From aa55aa1d4c8eb85b76fe60afccdc3dcad63fdb86 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Sun, 15 Jun 2025 20:39:12 -0600 Subject: [PATCH 01/10] feat: persist views (#120) --- disagreement/client.py | 77 +++++++++++++++++++++++++++++++--------- docs/using_components.md | 16 +++++++++ 2 files changed, 76 insertions(+), 17 deletions(-) diff --git a/disagreement/client.py b/disagreement/client.py index b335d8a..b474477 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -2,8 +2,11 @@ The main Client class for interacting with the Discord API. """ -import asyncio -import signal +import asyncio +import signal +import json +import os +import importlib from typing import ( Optional, Callable, @@ -16,7 +19,9 @@ from typing import ( Dict, cast, ) -from types import ModuleType +from types import ModuleType + +PERSISTENT_VIEWS_FILE = "persistent_views.json" from datetime import datetime, timedelta @@ -77,7 +82,7 @@ def _update_list(lst: List[Any], item: Any) -> None: lst.append(item) -class Client: +class Client: """ Represents a client connection that connects to Discord. This class is used to interact with the Discord WebSocket and API. @@ -193,7 +198,10 @@ class Client: self._views: Dict[Snowflake, "View"] = {} self._persistent_views: Dict[str, "View"] = {} self._voice_clients: Dict[Snowflake, VoiceClient] = {} - self._webhooks: Dict[Snowflake, "Webhook"] = {} + self._webhooks: Dict[Snowflake, "Webhook"] = {} + + # Load persistent views stored on disk + self._load_persistent_views() # Default whether replies mention the user self.mention_replies: bool = mention_replies @@ -210,13 +218,46 @@ class Client: self.loop.add_signal_handler( signal.SIGTERM, lambda: self.loop.create_task(self.close()) ) - except NotImplementedError: - # add_signal_handler is not available on all platforms (e.g., Windows default event loop policy) - # Users on these platforms would need to handle shutdown differently. - print( - "Warning: Signal handlers for SIGINT/SIGTERM could not be added. " - "Graceful shutdown via signals might not work as expected on this platform." - ) + except NotImplementedError: + # add_signal_handler is not available on all platforms (e.g., Windows default event loop policy) + # Users on these platforms would need to handle shutdown differently. + print( + "Warning: Signal handlers for SIGINT/SIGTERM could not be added. " + "Graceful shutdown via signals might not work as expected on this platform." + ) + + def _load_persistent_views(self) -> None: + """Load registered persistent views from disk.""" + if not os.path.isfile(PERSISTENT_VIEWS_FILE): + return + try: + with open(PERSISTENT_VIEWS_FILE, "r") as fp: + mapping = json.load(fp) + except Exception as e: # pragma: no cover - best effort load + print(f"Failed to load persistent views: {e}") + return + + for custom_id, path in mapping.items(): + try: + module_name, class_name = path.rsplit(".", 1) + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + view = cls() + self._persistent_views[custom_id] = view + except Exception as e: # pragma: no cover - best effort load + print(f"Failed to initialize persistent view {path}: {e}") + + def _save_persistent_views(self) -> None: + """Persist registered views to disk.""" + data = {} + for custom_id, view in self._persistent_views.items(): + cls = view.__class__ + data[custom_id] = f"{cls.__module__}.{cls.__name__}" + try: + with open(PERSISTENT_VIEWS_FILE, "w") as fp: + json.dump(data, fp) + except Exception as e: # pragma: no cover - best effort save + print(f"Failed to save persistent views: {e}") async def _initialize_gateway(self): """Initializes the GatewayClient if it doesn't exist.""" @@ -1707,11 +1748,13 @@ class Client: for item in view.children: if item.custom_id: # Ensure custom_id is not None - if item.custom_id in self._persistent_views: - raise ValueError( - f"A component with custom_id '{item.custom_id}' is already registered." - ) - self._persistent_views[item.custom_id] = view + if item.custom_id in self._persistent_views: + raise ValueError( + f"A component with custom_id '{item.custom_id}' is already registered." + ) + self._persistent_views[item.custom_id] = view + + self._save_persistent_views() # --- Application Command Methods --- async def process_interaction(self, interaction: Interaction) -> None: diff --git a/docs/using_components.md b/docs/using_components.md index e987217..b7736ca 100644 --- a/docs/using_components.md +++ b/docs/using_components.md @@ -157,6 +157,22 @@ container = Container( A container can itself contain layout and content components, letting you build complex messages. +## Persistent Views + +Views with ``timeout=None`` are persistent. Their ``custom_id`` components are saved to ``persistent_views.json`` so they survive bot restarts. + +```python +class MyView(View): + @button(label="Press", custom_id="press") + async def handle(self, view, inter): + await inter.respond("Pressed!") + +client.add_persistent_view(MyView()) +``` + +When the client starts, it loads this file and registers each view again. Remove +the file to clear stored views. + ## Next Steps - [Slash Commands](slash_commands.md) From 4b3b6aeb45b4b6cfd5c39bc6c66b7e890a6d8358 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Sun, 15 Jun 2025 20:39:14 -0600 Subject: [PATCH 02/10] Add Asset model and avatar helpers (#119) --- README.md | 12 ++ disagreement/__init__.py | 2 + disagreement/asset.py | 51 +++++++++ disagreement/http.py | 10 +- disagreement/models.py | 230 +++++++++++++++++++++++++++++++++++---- tests/test_asset.py | 14 +++ 6 files changed, 292 insertions(+), 27 deletions(-) create mode 100644 disagreement/asset.py create mode 100644 tests/test_asset.py diff --git a/README.md b/README.md index 26e5c55..b2c56ab 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ A Python library for interacting with the Discord API, with a focus on bot devel - `Message.jump_url` property for quick links to messages - Built-in caching layer - `Guild.me` property to access the bot's member object +- Easy CDN asset handling via the `Asset` model - Experimental voice support - Helpful error handling utilities @@ -126,6 +127,17 @@ client = disagreement.Client( This dictionary is used whenever ``send_message`` or helpers like ``Message.reply`` are called without an explicit ``allowed_mentions`` argument. +### Working With Assets + +Properties like ``User.avatar`` and ``Guild.icon`` return :class:`disagreement.Asset` objects. +Use ``read`` to get the bytes or ``save`` to write them to disk. + +```python +user = await client.fetch_user(123) +data = await user.avatar.read() +await user.avatar.save("avatar.png") +``` + ### Defining Subcommands with `AppCommandGroup` ```python diff --git a/disagreement/__init__.py b/disagreement/__init__.py index ca28c17..fc44bfd 100644 --- a/disagreement/__init__.py +++ b/disagreement/__init__.py @@ -15,6 +15,7 @@ __copyright__ = "Copyright 2025 Slipstream" __version__ = "0.8.1" from .client import Client, AutoShardedClient +from .asset import Asset from .models import ( Message, User, @@ -125,6 +126,7 @@ import logging __all__ = [ "Client", "AutoShardedClient", + "Asset", "Message", "User", "Reaction", diff --git a/disagreement/asset.py b/disagreement/asset.py new file mode 100644 index 0000000..1254f59 --- /dev/null +++ b/disagreement/asset.py @@ -0,0 +1,51 @@ +"""Utility class for Discord CDN assets.""" + +from __future__ import annotations + +import os +from typing import IO, Optional, Union, TYPE_CHECKING + +import aiohttp # pylint: disable=import-error + +if TYPE_CHECKING: + from .client import Client + + +class Asset: + """Represents a CDN asset such as an avatar or icon.""" + + def __init__(self, url: str, client_instance: Optional["Client"] = None) -> None: + self.url = url + self._client = client_instance + + async def read(self) -> bytes: + """Read the asset's bytes.""" + + session: Optional[aiohttp.ClientSession] = None + if self._client is not None: + await self._client._http._ensure_session() # type: ignore[attr-defined] + session = self._client._http._session # type: ignore[attr-defined] + if session is None: + session = aiohttp.ClientSession() + close = True + else: + close = False + async with session.get(self.url) as resp: + data = await resp.read() + if close: + await session.close() + return data + + async def save(self, fp: Union[str, os.PathLike[str], IO[bytes]]) -> None: + """Save the asset to the given file path or file-like object.""" + + data = await self.read() + if isinstance(fp, (str, os.PathLike)): + path = os.fspath(fp) + with open(path, "wb") as file: + file.write(data) + else: + fp.write(data) + + def __repr__(self) -> str: + return f"" diff --git a/disagreement/http.py b/disagreement/http.py index 3751585..2cb9bca 100644 --- a/disagreement/http.py +++ b/disagreement/http.py @@ -788,11 +788,6 @@ class HTTPClient: return Webhook(data) - async def get_webhook(self, webhook_id: "Snowflake") -> Dict[str, Any]: - """Fetches a webhook by ID.""" - - return await self.request("GET", f"/webhooks/{webhook_id}") - async def edit_webhook( self, webhook_id: "Snowflake", payload: Dict[str, Any] ) -> "Webhook": @@ -818,7 +813,10 @@ class HTTPClient: if token is not None: endpoint += f"/{token}" use_auth = False - data = await self.request("GET", endpoint, use_auth_header=use_auth) + if use_auth: + data = await self.request("GET", endpoint) + else: + data = await self.request("GET", endpoint, use_auth_header=False) from .models import Webhook return Webhook(data) diff --git a/disagreement/models.py b/disagreement/models.py index 98cbba2..de5d0af 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -1,6 +1,8 @@ -""" -Data models for Discord objects. -""" +""" +Data models for Discord objects. +""" + +from __future__ import annotations import asyncio import datetime @@ -47,13 +49,14 @@ from .enums import ( # These enums will need to be defined in disagreement/enum from .permissions import Permissions -if TYPE_CHECKING: - from .client import Client # For type hinting to avoid circular imports - from .enums import OverwriteType # For PermissionOverwrite model - from .ui.view import View - from .interactions import Snowflake - from .typing import Typing - from .shard_manager import Shard +if TYPE_CHECKING: + from .client import Client # For type hinting to avoid circular imports + from .enums import OverwriteType # For PermissionOverwrite model + from .ui.view import View + from .interactions import Snowflake + from .typing import Typing + from .shard_manager import Shard + from .asset import Asset # Forward reference Message if it were used in type hints before its definition # from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc. @@ -71,7 +74,12 @@ class User: self.username: Optional[str] = data.get("username") self.discriminator: Optional[str] = data.get("discriminator") self.bot: bool = data.get("bot", False) - self.avatar: Optional[str] = data.get("avatar") + avatar_hash = data.get("avatar") + self._avatar: Optional[str] = ( + f"https://cdn.discordapp.com/avatars/{self.id}/{avatar_hash}.png" + if avatar_hash + else None + ) @property def mention(self) -> str: @@ -83,6 +91,25 @@ class User: disc = self.discriminator or "????" return f"" + @property + def avatar(self) -> Optional["Asset"]: + """Return the user's avatar as an :class:`Asset`.""" + + if self._avatar: + from .asset import Asset + + return Asset(self._avatar, self._client) + return None + + @avatar.setter + def avatar(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._avatar = value + elif value is None: + self._avatar = None + else: + self._avatar = value.url + async def send( self, content: Optional[str] = None, @@ -780,7 +807,12 @@ class Role: self.name: str = data["name"] self.color: int = data["color"] self.hoist: bool = data["hoist"] - self.icon: Optional[str] = data.get("icon") + icon_hash = data.get("icon") + self._icon: Optional[str] = ( + f"https://cdn.discordapp.com/role-icons/{self.id}/{icon_hash}.png" + if icon_hash + else None + ) self.unicode_emoji: Optional[str] = data.get("unicode_emoji") self.position: int = data["position"] self.permissions: str = data["permissions"] # String of bitwise permissions @@ -798,6 +830,23 @@ class Role: def __repr__(self) -> str: return f"" + @property + def icon(self) -> Optional["Asset"]: + if self._icon: + from .asset import Asset + + return Asset(self._icon, None) + return None + + @icon.setter + def icon(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._icon = value + elif value is None: + self._icon = None + else: + self._icon = value.url + class Member(User): # Member inherits from User """Represents a Guild Member. @@ -826,7 +875,15 @@ class Member(User): # Member inherits from User ) # Pass user_data or data if user_data is empty self.nick: Optional[str] = data.get("nick") - self.avatar: Optional[str] = data.get("avatar") + avatar_hash = data.get("avatar") + if avatar_hash: + guild_id = data.get("guild_id") + if guild_id: + self._avatar = f"https://cdn.discordapp.com/guilds/{guild_id}/users/{self.id}/avatars/{avatar_hash}.png" + else: + self._avatar = ( + f"https://cdn.discordapp.com/avatars/{self.id}/{avatar_hash}.png" + ) self.roles: List[str] = data.get("roles", []) self.joined_at: str = data["joined_at"] self.premium_since: Optional[str] = data.get("premium_since") @@ -854,6 +911,25 @@ class Member(User): # Member inherits from User def __repr__(self) -> str: return f"" + @property + def avatar(self) -> Optional["Asset"]: + """Return the member's avatar as an :class:`Asset`.""" + + if self._avatar: + from .asset import Asset + + return Asset(self._avatar, self._client) + return None + + @avatar.setter + def avatar(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._avatar = value + elif value is None: + self._avatar = None + else: + self._avatar = value.url + @property def display_name(self) -> str: """Return the nickname if set, otherwise the username.""" @@ -921,7 +997,6 @@ class Member(User): # Member inherits from User return max(role_objects, key=lambda r: r.position) @property - def guild_permissions(self) -> "Permissions": """Return the member's guild-level permissions.""" @@ -948,8 +1023,9 @@ class Member(User): # Member inherits from User return base - def voice(self) -> Optional["VoiceState"]: - """Return the member's cached voice state as a :class:`VoiceState`.""" + @property + def voice(self) -> Optional["VoiceState"]: + """Return the member's cached voice state as a :class:`VoiceState`.""" if self.voice_state is None: return None @@ -1210,9 +1286,24 @@ class Guild: ) self.id: str = data["id"] self.name: str = data["name"] - self.icon: Optional[str] = data.get("icon") - self.splash: Optional[str] = data.get("splash") - self.discovery_splash: Optional[str] = data.get("discovery_splash") + icon_hash = data.get("icon") + self._icon: Optional[str] = ( + f"https://cdn.discordapp.com/icons/{self.id}/{icon_hash}.png" + if icon_hash + else None + ) + splash_hash = data.get("splash") + self._splash: Optional[str] = ( + f"https://cdn.discordapp.com/splashes/{self.id}/{splash_hash}.png" + if splash_hash + else None + ) + discovery_hash = data.get("discovery_splash") + self._discovery_splash: Optional[str] = ( + f"https://cdn.discordapp.com/discovery-splashes/{self.id}/{discovery_hash}.png" + if discovery_hash + else None + ) self.owner: Optional[bool] = data.get("owner") self.owner_id: str = data["owner_id"] self.permissions: Optional[str] = data.get("permissions") @@ -1249,7 +1340,12 @@ class Guild: self.max_members: Optional[int] = data.get("max_members") self.vanity_url_code: Optional[str] = data.get("vanity_url_code") self.description: Optional[str] = data.get("description") - self.banner: Optional[str] = data.get("banner") + banner_hash = data.get("banner") + self._banner: Optional[str] = ( + f"https://cdn.discordapp.com/banners/{self.id}/{banner_hash}.png" + if banner_hash + else None + ) self.premium_tier: PremiumTier = PremiumTier(data["premium_tier"]) self.premium_subscription_count: Optional[int] = data.get( "premium_subscription_count" @@ -1357,6 +1453,74 @@ class Guild: def __repr__(self) -> str: return f"" + @property + def icon(self) -> Optional["Asset"]: + if self._icon: + from .asset import Asset + + return Asset(self._icon, self._client) + return None + + @icon.setter + def icon(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._icon = value + elif value is None: + self._icon = None + else: + self._icon = value.url + + @property + def splash(self) -> Optional["Asset"]: + if self._splash: + from .asset import Asset + + return Asset(self._splash, self._client) + return None + + @splash.setter + def splash(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._splash = value + elif value is None: + self._splash = None + else: + self._splash = value.url + + @property + def discovery_splash(self) -> Optional["Asset"]: + if self._discovery_splash: + from .asset import Asset + + return Asset(self._discovery_splash, self._client) + return None + + @discovery_splash.setter + def discovery_splash(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._discovery_splash = value + elif value is None: + self._discovery_splash = None + else: + self._discovery_splash = value.url + + @property + def banner(self) -> Optional["Asset"]: + if self._banner: + from .asset import Asset + + return Asset(self._banner, self._client) + return None + + @banner.setter + def banner(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._banner = value + elif value is None: + self._banner = None + else: + self._banner = value.url + async def fetch_widget(self) -> Dict[str, Any]: """|coro| Fetch this guild's widget settings.""" @@ -2089,7 +2253,12 @@ class Webhook: self.guild_id: Optional[str] = data.get("guild_id") self.channel_id: Optional[str] = data.get("channel_id") self.name: Optional[str] = data.get("name") - self.avatar: Optional[str] = data.get("avatar") + avatar_hash = data.get("avatar") + self._avatar: Optional[str] = ( + f"https://cdn.discordapp.com/webhooks/{self.id}/{avatar_hash}.png" + if avatar_hash + else None + ) self.token: Optional[str] = data.get("token") self.application_id: Optional[str] = data.get("application_id") self.url: Optional[str] = data.get("url") @@ -2098,6 +2267,25 @@ class Webhook: def __repr__(self) -> str: return f"" + @property + def avatar(self) -> Optional["Asset"]: + """Return the webhook's avatar as an :class:`Asset`.""" + + if self._avatar: + from .asset import Asset + + return Asset(self._avatar, self._client) + return None + + @avatar.setter + def avatar(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._avatar = value + elif value is None: + self._avatar = None + else: + self._avatar = value.url + @classmethod def from_url( cls, url: str, session: Optional[aiohttp.ClientSession] = None diff --git a/tests/test_asset.py b/tests/test_asset.py new file mode 100644 index 0000000..5661863 --- /dev/null +++ b/tests/test_asset.py @@ -0,0 +1,14 @@ +from disagreement.models import User +from disagreement.asset import Asset + + +def test_user_avatar_returns_asset(): + user = User({"id": "1", "username": "u", "discriminator": "0001", "avatar": "abc"}) + avatar = user.avatar + assert isinstance(avatar, Asset) + assert avatar.url == "https://cdn.discordapp.com/avatars/1/abc.png" + + +def test_user_avatar_none(): + user = User({"id": "1", "username": "u", "discriminator": "0001"}) + assert user.avatar is None From 17751d3b09e6ce9e8e8cd05f3f9de9a9272960f3 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Sun, 15 Jun 2025 20:39:16 -0600 Subject: [PATCH 03/10] Add HashableById mixin and tests (#118) --- disagreement/models.py | 20 +++++++-- tests/test_hashable_mixin.py | 86 ++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 4 deletions(-) create mode 100644 tests/test_hashable_mixin.py diff --git a/disagreement/models.py b/disagreement/models.py index de5d0af..1f09d7c 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -63,7 +63,19 @@ if TYPE_CHECKING: from .components import component_factory -class User: +class HashableById: + """Mixin providing equality and hashing based on the ``id`` attribute.""" + + id: str + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) and self.id == other.id # type: ignore[attr-defined] + + def __hash__(self) -> int: # pragma: no cover - trivial + return hash(self.id) + + +class User(HashableById): """Represents a Discord User.""" def __init__(self, data: dict, client_instance: Optional["Client"] = None) -> None: @@ -125,7 +137,7 @@ class User: return await target_client.send_dm(self.id, content=content, **kwargs) -class Message: +class Message(HashableById): """Represents a message sent in a channel on Discord. Attributes: @@ -1228,7 +1240,7 @@ class PermissionOverwrite: return f"" -class Guild: +class Guild(HashableById): """Represents a Discord Guild (Server). Attributes: @@ -1649,7 +1661,7 @@ class Guild: return cast("CategoryChannel", self._client.parse_channel(data)) -class Channel: +class Channel(HashableById): """Base class for Discord channels.""" def __init__(self, data: Dict[str, Any], client_instance: "Client"): diff --git a/tests/test_hashable_mixin.py b/tests/test_hashable_mixin.py new file mode 100644 index 0000000..5a6c2ad --- /dev/null +++ b/tests/test_hashable_mixin.py @@ -0,0 +1,86 @@ +import types +from disagreement.models import User, Guild, Channel, Message +from disagreement.enums import ( + VerificationLevel, + MessageNotificationLevel, + ExplicitContentFilterLevel, + MFALevel, + GuildNSFWLevel, + PremiumTier, + ChannelType, +) + + +def _guild_data(gid="1"): + return { + "id": gid, + "name": "g", + "owner_id": gid, + "afk_timeout": 60, + "verification_level": VerificationLevel.NONE.value, + "default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value, + "explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value, + "roles": [], + "emojis": [], + "features": [], + "mfa_level": MFALevel.NONE.value, + "system_channel_flags": 0, + "premium_tier": PremiumTier.NONE.value, + "nsfw_level": GuildNSFWLevel.DEFAULT.value, + } + + +def _user(uid="1"): + return User({"id": uid, "username": "u", "discriminator": "0001"}) + + +def _message(mid="1"): + data = { + "id": mid, + "channel_id": "c", + "author": {"id": "2", "username": "u", "discriminator": "0001"}, + "content": "hi", + "timestamp": "t", + } + return Message(data, client_instance=types.SimpleNamespace()) + + +def _channel(cid="1"): + data = {"id": cid, "type": ChannelType.GUILD_TEXT.value} + return Channel(data, client_instance=types.SimpleNamespace()) + + +def test_user_hash_and_eq(): + a = _user() + b = _user() + c = _user("2") + assert a == b + assert hash(a) == hash(b) + assert a != c + + +def test_guild_hash_and_eq(): + a = Guild(_guild_data(), client_instance=types.SimpleNamespace()) + b = Guild(_guild_data(), client_instance=types.SimpleNamespace()) + c = Guild(_guild_data("2"), client_instance=types.SimpleNamespace()) + assert a == b + assert hash(a) == hash(b) + assert a != c + + +def test_channel_hash_and_eq(): + a = _channel() + b = _channel() + c = _channel("2") + assert a == b + assert hash(a) == hash(b) + assert a != c + + +def test_message_hash_and_eq(): + a = _message() + b = _message() + c = _message("2") + assert a == b + assert hash(a) == hash(b) + assert a != c From cec747a575f0dd7646aa8850229316cb1725b78a Mon Sep 17 00:00:00 2001 From: Slipstream Date: Sun, 15 Jun 2025 20:39:20 -0600 Subject: [PATCH 04/10] Improve help command (#116) --- disagreement/ext/commands/help.py | 58 +++++++++++++++++++++++++------ docs/commands.md | 6 +++- tests/test_help_command.py | 54 +++++++++++++++++++++++++--- 3 files changed, 101 insertions(+), 17 deletions(-) diff --git a/disagreement/ext/commands/help.py b/disagreement/ext/commands/help.py index 61c69d1..516f5ce 100644 --- a/disagreement/ext/commands/help.py +++ b/disagreement/ext/commands/help.py @@ -1,6 +1,8 @@ +from collections import defaultdict from typing import List, Optional -from .core import Command, CommandContext, CommandHandler +from ...utils import Paginator +from .core import Command, CommandContext, CommandHandler, Group class HelpCommand(Command): @@ -15,17 +17,12 @@ class HelpCommand(Command): if not cmd or cmd.name.lower() != command.lower(): await ctx.send(f"Command '{command}' not found.") return - description = cmd.description or cmd.brief or "No description provided." - await ctx.send(f"**{ctx.prefix}{cmd.name}**\n{description}") - else: - lines: List[str] = [] - for registered in dict.fromkeys(handler.commands.values()): - brief = registered.brief or registered.description or "" - lines.append(f"{ctx.prefix}{registered.name} - {brief}".strip()) - if lines: - await ctx.send("\n".join(lines)) + if isinstance(cmd, Group): + await self.send_group_help(ctx, cmd) else: - await ctx.send("No commands available.") + await self.send_command_help(ctx, cmd) + else: + await self.send_bot_help(ctx) super().__init__( callback, @@ -33,3 +30,42 @@ class HelpCommand(Command): brief="Show command help.", description="Displays help for commands.", ) + + async def send_bot_help(self, ctx: CommandContext) -> None: + groups = defaultdict(list) + for cmd in dict.fromkeys(self.handler.commands.values()): + key = cmd.cog.cog_name if cmd.cog else "No Category" + groups[key].append(cmd) + + paginator = Paginator() + for cog_name, cmds in groups.items(): + paginator.add_line(f"**{cog_name}**") + for cmd in cmds: + brief = cmd.brief or cmd.description or "" + paginator.add_line(f"{ctx.prefix}{cmd.name} - {brief}".strip()) + paginator.add_line("") + + pages = paginator.pages + if not pages: + await ctx.send("No commands available.") + return + for page in pages: + await ctx.send(page) + + async def send_command_help(self, ctx: CommandContext, command: Command) -> None: + description = command.description or command.brief or "No description provided." + await ctx.send(f"**{ctx.prefix}{command.name}**\n{description}") + + async def send_group_help(self, ctx: CommandContext, group: Group) -> None: + paginator = Paginator() + description = group.description or group.brief or "No description provided." + paginator.add_line(f"**{ctx.prefix}{group.name}**\n{description}") + if group.commands: + for sub in dict.fromkeys(group.commands.values()): + brief = sub.brief or sub.description or "" + paginator.add_line( + f"{ctx.prefix}{group.name} {sub.name} - {brief}".strip() + ) + + for page in paginator.pages: + await ctx.send(page) diff --git a/docs/commands.md b/docs/commands.md index ef312ec..68d1bd7 100644 --- a/docs/commands.md +++ b/docs/commands.md @@ -11,7 +11,11 @@ The command handler registers a `help` command automatically. Use it to list all !help ping # shows help for the "ping" command ``` -The help command will show each command's brief description if provided. +Commands are grouped by their Cog name and paginated so that long help +lists are split into multiple messages using the `Paginator` utility. + +If you need custom formatting you can subclass +`HelpCommand` and override `send_command_help` or `send_group_help`. ## Checks diff --git a/tests/test_help_command.py b/tests/test_help_command.py index 23a2c7a..6e1e30b 100644 --- a/tests/test_help_command.py +++ b/tests/test_help_command.py @@ -1,6 +1,7 @@ import pytest -from disagreement.ext.commands.core import CommandHandler, Command +from disagreement.ext import commands +from disagreement.ext.commands.core import CommandHandler, Command, Group from disagreement.models import Message @@ -13,15 +14,21 @@ class DummyBot: return {"id": "1", "channel_id": channel_id, "content": content} +class MyCog(commands.Cog): + def __init__(self, client) -> None: + super().__init__(client) + + @commands.command() + async def foo(self, ctx: commands.CommandContext) -> None: + pass + + @pytest.mark.asyncio async def test_help_lists_commands(): bot = DummyBot() handler = CommandHandler(client=bot, prefix="!") - async def foo(ctx): - pass - - handler.add_command(Command(foo, name="foo", brief="Foo cmd")) + handler.add_cog(MyCog(bot)) msg_data = { "id": "1", @@ -33,6 +40,7 @@ async def test_help_lists_commands(): msg = Message(msg_data, client_instance=bot) await handler.process_commands(msg) assert any("foo" in m for m in bot.sent) + assert any("MyCog" in m for m in bot.sent) @pytest.mark.asyncio @@ -55,3 +63,39 @@ async def test_help_specific_command(): msg = Message(msg_data, client_instance=bot) await handler.process_commands(msg) assert any("Bar desc" in m for m in bot.sent) + + +from disagreement.ext.commands.help import HelpCommand + + +class CustomHelp(HelpCommand): + async def send_command_help(self, ctx, command): + await ctx.send(f"custom {command.name}") + + async def send_group_help(self, ctx, group): + await ctx.send(f"group {group.name}") + + +@pytest.mark.asyncio +async def test_custom_help_methods(): + bot = DummyBot() + handler = CommandHandler(client=bot, prefix="!") + handler.remove_command("help") + handler.add_command(CustomHelp(handler)) + + async def sub(ctx): + pass + + group = Group(sub, name="grp") + handler.add_command(group) + + msg_data = { + "id": "1", + "channel_id": "c", + "author": {"id": "2", "username": "u", "discriminator": "0001"}, + "content": "!help grp", + "timestamp": "t", + } + msg = Message(msg_data, client_instance=bot) + await handler.process_commands(msg) + assert any("group grp" in m for m in bot.sent) From 132521fa396d278e2e1d0f9459e352e425df0f32 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Sun, 15 Jun 2025 20:39:23 -0600 Subject: [PATCH 05/10] Add Object class and partial docs (#113) --- disagreement/__init__.py | 2 ++ disagreement/object.py | 19 +++++++++++++++++++ docs/caching.md | 7 +++++++ tests/test_object.py | 15 +++++++++++++++ 4 files changed, 43 insertions(+) create mode 100644 disagreement/object.py create mode 100644 tests/test_object.py diff --git a/disagreement/__init__.py b/disagreement/__init__.py index fc44bfd..3c88eab 100644 --- a/disagreement/__init__.py +++ b/disagreement/__init__.py @@ -40,6 +40,7 @@ from .models import ( Container, Guild, ) +from .object import Object from .voice_client import VoiceClient from .audio import AudioSource, FFmpegAudioSource from .typing import Typing @@ -148,6 +149,7 @@ __all__ = [ "MediaGallery", "MediaGalleryItem", "Container", + "Object", "VoiceClient", "AudioSource", "FFmpegAudioSource", diff --git a/disagreement/object.py b/disagreement/object.py new file mode 100644 index 0000000..f4d6dfc --- /dev/null +++ b/disagreement/object.py @@ -0,0 +1,19 @@ +class Object: + """A minimal wrapper around a Discord snowflake ID.""" + + __slots__ = ("id",) + + def __init__(self, object_id: int) -> None: + self.id = int(object_id) + + def __int__(self) -> int: + return self.id + + def __hash__(self) -> int: + return hash(self.id) + + def __eq__(self, other: object) -> bool: + return isinstance(other, Object) and self.id == other.id + + def __repr__(self) -> str: + return f"" diff --git a/docs/caching.md b/docs/caching.md index 98bcc7d..f1add20 100644 --- a/docs/caching.md +++ b/docs/caching.md @@ -28,6 +28,13 @@ The cache can be cleared manually if needed: client.cache.clear() ``` +## Partial Objects + +Some events only include minimal data for related resources. When only an ``id`` +is available, Disagreement represents the resource using :class:`~disagreement.Object`. +These objects can be compared and used in sets or dictionaries and can be passed +to API methods to fetch the full data when needed. + ## Next Steps - [Components](using_components.md) diff --git a/tests/test_object.py b/tests/test_object.py new file mode 100644 index 0000000..03e60f0 --- /dev/null +++ b/tests/test_object.py @@ -0,0 +1,15 @@ +from disagreement.object import Object + + +def test_object_int(): + obj = Object(123) + assert int(obj) == 123 + + +def test_object_equality_and_hash(): + a = Object(1) + b = Object(1) + c = Object(2) + assert a == b + assert a != c + assert hash(a) == hash(b) From e2061adc5519d06a2dd766243c3d3cc94891602a Mon Sep 17 00:00:00 2001 From: Slipstream Date: Sun, 15 Jun 2025 20:39:26 -0600 Subject: [PATCH 06/10] Add logout method (#114) --- README.md | 4 ++++ disagreement/client.py | 26 +++++++++++++++++--------- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index b2c56ab..6709933 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,10 @@ These options are forwarded to ``HTTPClient`` when it creates the underlying ``aiohttp.ClientSession``. You can specify a custom ``connector`` or any other session parameter supported by ``aiohttp``. +### Logging Out + +Call ``Client.logout`` to disconnect from the Gateway and clear the current bot token while keeping the HTTP session alive. Assign a new token and call ``connect`` or ``run`` to log back in. + ### Default Allowed Mentions Specify default mention behaviour for all outgoing messages when constructing the client: diff --git a/disagreement/client.py b/disagreement/client.py index b474477..fecc20e 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -454,15 +454,23 @@ class Client: await self.close() return False - async def close_gateway(self, code: int = 1000) -> None: - """Closes only the gateway connection, allowing for potential reconnect.""" - if self._shard_manager: - await self._shard_manager.close() - self._shard_manager = None - if self._gateway: - await self._gateway.close(code=code) - self._gateway = None - self._ready_event.clear() # No longer ready if gateway is closed + async def close_gateway(self, code: int = 1000) -> None: + """Closes only the gateway connection, allowing for potential reconnect.""" + if self._shard_manager: + await self._shard_manager.close() + self._shard_manager = None + if self._gateway: + await self._gateway.close(code=code) + self._gateway = None + self._ready_event.clear() # No longer ready if gateway is closed + + async def logout(self) -> None: + """Invalidate the bot token and disconnect from the Gateway.""" + await self.close_gateway() + self.token = "" + self._http.token = "" + self.user = None + self.start_time = None def is_closed(self) -> bool: """Indicates if the client has been closed.""" From 506adeca201209549c717e47c8310fbe7fc144e2 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Sun, 15 Jun 2025 20:39:28 -0600 Subject: [PATCH 07/10] Add channel and emoji converters (#112) --- disagreement/ext/commands/converters.py | 91 ++++++++++++++++++++++++- tests/test_additional_converters.py | 91 ++++++++++++++++++++++--- 2 files changed, 173 insertions(+), 9 deletions(-) diff --git a/disagreement/ext/commands/converters.py b/disagreement/ext/commands/converters.py index ea09879..4235234 100644 --- a/disagreement/ext/commands/converters.py +++ b/disagreement/ext/commands/converters.py @@ -6,7 +6,16 @@ import re import inspect from .errors import BadArgument -from disagreement.models import Member, Guild, Role, User +from disagreement.models import ( + Member, + Guild, + Role, + User, + TextChannel, + VoiceChannel, + Emoji, + PartialEmoji, +) if TYPE_CHECKING: from .core import CommandContext @@ -158,6 +167,82 @@ class UserConverter(Converter["User"]): raise BadArgument(f"User '{argument}' not found.") +class TextChannelConverter(Converter["TextChannel"]): + async def convert(self, ctx: "CommandContext", argument: str) -> "TextChannel": + if not ctx.message.guild_id: + raise BadArgument("TextChannel converter requires guild context.") + + match = re.match(r"<#(?P\d+)>$", argument) + channel_id = match.group("id") if match else argument + + guild = ctx.bot.get_guild(ctx.message.guild_id) + if guild: + channel = guild.get_channel(channel_id) + if isinstance(channel, TextChannel): + return channel + + channel = ( + ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None + ) + if isinstance(channel, TextChannel): + return channel + + if hasattr(ctx.bot, "fetch_channel"): + channel = await ctx.bot.fetch_channel(channel_id) + if isinstance(channel, TextChannel): + return channel + + raise BadArgument(f"Text channel '{argument}' not found.") + + +class VoiceChannelConverter(Converter["VoiceChannel"]): + async def convert(self, ctx: "CommandContext", argument: str) -> "VoiceChannel": + if not ctx.message.guild_id: + raise BadArgument("VoiceChannel converter requires guild context.") + + match = re.match(r"<#(?P\d+)>$", argument) + channel_id = match.group("id") if match else argument + + guild = ctx.bot.get_guild(ctx.message.guild_id) + if guild: + channel = guild.get_channel(channel_id) + if isinstance(channel, VoiceChannel): + return channel + + channel = ( + ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None + ) + if isinstance(channel, VoiceChannel): + return channel + + if hasattr(ctx.bot, "fetch_channel"): + channel = await ctx.bot.fetch_channel(channel_id) + if isinstance(channel, VoiceChannel): + return channel + + raise BadArgument(f"Voice channel '{argument}' not found.") + + +class EmojiConverter(Converter["PartialEmoji"]): + _CUSTOM_RE = re.compile(r"<(?Pa)?:(?P[^:]+):(?P\d+)>$") + + async def convert(self, ctx: "CommandContext", argument: str) -> "PartialEmoji": + match = self._CUSTOM_RE.match(argument) + if match: + return PartialEmoji( + { + "id": match.group("id"), + "name": match.group("name"), + "animated": bool(match.group("animated")), + } + ) + + if argument: + return PartialEmoji({"id": None, "name": argument}) + + raise BadArgument(f"Emoji '{argument}' not found.") + + # Default converters mapping DEFAULT_CONVERTERS: dict[type, Converter[Any]] = { int: IntConverter(), @@ -168,6 +253,10 @@ DEFAULT_CONVERTERS: dict[type, Converter[Any]] = { Guild: GuildConverter(), Role: RoleConverter(), User: UserConverter(), + TextChannel: TextChannelConverter(), + VoiceChannel: VoiceChannelConverter(), + PartialEmoji: EmojiConverter(), + Emoji: EmojiConverter(), } diff --git a/tests/test_additional_converters.py b/tests/test_additional_converters.py index 8ee751f..952edc7 100644 --- a/tests/test_additional_converters.py +++ b/tests/test_additional_converters.py @@ -3,7 +3,16 @@ import pytest from disagreement.ext.commands.converters import run_converters from disagreement.ext.commands.core import CommandContext, Command from disagreement.ext.commands.errors import BadArgument -from disagreement.models import Message, Member, Role, Guild, User +from disagreement.models import ( + Message, + Member, + Role, + Guild, + User, + TextChannel, + VoiceChannel, + PartialEmoji, +) from disagreement.enums import ( VerificationLevel, MessageNotificationLevel, @@ -11,11 +20,12 @@ from disagreement.enums import ( MFALevel, GuildNSFWLevel, PremiumTier, + ChannelType, ) from disagreement.client import Client -from disagreement.cache import GuildCache, Cache +from disagreement.cache import GuildCache, Cache, ChannelCache class DummyBot(Client): @@ -23,10 +33,14 @@ class DummyBot(Client): super().__init__(token="test") self._guilds = GuildCache() self._users = Cache() + self._channels = ChannelCache() def get_guild(self, guild_id): return self._guilds.get(guild_id) + def get_channel(self, channel_id): + return self._channels.get(channel_id) + async def fetch_member(self, guild_id, member_id): guild = self._guilds.get(guild_id) return guild.get_member(member_id) if guild else None @@ -41,6 +55,9 @@ class DummyBot(Client): async def fetch_user(self, user_id): return self._users.get(user_id) + async def fetch_channel(self, channel_id): + return self._channels.get(channel_id) + @pytest.fixture() def guild_objects(): @@ -93,12 +110,38 @@ def guild_objects(): guild._members.set(member.id, member) guild.roles.append(role) - return guild, member, role, user + text_channel = TextChannel( + { + "id": "20", + "type": ChannelType.GUILD_TEXT.value, + "guild_id": guild.id, + "permission_overwrites": [], + }, + client_instance=bot, + ) + voice_channel = VoiceChannel( + { + "id": "21", + "type": ChannelType.GUILD_VOICE.value, + "guild_id": guild.id, + "permission_overwrites": [], + }, + client_instance=bot, + ) + + guild._channels.set(text_channel.id, text_channel) + guild.text_channels.append(text_channel) + guild._channels.set(voice_channel.id, voice_channel) + guild.voice_channels.append(voice_channel) + bot._channels.set(text_channel.id, text_channel) + bot._channels.set(voice_channel.id, voice_channel) + + return guild, member, role, user, text_channel, voice_channel @pytest.fixture() def command_context(guild_objects): - guild, member, role, _ = guild_objects + guild, member, role, _, _, _ = guild_objects bot = guild._client message_data = { "id": "10", @@ -121,7 +164,7 @@ def command_context(guild_objects): @pytest.mark.asyncio async def test_member_converter(command_context, guild_objects): - _, member, _, _ = guild_objects + _, member, _, _, _, _ = guild_objects mention = f"<@!{member.id}>" result = await run_converters(command_context, Member, mention) assert result is member @@ -131,7 +174,7 @@ async def test_member_converter(command_context, guild_objects): @pytest.mark.asyncio async def test_role_converter(command_context, guild_objects): - _, _, role, _ = guild_objects + _, _, role, _, _, _ = guild_objects mention = f"<@&{role.id}>" result = await run_converters(command_context, Role, mention) assert result is role @@ -141,7 +184,7 @@ async def test_role_converter(command_context, guild_objects): @pytest.mark.asyncio async def test_user_converter(command_context, guild_objects): - _, _, _, user = guild_objects + _, _, _, user, _, _ = guild_objects mention = f"<@{user.id}>" result = await run_converters(command_context, User, mention) assert result is user @@ -151,11 +194,43 @@ async def test_user_converter(command_context, guild_objects): @pytest.mark.asyncio async def test_guild_converter(command_context, guild_objects): - guild, _, _, _ = guild_objects + guild, _, _, _, _, _ = guild_objects result = await run_converters(command_context, Guild, guild.id) assert result is guild +@pytest.mark.asyncio +async def test_text_channel_converter(command_context, guild_objects): + _, _, _, _, text_channel, _ = guild_objects + mention = f"<#{text_channel.id}>" + result = await run_converters(command_context, TextChannel, mention) + assert result is text_channel + result = await run_converters(command_context, TextChannel, text_channel.id) + assert result is text_channel + + +@pytest.mark.asyncio +async def test_voice_channel_converter(command_context, guild_objects): + _, _, _, _, _, voice_channel = guild_objects + mention = f"<#{voice_channel.id}>" + result = await run_converters(command_context, VoiceChannel, mention) + assert result is voice_channel + result = await run_converters(command_context, VoiceChannel, voice_channel.id) + assert result is voice_channel + + +@pytest.mark.asyncio +async def test_emoji_converter(command_context): + result = await run_converters(command_context, PartialEmoji, "<:smile:1>") + assert isinstance(result, PartialEmoji) + assert result.id == "1" + assert result.name == "smile" + + result = await run_converters(command_context, PartialEmoji, "😄") + assert result.id is None + assert result.name == "😄" + + @pytest.mark.asyncio async def test_member_converter_no_guild(): guild_data = { From d710487fc2bb3dba570af16e57b2261107eede04 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Sun, 15 Jun 2025 20:39:30 -0600 Subject: [PATCH 08/10] Add voice playback control (#111) --- disagreement/voice_client.py | 96 +++++++++++++++++++++++------------- docs/voice_features.md | 4 ++ tests/test_voice_client.py | 57 +++++++++++++++++++++ 3 files changed, 124 insertions(+), 33 deletions(-) diff --git a/disagreement/voice_client.py b/disagreement/voice_client.py index 41df380..4a6f9eb 100644 --- a/disagreement/voice_client.py +++ b/disagreement/voice_client.py @@ -77,11 +77,14 @@ class VoiceClient: self.secret_key: Optional[Sequence[int]] = None self._server_ip: Optional[str] = None self._server_port: Optional[int] = None - self._current_source: Optional[AudioSource] = None - self._play_task: Optional[asyncio.Task] = None - self._sink: Optional[AudioSink] = None - self._ssrc_map: dict[int, int] = {} - self._ssrc_lock = threading.Lock() + self._current_source: Optional[AudioSource] = None + self._play_task: Optional[asyncio.Task] = None + self._pause_event = asyncio.Event() + self._pause_event.set() + self._is_playing = False + self._sink: Optional[AudioSink] = None + self._ssrc_map: dict[int, int] = {} + self._ssrc_lock = threading.Lock() async def connect(self) -> None: if self._ws is None: @@ -189,31 +192,37 @@ class VoiceClient: raise RuntimeError("UDP socket not initialised") self._udp.send(frame) - async def _play_loop(self) -> None: - assert self._current_source is not None - try: - while True: - data = await self._current_source.read() - if not data: - break - volume = getattr(self._current_source, "volume", 1.0) - if volume != 1.0: - data = _apply_volume(data, volume) - await self.send_audio_frame(data) - finally: - await self._current_source.close() - self._current_source = None - self._play_task = None + async def _play_loop(self) -> None: + assert self._current_source is not None + self._is_playing = True + try: + while True: + await self._pause_event.wait() + data = await self._current_source.read() + if not data: + break + volume = getattr(self._current_source, "volume", 1.0) + if volume != 1.0: + data = _apply_volume(data, volume) + await self.send_audio_frame(data) + finally: + await self._current_source.close() + self._current_source = None + self._play_task = None + self._is_playing = False + self._pause_event.set() - async def stop(self) -> None: - if self._play_task: - self._play_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._play_task - self._play_task = None - if self._current_source: - await self._current_source.close() - self._current_source = None + async def stop(self) -> None: + if self._play_task: + self._play_task.cancel() + self._pause_event.set() + with contextlib.suppress(asyncio.CancelledError): + await self._play_task + self._play_task = None + self._is_playing = False + if self._current_source: + await self._current_source.close() + self._current_source = None async def play(self, source: AudioSource, *, wait: bool = True) -> None: """|coro| Play an :class:`AudioSource` on the voice connection.""" @@ -224,10 +233,31 @@ class VoiceClient: if wait: await self._play_task - async def play_file(self, filename: str, *, wait: bool = True) -> None: - """|coro| Stream an audio file or URL using FFmpeg.""" - - await self.play(FFmpegAudioSource(filename), wait=wait) + async def play_file(self, filename: str, *, wait: bool = True) -> None: + """|coro| Stream an audio file or URL using FFmpeg.""" + + await self.play(FFmpegAudioSource(filename), wait=wait) + + def pause(self) -> None: + """Pause the current audio source.""" + + if self._play_task and not self._play_task.done(): + self._pause_event.clear() + + def resume(self) -> None: + """Resume playback of a paused source.""" + + if self._play_task and not self._play_task.done(): + self._pause_event.set() + + def is_paused(self) -> bool: + """Return ``True`` if playback is currently paused.""" + + return bool(self._play_task and not self._pause_event.is_set()) + + def is_playing(self) -> bool: + """Return ``True`` if audio is actively being played.""" + return self._is_playing and self._pause_event.is_set() def listen(self, sink: AudioSink) -> None: """Start listening to voice and routing to a sink.""" diff --git a/docs/voice_features.md b/docs/voice_features.md index 4391862..28014c0 100644 --- a/docs/voice_features.md +++ b/docs/voice_features.md @@ -6,6 +6,10 @@ Disagreement includes experimental support for connecting to voice channels. You voice = await client.join_voice(guild_id, channel_id) await voice.play_file("welcome.mp3") await voice.play_file("another.mp3") # switch sources while connected +voice.pause() +voice.resume() +if voice.is_playing(): + print("audio is playing") await voice.close() ``` diff --git a/tests/test_voice_client.py b/tests/test_voice_client.py index 1c6cb98..8881979 100644 --- a/tests/test_voice_client.py +++ b/tests/test_voice_client.py @@ -59,6 +59,17 @@ class DummySource(AudioSource): return b"" +class SlowSource(AudioSource): + def __init__(self, chunks): + self.chunks = list(chunks) + + async def read(self) -> bytes: + await asyncio.sleep(0) + if self.chunks: + return self.chunks.pop(0) + return b"" + + @pytest.mark.asyncio async def test_voice_client_handshake(): hello = {"d": {"heartbeat_interval": 50}} @@ -205,3 +216,49 @@ async def test_voice_client_volume_scaling(monkeypatch): samples[1] = int(samples[1] * 0.5) expected = samples.tobytes() assert udp.sent == [expected] + + +@pytest.mark.asyncio +async def test_pause_resume_and_status(): + ws = DummyWebSocket( + [ + {"d": {"heartbeat_interval": 50}}, + {"d": {"ssrc": 1, "ip": "127.0.0.1", "port": 4000}}, + {"d": {"secret_key": []}}, + ] + ) + udp = DummyUDP() + vc = VoiceClient( + client=DummyVoiceClient(), + endpoint="ws://localhost", + session_id="sess", + token="tok", + guild_id=1, + user_id=2, + ws=ws, + udp=udp, + ) + await vc.connect() + vc._heartbeat_task.cancel() + + src = SlowSource([b"a", b"b", b"c"]) + await vc.play(src, wait=False) + + while not udp.sent: + await asyncio.sleep(0) + + assert vc.is_playing() + vc.pause() + assert vc.is_paused() + await asyncio.sleep(0) + sent = len(udp.sent) + await asyncio.sleep(0.01) + assert len(udp.sent) == sent + assert not vc.is_playing() + + vc.resume() + assert not vc.is_paused() + + await vc._play_task + assert udp.sent == [b"a", b"b", b"c"] + assert not vc.is_playing() From 8e88aaec2f360ce0091776636c9d4cc96186f6d5 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Sun, 15 Jun 2025 20:42:21 -0600 Subject: [PATCH 09/10] Implement channel and member aggregation (#117) --- disagreement/cache.py | 4 +- disagreement/caching.py | 8 ++ disagreement/client.py | 28 +++- disagreement/http.py | 173 ++++++++++++------------ disagreement/models.py | 286 ++++++++++++++++++++-------------------- tests/test_cache.py | 73 ++++++++++ 6 files changed, 339 insertions(+), 233 deletions(-) diff --git a/disagreement/cache.py b/disagreement/cache.py index 32c6639..456797c 100644 --- a/disagreement/cache.py +++ b/disagreement/cache.py @@ -84,9 +84,9 @@ class MemberCache(Cache["Member"]): def _should_cache(self, member: Member) -> bool: """Determines if a member should be cached based on the flags.""" - if self.flags.all: + if self.flags.all_enabled: return True - if self.flags.none: + if self.flags.no_flags: return False if self.flags.online and member.status != "offline": diff --git a/disagreement/caching.py b/disagreement/caching.py index e9a481f..9decdd2 100644 --- a/disagreement/caching.py +++ b/disagreement/caching.py @@ -74,6 +74,14 @@ class MemberCacheFlags: for name in self.VALID_FLAGS: yield name, getattr(self, name) + @property + def all_enabled(self) -> bool: + return self.value == self.ALL_FLAGS + + @property + def no_flags(self) -> bool: + return self.value == 0 + def __int__(self) -> int: return self.value diff --git a/disagreement/client.py b/disagreement/client.py index fecc20e..2c4d4a9 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -1448,10 +1448,30 @@ class Client: return self._channels.get(channel_id) - def get_message(self, message_id: Snowflake) -> Optional["Message"]: - """Returns a message from the internal cache.""" - - return self._messages.get(message_id) + def get_message(self, message_id: Snowflake) -> Optional["Message"]: + """Returns a message from the internal cache.""" + + return self._messages.get(message_id) + + def get_all_channels(self) -> List["Channel"]: + """Return all channels cached in every guild.""" + + channels: List["Channel"] = [] + for guild in self._guilds.values(): + channels.extend(guild._channels.values()) + return channels + + def get_all_members(self) -> List["Member"]: + """Return all cached members across all guilds. + + When member caching is disabled via :class:`MemberCacheFlags.none`, this + list will always be empty. + """ + + members: List["Member"] = [] + for guild in self._guilds.values(): + members.extend(guild._members.values()) + return members async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]: """Fetches a guild by ID from Discord and caches it.""" diff --git a/disagreement/http.py b/disagreement/http.py index 2cb9bca..1831a6f 100644 --- a/disagreement/http.py +++ b/disagreement/http.py @@ -656,21 +656,21 @@ class HTTPClient: await self.request("PUT", f"/channels/{channel_id}/pins/{message_id}") - async def unpin_message( - self, channel_id: "Snowflake", message_id: "Snowflake" - ) -> None: - """Unpins a message from a channel.""" - - await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}") - - async def crosspost_message( - self, channel_id: "Snowflake", message_id: "Snowflake" - ) -> Dict[str, Any]: - """Crossposts a message to any following channels.""" - - return await self.request( - "POST", f"/channels/{channel_id}/messages/{message_id}/crosspost" - ) + async def unpin_message( + self, channel_id: "Snowflake", message_id: "Snowflake" + ) -> None: + """Unpins a message from a channel.""" + + await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}") + + async def crosspost_message( + self, channel_id: "Snowflake", message_id: "Snowflake" + ) -> Dict[str, Any]: + """Crossposts a message to any following channels.""" + + return await self.request( + "POST", f"/channels/{channel_id}/messages/{message_id}/crosspost" + ) async def delete_channel( self, channel_id: str, reason: Optional[str] = None @@ -734,63 +734,68 @@ class HTTPClient: return await self.request("GET", f"/channels/{channel_id}/invites") - async def create_invite( - self, channel_id: "Snowflake", payload: Dict[str, Any] - ) -> "Invite": - """Creates an invite for a channel.""" + async def create_invite( + self, channel_id: "Snowflake", payload: Dict[str, Any] + ) -> "Invite": + """Creates an invite for a channel.""" data = await self.request( "POST", f"/channels/{channel_id}/invites", payload=payload ) - from .models import Invite - - return Invite.from_dict(data) - - async def create_channel_invite( - self, - channel_id: "Snowflake", - payload: Dict[str, Any], - *, - reason: Optional[str] = None, - ) -> "Invite": - """Creates an invite for a channel with an optional audit log reason.""" - - headers = {"X-Audit-Log-Reason": reason} if reason else None - data = await self.request( - "POST", - f"/channels/{channel_id}/invites", - payload=payload, - custom_headers=headers, - ) - from .models import Invite - - return Invite.from_dict(data) + from .models import Invite - async def delete_invite(self, code: str) -> None: - """Deletes an invite by code.""" - - await self.request("DELETE", f"/invites/{code}") - - async def get_invite(self, code: "Snowflake") -> Dict[str, Any]: - """Fetches a single invite by its code.""" - - return await self.request("GET", f"/invites/{code}") + return Invite.from_dict(data) - async def create_webhook( - self, channel_id: "Snowflake", payload: Dict[str, Any] - ) -> "Webhook": - """Creates a webhook in the specified channel.""" + async def create_channel_invite( + self, + channel_id: "Snowflake", + payload: Dict[str, Any], + *, + reason: Optional[str] = None, + ) -> "Invite": + """Creates an invite for a channel with an optional audit log reason.""" + + headers = {"X-Audit-Log-Reason": reason} if reason else None + data = await self.request( + "POST", + f"/channels/{channel_id}/invites", + payload=payload, + custom_headers=headers, + ) + from .models import Invite + + return Invite.from_dict(data) + + async def delete_invite(self, code: str) -> None: + """Deletes an invite by code.""" + + await self.request("DELETE", f"/invites/{code}") + + async def get_invite(self, code: "Snowflake") -> Dict[str, Any]: + """Fetches a single invite by its code.""" + + return await self.request("GET", f"/invites/{code}") + + async def create_webhook( + self, channel_id: "Snowflake", payload: Dict[str, Any] + ) -> "Webhook": + """Creates a webhook in the specified channel.""" data = await self.request( "POST", f"/channels/{channel_id}/webhooks", payload=payload ) - from .models import Webhook - - return Webhook(data) - - async def edit_webhook( - self, webhook_id: "Snowflake", payload: Dict[str, Any] - ) -> "Webhook": + from .models import Webhook + + return Webhook(data) + + async def get_webhook(self, webhook_id: "Snowflake") -> Dict[str, Any]: + """Fetches a webhook by ID and returns the raw payload.""" + + return await self.request("GET", f"/webhooks/{webhook_id}") + + async def edit_webhook( + self, webhook_id: "Snowflake", payload: Dict[str, Any] + ) -> "Webhook": """Edits an existing webhook.""" data = await self.request("PATCH", f"/webhooks/{webhook_id}", payload=payload) @@ -798,28 +803,28 @@ class HTTPClient: return Webhook(data) - async def delete_webhook(self, webhook_id: "Snowflake") -> None: - """Deletes a webhook.""" - - await self.request("DELETE", f"/webhooks/{webhook_id}") - - async def get_webhook( - self, webhook_id: "Snowflake", token: Optional[str] = None - ) -> "Webhook": - """Fetches a webhook by ID, optionally using its token.""" - - endpoint = f"/webhooks/{webhook_id}" - use_auth = True - if token is not None: - endpoint += f"/{token}" - use_auth = False - if use_auth: - data = await self.request("GET", endpoint) - else: - data = await self.request("GET", endpoint, use_auth_header=False) - from .models import Webhook - - return Webhook(data) + async def delete_webhook(self, webhook_id: "Snowflake") -> None: + """Deletes a webhook.""" + + await self.request("DELETE", f"/webhooks/{webhook_id}") + + async def get_webhook_with_token( + self, webhook_id: "Snowflake", token: Optional[str] = None + ) -> "Webhook": + """Fetches a webhook by ID, optionally using its token.""" + + endpoint = f"/webhooks/{webhook_id}" + use_auth = True + if token is not None: + endpoint += f"/{token}" + use_auth = False + if use_auth: + data = await self.request("GET", endpoint) + else: + data = await self.request("GET", endpoint, use_auth_header=False) + from .models import Webhook + + return Webhook(data) async def execute_webhook( self, diff --git a/disagreement/models.py b/disagreement/models.py index 1f09d7c..f302664 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -1,8 +1,8 @@ -""" -Data models for Discord objects. -""" - -from __future__ import annotations +""" +Data models for Discord objects. +""" + +from __future__ import annotations import asyncio import datetime @@ -49,14 +49,14 @@ from .enums import ( # These enums will need to be defined in disagreement/enum from .permissions import Permissions -if TYPE_CHECKING: - from .client import Client # For type hinting to avoid circular imports - from .enums import OverwriteType # For PermissionOverwrite model - from .ui.view import View - from .interactions import Snowflake - from .typing import Typing - from .shard_manager import Shard - from .asset import Asset +if TYPE_CHECKING: + from .client import Client # For type hinting to avoid circular imports + from .enums import OverwriteType # For PermissionOverwrite model + from .ui.view import View + from .interactions import Snowflake + from .typing import Typing + from .shard_manager import Shard + from .asset import Asset # Forward reference Message if it were used in type hints before its definition # from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc. @@ -104,23 +104,23 @@ class User(HashableById): return f"" @property - def avatar(self) -> Optional["Asset"]: - """Return the user's avatar as an :class:`Asset`.""" - - if self._avatar: - from .asset import Asset - - return Asset(self._avatar, self._client) - return None - - @avatar.setter - def avatar(self, value: Optional[Union[str, "Asset"]]) -> None: - if isinstance(value, str): - self._avatar = value - elif value is None: - self._avatar = None - else: - self._avatar = value.url + def avatar(self) -> Optional["Asset"]: + """Return the user's avatar as an :class:`Asset`.""" + + if self._avatar: + from .asset import Asset + + return Asset(self._avatar, self._client) + return None + + @avatar.setter + def avatar(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._avatar = value + elif value is None: + self._avatar = None + else: + self._avatar = value.url async def send( self, @@ -843,21 +843,21 @@ class Role: return f"" @property - def icon(self) -> Optional["Asset"]: - if self._icon: - from .asset import Asset - - return Asset(self._icon, None) - return None - - @icon.setter - def icon(self, value: Optional[Union[str, "Asset"]]) -> None: - if isinstance(value, str): - self._icon = value - elif value is None: - self._icon = None - else: - self._icon = value.url + def icon(self) -> Optional["Asset"]: + if self._icon: + from .asset import Asset + + return Asset(self._icon, None) + return None + + @icon.setter + def icon(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._icon = value + elif value is None: + self._icon = None + else: + self._icon = value.url class Member(User): # Member inherits from User @@ -924,23 +924,23 @@ class Member(User): # Member inherits from User return f"" @property - def avatar(self) -> Optional["Asset"]: - """Return the member's avatar as an :class:`Asset`.""" - - if self._avatar: - from .asset import Asset - - return Asset(self._avatar, self._client) - return None - - @avatar.setter - def avatar(self, value: Optional[Union[str, "Asset"]]) -> None: - if isinstance(value, str): - self._avatar = value - elif value is None: - self._avatar = None - else: - self._avatar = value.url + def avatar(self) -> Optional["Asset"]: + """Return the member's avatar as an :class:`Asset`.""" + + if self._avatar: + from .asset import Asset + + return Asset(self._avatar, self._client) + return None + + @avatar.setter + def avatar(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._avatar = value + elif value is None: + self._avatar = None + else: + self._avatar = value.url @property def display_name(self) -> str: @@ -1034,10 +1034,10 @@ class Member(User): # Member inherits from User return Permissions(~0) return base - - @property - def voice(self) -> Optional["VoiceState"]: - """Return the member's cached voice state as a :class:`VoiceState`.""" + + @property + def voice(self) -> Optional["VoiceState"]: + """Return the member's cached voice state as a :class:`VoiceState`.""" if self.voice_state is None: return None @@ -1466,72 +1466,72 @@ class Guild(HashableById): return f"" @property - def icon(self) -> Optional["Asset"]: - if self._icon: - from .asset import Asset - - return Asset(self._icon, self._client) - return None - - @icon.setter - def icon(self, value: Optional[Union[str, "Asset"]]) -> None: - if isinstance(value, str): - self._icon = value - elif value is None: - self._icon = None - else: - self._icon = value.url + def icon(self) -> Optional["Asset"]: + if self._icon: + from .asset import Asset + + return Asset(self._icon, self._client) + return None + + @icon.setter + def icon(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._icon = value + elif value is None: + self._icon = None + else: + self._icon = value.url @property - def splash(self) -> Optional["Asset"]: - if self._splash: - from .asset import Asset - - return Asset(self._splash, self._client) - return None - - @splash.setter - def splash(self, value: Optional[Union[str, "Asset"]]) -> None: - if isinstance(value, str): - self._splash = value - elif value is None: - self._splash = None - else: - self._splash = value.url + def splash(self) -> Optional["Asset"]: + if self._splash: + from .asset import Asset + + return Asset(self._splash, self._client) + return None + + @splash.setter + def splash(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._splash = value + elif value is None: + self._splash = None + else: + self._splash = value.url @property - def discovery_splash(self) -> Optional["Asset"]: - if self._discovery_splash: - from .asset import Asset - - return Asset(self._discovery_splash, self._client) - return None - - @discovery_splash.setter - def discovery_splash(self, value: Optional[Union[str, "Asset"]]) -> None: - if isinstance(value, str): - self._discovery_splash = value - elif value is None: - self._discovery_splash = None - else: - self._discovery_splash = value.url + def discovery_splash(self) -> Optional["Asset"]: + if self._discovery_splash: + from .asset import Asset + + return Asset(self._discovery_splash, self._client) + return None + + @discovery_splash.setter + def discovery_splash(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._discovery_splash = value + elif value is None: + self._discovery_splash = None + else: + self._discovery_splash = value.url @property - def banner(self) -> Optional["Asset"]: - if self._banner: - from .asset import Asset - - return Asset(self._banner, self._client) - return None - - @banner.setter - def banner(self, value: Optional[Union[str, "Asset"]]) -> None: - if isinstance(value, str): - self._banner = value - elif value is None: - self._banner = None - else: - self._banner = value.url + def banner(self) -> Optional["Asset"]: + if self._banner: + from .asset import Asset + + return Asset(self._banner, self._client) + return None + + @banner.setter + def banner(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._banner = value + elif value is None: + self._banner = None + else: + self._banner = value.url async def fetch_widget(self) -> Dict[str, Any]: """|coro| Fetch this guild's widget settings.""" @@ -2280,23 +2280,23 @@ class Webhook: return f"" @property - def avatar(self) -> Optional["Asset"]: - """Return the webhook's avatar as an :class:`Asset`.""" - - if self._avatar: - from .asset import Asset - - return Asset(self._avatar, self._client) - return None - - @avatar.setter - def avatar(self, value: Optional[Union[str, "Asset"]]) -> None: - if isinstance(value, str): - self._avatar = value - elif value is None: - self._avatar = None - else: - self._avatar = value.url + def avatar(self) -> Optional["Asset"]: + """Return the webhook's avatar as an :class:`Asset`.""" + + if self._avatar: + from .asset import Asset + + return Asset(self._avatar, self._client) + return None + + @avatar.setter + def avatar(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._avatar = value + elif value is None: + self._avatar = None + else: + self._avatar = value.url @classmethod def from_url( diff --git a/tests/test_cache.py b/tests/test_cache.py index 88effe9..5dbef27 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,6 +1,60 @@ import time from disagreement.cache import Cache +from disagreement.client import Client +from disagreement.caching import MemberCacheFlags +from disagreement.enums import ( + ChannelType, + ExplicitContentFilterLevel, + GuildNSFWLevel, + MFALevel, + MessageNotificationLevel, + PremiumTier, + VerificationLevel, +) + + +def _guild_payload(gid: str, channel_count: int, member_count: int) -> dict: + base = { + "id": gid, + "name": f"g{gid}", + "owner_id": "1", + "afk_timeout": 60, + "verification_level": VerificationLevel.NONE.value, + "default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value, + "explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value, + "roles": [], + "emojis": [], + "features": [], + "mfa_level": MFALevel.NONE.value, + "system_channel_flags": 0, + "premium_tier": PremiumTier.NONE.value, + "nsfw_level": GuildNSFWLevel.DEFAULT.value, + "channels": [], + "members": [], + } + for i in range(channel_count): + base["channels"].append( + { + "id": f"{gid}-c{i}", + "type": ChannelType.GUILD_TEXT.value, + "guild_id": gid, + "permission_overwrites": [], + } + ) + for i in range(member_count): + base["members"].append( + { + "user": { + "id": f"{gid}-m{i}", + "username": f"u{i}", + "discriminator": "0001", + }, + "joined_at": "t", + "roles": [], + } + ) + return base def test_cache_store_and_get(): @@ -65,3 +119,22 @@ def test_get_or_fetch_fetches_expired_item(): assert cache.get_or_fetch("c", fetch) == 3 assert called + + +def test_client_get_all_channels_and_members(): + client = Client(token="t") + client.parse_guild(_guild_payload("1", 2, 2)) + client.parse_guild(_guild_payload("2", 1, 1)) + + channels = {c.id for c in client.get_all_channels()} + members = {m.id for m in client.get_all_members()} + + assert channels == {"1-c0", "1-c1", "2-c0"} + assert members == {"1-m0", "1-m1", "2-m0"} + + +def test_client_get_all_members_disabled_cache(): + client = Client(token="t", member_cache_flags=MemberCacheFlags.none()) + client.parse_guild(_guild_payload("1", 1, 2)) + + assert client.get_all_members() == [] From e5ad932321d18f24824b81bd0ad88366cbff6284 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Sun, 15 Jun 2025 20:46:20 -0600 Subject: [PATCH 10/10] Add recursive command enumeration (#115) --- disagreement/ext/commands/core.py | 18 ++++++++++++++-- disagreement/ext/commands/help.py | 10 +++++++++ tests/test_help_command.py | 33 ++++++++++++++++++++++++++--- tests/test_walk_commands.py | 35 +++++++++++++++++++++++++++++++ tests/test_webhooks.py | 2 +- 5 files changed, 92 insertions(+), 6 deletions(-) create mode 100644 tests/test_walk_commands.py diff --git a/disagreement/ext/commands/core.py b/disagreement/ext/commands/core.py index 663cb65..ee841c8 100644 --- a/disagreement/ext/commands/core.py +++ b/disagreement/ext/commands/core.py @@ -79,8 +79,15 @@ class GroupMixin: ) self.commands[alias.lower()] = command - def get_command(self, name: str) -> Optional["Command"]: - return self.commands.get(name.lower()) + def get_command(self, name: str) -> Optional["Command"]: + return self.commands.get(name.lower()) + + def walk_commands(self): + """Yield all commands in this group recursively.""" + for cmd in dict.fromkeys(self.commands.values()): + yield cmd + if isinstance(cmd, Group): + yield from cmd.walk_commands() class Command(GroupMixin): @@ -366,6 +373,13 @@ class CommandHandler: def get_command(self, name: str) -> Optional[Command]: return self.commands.get(name.lower()) + def walk_commands(self): + """Yield every registered command, including subcommands.""" + for cmd in dict.fromkeys(self.commands.values()): + yield cmd + if isinstance(cmd, Group): + yield from cmd.walk_commands() + def get_cog(self, name: str) -> Optional["Cog"]: """Return a loaded cog by name if present.""" diff --git a/disagreement/ext/commands/help.py b/disagreement/ext/commands/help.py index 516f5ce..f30292a 100644 --- a/disagreement/ext/commands/help.py +++ b/disagreement/ext/commands/help.py @@ -19,6 +19,16 @@ class HelpCommand(Command): return if isinstance(cmd, Group): await self.send_group_help(ctx, cmd) + elif cmd: + description = cmd.description or cmd.brief or "No description provided." + await ctx.send(f"**{ctx.prefix}{cmd.name}**\n{description}") + else: + lines: List[str] = [] + for registered in handler.walk_commands(): + brief = registered.brief or registered.description or "" + lines.append(f"{ctx.prefix}{registered.name} - {brief}".strip()) + if lines: + await ctx.send("\n".join(lines)) else: await self.send_command_help(ctx, cmd) else: diff --git a/tests/test_help_command.py b/tests/test_help_command.py index 6e1e30b..e3a1c04 100644 --- a/tests/test_help_command.py +++ b/tests/test_help_command.py @@ -3,6 +3,7 @@ import pytest from disagreement.ext import commands from disagreement.ext.commands.core import CommandHandler, Command, Group from disagreement.models import Message +from disagreement.ext.commands.help import HelpCommand class DummyBot: @@ -65,9 +66,6 @@ async def test_help_specific_command(): assert any("Bar desc" in m for m in bot.sent) -from disagreement.ext.commands.help import HelpCommand - - class CustomHelp(HelpCommand): async def send_command_help(self, ctx, command): await ctx.send(f"custom {command.name}") @@ -99,3 +97,32 @@ async def test_custom_help_methods(): msg = Message(msg_data, client_instance=bot) await handler.process_commands(msg) assert any("group grp" in m for m in bot.sent) + + +@pytest.mark.asyncio +async def test_help_lists_subcommands(): + bot = DummyBot() + handler = CommandHandler(client=bot, prefix="!") + + async def root(ctx): + pass + + group = Group(root, name="root") + + @group.command(name="child") + async def child(ctx): + pass + + handler.add_command(group) + + msg_data = { + "id": "1", + "channel_id": "c", + "author": {"id": "2", "username": "u", "discriminator": "0001"}, + "content": "!help", + "timestamp": "t", + } + msg = Message(msg_data, client_instance=bot) + await handler.process_commands(msg) + assert any("root" in m for m in bot.sent) + assert any("child" in m for m in bot.sent) diff --git a/tests/test_walk_commands.py b/tests/test_walk_commands.py new file mode 100644 index 0000000..ca34749 --- /dev/null +++ b/tests/test_walk_commands.py @@ -0,0 +1,35 @@ +import pytest + +from disagreement.ext.commands.core import CommandHandler, Command, Group + + +class DummyBot: + pass + + +@pytest.mark.asyncio +async def test_walk_commands_recurses_groups(): + bot = DummyBot() + handler = CommandHandler(client=bot, prefix="!") + + async def root(ctx): + pass + + root_group = Group(root, name="root") + + @root_group.command(name="child") + async def child(ctx): + pass + + @root_group.group(name="sub") + async def sub(ctx): + pass + + @sub.command(name="leaf") + async def leaf(ctx): + pass + + handler.add_command(root_group) + + names = [cmd.name for cmd in handler.walk_commands()] + assert set(names) == {"help", "root", "child", "sub", "leaf"} diff --git a/tests/test_webhooks.py b/tests/test_webhooks.py index 4141197..dcb6539 100644 --- a/tests/test_webhooks.py +++ b/tests/test_webhooks.py @@ -204,7 +204,7 @@ async def test_get_webhook_calls_request(): await http.get_webhook("1") - http.request.assert_called_once_with("GET", "/webhooks/1") + http.request.assert_called_once_with("GET", "/webhooks/1", use_auth_header=True) @pytest.mark.asyncio