This commit is contained in:
Slipstream 2025-06-15 20:53:38 -06:00
commit 380feddeeb
Signed by: slipstream
GPG Key ID: 13E498CE010AC6FD
26 changed files with 1202 additions and 194 deletions

View File

@ -14,6 +14,7 @@ A Python library for interacting with the Discord API, with a focus on bot devel
- `Message.jump_url` property for quick links to messages
- Built-in caching layer
- `Guild.me` property to access the bot's member object
- Easy CDN asset handling via the `Asset` model
- Experimental voice support
- Helpful error handling utilities
@ -111,6 +112,10 @@ These options are forwarded to ``HTTPClient`` when it creates the underlying
``aiohttp.ClientSession``. You can specify a custom ``connector`` or any other
session parameter supported by ``aiohttp``.
### Logging Out
Call ``Client.logout`` to disconnect from the Gateway and clear the current bot token while keeping the HTTP session alive. Assign a new token and call ``connect`` or ``run`` to log back in.
### Default Allowed Mentions
Specify default mention behaviour for all outgoing messages when constructing the client:
@ -126,6 +131,17 @@ client = disagreement.Client(
This dictionary is used whenever ``send_message`` or helpers like ``Message.reply``
are called without an explicit ``allowed_mentions`` argument.
### Working With Assets
Properties like ``User.avatar`` and ``Guild.icon`` return :class:`disagreement.Asset` objects.
Use ``read`` to get the bytes or ``save`` to write them to disk.
```python
user = await client.fetch_user(123)
data = await user.avatar.read()
await user.avatar.save("avatar.png")
```
### Defining Subcommands with `AppCommandGroup`
```python

View File

@ -15,6 +15,7 @@ __copyright__ = "Copyright 2025 Slipstream"
__version__ = "0.8.1"
from .client import Client, AutoShardedClient
from .asset import Asset
from .models import (
Message,
User,
@ -39,6 +40,7 @@ from .models import (
Container,
Guild,
)
from .object import Object
from .voice_client import VoiceClient
from .audio import AudioSource, FFmpegAudioSource
from .typing import Typing
@ -125,6 +127,7 @@ import logging
__all__ = [
"Client",
"AutoShardedClient",
"Asset",
"Message",
"User",
"Reaction",
@ -146,6 +149,7 @@ __all__ = [
"MediaGallery",
"MediaGalleryItem",
"Container",
"Object",
"VoiceClient",
"AudioSource",
"FFmpegAudioSource",

51
disagreement/asset.py Normal file
View File

@ -0,0 +1,51 @@
"""Utility class for Discord CDN assets."""
from __future__ import annotations
import os
from typing import IO, Optional, Union, TYPE_CHECKING
import aiohttp # pylint: disable=import-error
if TYPE_CHECKING:
from .client import Client
class Asset:
"""Represents a CDN asset such as an avatar or icon."""
def __init__(self, url: str, client_instance: Optional["Client"] = None) -> None:
self.url = url
self._client = client_instance
async def read(self) -> bytes:
"""Read the asset's bytes."""
session: Optional[aiohttp.ClientSession] = None
if self._client is not None:
await self._client._http._ensure_session() # type: ignore[attr-defined]
session = self._client._http._session # type: ignore[attr-defined]
if session is None:
session = aiohttp.ClientSession()
close = True
else:
close = False
async with session.get(self.url) as resp:
data = await resp.read()
if close:
await session.close()
return data
async def save(self, fp: Union[str, os.PathLike[str], IO[bytes]]) -> None:
"""Save the asset to the given file path or file-like object."""
data = await self.read()
if isinstance(fp, (str, os.PathLike)):
path = os.fspath(fp)
with open(path, "wb") as file:
file.write(data)
else:
fp.write(data)
def __repr__(self) -> str:
return f"<Asset url='{self.url}'>"

View File

@ -84,9 +84,9 @@ class MemberCache(Cache["Member"]):
def _should_cache(self, member: Member) -> bool:
"""Determines if a member should be cached based on the flags."""
if self.flags.all:
if self.flags.all_enabled:
return True
if self.flags.none:
if self.flags.no_flags:
return False
if self.flags.online and member.status != "offline":

View File

@ -74,6 +74,14 @@ class MemberCacheFlags:
for name in self.VALID_FLAGS:
yield name, getattr(self, name)
@property
def all_enabled(self) -> bool:
return self.value == self.ALL_FLAGS
@property
def no_flags(self) -> bool:
return self.value == 0
def __int__(self) -> int:
return self.value

View File

@ -2,8 +2,11 @@
The main Client class for interacting with the Discord API.
"""
import asyncio
import signal
import asyncio
import signal
import json
import os
import importlib
from typing import (
Optional,
Callable,
@ -16,7 +19,9 @@ from typing import (
Dict,
cast,
)
from types import ModuleType
from types import ModuleType
PERSISTENT_VIEWS_FILE = "persistent_views.json"
from datetime import datetime, timedelta
@ -77,7 +82,7 @@ def _update_list(lst: List[Any], item: Any) -> None:
lst.append(item)
class Client:
class Client:
"""
Represents a client connection that connects to Discord.
This class is used to interact with the Discord WebSocket and API.
@ -193,7 +198,10 @@ class Client:
self._views: Dict[Snowflake, "View"] = {}
self._persistent_views: Dict[str, "View"] = {}
self._voice_clients: Dict[Snowflake, VoiceClient] = {}
self._webhooks: Dict[Snowflake, "Webhook"] = {}
self._webhooks: Dict[Snowflake, "Webhook"] = {}
# Load persistent views stored on disk
self._load_persistent_views()
# Default whether replies mention the user
self.mention_replies: bool = mention_replies
@ -210,13 +218,46 @@ class Client:
self.loop.add_signal_handler(
signal.SIGTERM, lambda: self.loop.create_task(self.close())
)
except NotImplementedError:
# add_signal_handler is not available on all platforms (e.g., Windows default event loop policy)
# Users on these platforms would need to handle shutdown differently.
print(
"Warning: Signal handlers for SIGINT/SIGTERM could not be added. "
"Graceful shutdown via signals might not work as expected on this platform."
)
except NotImplementedError:
# add_signal_handler is not available on all platforms (e.g., Windows default event loop policy)
# Users on these platforms would need to handle shutdown differently.
print(
"Warning: Signal handlers for SIGINT/SIGTERM could not be added. "
"Graceful shutdown via signals might not work as expected on this platform."
)
def _load_persistent_views(self) -> None:
"""Load registered persistent views from disk."""
if not os.path.isfile(PERSISTENT_VIEWS_FILE):
return
try:
with open(PERSISTENT_VIEWS_FILE, "r") as fp:
mapping = json.load(fp)
except Exception as e: # pragma: no cover - best effort load
print(f"Failed to load persistent views: {e}")
return
for custom_id, path in mapping.items():
try:
module_name, class_name = path.rsplit(".", 1)
module = importlib.import_module(module_name)
cls = getattr(module, class_name)
view = cls()
self._persistent_views[custom_id] = view
except Exception as e: # pragma: no cover - best effort load
print(f"Failed to initialize persistent view {path}: {e}")
def _save_persistent_views(self) -> None:
"""Persist registered views to disk."""
data = {}
for custom_id, view in self._persistent_views.items():
cls = view.__class__
data[custom_id] = f"{cls.__module__}.{cls.__name__}"
try:
with open(PERSISTENT_VIEWS_FILE, "w") as fp:
json.dump(data, fp)
except Exception as e: # pragma: no cover - best effort save
print(f"Failed to save persistent views: {e}")
async def _initialize_gateway(self):
"""Initializes the GatewayClient if it doesn't exist."""
@ -413,15 +454,23 @@ class Client:
await self.close()
return False
async def close_gateway(self, code: int = 1000) -> None:
"""Closes only the gateway connection, allowing for potential reconnect."""
if self._shard_manager:
await self._shard_manager.close()
self._shard_manager = None
if self._gateway:
await self._gateway.close(code=code)
self._gateway = None
self._ready_event.clear() # No longer ready if gateway is closed
async def close_gateway(self, code: int = 1000) -> None:
"""Closes only the gateway connection, allowing for potential reconnect."""
if self._shard_manager:
await self._shard_manager.close()
self._shard_manager = None
if self._gateway:
await self._gateway.close(code=code)
self._gateway = None
self._ready_event.clear() # No longer ready if gateway is closed
async def logout(self) -> None:
"""Invalidate the bot token and disconnect from the Gateway."""
await self.close_gateway()
self.token = ""
self._http.token = ""
self.user = None
self.start_time = None
def is_closed(self) -> bool:
"""Indicates if the client has been closed."""
@ -1399,10 +1448,30 @@ class Client:
return self._channels.get(channel_id)
def get_message(self, message_id: Snowflake) -> Optional["Message"]:
"""Returns a message from the internal cache."""
return self._messages.get(message_id)
def get_message(self, message_id: Snowflake) -> Optional["Message"]:
"""Returns a message from the internal cache."""
return self._messages.get(message_id)
def get_all_channels(self) -> List["Channel"]:
"""Return all channels cached in every guild."""
channels: List["Channel"] = []
for guild in self._guilds.values():
channels.extend(guild._channels.values())
return channels
def get_all_members(self) -> List["Member"]:
"""Return all cached members across all guilds.
When member caching is disabled via :class:`MemberCacheFlags.none`, this
list will always be empty.
"""
members: List["Member"] = []
for guild in self._guilds.values():
members.extend(guild._members.values())
return members
async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]:
"""Fetches a guild by ID from Discord and caches it."""
@ -1707,11 +1776,13 @@ class Client:
for item in view.children:
if item.custom_id: # Ensure custom_id is not None
if item.custom_id in self._persistent_views:
raise ValueError(
f"A component with custom_id '{item.custom_id}' is already registered."
)
self._persistent_views[item.custom_id] = view
if item.custom_id in self._persistent_views:
raise ValueError(
f"A component with custom_id '{item.custom_id}' is already registered."
)
self._persistent_views[item.custom_id] = view
self._save_persistent_views()
# --- Application Command Methods ---
async def process_interaction(self, interaction: Interaction) -> None:

View File

@ -6,7 +6,16 @@ import re
import inspect
from .errors import BadArgument
from disagreement.models import Member, Guild, Role, User
from disagreement.models import (
Member,
Guild,
Role,
User,
TextChannel,
VoiceChannel,
Emoji,
PartialEmoji,
)
if TYPE_CHECKING:
from .core import CommandContext
@ -158,6 +167,82 @@ class UserConverter(Converter["User"]):
raise BadArgument(f"User '{argument}' not found.")
class TextChannelConverter(Converter["TextChannel"]):
async def convert(self, ctx: "CommandContext", argument: str) -> "TextChannel":
if not ctx.message.guild_id:
raise BadArgument("TextChannel converter requires guild context.")
match = re.match(r"<#(?P<id>\d+)>$", argument)
channel_id = match.group("id") if match else argument
guild = ctx.bot.get_guild(ctx.message.guild_id)
if guild:
channel = guild.get_channel(channel_id)
if isinstance(channel, TextChannel):
return channel
channel = (
ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None
)
if isinstance(channel, TextChannel):
return channel
if hasattr(ctx.bot, "fetch_channel"):
channel = await ctx.bot.fetch_channel(channel_id)
if isinstance(channel, TextChannel):
return channel
raise BadArgument(f"Text channel '{argument}' not found.")
class VoiceChannelConverter(Converter["VoiceChannel"]):
async def convert(self, ctx: "CommandContext", argument: str) -> "VoiceChannel":
if not ctx.message.guild_id:
raise BadArgument("VoiceChannel converter requires guild context.")
match = re.match(r"<#(?P<id>\d+)>$", argument)
channel_id = match.group("id") if match else argument
guild = ctx.bot.get_guild(ctx.message.guild_id)
if guild:
channel = guild.get_channel(channel_id)
if isinstance(channel, VoiceChannel):
return channel
channel = (
ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None
)
if isinstance(channel, VoiceChannel):
return channel
if hasattr(ctx.bot, "fetch_channel"):
channel = await ctx.bot.fetch_channel(channel_id)
if isinstance(channel, VoiceChannel):
return channel
raise BadArgument(f"Voice channel '{argument}' not found.")
class EmojiConverter(Converter["PartialEmoji"]):
_CUSTOM_RE = re.compile(r"<(?P<animated>a)?:(?P<name>[^:]+):(?P<id>\d+)>$")
async def convert(self, ctx: "CommandContext", argument: str) -> "PartialEmoji":
match = self._CUSTOM_RE.match(argument)
if match:
return PartialEmoji(
{
"id": match.group("id"),
"name": match.group("name"),
"animated": bool(match.group("animated")),
}
)
if argument:
return PartialEmoji({"id": None, "name": argument})
raise BadArgument(f"Emoji '{argument}' not found.")
# Default converters mapping
DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
int: IntConverter(),
@ -168,6 +253,10 @@ DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
Guild: GuildConverter(),
Role: RoleConverter(),
User: UserConverter(),
TextChannel: TextChannelConverter(),
VoiceChannel: VoiceChannelConverter(),
PartialEmoji: EmojiConverter(),
Emoji: EmojiConverter(),
}

View File

@ -79,8 +79,15 @@ class GroupMixin:
)
self.commands[alias.lower()] = command
def get_command(self, name: str) -> Optional["Command"]:
return self.commands.get(name.lower())
def get_command(self, name: str) -> Optional["Command"]:
return self.commands.get(name.lower())
def walk_commands(self):
"""Yield all commands in this group recursively."""
for cmd in dict.fromkeys(self.commands.values()):
yield cmd
if isinstance(cmd, Group):
yield from cmd.walk_commands()
class Command(GroupMixin):
@ -366,6 +373,13 @@ class CommandHandler:
def get_command(self, name: str) -> Optional[Command]:
return self.commands.get(name.lower())
def walk_commands(self):
"""Yield every registered command, including subcommands."""
for cmd in dict.fromkeys(self.commands.values()):
yield cmd
if isinstance(cmd, Group):
yield from cmd.walk_commands()
def get_cog(self, name: str) -> Optional["Cog"]:
"""Return a loaded cog by name if present."""

View File

@ -1,6 +1,8 @@
from collections import defaultdict
from typing import List, Optional
from .core import Command, CommandContext, CommandHandler
from ...utils import Paginator
from .core import Command, CommandContext, CommandHandler, Group
class HelpCommand(Command):
@ -15,17 +17,22 @@ class HelpCommand(Command):
if not cmd or cmd.name.lower() != command.lower():
await ctx.send(f"Command '{command}' not found.")
return
description = cmd.description or cmd.brief or "No description provided."
await ctx.send(f"**{ctx.prefix}{cmd.name}**\n{description}")
else:
lines: List[str] = []
for registered in dict.fromkeys(handler.commands.values()):
brief = registered.brief or registered.description or ""
lines.append(f"{ctx.prefix}{registered.name} - {brief}".strip())
if lines:
await ctx.send("\n".join(lines))
if isinstance(cmd, Group):
await self.send_group_help(ctx, cmd)
elif cmd:
description = cmd.description or cmd.brief or "No description provided."
await ctx.send(f"**{ctx.prefix}{cmd.name}**\n{description}")
else:
await ctx.send("No commands available.")
lines: List[str] = []
for registered in handler.walk_commands():
brief = registered.brief or registered.description or ""
lines.append(f"{ctx.prefix}{registered.name} - {brief}".strip())
if lines:
await ctx.send("\n".join(lines))
else:
await self.send_command_help(ctx, cmd)
else:
await self.send_bot_help(ctx)
super().__init__(
callback,
@ -33,3 +40,42 @@ class HelpCommand(Command):
brief="Show command help.",
description="Displays help for commands.",
)
async def send_bot_help(self, ctx: CommandContext) -> None:
groups = defaultdict(list)
for cmd in dict.fromkeys(self.handler.commands.values()):
key = cmd.cog.cog_name if cmd.cog else "No Category"
groups[key].append(cmd)
paginator = Paginator()
for cog_name, cmds in groups.items():
paginator.add_line(f"**{cog_name}**")
for cmd in cmds:
brief = cmd.brief or cmd.description or ""
paginator.add_line(f"{ctx.prefix}{cmd.name} - {brief}".strip())
paginator.add_line("")
pages = paginator.pages
if not pages:
await ctx.send("No commands available.")
return
for page in pages:
await ctx.send(page)
async def send_command_help(self, ctx: CommandContext, command: Command) -> None:
description = command.description or command.brief or "No description provided."
await ctx.send(f"**{ctx.prefix}{command.name}**\n{description}")
async def send_group_help(self, ctx: CommandContext, group: Group) -> None:
paginator = Paginator()
description = group.description or group.brief or "No description provided."
paginator.add_line(f"**{ctx.prefix}{group.name}**\n{description}")
if group.commands:
for sub in dict.fromkeys(group.commands.values()):
brief = sub.brief or sub.description or ""
paginator.add_line(
f"{ctx.prefix}{group.name} {sub.name} - {brief}".strip()
)
for page in paginator.pages:
await ctx.send(page)

View File

@ -656,21 +656,21 @@ class HTTPClient:
await self.request("PUT", f"/channels/{channel_id}/pins/{message_id}")
async def unpin_message(
self, channel_id: "Snowflake", message_id: "Snowflake"
) -> None:
"""Unpins a message from a channel."""
await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}")
async def crosspost_message(
self, channel_id: "Snowflake", message_id: "Snowflake"
) -> Dict[str, Any]:
"""Crossposts a message to any following channels."""
return await self.request(
"POST", f"/channels/{channel_id}/messages/{message_id}/crosspost"
)
async def unpin_message(
self, channel_id: "Snowflake", message_id: "Snowflake"
) -> None:
"""Unpins a message from a channel."""
await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}")
async def crosspost_message(
self, channel_id: "Snowflake", message_id: "Snowflake"
) -> Dict[str, Any]:
"""Crossposts a message to any following channels."""
return await self.request(
"POST", f"/channels/{channel_id}/messages/{message_id}/crosspost"
)
async def delete_channel(
self, channel_id: str, reason: Optional[str] = None
@ -734,68 +734,68 @@ class HTTPClient:
return await self.request("GET", f"/channels/{channel_id}/invites")
async def create_invite(
self, channel_id: "Snowflake", payload: Dict[str, Any]
) -> "Invite":
"""Creates an invite for a channel."""
async def create_invite(
self, channel_id: "Snowflake", payload: Dict[str, Any]
) -> "Invite":
"""Creates an invite for a channel."""
data = await self.request(
"POST", f"/channels/{channel_id}/invites", payload=payload
)
from .models import Invite
return Invite.from_dict(data)
async def create_channel_invite(
self,
channel_id: "Snowflake",
payload: Dict[str, Any],
*,
reason: Optional[str] = None,
) -> "Invite":
"""Creates an invite for a channel with an optional audit log reason."""
headers = {"X-Audit-Log-Reason": reason} if reason else None
data = await self.request(
"POST",
f"/channels/{channel_id}/invites",
payload=payload,
custom_headers=headers,
)
from .models import Invite
return Invite.from_dict(data)
from .models import Invite
async def delete_invite(self, code: str) -> None:
"""Deletes an invite by code."""
await self.request("DELETE", f"/invites/{code}")
async def get_invite(self, code: "Snowflake") -> Dict[str, Any]:
"""Fetches a single invite by its code."""
return await self.request("GET", f"/invites/{code}")
return Invite.from_dict(data)
async def create_webhook(
self, channel_id: "Snowflake", payload: Dict[str, Any]
) -> "Webhook":
"""Creates a webhook in the specified channel."""
async def create_channel_invite(
self,
channel_id: "Snowflake",
payload: Dict[str, Any],
*,
reason: Optional[str] = None,
) -> "Invite":
"""Creates an invite for a channel with an optional audit log reason."""
headers = {"X-Audit-Log-Reason": reason} if reason else None
data = await self.request(
"POST",
f"/channels/{channel_id}/invites",
payload=payload,
custom_headers=headers,
)
from .models import Invite
return Invite.from_dict(data)
async def delete_invite(self, code: str) -> None:
"""Deletes an invite by code."""
await self.request("DELETE", f"/invites/{code}")
async def get_invite(self, code: "Snowflake") -> Dict[str, Any]:
"""Fetches a single invite by its code."""
return await self.request("GET", f"/invites/{code}")
async def create_webhook(
self, channel_id: "Snowflake", payload: Dict[str, Any]
) -> "Webhook":
"""Creates a webhook in the specified channel."""
data = await self.request(
"POST", f"/channels/{channel_id}/webhooks", payload=payload
)
from .models import Webhook
return Webhook(data)
async def get_webhook(self, webhook_id: "Snowflake") -> Dict[str, Any]:
"""Fetches a webhook by ID."""
return await self.request("GET", f"/webhooks/{webhook_id}")
async def edit_webhook(
self, webhook_id: "Snowflake", payload: Dict[str, Any]
) -> "Webhook":
from .models import Webhook
return Webhook(data)
async def get_webhook(self, webhook_id: "Snowflake") -> Dict[str, Any]:
"""Fetches a webhook by ID and returns the raw payload."""
return await self.request("GET", f"/webhooks/{webhook_id}")
async def edit_webhook(
self, webhook_id: "Snowflake", payload: Dict[str, Any]
) -> "Webhook":
"""Edits an existing webhook."""
data = await self.request("PATCH", f"/webhooks/{webhook_id}", payload=payload)
@ -803,25 +803,28 @@ class HTTPClient:
return Webhook(data)
async def delete_webhook(self, webhook_id: "Snowflake") -> None:
"""Deletes a webhook."""
await self.request("DELETE", f"/webhooks/{webhook_id}")
async def get_webhook(
self, webhook_id: "Snowflake", token: Optional[str] = None
) -> "Webhook":
"""Fetches a webhook by ID, optionally using its token."""
endpoint = f"/webhooks/{webhook_id}"
use_auth = True
if token is not None:
endpoint += f"/{token}"
use_auth = False
data = await self.request("GET", endpoint, use_auth_header=use_auth)
from .models import Webhook
return Webhook(data)
async def delete_webhook(self, webhook_id: "Snowflake") -> None:
"""Deletes a webhook."""
await self.request("DELETE", f"/webhooks/{webhook_id}")
async def get_webhook_with_token(
self, webhook_id: "Snowflake", token: Optional[str] = None
) -> "Webhook":
"""Fetches a webhook by ID, optionally using its token."""
endpoint = f"/webhooks/{webhook_id}"
use_auth = True
if token is not None:
endpoint += f"/{token}"
use_auth = False
if use_auth:
data = await self.request("GET", endpoint)
else:
data = await self.request("GET", endpoint, use_auth_header=False)
from .models import Webhook
return Webhook(data)
async def execute_webhook(
self,

View File

@ -2,6 +2,8 @@
Data models for Discord objects.
"""
from __future__ import annotations
import asyncio
import datetime
import io
@ -54,13 +56,26 @@ if TYPE_CHECKING:
from .interactions import Snowflake
from .typing import Typing
from .shard_manager import Shard
from .asset import Asset
# Forward reference Message if it were used in type hints before its definition
# from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc.
from .components import component_factory
class User:
class HashableById:
"""Mixin providing equality and hashing based on the ``id`` attribute."""
id: str
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.id == other.id # type: ignore[attr-defined]
def __hash__(self) -> int: # pragma: no cover - trivial
return hash(self.id)
class User(HashableById):
"""Represents a Discord User."""
def __init__(self, data: dict, client_instance: Optional["Client"] = None) -> None:
@ -71,7 +86,12 @@ class User:
self.username: Optional[str] = data.get("username")
self.discriminator: Optional[str] = data.get("discriminator")
self.bot: bool = data.get("bot", False)
self.avatar: Optional[str] = data.get("avatar")
avatar_hash = data.get("avatar")
self._avatar: Optional[str] = (
f"https://cdn.discordapp.com/avatars/{self.id}/{avatar_hash}.png"
if avatar_hash
else None
)
@property
def mention(self) -> str:
@ -83,6 +103,25 @@ class User:
disc = self.discriminator or "????"
return f"<User id='{self.id}' username='{username}' discriminator='{disc}'>"
@property
def avatar(self) -> Optional["Asset"]:
"""Return the user's avatar as an :class:`Asset`."""
if self._avatar:
from .asset import Asset
return Asset(self._avatar, self._client)
return None
@avatar.setter
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._avatar = value
elif value is None:
self._avatar = None
else:
self._avatar = value.url
async def send(
self,
content: Optional[str] = None,
@ -98,7 +137,7 @@ class User:
return await target_client.send_dm(self.id, content=content, **kwargs)
class Message:
class Message(HashableById):
"""Represents a message sent in a channel on Discord.
Attributes:
@ -780,7 +819,12 @@ class Role:
self.name: str = data["name"]
self.color: int = data["color"]
self.hoist: bool = data["hoist"]
self.icon: Optional[str] = data.get("icon")
icon_hash = data.get("icon")
self._icon: Optional[str] = (
f"https://cdn.discordapp.com/role-icons/{self.id}/{icon_hash}.png"
if icon_hash
else None
)
self.unicode_emoji: Optional[str] = data.get("unicode_emoji")
self.position: int = data["position"]
self.permissions: str = data["permissions"] # String of bitwise permissions
@ -798,6 +842,23 @@ class Role:
def __repr__(self) -> str:
return f"<Role id='{self.id}' name='{self.name}'>"
@property
def icon(self) -> Optional["Asset"]:
if self._icon:
from .asset import Asset
return Asset(self._icon, None)
return None
@icon.setter
def icon(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._icon = value
elif value is None:
self._icon = None
else:
self._icon = value.url
class Member(User): # Member inherits from User
"""Represents a Guild Member.
@ -826,7 +887,15 @@ class Member(User): # Member inherits from User
) # Pass user_data or data if user_data is empty
self.nick: Optional[str] = data.get("nick")
self.avatar: Optional[str] = data.get("avatar")
avatar_hash = data.get("avatar")
if avatar_hash:
guild_id = data.get("guild_id")
if guild_id:
self._avatar = f"https://cdn.discordapp.com/guilds/{guild_id}/users/{self.id}/avatars/{avatar_hash}.png"
else:
self._avatar = (
f"https://cdn.discordapp.com/avatars/{self.id}/{avatar_hash}.png"
)
self.roles: List[str] = data.get("roles", [])
self.joined_at: str = data["joined_at"]
self.premium_since: Optional[str] = data.get("premium_since")
@ -854,6 +923,25 @@ class Member(User): # Member inherits from User
def __repr__(self) -> str:
return f"<Member id='{self.id}' username='{self.username}' nick='{self.nick}'>"
@property
def avatar(self) -> Optional["Asset"]:
"""Return the member's avatar as an :class:`Asset`."""
if self._avatar:
from .asset import Asset
return Asset(self._avatar, self._client)
return None
@avatar.setter
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._avatar = value
elif value is None:
self._avatar = None
else:
self._avatar = value.url
@property
def display_name(self) -> str:
"""Return the nickname if set, otherwise the username."""
@ -921,7 +1009,6 @@ class Member(User): # Member inherits from User
return max(role_objects, key=lambda r: r.position)
@property
def guild_permissions(self) -> "Permissions":
"""Return the member's guild-level permissions."""
@ -947,7 +1034,8 @@ class Member(User): # Member inherits from User
return Permissions(~0)
return base
@property
def voice(self) -> Optional["VoiceState"]:
"""Return the member's cached voice state as a :class:`VoiceState`."""
@ -1152,7 +1240,7 @@ class PermissionOverwrite:
return f"<PermissionOverwrite id='{self.id}' type='{self.type.name if hasattr(self.type, 'name') else self._type_val}' allow='{self.allow}' deny='{self.deny}'>"
class Guild:
class Guild(HashableById):
"""Represents a Discord Guild (Server).
Attributes:
@ -1210,9 +1298,24 @@ class Guild:
)
self.id: str = data["id"]
self.name: str = data["name"]
self.icon: Optional[str] = data.get("icon")
self.splash: Optional[str] = data.get("splash")
self.discovery_splash: Optional[str] = data.get("discovery_splash")
icon_hash = data.get("icon")
self._icon: Optional[str] = (
f"https://cdn.discordapp.com/icons/{self.id}/{icon_hash}.png"
if icon_hash
else None
)
splash_hash = data.get("splash")
self._splash: Optional[str] = (
f"https://cdn.discordapp.com/splashes/{self.id}/{splash_hash}.png"
if splash_hash
else None
)
discovery_hash = data.get("discovery_splash")
self._discovery_splash: Optional[str] = (
f"https://cdn.discordapp.com/discovery-splashes/{self.id}/{discovery_hash}.png"
if discovery_hash
else None
)
self.owner: Optional[bool] = data.get("owner")
self.owner_id: str = data["owner_id"]
self.permissions: Optional[str] = data.get("permissions")
@ -1249,7 +1352,12 @@ class Guild:
self.max_members: Optional[int] = data.get("max_members")
self.vanity_url_code: Optional[str] = data.get("vanity_url_code")
self.description: Optional[str] = data.get("description")
self.banner: Optional[str] = data.get("banner")
banner_hash = data.get("banner")
self._banner: Optional[str] = (
f"https://cdn.discordapp.com/banners/{self.id}/{banner_hash}.png"
if banner_hash
else None
)
self.premium_tier: PremiumTier = PremiumTier(data["premium_tier"])
self.premium_subscription_count: Optional[int] = data.get(
"premium_subscription_count"
@ -1357,6 +1465,74 @@ class Guild:
def __repr__(self) -> str:
return f"<Guild id='{self.id}' name='{self.name}'>"
@property
def icon(self) -> Optional["Asset"]:
if self._icon:
from .asset import Asset
return Asset(self._icon, self._client)
return None
@icon.setter
def icon(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._icon = value
elif value is None:
self._icon = None
else:
self._icon = value.url
@property
def splash(self) -> Optional["Asset"]:
if self._splash:
from .asset import Asset
return Asset(self._splash, self._client)
return None
@splash.setter
def splash(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._splash = value
elif value is None:
self._splash = None
else:
self._splash = value.url
@property
def discovery_splash(self) -> Optional["Asset"]:
if self._discovery_splash:
from .asset import Asset
return Asset(self._discovery_splash, self._client)
return None
@discovery_splash.setter
def discovery_splash(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._discovery_splash = value
elif value is None:
self._discovery_splash = None
else:
self._discovery_splash = value.url
@property
def banner(self) -> Optional["Asset"]:
if self._banner:
from .asset import Asset
return Asset(self._banner, self._client)
return None
@banner.setter
def banner(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._banner = value
elif value is None:
self._banner = None
else:
self._banner = value.url
async def fetch_widget(self) -> Dict[str, Any]:
"""|coro| Fetch this guild's widget settings."""
@ -1485,7 +1661,7 @@ class Guild:
return cast("CategoryChannel", self._client.parse_channel(data))
class Channel:
class Channel(HashableById):
"""Base class for Discord channels."""
def __init__(self, data: Dict[str, Any], client_instance: "Client"):
@ -2089,7 +2265,12 @@ class Webhook:
self.guild_id: Optional[str] = data.get("guild_id")
self.channel_id: Optional[str] = data.get("channel_id")
self.name: Optional[str] = data.get("name")
self.avatar: Optional[str] = data.get("avatar")
avatar_hash = data.get("avatar")
self._avatar: Optional[str] = (
f"https://cdn.discordapp.com/webhooks/{self.id}/{avatar_hash}.png"
if avatar_hash
else None
)
self.token: Optional[str] = data.get("token")
self.application_id: Optional[str] = data.get("application_id")
self.url: Optional[str] = data.get("url")
@ -2098,6 +2279,25 @@ class Webhook:
def __repr__(self) -> str:
return f"<Webhook id='{self.id}' name='{self.name}'>"
@property
def avatar(self) -> Optional["Asset"]:
"""Return the webhook's avatar as an :class:`Asset`."""
if self._avatar:
from .asset import Asset
return Asset(self._avatar, self._client)
return None
@avatar.setter
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._avatar = value
elif value is None:
self._avatar = None
else:
self._avatar = value.url
@classmethod
def from_url(
cls, url: str, session: Optional[aiohttp.ClientSession] = None

19
disagreement/object.py Normal file
View File

@ -0,0 +1,19 @@
class Object:
"""A minimal wrapper around a Discord snowflake ID."""
__slots__ = ("id",)
def __init__(self, object_id: int) -> None:
self.id = int(object_id)
def __int__(self) -> int:
return self.id
def __hash__(self) -> int:
return hash(self.id)
def __eq__(self, other: object) -> bool:
return isinstance(other, Object) and self.id == other.id
def __repr__(self) -> str:
return f"<Object id={self.id}>"

View File

@ -77,11 +77,14 @@ class VoiceClient:
self.secret_key: Optional[Sequence[int]] = None
self._server_ip: Optional[str] = None
self._server_port: Optional[int] = None
self._current_source: Optional[AudioSource] = None
self._play_task: Optional[asyncio.Task] = None
self._sink: Optional[AudioSink] = None
self._ssrc_map: dict[int, int] = {}
self._ssrc_lock = threading.Lock()
self._current_source: Optional[AudioSource] = None
self._play_task: Optional[asyncio.Task] = None
self._pause_event = asyncio.Event()
self._pause_event.set()
self._is_playing = False
self._sink: Optional[AudioSink] = None
self._ssrc_map: dict[int, int] = {}
self._ssrc_lock = threading.Lock()
async def connect(self) -> None:
if self._ws is None:
@ -189,31 +192,37 @@ class VoiceClient:
raise RuntimeError("UDP socket not initialised")
self._udp.send(frame)
async def _play_loop(self) -> None:
assert self._current_source is not None
try:
while True:
data = await self._current_source.read()
if not data:
break
volume = getattr(self._current_source, "volume", 1.0)
if volume != 1.0:
data = _apply_volume(data, volume)
await self.send_audio_frame(data)
finally:
await self._current_source.close()
self._current_source = None
self._play_task = None
async def _play_loop(self) -> None:
assert self._current_source is not None
self._is_playing = True
try:
while True:
await self._pause_event.wait()
data = await self._current_source.read()
if not data:
break
volume = getattr(self._current_source, "volume", 1.0)
if volume != 1.0:
data = _apply_volume(data, volume)
await self.send_audio_frame(data)
finally:
await self._current_source.close()
self._current_source = None
self._play_task = None
self._is_playing = False
self._pause_event.set()
async def stop(self) -> None:
if self._play_task:
self._play_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._play_task
self._play_task = None
if self._current_source:
await self._current_source.close()
self._current_source = None
async def stop(self) -> None:
if self._play_task:
self._play_task.cancel()
self._pause_event.set()
with contextlib.suppress(asyncio.CancelledError):
await self._play_task
self._play_task = None
self._is_playing = False
if self._current_source:
await self._current_source.close()
self._current_source = None
async def play(self, source: AudioSource, *, wait: bool = True) -> None:
"""|coro| Play an :class:`AudioSource` on the voice connection."""
@ -224,10 +233,31 @@ class VoiceClient:
if wait:
await self._play_task
async def play_file(self, filename: str, *, wait: bool = True) -> None:
"""|coro| Stream an audio file or URL using FFmpeg."""
await self.play(FFmpegAudioSource(filename), wait=wait)
async def play_file(self, filename: str, *, wait: bool = True) -> None:
"""|coro| Stream an audio file or URL using FFmpeg."""
await self.play(FFmpegAudioSource(filename), wait=wait)
def pause(self) -> None:
"""Pause the current audio source."""
if self._play_task and not self._play_task.done():
self._pause_event.clear()
def resume(self) -> None:
"""Resume playback of a paused source."""
if self._play_task and not self._play_task.done():
self._pause_event.set()
def is_paused(self) -> bool:
"""Return ``True`` if playback is currently paused."""
return bool(self._play_task and not self._pause_event.is_set())
def is_playing(self) -> bool:
"""Return ``True`` if audio is actively being played."""
return self._is_playing and self._pause_event.is_set()
def listen(self, sink: AudioSink) -> None:
"""Start listening to voice and routing to a sink."""

View File

@ -28,6 +28,13 @@ The cache can be cleared manually if needed:
client.cache.clear()
```
## Partial Objects
Some events only include minimal data for related resources. When only an ``id``
is available, Disagreement represents the resource using :class:`~disagreement.Object`.
These objects can be compared and used in sets or dictionaries and can be passed
to API methods to fetch the full data when needed.
## Next Steps
- [Components](using_components.md)

View File

@ -11,7 +11,11 @@ The command handler registers a `help` command automatically. Use it to list all
!help ping # shows help for the "ping" command
```
The help command will show each command's brief description if provided.
Commands are grouped by their Cog name and paginated so that long help
lists are split into multiple messages using the `Paginator` utility.
If you need custom formatting you can subclass
`HelpCommand` and override `send_command_help` or `send_group_help`.
## Checks

View File

@ -157,6 +157,22 @@ container = Container(
A container can itself contain layout and content components, letting you build complex messages.
## Persistent Views
Views with ``timeout=None`` are persistent. Their ``custom_id`` components are saved to ``persistent_views.json`` so they survive bot restarts.
```python
class MyView(View):
@button(label="Press", custom_id="press")
async def handle(self, view, inter):
await inter.respond("Pressed!")
client.add_persistent_view(MyView())
```
When the client starts, it loads this file and registers each view again. Remove
the file to clear stored views.
## Next Steps
- [Slash Commands](slash_commands.md)

View File

@ -6,6 +6,10 @@ Disagreement includes experimental support for connecting to voice channels. You
voice = await client.join_voice(guild_id, channel_id)
await voice.play_file("welcome.mp3")
await voice.play_file("another.mp3") # switch sources while connected
voice.pause()
voice.resume()
if voice.is_playing():
print("audio is playing")
await voice.close()
```

View File

@ -3,7 +3,16 @@ import pytest
from disagreement.ext.commands.converters import run_converters
from disagreement.ext.commands.core import CommandContext, Command
from disagreement.ext.commands.errors import BadArgument
from disagreement.models import Message, Member, Role, Guild, User
from disagreement.models import (
Message,
Member,
Role,
Guild,
User,
TextChannel,
VoiceChannel,
PartialEmoji,
)
from disagreement.enums import (
VerificationLevel,
MessageNotificationLevel,
@ -11,11 +20,12 @@ from disagreement.enums import (
MFALevel,
GuildNSFWLevel,
PremiumTier,
ChannelType,
)
from disagreement.client import Client
from disagreement.cache import GuildCache, Cache
from disagreement.cache import GuildCache, Cache, ChannelCache
class DummyBot(Client):
@ -23,10 +33,14 @@ class DummyBot(Client):
super().__init__(token="test")
self._guilds = GuildCache()
self._users = Cache()
self._channels = ChannelCache()
def get_guild(self, guild_id):
return self._guilds.get(guild_id)
def get_channel(self, channel_id):
return self._channels.get(channel_id)
async def fetch_member(self, guild_id, member_id):
guild = self._guilds.get(guild_id)
return guild.get_member(member_id) if guild else None
@ -41,6 +55,9 @@ class DummyBot(Client):
async def fetch_user(self, user_id):
return self._users.get(user_id)
async def fetch_channel(self, channel_id):
return self._channels.get(channel_id)
@pytest.fixture()
def guild_objects():
@ -93,12 +110,38 @@ def guild_objects():
guild._members.set(member.id, member)
guild.roles.append(role)
return guild, member, role, user
text_channel = TextChannel(
{
"id": "20",
"type": ChannelType.GUILD_TEXT.value,
"guild_id": guild.id,
"permission_overwrites": [],
},
client_instance=bot,
)
voice_channel = VoiceChannel(
{
"id": "21",
"type": ChannelType.GUILD_VOICE.value,
"guild_id": guild.id,
"permission_overwrites": [],
},
client_instance=bot,
)
guild._channels.set(text_channel.id, text_channel)
guild.text_channels.append(text_channel)
guild._channels.set(voice_channel.id, voice_channel)
guild.voice_channels.append(voice_channel)
bot._channels.set(text_channel.id, text_channel)
bot._channels.set(voice_channel.id, voice_channel)
return guild, member, role, user, text_channel, voice_channel
@pytest.fixture()
def command_context(guild_objects):
guild, member, role, _ = guild_objects
guild, member, role, _, _, _ = guild_objects
bot = guild._client
message_data = {
"id": "10",
@ -121,7 +164,7 @@ def command_context(guild_objects):
@pytest.mark.asyncio
async def test_member_converter(command_context, guild_objects):
_, member, _, _ = guild_objects
_, member, _, _, _, _ = guild_objects
mention = f"<@!{member.id}>"
result = await run_converters(command_context, Member, mention)
assert result is member
@ -131,7 +174,7 @@ async def test_member_converter(command_context, guild_objects):
@pytest.mark.asyncio
async def test_role_converter(command_context, guild_objects):
_, _, role, _ = guild_objects
_, _, role, _, _, _ = guild_objects
mention = f"<@&{role.id}>"
result = await run_converters(command_context, Role, mention)
assert result is role
@ -141,7 +184,7 @@ async def test_role_converter(command_context, guild_objects):
@pytest.mark.asyncio
async def test_user_converter(command_context, guild_objects):
_, _, _, user = guild_objects
_, _, _, user, _, _ = guild_objects
mention = f"<@{user.id}>"
result = await run_converters(command_context, User, mention)
assert result is user
@ -151,11 +194,43 @@ async def test_user_converter(command_context, guild_objects):
@pytest.mark.asyncio
async def test_guild_converter(command_context, guild_objects):
guild, _, _, _ = guild_objects
guild, _, _, _, _, _ = guild_objects
result = await run_converters(command_context, Guild, guild.id)
assert result is guild
@pytest.mark.asyncio
async def test_text_channel_converter(command_context, guild_objects):
_, _, _, _, text_channel, _ = guild_objects
mention = f"<#{text_channel.id}>"
result = await run_converters(command_context, TextChannel, mention)
assert result is text_channel
result = await run_converters(command_context, TextChannel, text_channel.id)
assert result is text_channel
@pytest.mark.asyncio
async def test_voice_channel_converter(command_context, guild_objects):
_, _, _, _, _, voice_channel = guild_objects
mention = f"<#{voice_channel.id}>"
result = await run_converters(command_context, VoiceChannel, mention)
assert result is voice_channel
result = await run_converters(command_context, VoiceChannel, voice_channel.id)
assert result is voice_channel
@pytest.mark.asyncio
async def test_emoji_converter(command_context):
result = await run_converters(command_context, PartialEmoji, "<:smile:1>")
assert isinstance(result, PartialEmoji)
assert result.id == "1"
assert result.name == "smile"
result = await run_converters(command_context, PartialEmoji, "😄")
assert result.id is None
assert result.name == "😄"
@pytest.mark.asyncio
async def test_member_converter_no_guild():
guild_data = {

14
tests/test_asset.py Normal file
View File

@ -0,0 +1,14 @@
from disagreement.models import User
from disagreement.asset import Asset
def test_user_avatar_returns_asset():
user = User({"id": "1", "username": "u", "discriminator": "0001", "avatar": "abc"})
avatar = user.avatar
assert isinstance(avatar, Asset)
assert avatar.url == "https://cdn.discordapp.com/avatars/1/abc.png"
def test_user_avatar_none():
user = User({"id": "1", "username": "u", "discriminator": "0001"})
assert user.avatar is None

View File

@ -1,6 +1,60 @@
import time
from disagreement.cache import Cache
from disagreement.client import Client
from disagreement.caching import MemberCacheFlags
from disagreement.enums import (
ChannelType,
ExplicitContentFilterLevel,
GuildNSFWLevel,
MFALevel,
MessageNotificationLevel,
PremiumTier,
VerificationLevel,
)
def _guild_payload(gid: str, channel_count: int, member_count: int) -> dict:
base = {
"id": gid,
"name": f"g{gid}",
"owner_id": "1",
"afk_timeout": 60,
"verification_level": VerificationLevel.NONE.value,
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
"roles": [],
"emojis": [],
"features": [],
"mfa_level": MFALevel.NONE.value,
"system_channel_flags": 0,
"premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
"channels": [],
"members": [],
}
for i in range(channel_count):
base["channels"].append(
{
"id": f"{gid}-c{i}",
"type": ChannelType.GUILD_TEXT.value,
"guild_id": gid,
"permission_overwrites": [],
}
)
for i in range(member_count):
base["members"].append(
{
"user": {
"id": f"{gid}-m{i}",
"username": f"u{i}",
"discriminator": "0001",
},
"joined_at": "t",
"roles": [],
}
)
return base
def test_cache_store_and_get():
@ -65,3 +119,22 @@ def test_get_or_fetch_fetches_expired_item():
assert cache.get_or_fetch("c", fetch) == 3
assert called
def test_client_get_all_channels_and_members():
client = Client(token="t")
client.parse_guild(_guild_payload("1", 2, 2))
client.parse_guild(_guild_payload("2", 1, 1))
channels = {c.id for c in client.get_all_channels()}
members = {m.id for m in client.get_all_members()}
assert channels == {"1-c0", "1-c1", "2-c0"}
assert members == {"1-m0", "1-m1", "2-m0"}
def test_client_get_all_members_disabled_cache():
client = Client(token="t", member_cache_flags=MemberCacheFlags.none())
client.parse_guild(_guild_payload("1", 1, 2))
assert client.get_all_members() == []

View File

@ -0,0 +1,86 @@
import types
from disagreement.models import User, Guild, Channel, Message
from disagreement.enums import (
VerificationLevel,
MessageNotificationLevel,
ExplicitContentFilterLevel,
MFALevel,
GuildNSFWLevel,
PremiumTier,
ChannelType,
)
def _guild_data(gid="1"):
return {
"id": gid,
"name": "g",
"owner_id": gid,
"afk_timeout": 60,
"verification_level": VerificationLevel.NONE.value,
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
"roles": [],
"emojis": [],
"features": [],
"mfa_level": MFALevel.NONE.value,
"system_channel_flags": 0,
"premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
}
def _user(uid="1"):
return User({"id": uid, "username": "u", "discriminator": "0001"})
def _message(mid="1"):
data = {
"id": mid,
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
return Message(data, client_instance=types.SimpleNamespace())
def _channel(cid="1"):
data = {"id": cid, "type": ChannelType.GUILD_TEXT.value}
return Channel(data, client_instance=types.SimpleNamespace())
def test_user_hash_and_eq():
a = _user()
b = _user()
c = _user("2")
assert a == b
assert hash(a) == hash(b)
assert a != c
def test_guild_hash_and_eq():
a = Guild(_guild_data(), client_instance=types.SimpleNamespace())
b = Guild(_guild_data(), client_instance=types.SimpleNamespace())
c = Guild(_guild_data("2"), client_instance=types.SimpleNamespace())
assert a == b
assert hash(a) == hash(b)
assert a != c
def test_channel_hash_and_eq():
a = _channel()
b = _channel()
c = _channel("2")
assert a == b
assert hash(a) == hash(b)
assert a != c
def test_message_hash_and_eq():
a = _message()
b = _message()
c = _message("2")
assert a == b
assert hash(a) == hash(b)
assert a != c

View File

@ -1,7 +1,9 @@
import pytest
from disagreement.ext.commands.core import CommandHandler, Command
from disagreement.ext import commands
from disagreement.ext.commands.core import CommandHandler, Command, Group
from disagreement.models import Message
from disagreement.ext.commands.help import HelpCommand
class DummyBot:
@ -13,15 +15,21 @@ class DummyBot:
return {"id": "1", "channel_id": channel_id, "content": content}
class MyCog(commands.Cog):
def __init__(self, client) -> None:
super().__init__(client)
@commands.command()
async def foo(self, ctx: commands.CommandContext) -> None:
pass
@pytest.mark.asyncio
async def test_help_lists_commands():
bot = DummyBot()
handler = CommandHandler(client=bot, prefix="!")
async def foo(ctx):
pass
handler.add_command(Command(foo, name="foo", brief="Foo cmd"))
handler.add_cog(MyCog(bot))
msg_data = {
"id": "1",
@ -33,6 +41,7 @@ async def test_help_lists_commands():
msg = Message(msg_data, client_instance=bot)
await handler.process_commands(msg)
assert any("foo" in m for m in bot.sent)
assert any("MyCog" in m for m in bot.sent)
@pytest.mark.asyncio
@ -55,3 +64,65 @@ async def test_help_specific_command():
msg = Message(msg_data, client_instance=bot)
await handler.process_commands(msg)
assert any("Bar desc" in m for m in bot.sent)
class CustomHelp(HelpCommand):
async def send_command_help(self, ctx, command):
await ctx.send(f"custom {command.name}")
async def send_group_help(self, ctx, group):
await ctx.send(f"group {group.name}")
@pytest.mark.asyncio
async def test_custom_help_methods():
bot = DummyBot()
handler = CommandHandler(client=bot, prefix="!")
handler.remove_command("help")
handler.add_command(CustomHelp(handler))
async def sub(ctx):
pass
group = Group(sub, name="grp")
handler.add_command(group)
msg_data = {
"id": "1",
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "!help grp",
"timestamp": "t",
}
msg = Message(msg_data, client_instance=bot)
await handler.process_commands(msg)
assert any("group grp" in m for m in bot.sent)
@pytest.mark.asyncio
async def test_help_lists_subcommands():
bot = DummyBot()
handler = CommandHandler(client=bot, prefix="!")
async def root(ctx):
pass
group = Group(root, name="root")
@group.command(name="child")
async def child(ctx):
pass
handler.add_command(group)
msg_data = {
"id": "1",
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "!help",
"timestamp": "t",
}
msg = Message(msg_data, client_instance=bot)
await handler.process_commands(msg)
assert any("root" in m for m in bot.sent)
assert any("child" in m for m in bot.sent)

15
tests/test_object.py Normal file
View File

@ -0,0 +1,15 @@
from disagreement.object import Object
def test_object_int():
obj = Object(123)
assert int(obj) == 123
def test_object_equality_and_hash():
a = Object(1)
b = Object(1)
c = Object(2)
assert a == b
assert a != c
assert hash(a) == hash(b)

View File

@ -59,6 +59,17 @@ class DummySource(AudioSource):
return b""
class SlowSource(AudioSource):
def __init__(self, chunks):
self.chunks = list(chunks)
async def read(self) -> bytes:
await asyncio.sleep(0)
if self.chunks:
return self.chunks.pop(0)
return b""
@pytest.mark.asyncio
async def test_voice_client_handshake():
hello = {"d": {"heartbeat_interval": 50}}
@ -205,3 +216,49 @@ async def test_voice_client_volume_scaling(monkeypatch):
samples[1] = int(samples[1] * 0.5)
expected = samples.tobytes()
assert udp.sent == [expected]
@pytest.mark.asyncio
async def test_pause_resume_and_status():
ws = DummyWebSocket(
[
{"d": {"heartbeat_interval": 50}},
{"d": {"ssrc": 1, "ip": "127.0.0.1", "port": 4000}},
{"d": {"secret_key": []}},
]
)
udp = DummyUDP()
vc = VoiceClient(
client=DummyVoiceClient(),
endpoint="ws://localhost",
session_id="sess",
token="tok",
guild_id=1,
user_id=2,
ws=ws,
udp=udp,
)
await vc.connect()
vc._heartbeat_task.cancel()
src = SlowSource([b"a", b"b", b"c"])
await vc.play(src, wait=False)
while not udp.sent:
await asyncio.sleep(0)
assert vc.is_playing()
vc.pause()
assert vc.is_paused()
await asyncio.sleep(0)
sent = len(udp.sent)
await asyncio.sleep(0.01)
assert len(udp.sent) == sent
assert not vc.is_playing()
vc.resume()
assert not vc.is_paused()
await vc._play_task
assert udp.sent == [b"a", b"b", b"c"]
assert not vc.is_playing()

View File

@ -0,0 +1,35 @@
import pytest
from disagreement.ext.commands.core import CommandHandler, Command, Group
class DummyBot:
pass
@pytest.mark.asyncio
async def test_walk_commands_recurses_groups():
bot = DummyBot()
handler = CommandHandler(client=bot, prefix="!")
async def root(ctx):
pass
root_group = Group(root, name="root")
@root_group.command(name="child")
async def child(ctx):
pass
@root_group.group(name="sub")
async def sub(ctx):
pass
@sub.command(name="leaf")
async def leaf(ctx):
pass
handler.add_command(root_group)
names = [cmd.name for cmd in handler.walk_commands()]
assert set(names) == {"help", "root", "child", "sub", "leaf"}

View File

@ -204,7 +204,7 @@ async def test_get_webhook_calls_request():
await http.get_webhook("1")
http.request.assert_called_once_with("GET", "/webhooks/1")
http.request.assert_called_once_with("GET", "/webhooks/1", use_auth_header=True)
@pytest.mark.asyncio