diff --git a/disagreement/client.py b/disagreement/client.py index e5604a3..df73b26 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -614,75 +614,75 @@ class Client: return decorator - def add_app_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None: - """ - Adds a standalone application command or group to the bot. - Use this for commands not defined within a Cog. - - Args: - command (Union[AppCommand, AppCommandGroup]): The application command or group instance. - This is typically the object returned by a decorator like @slash_command. - """ - from .ext.app_commands.commands import ( - AppCommand, - AppCommandGroup, - ) # Ensure types - - if not isinstance(command, (AppCommand, AppCommandGroup)): - raise TypeError( - "Command must be an instance of AppCommand or AppCommandGroup." - ) - - # If it's a decorated function, the command object might be on __app_command_object__ - if hasattr(command, "__app_command_object__") and isinstance( - getattr(command, "__app_command_object__"), (AppCommand, AppCommandGroup) - ): - actual_command_obj = getattr(command, "__app_command_object__") - self.app_command_handler.add_command(actual_command_obj) - print( - f"Registered standalone app command/group '{actual_command_obj.name}'." - ) - elif isinstance( - command, (AppCommand, AppCommandGroup) - ): # It's already the command object - self.app_command_handler.add_command(command) - print(f"Registered standalone app command/group '{command.name}'.") - else: - # This case should ideally not be hit if type checks are done by decorators - print( - f"Warning: Could not register app command {command}. It's not a recognized command object or decorated function." - ) - - async def on_command_error( - self, ctx: "CommandContext", error: "CommandError" - ) -> None: - """ - Default command error handler. Called when a command raises an error. - Users can override this method in a subclass of Client to implement custom error handling. - - Args: - ctx (CommandContext): The context of the command that raised the error. - error (CommandError): The error that was raised. - """ - # Default behavior: print to console. - # Users might want to send a message to ctx.channel or log to a file. - print( - f"Error in command '{ctx.command.name if ctx.command else 'unknown'}': {error}" - ) - - # Need to import CommandInvokeError for this check if not already globally available - # For now, assuming it's imported via TYPE_CHECKING or directly if needed at runtime - from .ext.commands.errors import ( - CommandInvokeError as CIE, - ) # Local import for isinstance check - - if isinstance(error, CIE): - # Now it's safe to access error.original - print( - f"Original exception: {type(error.original).__name__}: {error.original}" - ) - # import traceback - # traceback.print_exception(type(error.original), error.original, error.original.__traceback__) + def add_app_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None: + """ + Adds a standalone application command or group to the bot. + Use this for commands not defined within a Cog. + + Args: + command (Union[AppCommand, AppCommandGroup]): The application command or group instance. + This is typically the object returned by a decorator like @slash_command. + """ + from .ext.app_commands.commands import ( + AppCommand, + AppCommandGroup, + ) # Ensure types + + if not isinstance(command, (AppCommand, AppCommandGroup)): + raise TypeError( + "Command must be an instance of AppCommand or AppCommandGroup." + ) + + # If it's a decorated function, the command object might be on __app_command_object__ + if hasattr(command, "__app_command_object__") and isinstance( + getattr(command, "__app_command_object__"), (AppCommand, AppCommandGroup) + ): + actual_command_obj = getattr(command, "__app_command_object__") + self.app_command_handler.add_command(actual_command_obj) + print( + f"Registered standalone app command/group '{actual_command_obj.name}'." + ) + elif isinstance( + command, (AppCommand, AppCommandGroup) + ): # It's already the command object + self.app_command_handler.add_command(command) + print(f"Registered standalone app command/group '{command.name}'.") + else: + # This case should ideally not be hit if type checks are done by decorators + print( + f"Warning: Could not register app command {command}. It's not a recognized command object or decorated function." + ) + + async def on_command_error( + self, ctx: "CommandContext", error: "CommandError" + ) -> None: + """ + Default command error handler. Called when a command raises an error. + Users can override this method in a subclass of Client to implement custom error handling. + + Args: + ctx (CommandContext): The context of the command that raised the error. + error (CommandError): The error that was raised. + """ + # Default behavior: print to console. + # Users might want to send a message to ctx.channel or log to a file. + print( + f"Error in command '{ctx.command.name if ctx.command else 'unknown'}': {error}" + ) + + # Need to import CommandInvokeError for this check if not already globally available + # For now, assuming it's imported via TYPE_CHECKING or directly if needed at runtime + from .ext.commands.errors import ( + CommandInvokeError as CIE, + ) # Local import for isinstance check + + if isinstance(error, CIE): + # Now it's safe to access error.original + print( + f"Original exception: {type(error.original).__name__}: {error.original}" + ) + # import traceback + # traceback.print_exception(type(error.original), error.original, error.original.__traceback__) async def on_command_completion(self, ctx: "CommandContext") -> None: """ @@ -1551,86 +1551,86 @@ class Client: self._views[interaction.message.id] = new_view asyncio.create_task(new_view._dispatch(interaction)) return - - await self.app_command_handler.process_interaction(interaction) - - async def sync_application_commands( - self, guild_id: Optional[Snowflake] = None - ) -> None: - """Synchronizes application commands with Discord.""" - - if not self.application_id: - print( - "Warning: Cannot sync application commands, application_id is not set. " - "Ensure the client is connected and READY." - ) - return - if not self.is_ready(): - print( - "Warning: Client is not ready. Waiting for client to be ready before syncing commands." - ) - await self.wait_until_ready() - if not self.application_id: - print( - "Error: application_id still not set after client is ready. Cannot sync commands." - ) - return - - await self.app_command_handler.sync_commands( - application_id=self.application_id, guild_id=guild_id - ) - - async def on_interaction_create(self, interaction: Interaction) -> None: - """|coro| Called when an interaction is created.""" - - pass - - async def on_presence_update(self, presence) -> None: - """|coro| Called when a user's presence is updated.""" - - pass - - async def on_typing_start(self, typing) -> None: - """|coro| Called when a user starts typing in a channel.""" - - pass - - async def on_app_command_error( - self, context: AppCommandContext, error: Exception - ) -> None: - """Default error handler for application commands.""" - - print( - f"Error in application command '{context.command.name if context.command else 'unknown'}': {error}" - ) - try: - if not context._responded: - await context.send( - "An error occurred while running this command.", ephemeral=True - ) - except Exception as e: - print(f"Failed to send error message for app command: {e}") - - async def on_error( - self, event_method: str, exc: Exception, *args: Any, **kwargs: Any - ) -> None: - """Default event listener error handler.""" - - print(f"Unhandled exception in event listener for '{event_method}':") - print(f"{type(exc).__name__}: {exc}") - - -class AutoShardedClient(Client): - """A :class:`Client` that automatically determines the shard count. - - If ``shard_count`` is not provided, the client will query the Discord API - via :meth:`HTTPClient.get_gateway_bot` for the recommended shard count and - use that when connecting. - """ - - async def connect(self, reconnect: bool = True) -> None: # type: ignore[override] - if self.shard_count is None: - data = await self._http.get_gateway_bot() - self.shard_count = data.get("shards", 1) - - await super().connect(reconnect=reconnect) + + await self.app_command_handler.process_interaction(interaction) + + async def sync_application_commands( + self, guild_id: Optional[Snowflake] = None + ) -> None: + """Synchronizes application commands with Discord.""" + + if not self.application_id: + print( + "Warning: Cannot sync application commands, application_id is not set. " + "Ensure the client is connected and READY." + ) + return + if not self.is_ready(): + print( + "Warning: Client is not ready. Waiting for client to be ready before syncing commands." + ) + await self.wait_until_ready() + if not self.application_id: + print( + "Error: application_id still not set after client is ready. Cannot sync commands." + ) + return + + await self.app_command_handler.sync_commands( + application_id=self.application_id, guild_id=guild_id + ) + + async def on_interaction_create(self, interaction: Interaction) -> None: + """|coro| Called when an interaction is created.""" + + pass + + async def on_presence_update(self, presence) -> None: + """|coro| Called when a user's presence is updated.""" + + pass + + async def on_typing_start(self, typing) -> None: + """|coro| Called when a user starts typing in a channel.""" + + pass + + async def on_app_command_error( + self, context: AppCommandContext, error: Exception + ) -> None: + """Default error handler for application commands.""" + + print( + f"Error in application command '{context.command.name if context.command else 'unknown'}': {error}" + ) + try: + if not context._responded: + await context.send( + "An error occurred while running this command.", ephemeral=True + ) + except Exception as e: + print(f"Failed to send error message for app command: {e}") + + async def on_error( + self, event_method: str, exc: Exception, *args: Any, **kwargs: Any + ) -> None: + """Default event listener error handler.""" + + print(f"Unhandled exception in event listener for '{event_method}':") + print(f"{type(exc).__name__}: {exc}") + + +class AutoShardedClient(Client): + """A :class:`Client` that automatically determines the shard count. + + If ``shard_count`` is not provided, the client will query the Discord API + via :meth:`HTTPClient.get_gateway_bot` for the recommended shard count and + use that when connecting. + """ + + async def connect(self, reconnect: bool = True) -> None: # type: ignore[override] + if self.shard_count is None: + data = await self._http.get_gateway_bot() + self.shard_count = data.get("shards", 1) + + await super().connect(reconnect=reconnect) diff --git a/disagreement/ext/commands/__init__.py b/disagreement/ext/commands/__init__.py index ed4b952..5e2462c 100644 --- a/disagreement/ext/commands/__init__.py +++ b/disagreement/ext/commands/__init__.py @@ -1,65 +1,65 @@ -# disagreement/ext/commands/__init__.py - -""" -disagreement.ext.commands - A command framework extension for the Disagreement library. -""" - -from .cog import Cog -from .core import ( - Command, - CommandContext, - CommandHandler, -) # CommandHandler might be internal -from .decorators import ( - command, - listener, - check, - check_any, - cooldown, - max_concurrency, - requires_permissions, +# disagreement/ext/commands/__init__.py + +""" +disagreement.ext.commands - A command framework extension for the Disagreement library. +""" + +from .cog import Cog +from .core import ( + Command, + CommandContext, + CommandHandler, +) # CommandHandler might be internal +from .decorators import ( + command, + listener, + check, + check_any, + cooldown, + max_concurrency, + requires_permissions, has_role, has_any_role, -) -from .errors import ( - CommandError, - CommandNotFound, - BadArgument, - MissingRequiredArgument, - ArgumentParsingError, - CheckFailure, - CheckAnyFailure, - CommandOnCooldown, - CommandInvokeError, - MaxConcurrencyReached, -) - -__all__ = [ - # Cog - "Cog", - # Core - "Command", - "CommandContext", - # "CommandHandler", # Usually not part of public API for direct use by bot devs - # Decorators - "command", - "listener", - "check", - "check_any", - "cooldown", - "max_concurrency", - "requires_permissions", +) +from .errors import ( + CommandError, + CommandNotFound, + BadArgument, + MissingRequiredArgument, + ArgumentParsingError, + CheckFailure, + CheckAnyFailure, + CommandOnCooldown, + CommandInvokeError, + MaxConcurrencyReached, +) + +__all__ = [ + # Cog + "Cog", + # Core + "Command", + "CommandContext", + # "CommandHandler", # Usually not part of public API for direct use by bot devs + # Decorators + "command", + "listener", + "check", + "check_any", + "cooldown", + "max_concurrency", + "requires_permissions", "has_role", "has_any_role", - # Errors - "CommandError", - "CommandNotFound", - "BadArgument", - "MissingRequiredArgument", - "ArgumentParsingError", - "CheckFailure", - "CheckAnyFailure", - "CommandOnCooldown", - "CommandInvokeError", - "MaxConcurrencyReached", -] + # Errors + "CommandError", + "CommandNotFound", + "BadArgument", + "MissingRequiredArgument", + "ArgumentParsingError", + "CheckFailure", + "CheckAnyFailure", + "CommandOnCooldown", + "CommandInvokeError", + "MaxConcurrencyReached", +] diff --git a/disagreement/ext/commands/core.py b/disagreement/ext/commands/core.py index 30c4871..b9f48c2 100644 --- a/disagreement/ext/commands/core.py +++ b/disagreement/ext/commands/core.py @@ -1,45 +1,45 @@ -# disagreement/ext/commands/core.py - -from __future__ import annotations - -import asyncio -import logging -import inspect -from typing import ( - TYPE_CHECKING, - Optional, - List, - Dict, - Any, - Union, - Callable, - Awaitable, - Tuple, - get_origin, - get_args, -) - -from .view import StringView -from .errors import ( - CommandError, - CommandNotFound, - BadArgument, - MissingRequiredArgument, - ArgumentParsingError, - CheckFailure, - CommandInvokeError, -) -from .converters import run_converters, DEFAULT_CONVERTERS, Converter -from disagreement.typing import Typing - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from .cog import Cog - from disagreement.client import Client - from disagreement.models import Message, User - - +# disagreement/ext/commands/core.py + +from __future__ import annotations + +import asyncio +import logging +import inspect +from typing import ( + TYPE_CHECKING, + Optional, + List, + Dict, + Any, + Union, + Callable, + Awaitable, + Tuple, + get_origin, + get_args, +) + +from .view import StringView +from .errors import ( + CommandError, + CommandNotFound, + BadArgument, + MissingRequiredArgument, + ArgumentParsingError, + CheckFailure, + CommandInvokeError, +) +from .converters import run_converters, DEFAULT_CONVERTERS, Converter +from disagreement.typing import Typing + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from .cog import Cog + from disagreement.client import Client + from disagreement.models import Message, User + + class GroupMixin: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -76,50 +76,50 @@ class GroupMixin: class Command(GroupMixin): - """ - Represents a bot command. - - Attributes: - name (str): The primary name of the command. - callback (Callable[..., Awaitable[None]]): The coroutine function to execute. - aliases (List[str]): Alternative names for the command. - brief (Optional[str]): A short description for help commands. - description (Optional[str]): A longer description for help commands. - cog (Optional['Cog']): Reference to the Cog this command belongs to. - params (Dict[str, inspect.Parameter]): Cached parameters of the callback. - """ - - def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any): - if not asyncio.iscoroutinefunction(callback): - raise TypeError("Command callback must be a coroutine function.") - + """ + Represents a bot command. + + Attributes: + name (str): The primary name of the command. + callback (Callable[..., Awaitable[None]]): The coroutine function to execute. + aliases (List[str]): Alternative names for the command. + brief (Optional[str]): A short description for help commands. + description (Optional[str]): A longer description for help commands. + cog (Optional['Cog']): Reference to the Cog this command belongs to. + params (Dict[str, inspect.Parameter]): Cached parameters of the callback. + """ + + def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any): + if not asyncio.iscoroutinefunction(callback): + raise TypeError("Command callback must be a coroutine function.") + super().__init__(**attrs) - self.callback: Callable[..., Awaitable[None]] = callback - self.name: str = attrs.get("name", callback.__name__) - self.aliases: List[str] = attrs.get("aliases", []) - self.brief: Optional[str] = attrs.get("brief") - self.description: Optional[str] = attrs.get("description") or callback.__doc__ - self.cog: Optional["Cog"] = attrs.get("cog") + self.callback: Callable[..., Awaitable[None]] = callback + self.name: str = attrs.get("name", callback.__name__) + self.aliases: List[str] = attrs.get("aliases", []) + self.brief: Optional[str] = attrs.get("brief") + self.description: Optional[str] = attrs.get("description") or callback.__doc__ + self.cog: Optional["Cog"] = attrs.get("cog") self.invoke_without_command: bool = attrs.get("invoke_without_command", False) - - self.params = inspect.signature(callback).parameters - self.checks: List[Callable[["CommandContext"], Awaitable[bool] | bool]] = [] - if hasattr(callback, "__command_checks__"): - self.checks.extend(getattr(callback, "__command_checks__")) - - self.max_concurrency: Optional[Tuple[int, str]] = None - if hasattr(callback, "__max_concurrency__"): - self.max_concurrency = getattr(callback, "__max_concurrency__") - - def add_check( - self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool] - ) -> None: - self.checks.append(predicate) - + + self.params = inspect.signature(callback).parameters + self.checks: List[Callable[["CommandContext"], Awaitable[bool] | bool]] = [] + if hasattr(callback, "__command_checks__"): + self.checks.extend(getattr(callback, "__command_checks__")) + + self.max_concurrency: Optional[Tuple[int, str]] = None + if hasattr(callback, "__max_concurrency__"): + self.max_concurrency = getattr(callback, "__max_concurrency__") + + def add_check( + self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool] + ) -> None: + self.checks.append(predicate) + async def _run_checks(self, ctx: "CommandContext") -> None: """Runs all cog, local and global checks for the command.""" - from .errors import CheckFailure - + from .errors import CheckFailure + # Run cog-level check first if self.cog: cog_check = getattr(self.cog, "cog_check", None) @@ -138,11 +138,11 @@ class Command(GroupMixin): raise CommandInvokeError(e) from e # Run local checks - for predicate in self.checks: - result = predicate(ctx) - if inspect.isawaitable(result): - result = await result - if not result: + for predicate in self.checks: + result = predicate(ctx) + if inspect.isawaitable(result): + result = await result + if not result: raise CheckFailure(f"A local check for command '{self.name}' failed.") # Then run global checks from the handler @@ -161,8 +161,8 @@ class Command(GroupMixin): before_invoke = None after_invoke = None - - if self.cog: + + if self.cog: before_invoke = getattr(self.cog, "cog_before_invoke", None) after_invoke = getattr(self.cog, "cog_after_invoke", None) @@ -183,161 +183,161 @@ class Group(Command): """A command that can have subcommands.""" def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any): super().__init__(callback, **attrs) - - -PrefixCommand = Command # Alias for clarity in hybrid commands - - -class CommandContext: - """ - Represents the context in which a command is being invoked. - """ - - def __init__( - self, - *, - message: "Message", - bot: "Client", - prefix: str, - command: "Command", - invoked_with: str, - args: Optional[List[Any]] = None, - kwargs: Optional[Dict[str, Any]] = None, - cog: Optional["Cog"] = None, - ): - self.message: "Message" = message - self.bot: "Client" = bot - self.prefix: str = prefix - self.command: "Command" = command - self.invoked_with: str = invoked_with - self.args: List[Any] = args or [] - self.kwargs: Dict[str, Any] = kwargs or {} - self.cog: Optional["Cog"] = cog - - self.author: "User" = message.author - - @property - def guild(self): - """The guild this command was invoked in.""" - if self.message.guild_id and hasattr(self.bot, "get_guild"): - return self.bot.get_guild(self.message.guild_id) - return None - - async def reply( - self, - content: Optional[str] = None, - *, - mention_author: Optional[bool] = None, - **kwargs: Any, - ) -> "Message": - """Replies to the invoking message. - - Parameters - ---------- - content: str - The content to send. - mention_author: Optional[bool] - Whether to mention the author in the reply. If ``None`` the - client's :attr:`mention_replies` value is used. - """ - - allowed_mentions = kwargs.pop("allowed_mentions", None) - if mention_author is None: - mention_author = getattr(self.bot, "mention_replies", False) - - if allowed_mentions is None: - allowed_mentions = {"replied_user": mention_author} - else: - allowed_mentions = dict(allowed_mentions) - allowed_mentions.setdefault("replied_user", mention_author) - - return await self.bot.send_message( - channel_id=self.message.channel_id, - content=content, - message_reference={ - "message_id": self.message.id, - "channel_id": self.message.channel_id, - "guild_id": self.message.guild_id, - }, - allowed_mentions=allowed_mentions, - **kwargs, - ) - - async def send(self, content: str, **kwargs: Any) -> "Message": - return await self.bot.send_message( - channel_id=self.message.channel_id, content=content, **kwargs - ) - - async def edit( - self, - message: Union[str, "Message"], - *, - content: Optional[str] = None, - **kwargs: Any, - ) -> "Message": - """Edits a message previously sent by the bot.""" - - message_id = message if isinstance(message, str) else message.id - return await self.bot.edit_message( - channel_id=self.message.channel_id, - message_id=message_id, - content=content, - **kwargs, - ) - - def typing(self) -> "Typing": - """Return a typing context manager for this context's channel.""" - - return self.bot.typing(self.message.channel_id) - - -class CommandHandler: - """ - Manages command registration, parsing, and dispatching. - """ - - def __init__( - self, - client: "Client", - prefix: Union[ - str, List[str], Callable[["Client", "Message"], Union[str, List[str]]] - ], - ): - self.client: "Client" = client - self.prefix: Union[ - str, List[str], Callable[["Client", "Message"], Union[str, List[str]]] - ] = prefix - self.commands: Dict[str, Command] = {} - self.cogs: Dict[str, "Cog"] = {} - self._concurrency: Dict[str, Dict[str, int]] = {} + + +PrefixCommand = Command # Alias for clarity in hybrid commands + + +class CommandContext: + """ + Represents the context in which a command is being invoked. + """ + + def __init__( + self, + *, + message: "Message", + bot: "Client", + prefix: str, + command: "Command", + invoked_with: str, + args: Optional[List[Any]] = None, + kwargs: Optional[Dict[str, Any]] = None, + cog: Optional["Cog"] = None, + ): + self.message: "Message" = message + self.bot: "Client" = bot + self.prefix: str = prefix + self.command: "Command" = command + self.invoked_with: str = invoked_with + self.args: List[Any] = args or [] + self.kwargs: Dict[str, Any] = kwargs or {} + self.cog: Optional["Cog"] = cog + + self.author: "User" = message.author + + @property + def guild(self): + """The guild this command was invoked in.""" + if self.message.guild_id and hasattr(self.bot, "get_guild"): + return self.bot.get_guild(self.message.guild_id) + return None + + async def reply( + self, + content: Optional[str] = None, + *, + mention_author: Optional[bool] = None, + **kwargs: Any, + ) -> "Message": + """Replies to the invoking message. + + Parameters + ---------- + content: str + The content to send. + mention_author: Optional[bool] + Whether to mention the author in the reply. If ``None`` the + client's :attr:`mention_replies` value is used. + """ + + allowed_mentions = kwargs.pop("allowed_mentions", None) + if mention_author is None: + mention_author = getattr(self.bot, "mention_replies", False) + + if allowed_mentions is None: + allowed_mentions = {"replied_user": mention_author} + else: + allowed_mentions = dict(allowed_mentions) + allowed_mentions.setdefault("replied_user", mention_author) + + return await self.bot.send_message( + channel_id=self.message.channel_id, + content=content, + message_reference={ + "message_id": self.message.id, + "channel_id": self.message.channel_id, + "guild_id": self.message.guild_id, + }, + allowed_mentions=allowed_mentions, + **kwargs, + ) + + async def send(self, content: str, **kwargs: Any) -> "Message": + return await self.bot.send_message( + channel_id=self.message.channel_id, content=content, **kwargs + ) + + async def edit( + self, + message: Union[str, "Message"], + *, + content: Optional[str] = None, + **kwargs: Any, + ) -> "Message": + """Edits a message previously sent by the bot.""" + + message_id = message if isinstance(message, str) else message.id + return await self.bot.edit_message( + channel_id=self.message.channel_id, + message_id=message_id, + content=content, + **kwargs, + ) + + def typing(self) -> "Typing": + """Return a typing context manager for this context's channel.""" + + return self.bot.typing(self.message.channel_id) + + +class CommandHandler: + """ + Manages command registration, parsing, and dispatching. + """ + + def __init__( + self, + client: "Client", + prefix: Union[ + str, List[str], Callable[["Client", "Message"], Union[str, List[str]]] + ], + ): + self.client: "Client" = client + self.prefix: Union[ + str, List[str], Callable[["Client", "Message"], Union[str, List[str]]] + ] = prefix + self.commands: Dict[str, Command] = {} + self.cogs: Dict[str, "Cog"] = {} + self._concurrency: Dict[str, Dict[str, int]] = {} self._global_checks: List[ Callable[["CommandContext"], Awaitable[bool] | bool] ] = [] - - from .help import HelpCommand - - self.add_command(HelpCommand(self)) - + + from .help import HelpCommand + + self.add_command(HelpCommand(self)) + def add_check( self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool] ) -> None: """Adds a global check to the command handler.""" self._global_checks.append(predicate) - def add_command(self, command: Command) -> None: - if command.name in self.commands: - raise ValueError(f"Command '{command.name}' is already registered.") - - self.commands[command.name.lower()] = command - for alias in command.aliases: - if alias in self.commands: - logger.warning( - "Alias '%s' for command '%s' conflicts with an existing command or alias.", - alias, - command.name, - ) - self.commands[alias.lower()] = command - + def add_command(self, command: Command) -> None: + if command.name in self.commands: + raise ValueError(f"Command '{command.name}' is already registered.") + + self.commands[command.name.lower()] = command + for alias in command.aliases: + if alias in self.commands: + logger.warning( + "Alias '%s' for command '%s' conflicts with an existing command or alias.", + alias, + command.name, + ) + self.commands[alias.lower()] = command + if isinstance(command, Group): for sub_cmd in command.commands.values(): if sub_cmd.name in self.commands: @@ -347,301 +347,301 @@ class CommandHandler: command.name, ) - def remove_command(self, name: str) -> Optional[Command]: - command = self.commands.pop(name.lower(), None) - if command: - for alias in command.aliases: - self.commands.pop(alias.lower(), None) - return command - - def get_command(self, name: str) -> Optional[Command]: - return self.commands.get(name.lower()) - - def add_cog(self, cog_to_add: "Cog") -> None: - from .cog import Cog - - if not isinstance(cog_to_add, Cog): - raise TypeError("Argument must be a subclass of Cog.") - - if cog_to_add.cog_name in self.cogs: - raise ValueError( - f"Cog with name '{cog_to_add.cog_name}' is already registered." - ) - - self.cogs[cog_to_add.cog_name] = cog_to_add - - for cmd in cog_to_add.get_commands(): - self.add_command(cmd) - - if hasattr(self.client, "_event_dispatcher"): - for event_name, callback in cog_to_add.get_listeners(): - self.client._event_dispatcher.register(event_name.upper(), callback) - else: - logger.warning( - "Client does not have '_event_dispatcher'. Listeners for cog '%s' not registered.", - cog_to_add.cog_name, - ) - - if hasattr(cog_to_add, "cog_load") and inspect.iscoroutinefunction( - cog_to_add.cog_load - ): - asyncio.create_task(cog_to_add.cog_load()) - - logger.info("Cog '%s' added.", cog_to_add.cog_name) - - def remove_cog(self, cog_name: str) -> Optional["Cog"]: - cog_to_remove = self.cogs.pop(cog_name, None) - if cog_to_remove: - for cmd in cog_to_remove.get_commands(): - self.remove_command(cmd.name) - - if hasattr(self.client, "_event_dispatcher"): - for event_name, callback in cog_to_remove.get_listeners(): - logger.debug( - "Listener '%s' for event '%s' from cog '%s' needs manual unregistration logic in EventDispatcher.", - callback.__name__, - event_name, - cog_name, - ) - - if hasattr(cog_to_remove, "cog_unload") and inspect.iscoroutinefunction( - cog_to_remove.cog_unload - ): - asyncio.create_task(cog_to_remove.cog_unload()) - - cog_to_remove._eject() - logger.info("Cog '%s' removed.", cog_name) - return cog_to_remove - - def _acquire_concurrency(self, ctx: CommandContext) -> None: - mc = getattr(ctx.command, "max_concurrency", None) - if not mc: - return - limit, scope = mc - if scope == "user": - key = ctx.author.id - elif scope == "guild": - key = ctx.message.guild_id or ctx.author.id - else: - key = "global" - buckets = self._concurrency.setdefault(ctx.command.name, {}) - current = buckets.get(key, 0) - if current >= limit: - from .errors import MaxConcurrencyReached - - raise MaxConcurrencyReached(limit) - buckets[key] = current + 1 - - def _release_concurrency(self, ctx: CommandContext) -> None: - mc = getattr(ctx.command, "max_concurrency", None) - if not mc: - return - _, scope = mc - if scope == "user": - key = ctx.author.id - elif scope == "guild": - key = ctx.message.guild_id or ctx.author.id - else: - key = "global" - buckets = self._concurrency.get(ctx.command.name) - if not buckets: - return - current = buckets.get(key, 0) - if current <= 1: - buckets.pop(key, None) - else: - buckets[key] = current - 1 - if not buckets: - self._concurrency.pop(ctx.command.name, None) - - async def get_prefix(self, message: "Message") -> Union[str, List[str], None]: - if callable(self.prefix): - if inspect.iscoroutinefunction(self.prefix): - return await self.prefix(self.client, message) - else: - return self.prefix(self.client, message) # type: ignore - return self.prefix - - async def _parse_arguments( - self, command: Command, ctx: CommandContext, view: StringView - ) -> Tuple[List[Any], Dict[str, Any]]: - args_list = [] - kwargs_dict = {} - params_to_parse = list(command.params.values()) - - if params_to_parse and params_to_parse[0].name == "self" and command.cog: - params_to_parse.pop(0) - if params_to_parse and params_to_parse[0].name == "ctx": - params_to_parse.pop(0) - - for param in params_to_parse: - view.skip_whitespace() - final_value_for_param: Any = inspect.Parameter.empty - - if param.kind == inspect.Parameter.VAR_POSITIONAL: - while not view.eof: - view.skip_whitespace() - if view.eof: - break - word = view.get_word() - if word or not view.eof: - args_list.append(word) - elif view.eof: - break - break - - arg_str_value: Optional[str] = ( - None # Holds the raw string for current param - ) - - if view.eof: # No more input string - if param.default is not inspect.Parameter.empty: - final_value_for_param = param.default - elif param.kind != inspect.Parameter.VAR_KEYWORD: - raise MissingRequiredArgument(param.name) - else: # VAR_KEYWORD at EOF is fine - break - else: # Input available - is_last_pos_str_greedy = ( - param == params_to_parse[-1] - and param.annotation is str - and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD - ) - - if is_last_pos_str_greedy: - arg_str_value = view.read_rest().strip() - if ( - not arg_str_value - and param.default is not inspect.Parameter.empty - ): - final_value_for_param = param.default - else: # Includes empty string if that's what's left - final_value_for_param = arg_str_value - else: # Not greedy, or not string, or not last positional - if view.buffer[view.index] == '"': - arg_str_value = view.get_quoted_string() - if arg_str_value == "" and view.buffer[view.index] == '"': - raise BadArgument( - f"Unterminated quoted string for argument '{param.name}'." - ) - else: - arg_str_value = view.get_word() - - # If final_value_for_param was not set by greedy logic, try conversion - if final_value_for_param is inspect.Parameter.empty: - if ( - arg_str_value is None - ): # Should not happen if view.get_word/get_quoted_string is robust - if param.default is not inspect.Parameter.empty: - final_value_for_param = param.default - else: - raise MissingRequiredArgument(param.name) - else: # We have an arg_str_value (could be empty string "" from quotes) - annotation = param.annotation - origin = get_origin(annotation) - - if origin is Union: # Handles Optional[T] and Union[T1, T2] - union_args = get_args(annotation) - is_optional = ( - len(union_args) == 2 and type(None) in union_args - ) - - converted_for_union = False - last_err_union: Optional[BadArgument] = None - for t_arg in union_args: - if t_arg is type(None): - continue - try: - final_value_for_param = await run_converters( - ctx, t_arg, arg_str_value - ) - converted_for_union = True - break - except BadArgument as e: - last_err_union = e - - if not converted_for_union: - if ( - is_optional and param.default is None - ): # Special handling for Optional[T] if conversion failed - # If arg_str_value was "" and type was Optional[str], StringConverter would return "" - # If arg_str_value was "" and type was Optional[int], BadArgument would be raised. - # This path is for when all actual types in Optional[T] fail conversion. - # If default is None, we can assign None. - final_value_for_param = None - elif last_err_union: - raise last_err_union - else: # Should not be reached if logic is correct - raise BadArgument( - f"Could not convert '{arg_str_value}' to any of {union_args} for param '{param.name}'." - ) - elif annotation is inspect.Parameter.empty or annotation is str: - final_value_for_param = arg_str_value - else: # Standard type hint - final_value_for_param = await run_converters( - ctx, annotation, arg_str_value - ) - - # Final check if value was resolved - if final_value_for_param is inspect.Parameter.empty: - if param.default is not inspect.Parameter.empty: - final_value_for_param = param.default - elif param.kind != inspect.Parameter.VAR_KEYWORD: - # This state implies an issue if required and no default, and no input was parsed. - raise MissingRequiredArgument( - f"Parameter '{param.name}' could not be resolved." - ) - - # Assign to args_list or kwargs_dict if a value was determined - if final_value_for_param is not inspect.Parameter.empty: - if ( - param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD - or param.kind == inspect.Parameter.POSITIONAL_ONLY - ): - args_list.append(final_value_for_param) - elif param.kind == inspect.Parameter.KEYWORD_ONLY: - kwargs_dict[param.name] = final_value_for_param - - return args_list, kwargs_dict - - async def process_commands(self, message: "Message") -> None: - if not message.content: - return - - prefix_to_use = await self.get_prefix(message) - if not prefix_to_use: - return - - actual_prefix: Optional[str] = None - if isinstance(prefix_to_use, list): - for p in prefix_to_use: - if message.content.startswith(p): - actual_prefix = p - break - if not actual_prefix: - return - elif isinstance(prefix_to_use, str): - if message.content.startswith(prefix_to_use): - actual_prefix = prefix_to_use - else: - return - else: - return - - if actual_prefix is None: - return - - content_without_prefix = message.content[len(actual_prefix) :] - view = StringView(content_without_prefix) - - command_name = view.get_word() - if not command_name: - return - - command = self.get_command(command_name) - if not command: - return - + def remove_command(self, name: str) -> Optional[Command]: + command = self.commands.pop(name.lower(), None) + if command: + for alias in command.aliases: + self.commands.pop(alias.lower(), None) + return command + + def get_command(self, name: str) -> Optional[Command]: + return self.commands.get(name.lower()) + + def add_cog(self, cog_to_add: "Cog") -> None: + from .cog import Cog + + if not isinstance(cog_to_add, Cog): + raise TypeError("Argument must be a subclass of Cog.") + + if cog_to_add.cog_name in self.cogs: + raise ValueError( + f"Cog with name '{cog_to_add.cog_name}' is already registered." + ) + + self.cogs[cog_to_add.cog_name] = cog_to_add + + for cmd in cog_to_add.get_commands(): + self.add_command(cmd) + + if hasattr(self.client, "_event_dispatcher"): + for event_name, callback in cog_to_add.get_listeners(): + self.client._event_dispatcher.register(event_name.upper(), callback) + else: + logger.warning( + "Client does not have '_event_dispatcher'. Listeners for cog '%s' not registered.", + cog_to_add.cog_name, + ) + + if hasattr(cog_to_add, "cog_load") and inspect.iscoroutinefunction( + cog_to_add.cog_load + ): + asyncio.create_task(cog_to_add.cog_load()) + + logger.info("Cog '%s' added.", cog_to_add.cog_name) + + def remove_cog(self, cog_name: str) -> Optional["Cog"]: + cog_to_remove = self.cogs.pop(cog_name, None) + if cog_to_remove: + for cmd in cog_to_remove.get_commands(): + self.remove_command(cmd.name) + + if hasattr(self.client, "_event_dispatcher"): + for event_name, callback in cog_to_remove.get_listeners(): + logger.debug( + "Listener '%s' for event '%s' from cog '%s' needs manual unregistration logic in EventDispatcher.", + callback.__name__, + event_name, + cog_name, + ) + + if hasattr(cog_to_remove, "cog_unload") and inspect.iscoroutinefunction( + cog_to_remove.cog_unload + ): + asyncio.create_task(cog_to_remove.cog_unload()) + + cog_to_remove._eject() + logger.info("Cog '%s' removed.", cog_name) + return cog_to_remove + + def _acquire_concurrency(self, ctx: CommandContext) -> None: + mc = getattr(ctx.command, "max_concurrency", None) + if not mc: + return + limit, scope = mc + if scope == "user": + key = ctx.author.id + elif scope == "guild": + key = ctx.message.guild_id or ctx.author.id + else: + key = "global" + buckets = self._concurrency.setdefault(ctx.command.name, {}) + current = buckets.get(key, 0) + if current >= limit: + from .errors import MaxConcurrencyReached + + raise MaxConcurrencyReached(limit) + buckets[key] = current + 1 + + def _release_concurrency(self, ctx: CommandContext) -> None: + mc = getattr(ctx.command, "max_concurrency", None) + if not mc: + return + _, scope = mc + if scope == "user": + key = ctx.author.id + elif scope == "guild": + key = ctx.message.guild_id or ctx.author.id + else: + key = "global" + buckets = self._concurrency.get(ctx.command.name) + if not buckets: + return + current = buckets.get(key, 0) + if current <= 1: + buckets.pop(key, None) + else: + buckets[key] = current - 1 + if not buckets: + self._concurrency.pop(ctx.command.name, None) + + async def get_prefix(self, message: "Message") -> Union[str, List[str], None]: + if callable(self.prefix): + if inspect.iscoroutinefunction(self.prefix): + return await self.prefix(self.client, message) + else: + return self.prefix(self.client, message) # type: ignore + return self.prefix + + async def _parse_arguments( + self, command: Command, ctx: CommandContext, view: StringView + ) -> Tuple[List[Any], Dict[str, Any]]: + args_list = [] + kwargs_dict = {} + params_to_parse = list(command.params.values()) + + if params_to_parse and params_to_parse[0].name == "self" and command.cog: + params_to_parse.pop(0) + if params_to_parse and params_to_parse[0].name == "ctx": + params_to_parse.pop(0) + + for param in params_to_parse: + view.skip_whitespace() + final_value_for_param: Any = inspect.Parameter.empty + + if param.kind == inspect.Parameter.VAR_POSITIONAL: + while not view.eof: + view.skip_whitespace() + if view.eof: + break + word = view.get_word() + if word or not view.eof: + args_list.append(word) + elif view.eof: + break + break + + arg_str_value: Optional[str] = ( + None # Holds the raw string for current param + ) + + if view.eof: # No more input string + if param.default is not inspect.Parameter.empty: + final_value_for_param = param.default + elif param.kind != inspect.Parameter.VAR_KEYWORD: + raise MissingRequiredArgument(param.name) + else: # VAR_KEYWORD at EOF is fine + break + else: # Input available + is_last_pos_str_greedy = ( + param == params_to_parse[-1] + and param.annotation is str + and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ) + + if is_last_pos_str_greedy: + arg_str_value = view.read_rest().strip() + if ( + not arg_str_value + and param.default is not inspect.Parameter.empty + ): + final_value_for_param = param.default + else: # Includes empty string if that's what's left + final_value_for_param = arg_str_value + else: # Not greedy, or not string, or not last positional + if view.buffer[view.index] == '"': + arg_str_value = view.get_quoted_string() + if arg_str_value == "" and view.buffer[view.index] == '"': + raise BadArgument( + f"Unterminated quoted string for argument '{param.name}'." + ) + else: + arg_str_value = view.get_word() + + # If final_value_for_param was not set by greedy logic, try conversion + if final_value_for_param is inspect.Parameter.empty: + if ( + arg_str_value is None + ): # Should not happen if view.get_word/get_quoted_string is robust + if param.default is not inspect.Parameter.empty: + final_value_for_param = param.default + else: + raise MissingRequiredArgument(param.name) + else: # We have an arg_str_value (could be empty string "" from quotes) + annotation = param.annotation + origin = get_origin(annotation) + + if origin is Union: # Handles Optional[T] and Union[T1, T2] + union_args = get_args(annotation) + is_optional = ( + len(union_args) == 2 and type(None) in union_args + ) + + converted_for_union = False + last_err_union: Optional[BadArgument] = None + for t_arg in union_args: + if t_arg is type(None): + continue + try: + final_value_for_param = await run_converters( + ctx, t_arg, arg_str_value + ) + converted_for_union = True + break + except BadArgument as e: + last_err_union = e + + if not converted_for_union: + if ( + is_optional and param.default is None + ): # Special handling for Optional[T] if conversion failed + # If arg_str_value was "" and type was Optional[str], StringConverter would return "" + # If arg_str_value was "" and type was Optional[int], BadArgument would be raised. + # This path is for when all actual types in Optional[T] fail conversion. + # If default is None, we can assign None. + final_value_for_param = None + elif last_err_union: + raise last_err_union + else: # Should not be reached if logic is correct + raise BadArgument( + f"Could not convert '{arg_str_value}' to any of {union_args} for param '{param.name}'." + ) + elif annotation is inspect.Parameter.empty or annotation is str: + final_value_for_param = arg_str_value + else: # Standard type hint + final_value_for_param = await run_converters( + ctx, annotation, arg_str_value + ) + + # Final check if value was resolved + if final_value_for_param is inspect.Parameter.empty: + if param.default is not inspect.Parameter.empty: + final_value_for_param = param.default + elif param.kind != inspect.Parameter.VAR_KEYWORD: + # This state implies an issue if required and no default, and no input was parsed. + raise MissingRequiredArgument( + f"Parameter '{param.name}' could not be resolved." + ) + + # Assign to args_list or kwargs_dict if a value was determined + if final_value_for_param is not inspect.Parameter.empty: + if ( + param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + or param.kind == inspect.Parameter.POSITIONAL_ONLY + ): + args_list.append(final_value_for_param) + elif param.kind == inspect.Parameter.KEYWORD_ONLY: + kwargs_dict[param.name] = final_value_for_param + + return args_list, kwargs_dict + + async def process_commands(self, message: "Message") -> None: + if not message.content: + return + + prefix_to_use = await self.get_prefix(message) + if not prefix_to_use: + return + + actual_prefix: Optional[str] = None + if isinstance(prefix_to_use, list): + for p in prefix_to_use: + if message.content.startswith(p): + actual_prefix = p + break + if not actual_prefix: + return + elif isinstance(prefix_to_use, str): + if message.content.startswith(prefix_to_use): + actual_prefix = prefix_to_use + else: + return + else: + return + + if actual_prefix is None: + return + + content_without_prefix = message.content[len(actual_prefix) :] + view = StringView(content_without_prefix) + + command_name = view.get_word() + if not command_name: + return + + command = self.get_command(command_name) + if not command: + return + invoked_with = command_name original_command = command @@ -658,33 +658,33 @@ class CommandHandler: else: raise CommandNotFound(f"Subcommand '{potential_subcommand}' not found.") - ctx = CommandContext( - message=message, - bot=self.client, - prefix=actual_prefix, - command=command, + ctx = CommandContext( + message=message, + bot=self.client, + prefix=actual_prefix, + command=command, invoked_with=invoked_with, - cog=command.cog, - ) - - try: - parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view) - ctx.args = parsed_args - ctx.kwargs = parsed_kwargs - self._acquire_concurrency(ctx) - try: - await command.invoke(ctx, *parsed_args, **parsed_kwargs) - finally: - self._release_concurrency(ctx) - except CommandError as e: + cog=command.cog, + ) + + try: + parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view) + ctx.args = parsed_args + ctx.kwargs = parsed_kwargs + self._acquire_concurrency(ctx) + try: + await command.invoke(ctx, *parsed_args, **parsed_kwargs) + finally: + self._release_concurrency(ctx) + except CommandError as e: logger.error("Command error for '%s': %s", original_command.name, e) - if hasattr(self.client, "on_command_error"): - await self.client.on_command_error(ctx, e) - except Exception as e: + if hasattr(self.client, "on_command_error"): + await self.client.on_command_error(ctx, e) + except Exception as e: logger.error("Unexpected error invoking command '%s': %s", original_command.name, e) - exc = CommandInvokeError(e) - if hasattr(self.client, "on_command_error"): - await self.client.on_command_error(ctx, exc) + exc = CommandInvokeError(e) + if hasattr(self.client, "on_command_error"): + await self.client.on_command_error(ctx, exc) else: if hasattr(self.client, "on_command_completion"): await self.client.on_command_completion(ctx) diff --git a/disagreement/ext/commands/decorators.py b/disagreement/ext/commands/decorators.py index a0c9019..7400118 100644 --- a/disagreement/ext/commands/decorators.py +++ b/disagreement/ext/commands/decorators.py @@ -1,222 +1,222 @@ -# disagreement/ext/commands/decorators.py -from __future__ import annotations - -import asyncio -import inspect -import time -from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable - -if TYPE_CHECKING: - from .core import Command, CommandContext - from disagreement.permissions import Permissions - from disagreement.models import Member, Guild, Channel - - -def command( - name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any -) -> Callable: - """ - A decorator that transforms a function into a Command. - - Args: - name (Optional[str]): The name of the command. Defaults to the function name. - aliases (Optional[List[str]]): Alternative names for the command. - **attrs: Additional attributes to pass to the Command constructor - (e.g., brief, description, hidden). - - Returns: - Callable: A decorator that registers the command. - """ - - def decorator( - func: Callable[..., Awaitable[None]], - ) -> Callable[..., Awaitable[None]]: - if not asyncio.iscoroutinefunction(func): - raise TypeError("Command callback must be a coroutine function.") - - from .core import Command - - cmd_name = name or func.__name__ - - if hasattr(func, "__command_attrs__"): - raise TypeError("Function is already a command or has command attributes.") - - cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs) - func.__command_object__ = cmd # type: ignore - return func - - return decorator - - -def listener( - name: Optional[str] = None, -) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: - """ - A decorator that marks a function as an event listener within a Cog. - """ - - def decorator( - func: Callable[..., Awaitable[None]], - ) -> Callable[..., Awaitable[None]]: - if not asyncio.iscoroutinefunction(func): - raise TypeError("Listener callback must be a coroutine function.") - - actual_event_name = name or func.__name__ - setattr(func, "__listener_name__", actual_event_name) - return func - - return decorator - - -def check( - predicate: Callable[["CommandContext"], Awaitable[bool] | bool], -) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: - """Decorator to add a check to a command.""" - - def decorator( - func: Callable[..., Awaitable[None]], - ) -> Callable[..., Awaitable[None]]: - checks = getattr(func, "__command_checks__", []) - checks.append(predicate) - setattr(func, "__command_checks__", checks) - return func - - return decorator - - -def check_any( - *predicates: Callable[["CommandContext"], Awaitable[bool] | bool] -) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: - """Decorator that passes if any predicate returns ``True``.""" - - async def predicate(ctx: "CommandContext") -> bool: - from .errors import CheckAnyFailure, CheckFailure - - errors = [] - for p in predicates: - try: - result = p(ctx) - if inspect.isawaitable(result): - result = await result - if result: - return True - except CheckFailure as e: - errors.append(e) - raise CheckAnyFailure(errors) - - return check(predicate) - - -def max_concurrency( - number: int, per: str = "user" -) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: - """Limit how many concurrent invocations of a command are allowed. - - Parameters - ---------- - number: - The maximum number of concurrent invocations. - per: - The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``. - """ - - if number < 1: - raise ValueError("Concurrency number must be at least 1.") - if per not in {"user", "guild", "global"}: - raise ValueError("per must be 'user', 'guild', or 'global'.") - - def decorator( - func: Callable[..., Awaitable[None]], - ) -> Callable[..., Awaitable[None]]: - setattr(func, "__max_concurrency__", (number, per)) - return func - - return decorator - - -def cooldown( - rate: int, per: float -) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: - """Simple per-user cooldown decorator.""" - - buckets: dict[str, dict[str, float]] = {} - - async def predicate(ctx: "CommandContext") -> bool: - from .errors import CommandOnCooldown - - now = time.monotonic() - user_buckets = buckets.setdefault(ctx.command.name, {}) - reset = user_buckets.get(ctx.author.id, 0) - if now < reset: - raise CommandOnCooldown(reset - now) - user_buckets[ctx.author.id] = now + per - return True - - return check(predicate) - - -def _compute_permissions( - member: "Member", channel: "Channel", guild: "Guild" -) -> "Permissions": - """Compute the effective permissions for a member in a channel.""" - return channel.permissions_for(member) - - -def requires_permissions( - *perms: "Permissions", -) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: - """Check that the invoking member has the given permissions in the channel.""" - - async def predicate(ctx: "CommandContext") -> bool: - from .errors import CheckFailure - from disagreement.permissions import ( - has_permissions, - missing_permissions, - ) - from disagreement.models import Member - - channel = getattr(ctx, "channel", None) - if channel is None and hasattr(ctx.bot, "get_channel"): - channel = ctx.bot.get_channel(ctx.message.channel_id) - if channel is None and hasattr(ctx.bot, "fetch_channel"): - channel = await ctx.bot.fetch_channel(ctx.message.channel_id) - - if channel is None: - raise CheckFailure("Channel for permission check not found.") - - guild = getattr(channel, "guild", None) - if not guild and hasattr(channel, "guild_id") and channel.guild_id: - if hasattr(ctx.bot, "get_guild"): - guild = ctx.bot.get_guild(channel.guild_id) - if not guild and hasattr(ctx.bot, "fetch_guild"): - guild = await ctx.bot.fetch_guild(channel.guild_id) - - if not guild: - is_dm = not hasattr(channel, "guild_id") or not channel.guild_id - if is_dm: - if perms: - raise CheckFailure("Permission checks are not supported in DMs.") - return True - raise CheckFailure("Guild for permission check not found.") - - member = ctx.author - if not isinstance(member, Member): - member = guild.get_member(ctx.author.id) - if not member and hasattr(ctx.bot, "fetch_member"): - member = await ctx.bot.fetch_member(guild.id, ctx.author.id) - - if not member: - raise CheckFailure("Could not resolve author to a guild member.") - - perms_value = _compute_permissions(member, channel, guild) - - if not has_permissions(perms_value, *perms): - missing = missing_permissions(perms_value, *perms) - missing_names = ", ".join(p.name for p in missing if p.name) - raise CheckFailure(f"Missing permissions: {missing_names}") - return True - - return check(predicate) +# disagreement/ext/commands/decorators.py +from __future__ import annotations + +import asyncio +import inspect +import time +from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable + +if TYPE_CHECKING: + from .core import Command, CommandContext + from disagreement.permissions import Permissions + from disagreement.models import Member, Guild, Channel + + +def command( + name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any +) -> Callable: + """ + A decorator that transforms a function into a Command. + + Args: + name (Optional[str]): The name of the command. Defaults to the function name. + aliases (Optional[List[str]]): Alternative names for the command. + **attrs: Additional attributes to pass to the Command constructor + (e.g., brief, description, hidden). + + Returns: + Callable: A decorator that registers the command. + """ + + def decorator( + func: Callable[..., Awaitable[None]], + ) -> Callable[..., Awaitable[None]]: + if not asyncio.iscoroutinefunction(func): + raise TypeError("Command callback must be a coroutine function.") + + from .core import Command + + cmd_name = name or func.__name__ + + if hasattr(func, "__command_attrs__"): + raise TypeError("Function is already a command or has command attributes.") + + cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs) + func.__command_object__ = cmd # type: ignore + return func + + return decorator + + +def listener( + name: Optional[str] = None, +) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: + """ + A decorator that marks a function as an event listener within a Cog. + """ + + def decorator( + func: Callable[..., Awaitable[None]], + ) -> Callable[..., Awaitable[None]]: + if not asyncio.iscoroutinefunction(func): + raise TypeError("Listener callback must be a coroutine function.") + + actual_event_name = name or func.__name__ + setattr(func, "__listener_name__", actual_event_name) + return func + + return decorator + + +def check( + predicate: Callable[["CommandContext"], Awaitable[bool] | bool], +) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: + """Decorator to add a check to a command.""" + + def decorator( + func: Callable[..., Awaitable[None]], + ) -> Callable[..., Awaitable[None]]: + checks = getattr(func, "__command_checks__", []) + checks.append(predicate) + setattr(func, "__command_checks__", checks) + return func + + return decorator + + +def check_any( + *predicates: Callable[["CommandContext"], Awaitable[bool] | bool] +) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: + """Decorator that passes if any predicate returns ``True``.""" + + async def predicate(ctx: "CommandContext") -> bool: + from .errors import CheckAnyFailure, CheckFailure + + errors = [] + for p in predicates: + try: + result = p(ctx) + if inspect.isawaitable(result): + result = await result + if result: + return True + except CheckFailure as e: + errors.append(e) + raise CheckAnyFailure(errors) + + return check(predicate) + + +def max_concurrency( + number: int, per: str = "user" +) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: + """Limit how many concurrent invocations of a command are allowed. + + Parameters + ---------- + number: + The maximum number of concurrent invocations. + per: + The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``. + """ + + if number < 1: + raise ValueError("Concurrency number must be at least 1.") + if per not in {"user", "guild", "global"}: + raise ValueError("per must be 'user', 'guild', or 'global'.") + + def decorator( + func: Callable[..., Awaitable[None]], + ) -> Callable[..., Awaitable[None]]: + setattr(func, "__max_concurrency__", (number, per)) + return func + + return decorator + + +def cooldown( + rate: int, per: float +) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: + """Simple per-user cooldown decorator.""" + + buckets: dict[str, dict[str, float]] = {} + + async def predicate(ctx: "CommandContext") -> bool: + from .errors import CommandOnCooldown + + now = time.monotonic() + user_buckets = buckets.setdefault(ctx.command.name, {}) + reset = user_buckets.get(ctx.author.id, 0) + if now < reset: + raise CommandOnCooldown(reset - now) + user_buckets[ctx.author.id] = now + per + return True + + return check(predicate) + + +def _compute_permissions( + member: "Member", channel: "Channel", guild: "Guild" +) -> "Permissions": + """Compute the effective permissions for a member in a channel.""" + return channel.permissions_for(member) + + +def requires_permissions( + *perms: "Permissions", +) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: + """Check that the invoking member has the given permissions in the channel.""" + + async def predicate(ctx: "CommandContext") -> bool: + from .errors import CheckFailure + from disagreement.permissions import ( + has_permissions, + missing_permissions, + ) + from disagreement.models import Member + + channel = getattr(ctx, "channel", None) + if channel is None and hasattr(ctx.bot, "get_channel"): + channel = ctx.bot.get_channel(ctx.message.channel_id) + if channel is None and hasattr(ctx.bot, "fetch_channel"): + channel = await ctx.bot.fetch_channel(ctx.message.channel_id) + + if channel is None: + raise CheckFailure("Channel for permission check not found.") + + guild = getattr(channel, "guild", None) + if not guild and hasattr(channel, "guild_id") and channel.guild_id: + if hasattr(ctx.bot, "get_guild"): + guild = ctx.bot.get_guild(channel.guild_id) + if not guild and hasattr(ctx.bot, "fetch_guild"): + guild = await ctx.bot.fetch_guild(channel.guild_id) + + if not guild: + is_dm = not hasattr(channel, "guild_id") or not channel.guild_id + if is_dm: + if perms: + raise CheckFailure("Permission checks are not supported in DMs.") + return True + raise CheckFailure("Guild for permission check not found.") + + member = ctx.author + if not isinstance(member, Member): + member = guild.get_member(ctx.author.id) + if not member and hasattr(ctx.bot, "fetch_member"): + member = await ctx.bot.fetch_member(guild.id, ctx.author.id) + + if not member: + raise CheckFailure("Could not resolve author to a guild member.") + + perms_value = _compute_permissions(member, channel, guild) + + if not has_permissions(perms_value, *perms): + missing = missing_permissions(perms_value, *perms) + missing_names = ", ".join(p.name for p in missing if p.name) + raise CheckFailure(f"Missing permissions: {missing_names}") + return True + + return check(predicate) def has_role( name_or_id: str | int, diff --git a/disagreement/gateway.py b/disagreement/gateway.py index ff132d7..3875d92 100644 --- a/disagreement/gateway.py +++ b/disagreement/gateway.py @@ -1,244 +1,244 @@ -# disagreement/gateway.py - -""" -Manages the WebSocket connection to the Discord Gateway. -""" - -import asyncio -import logging -import traceback -import aiohttp -import json -import zlib -import time -import random -from typing import Optional, TYPE_CHECKING, Any, Dict - -from .enums import GatewayOpcode, GatewayIntent -from .errors import GatewayException, DisagreementException, AuthenticationError -from .interactions import Interaction - -if TYPE_CHECKING: - from .client import Client # For type hinting - from .event_dispatcher import EventDispatcher - from .http import HTTPClient - from .interactions import Interaction # Added for INTERACTION_CREATE - -# ZLIB Decompression constants -ZLIB_SUFFIX = b"\x00\x00\xff\xff" -MAX_DECOMPRESSION_SIZE = 10 * 1024 * 1024 # 10 MiB, adjust as needed - - -logger = logging.getLogger(__name__) - - -class GatewayClient: - """ - Handles the Discord Gateway WebSocket connection, heartbeating, and event dispatching. - """ - - def __init__( - self, - http_client: "HTTPClient", - event_dispatcher: "EventDispatcher", - token: str, - intents: int, - client_instance: "Client", # Pass the main client instance - verbose: bool = False, - *, - shard_id: Optional[int] = None, - shard_count: Optional[int] = None, - max_retries: int = 5, - max_backoff: float = 60.0, - ): - self._http: "HTTPClient" = http_client - self._dispatcher: "EventDispatcher" = event_dispatcher - self._token: str = token - self._intents: int = intents - self._client_instance: "Client" = client_instance # Store client instance - self.verbose: bool = verbose - self._shard_id: Optional[int] = shard_id - self._shard_count: Optional[int] = shard_count - self._max_retries: int = max_retries - self._max_backoff: float = max_backoff - - self._ws: Optional[aiohttp.ClientWebSocketResponse] = None - self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() - self._heartbeat_interval: Optional[float] = None - self._last_sequence: Optional[int] = None - self._session_id: Optional[str] = None - self._resume_gateway_url: Optional[str] = None - - self._keep_alive_task: Optional[asyncio.Task] = None - self._receive_task: Optional[asyncio.Task] = None - - self._last_heartbeat_sent: Optional[float] = None - self._last_heartbeat_ack: Optional[float] = None - - # For zlib decompression - self._buffer = bytearray() - self._inflator = zlib.decompressobj() - +# disagreement/gateway.py + +""" +Manages the WebSocket connection to the Discord Gateway. +""" + +import asyncio +import logging +import traceback +import aiohttp +import json +import zlib +import time +import random +from typing import Optional, TYPE_CHECKING, Any, Dict + +from .enums import GatewayOpcode, GatewayIntent +from .errors import GatewayException, DisagreementException, AuthenticationError +from .interactions import Interaction + +if TYPE_CHECKING: + from .client import Client # For type hinting + from .event_dispatcher import EventDispatcher + from .http import HTTPClient + from .interactions import Interaction # Added for INTERACTION_CREATE + +# ZLIB Decompression constants +ZLIB_SUFFIX = b"\x00\x00\xff\xff" +MAX_DECOMPRESSION_SIZE = 10 * 1024 * 1024 # 10 MiB, adjust as needed + + +logger = logging.getLogger(__name__) + + +class GatewayClient: + """ + Handles the Discord Gateway WebSocket connection, heartbeating, and event dispatching. + """ + + def __init__( + self, + http_client: "HTTPClient", + event_dispatcher: "EventDispatcher", + token: str, + intents: int, + client_instance: "Client", # Pass the main client instance + verbose: bool = False, + *, + shard_id: Optional[int] = None, + shard_count: Optional[int] = None, + max_retries: int = 5, + max_backoff: float = 60.0, + ): + self._http: "HTTPClient" = http_client + self._dispatcher: "EventDispatcher" = event_dispatcher + self._token: str = token + self._intents: int = intents + self._client_instance: "Client" = client_instance # Store client instance + self.verbose: bool = verbose + self._shard_id: Optional[int] = shard_id + self._shard_count: Optional[int] = shard_count + self._max_retries: int = max_retries + self._max_backoff: float = max_backoff + + self._ws: Optional[aiohttp.ClientWebSocketResponse] = None + self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() + self._heartbeat_interval: Optional[float] = None + self._last_sequence: Optional[int] = None + self._session_id: Optional[str] = None + self._resume_gateway_url: Optional[str] = None + + self._keep_alive_task: Optional[asyncio.Task] = None + self._receive_task: Optional[asyncio.Task] = None + + self._last_heartbeat_sent: Optional[float] = None + self._last_heartbeat_ack: Optional[float] = None + + # For zlib decompression + self._buffer = bytearray() + self._inflator = zlib.decompressobj() + self._member_chunk_requests: Dict[str, asyncio.Future] = {} - async def _reconnect(self) -> None: - """Attempts to reconnect using exponential backoff with jitter.""" - delay = 1.0 - for attempt in range(self._max_retries): - try: - await self.connect() - return - except Exception as e: # noqa: BLE001 - if attempt >= self._max_retries - 1: - logger.error( - "Reconnect failed after %s attempts: %s", attempt + 1, e - ) - raise - jitter = random.uniform(0, delay) - wait_time = min(delay + jitter, self._max_backoff) - logger.warning( - "Reconnect attempt %s failed: %s. Retrying in %.2f seconds...", - attempt + 1, - e, - wait_time, - ) - await asyncio.sleep(wait_time) - delay = min(delay * 2, self._max_backoff) - - async def _decompress_message( - self, message_bytes: bytes - ) -> Optional[Dict[str, Any]]: - """Decompresses a zlib-compressed message from the Gateway.""" - self._buffer.extend(message_bytes) - - if len(message_bytes) < 4 or message_bytes[-4:] != ZLIB_SUFFIX: - # Message is not complete or not zlib compressed in the expected way - return None - # Or handle partial messages if Discord ever sends them fragmented like this, - # but typically each binary message is a complete zlib stream. - - try: - decompressed = self._inflator.decompress(self._buffer) - self._buffer.clear() # Reset buffer after successful decompression - return json.loads(decompressed.decode("utf-8")) - except zlib.error as e: - logger.error("Zlib decompression error: %s", e) - self._buffer.clear() # Clear buffer on error - self._inflator = zlib.decompressobj() # Reset inflator - return None - except json.JSONDecodeError as e: - logger.error("JSON decode error after decompression: %s", e) - return None - - async def _send_json(self, payload: Dict[str, Any]): - if self._ws and not self._ws.closed: - if self.verbose: - logger.debug("GATEWAY SEND: %s", payload) - await self._ws.send_json(payload) - else: - logger.warning( - "Gateway send attempted but WebSocket is closed or not available." - ) - # raise GatewayException("WebSocket is not connected.") - - async def _heartbeat(self): - """Sends a heartbeat to the Gateway.""" - self._last_heartbeat_sent = time.monotonic() - payload = {"op": GatewayOpcode.HEARTBEAT, "d": self._last_sequence} - await self._send_json(payload) - # print("Sent heartbeat.") - - async def _keep_alive(self): - """Manages the heartbeating loop.""" - if self._heartbeat_interval is None: - # This should not happen if HELLO was processed correctly - logger.error("Heartbeat interval not set. Cannot start keep_alive.") - return - - try: - while True: - await self._heartbeat() - await asyncio.sleep( - self._heartbeat_interval / 1000 - ) # Interval is in ms - except asyncio.CancelledError: - logger.debug("Keep_alive task cancelled.") - except Exception as e: - logger.error("Error in keep_alive loop: %s", e) - # Potentially trigger a reconnect here or notify client - await self._client_instance.close_gateway(code=1000) # Generic close - - async def _identify(self): - """Sends the IDENTIFY payload to the Gateway.""" - payload = { - "op": GatewayOpcode.IDENTIFY, - "d": { - "token": self._token, - "intents": self._intents, - "properties": { - "$os": "python", # Or platform.system() - "$browser": "disagreement", # Library name - "$device": "disagreement", # Library name - }, - "compress": True, # Request zlib compression - }, - } - if self._shard_id is not None and self._shard_count is not None: - payload["d"]["shard"] = [self._shard_id, self._shard_count] - await self._send_json(payload) - logger.info("Sent IDENTIFY.") - - async def _resume(self): - """Sends the RESUME payload to the Gateway.""" - if not self._session_id or self._last_sequence is None: - logger.warning("Cannot RESUME: session_id or last_sequence is missing.") - await self._identify() # Fallback to identify - return - - payload = { - "op": GatewayOpcode.RESUME, - "d": { - "token": self._token, - "session_id": self._session_id, - "seq": self._last_sequence, - }, - } - await self._send_json(payload) - logger.info( - "Sent RESUME for session %s at sequence %s.", - self._session_id, - self._last_sequence, - ) - - async def update_presence( - self, - status: str, - activity_name: Optional[str] = None, - activity_type: int = 0, - since: int = 0, - afk: bool = False, - ): - """Sends the presence update payload to the Gateway.""" - payload = { - "op": GatewayOpcode.PRESENCE_UPDATE, - "d": { - "since": since, - "activities": ( - [ - { - "name": activity_name, - "type": activity_type, - } - ] - if activity_name - else [] - ), - "status": status, - "afk": afk, - }, - } - await self._send_json(payload) - + async def _reconnect(self) -> None: + """Attempts to reconnect using exponential backoff with jitter.""" + delay = 1.0 + for attempt in range(self._max_retries): + try: + await self.connect() + return + except Exception as e: # noqa: BLE001 + if attempt >= self._max_retries - 1: + logger.error( + "Reconnect failed after %s attempts: %s", attempt + 1, e + ) + raise + jitter = random.uniform(0, delay) + wait_time = min(delay + jitter, self._max_backoff) + logger.warning( + "Reconnect attempt %s failed: %s. Retrying in %.2f seconds...", + attempt + 1, + e, + wait_time, + ) + await asyncio.sleep(wait_time) + delay = min(delay * 2, self._max_backoff) + + async def _decompress_message( + self, message_bytes: bytes + ) -> Optional[Dict[str, Any]]: + """Decompresses a zlib-compressed message from the Gateway.""" + self._buffer.extend(message_bytes) + + if len(message_bytes) < 4 or message_bytes[-4:] != ZLIB_SUFFIX: + # Message is not complete or not zlib compressed in the expected way + return None + # Or handle partial messages if Discord ever sends them fragmented like this, + # but typically each binary message is a complete zlib stream. + + try: + decompressed = self._inflator.decompress(self._buffer) + self._buffer.clear() # Reset buffer after successful decompression + return json.loads(decompressed.decode("utf-8")) + except zlib.error as e: + logger.error("Zlib decompression error: %s", e) + self._buffer.clear() # Clear buffer on error + self._inflator = zlib.decompressobj() # Reset inflator + return None + except json.JSONDecodeError as e: + logger.error("JSON decode error after decompression: %s", e) + return None + + async def _send_json(self, payload: Dict[str, Any]): + if self._ws and not self._ws.closed: + if self.verbose: + logger.debug("GATEWAY SEND: %s", payload) + await self._ws.send_json(payload) + else: + logger.warning( + "Gateway send attempted but WebSocket is closed or not available." + ) + # raise GatewayException("WebSocket is not connected.") + + async def _heartbeat(self): + """Sends a heartbeat to the Gateway.""" + self._last_heartbeat_sent = time.monotonic() + payload = {"op": GatewayOpcode.HEARTBEAT, "d": self._last_sequence} + await self._send_json(payload) + # print("Sent heartbeat.") + + async def _keep_alive(self): + """Manages the heartbeating loop.""" + if self._heartbeat_interval is None: + # This should not happen if HELLO was processed correctly + logger.error("Heartbeat interval not set. Cannot start keep_alive.") + return + + try: + while True: + await self._heartbeat() + await asyncio.sleep( + self._heartbeat_interval / 1000 + ) # Interval is in ms + except asyncio.CancelledError: + logger.debug("Keep_alive task cancelled.") + except Exception as e: + logger.error("Error in keep_alive loop: %s", e) + # Potentially trigger a reconnect here or notify client + await self._client_instance.close_gateway(code=1000) # Generic close + + async def _identify(self): + """Sends the IDENTIFY payload to the Gateway.""" + payload = { + "op": GatewayOpcode.IDENTIFY, + "d": { + "token": self._token, + "intents": self._intents, + "properties": { + "$os": "python", # Or platform.system() + "$browser": "disagreement", # Library name + "$device": "disagreement", # Library name + }, + "compress": True, # Request zlib compression + }, + } + if self._shard_id is not None and self._shard_count is not None: + payload["d"]["shard"] = [self._shard_id, self._shard_count] + await self._send_json(payload) + logger.info("Sent IDENTIFY.") + + async def _resume(self): + """Sends the RESUME payload to the Gateway.""" + if not self._session_id or self._last_sequence is None: + logger.warning("Cannot RESUME: session_id or last_sequence is missing.") + await self._identify() # Fallback to identify + return + + payload = { + "op": GatewayOpcode.RESUME, + "d": { + "token": self._token, + "session_id": self._session_id, + "seq": self._last_sequence, + }, + } + await self._send_json(payload) + logger.info( + "Sent RESUME for session %s at sequence %s.", + self._session_id, + self._last_sequence, + ) + + async def update_presence( + self, + status: str, + activity_name: Optional[str] = None, + activity_type: int = 0, + since: int = 0, + afk: bool = False, + ): + """Sends the presence update payload to the Gateway.""" + payload = { + "op": GatewayOpcode.PRESENCE_UPDATE, + "d": { + "since": since, + "activities": ( + [ + { + "name": activity_name, + "type": activity_type, + } + ] + if activity_name + else [] + ), + "status": status, + "afk": afk, + }, + } + await self._send_json(payload) + async def request_guild_members( self, guild_id: str, @@ -265,82 +265,82 @@ class GatewayClient: await self._send_json(payload) - async def _handle_dispatch(self, data: Dict[str, Any]): - """Handles DISPATCH events (actual Discord events).""" - event_name = data.get("t") - sequence_num = data.get("s") - raw_event_d_payload = data.get( - "d" - ) # This is the 'd' field from the gateway event - - if sequence_num is not None: - self._last_sequence = sequence_num - - if event_name == "READY": # Special handling for READY - if not isinstance(raw_event_d_payload, dict): - logger.error( - "READY event 'd' payload is not a dict or is missing: %s", - raw_event_d_payload, - ) - # Consider raising an error or attempting a reconnect - return - self._session_id = raw_event_d_payload.get("session_id") - self._resume_gateway_url = raw_event_d_payload.get("resume_gateway_url") - - app_id_str = "N/A" - # Store application_id on the client instance - if ( - "application" in raw_event_d_payload - and isinstance(raw_event_d_payload["application"], dict) - and "id" in raw_event_d_payload["application"] - ): - app_id_value = raw_event_d_payload["application"]["id"] - self._client_instance.application_id = ( - app_id_value # Snowflake can be str or int - ) - app_id_str = str(app_id_value) - else: - logger.warning( - "Could not find application ID in READY payload. App commands may not work." - ) - - # Parse and store the bot's own user object - if "user" in raw_event_d_payload and isinstance( - raw_event_d_payload["user"], dict - ): - try: - # Assuming Client has a parse_user method that takes user data dict - # and returns a User object, also caching it. - bot_user_obj = self._client_instance.parse_user( - raw_event_d_payload["user"] - ) - self._client_instance.user = bot_user_obj - logger.info( - "Gateway READY. Bot User: %s#%s. Session ID: %s. App ID: %s. Resume URL: %s", - bot_user_obj.username, - bot_user_obj.discriminator, - self._session_id, - app_id_str, - self._resume_gateway_url, - ) - except Exception as e: - logger.error("Error parsing bot user from READY payload: %s", e) - logger.info( - "Gateway READY (user parse failed). Session ID: %s. App ID: %s. Resume URL: %s", - self._session_id, - app_id_str, - self._resume_gateway_url, - ) - else: - logger.warning("Bot user object not found or invalid in READY payload.") - logger.info( - "Gateway READY (no user). Session ID: %s. App ID: %s. Resume URL: %s", - self._session_id, - app_id_str, - self._resume_gateway_url, - ) - - await self._dispatcher.dispatch(event_name, raw_event_d_payload) + async def _handle_dispatch(self, data: Dict[str, Any]): + """Handles DISPATCH events (actual Discord events).""" + event_name = data.get("t") + sequence_num = data.get("s") + raw_event_d_payload = data.get( + "d" + ) # This is the 'd' field from the gateway event + + if sequence_num is not None: + self._last_sequence = sequence_num + + if event_name == "READY": # Special handling for READY + if not isinstance(raw_event_d_payload, dict): + logger.error( + "READY event 'd' payload is not a dict or is missing: %s", + raw_event_d_payload, + ) + # Consider raising an error or attempting a reconnect + return + self._session_id = raw_event_d_payload.get("session_id") + self._resume_gateway_url = raw_event_d_payload.get("resume_gateway_url") + + app_id_str = "N/A" + # Store application_id on the client instance + if ( + "application" in raw_event_d_payload + and isinstance(raw_event_d_payload["application"], dict) + and "id" in raw_event_d_payload["application"] + ): + app_id_value = raw_event_d_payload["application"]["id"] + self._client_instance.application_id = ( + app_id_value # Snowflake can be str or int + ) + app_id_str = str(app_id_value) + else: + logger.warning( + "Could not find application ID in READY payload. App commands may not work." + ) + + # Parse and store the bot's own user object + if "user" in raw_event_d_payload and isinstance( + raw_event_d_payload["user"], dict + ): + try: + # Assuming Client has a parse_user method that takes user data dict + # and returns a User object, also caching it. + bot_user_obj = self._client_instance.parse_user( + raw_event_d_payload["user"] + ) + self._client_instance.user = bot_user_obj + logger.info( + "Gateway READY. Bot User: %s#%s. Session ID: %s. App ID: %s. Resume URL: %s", + bot_user_obj.username, + bot_user_obj.discriminator, + self._session_id, + app_id_str, + self._resume_gateway_url, + ) + except Exception as e: + logger.error("Error parsing bot user from READY payload: %s", e) + logger.info( + "Gateway READY (user parse failed). Session ID: %s. App ID: %s. Resume URL: %s", + self._session_id, + app_id_str, + self._resume_gateway_url, + ) + else: + logger.warning("Bot user object not found or invalid in READY payload.") + logger.info( + "Gateway READY (no user). Session ID: %s. App ID: %s. Resume URL: %s", + self._session_id, + app_id_str, + self._resume_gateway_url, + ) + + await self._dispatcher.dispatch(event_name, raw_event_d_payload) elif event_name == "GUILD_MEMBERS_CHUNK": if isinstance(raw_event_d_payload, dict): nonce = raw_event_d_payload.get("nonce") @@ -357,274 +357,274 @@ class GatewayClient: future.set_result(future._members) # type: ignore del self._member_chunk_requests[nonce] - elif event_name == "INTERACTION_CREATE": - # print(f"GATEWAY RECV INTERACTION_CREATE: {raw_event_d_payload}") - if isinstance(raw_event_d_payload, dict): - interaction = Interaction( - data=raw_event_d_payload, client_instance=self._client_instance - ) - await self._dispatcher.dispatch( - "INTERACTION_CREATE", raw_event_d_payload - ) - # Dispatch to a new client method that will then call AppCommandHandler - if hasattr(self._client_instance, "process_interaction"): - asyncio.create_task( - self._client_instance.process_interaction(interaction) - ) # type: ignore - else: - logger.warning( - "Client instance does not have process_interaction method for INTERACTION_CREATE." - ) - else: - logger.error( - "INTERACTION_CREATE event 'd' payload is not a dict: %s", - raw_event_d_payload, - ) - elif event_name == "RESUMED": - logger.info("Gateway RESUMED successfully.") - # RESUMED 'd' payload is often an empty object or debug info. - # Ensure it's a dict for the dispatcher. - event_data_to_dispatch = ( - raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {} - ) - await self._dispatcher.dispatch(event_name, event_data_to_dispatch) - await self._dispatcher.dispatch( - "SHARD_RESUME", {"shard_id": self._shard_id} - ) - elif event_name: - # For other events, ensure 'd' is a dict, or pass {} if 'd' is null/missing. - # Models/parsers in EventDispatcher will need to handle potentially empty dicts. - event_data_to_dispatch = ( - raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {} - ) - # print(f"GATEWAY RECV EVENT: {event_name} | DATA: {event_data_to_dispatch}") - await self._dispatcher.dispatch(event_name, event_data_to_dispatch) - else: - logger.warning("Received dispatch with no event name: %s", data) - - async def _process_message(self, msg: aiohttp.WSMessage): - """Processes a single message from the WebSocket.""" - if msg.type == aiohttp.WSMsgType.TEXT: - try: - data = json.loads(msg.data) - except json.JSONDecodeError: - logger.error("Failed to decode JSON from Gateway: %s", msg.data[:200]) - return - elif msg.type == aiohttp.WSMsgType.BINARY: - decompressed_data = await self._decompress_message(msg.data) - if decompressed_data is None: - logger.error( - "Failed to decompress or decode binary message from Gateway." - ) - return - data = decompressed_data - elif msg.type == aiohttp.WSMsgType.ERROR: - logger.error( - "WebSocket error: %s", - self._ws.exception() if self._ws else "Unknown WSError", - ) - raise GatewayException( - f"WebSocket error: {self._ws.exception() if self._ws else 'Unknown WSError'}" - ) - elif msg.type == aiohttp.WSMsgType.CLOSED: - close_code = ( - self._ws.close_code - if self._ws and hasattr(self._ws, "close_code") - else "N/A" - ) - logger.warning( - "WebSocket connection closed by server. Code: %s", close_code - ) - # Raise an exception to signal the closure to the client's main run loop - raise GatewayException(f"WebSocket closed by server. Code: {close_code}") - else: - logger.warning("Received unhandled WebSocket message type: %s", msg.type) - return - - if self.verbose: - logger.debug("GATEWAY RECV: %s", data) - op = data.get("op") - # 'd' payload (event_data) is handled specifically by each opcode handler below - - if op == GatewayOpcode.DISPATCH: - await self._handle_dispatch(data) # _handle_dispatch will extract 'd' - elif op == GatewayOpcode.HEARTBEAT: # Server requests a heartbeat - await self._heartbeat() - elif op == GatewayOpcode.RECONNECT: # Server requests a reconnect - logger.info( - "Gateway requested RECONNECT. Closing and will attempt to reconnect." - ) - await self.close(code=4000, reconnect=True) - elif op == GatewayOpcode.INVALID_SESSION: - # The 'd' payload for INVALID_SESSION is a boolean indicating resumability - can_resume = data.get("d") is True - logger.warning( - "Gateway indicated INVALID_SESSION. Resumable: %s", can_resume - ) - if not can_resume: - self._session_id = None # Clear session_id to force re-identify - self._last_sequence = None - # Close and reconnect. The connect logic will decide to resume or identify. - await self.close(code=4000 if can_resume else 4009, reconnect=True) - elif op == GatewayOpcode.HELLO: - hello_d_payload = data.get("d") - if ( - not isinstance(hello_d_payload, dict) - or "heartbeat_interval" not in hello_d_payload - ): - logger.error( - "HELLO event 'd' payload is invalid or missing heartbeat_interval: %s", - hello_d_payload, - ) - await self.close(code=1011) # Internal error, malformed HELLO - return - self._heartbeat_interval = hello_d_payload["heartbeat_interval"] - logger.info( - "Gateway HELLO. Heartbeat interval: %sms.", self._heartbeat_interval - ) - # Start heartbeating - if self._keep_alive_task: - self._keep_alive_task.cancel() - self._keep_alive_task = self._loop.create_task(self._keep_alive()) - - # Identify or Resume - if self._session_id and self._resume_gateway_url: # Check if we can resume - logger.info("Attempting to RESUME session.") - await self._resume() - else: - logger.info("Performing initial IDENTIFY.") - await self._identify() - elif op == GatewayOpcode.HEARTBEAT_ACK: - self._last_heartbeat_ack = time.monotonic() - # print("Received heartbeat ACK.") - pass # Good, connection is alive - else: - logger.warning( - "Received unhandled Gateway Opcode: %s with data: %s", op, data - ) - - async def _receive_loop(self): - """Continuously receives and processes messages from the WebSocket.""" - if not self._ws or self._ws.closed: - logger.warning( - "Receive loop cannot start: WebSocket is not connected or closed." - ) - return - - try: - async for msg in self._ws: - await self._process_message(msg) - except asyncio.CancelledError: - logger.debug("Receive_loop task cancelled.") - except aiohttp.ClientConnectionError as e: - logger.warning( - "ClientConnectionError in receive_loop: %s. Attempting reconnect.", e - ) - await self.close(code=1006, reconnect=True) # Abnormal closure - except Exception as e: - logger.error("Unexpected error in receive_loop: %s", e) - traceback.print_exc() - await self.close(code=1011, reconnect=True) - finally: - logger.info("Receive_loop ended.") - # If the loop ends unexpectedly (not due to explicit close), - # the main client might want to try reconnecting. - - async def connect(self): - """Connects to the Discord Gateway.""" - if self._ws and not self._ws.closed: - logger.warning("Gateway already connected or connecting.") - return - - gateway_url = ( - self._resume_gateway_url or (await self._http.get_gateway_bot())["url"] - ) - if not gateway_url.endswith("?v=10&encoding=json&compress=zlib-stream"): - gateway_url += "?v=10&encoding=json&compress=zlib-stream" - - logger.info("Connecting to Gateway: %s", gateway_url) - try: - await self._http._ensure_session() # Ensure the HTTP client's session is active - assert ( - self._http._session is not None - ), "HTTPClient session not initialized after ensure_session" - self._ws = await self._http._session.ws_connect(gateway_url, max_msg_size=0) - logger.info("Gateway WebSocket connection established.") - - if self._receive_task: - self._receive_task.cancel() - self._receive_task = self._loop.create_task(self._receive_loop()) - - await self._dispatcher.dispatch( - "SHARD_CONNECT", {"shard_id": self._shard_id} - ) - - except aiohttp.ClientConnectorError as e: - raise GatewayException( - f"Failed to connect to Gateway (Connector Error): {e}" - ) from e - except aiohttp.WSServerHandshakeError as e: - if e.status == 401: # Unauthorized during handshake - raise AuthenticationError( - f"Gateway handshake failed (401 Unauthorized): {e.message}. Check your bot token." - ) from e - raise GatewayException( - f"Gateway handshake failed (Status: {e.status}): {e.message}" - ) from e - except Exception as e: # Catch other potential errors during connection - raise GatewayException( - f"An unexpected error occurred during Gateway connection: {e}" - ) from e - - async def close(self, code: int = 1000, *, reconnect: bool = False): - """Closes the Gateway connection.""" - logger.info("Closing Gateway connection with code %s...", code) - if self._keep_alive_task and not self._keep_alive_task.done(): - self._keep_alive_task.cancel() - try: - await self._keep_alive_task - except asyncio.CancelledError: - pass # Expected - - if self._receive_task and not self._receive_task.done(): - current = asyncio.current_task(loop=self._loop) - self._receive_task.cancel() - if self._receive_task is not current: - try: - await self._receive_task - except asyncio.CancelledError: - pass # Expected - - if self._ws and not self._ws.closed: - await self._ws.close(code=code) - logger.info("Gateway WebSocket closed.") - - self._ws = None - # Do not reset session_id, last_sequence, or resume_gateway_url here - # if the close code indicates a resumable disconnect (e.g. 4000-4009, or server-initiated RECONNECT) - # The connect logic will decide whether to resume or re-identify. - # However, if it's a non-resumable close (e.g. Invalid Session non-resumable), clear them. - if code == 4009: # Invalid session, not resumable - logger.info("Clearing session state due to non-resumable invalid session.") - self._session_id = None - self._last_sequence = None - self._resume_gateway_url = None # This might be re-fetched anyway - - await self._dispatcher.dispatch( - "SHARD_DISCONNECT", {"shard_id": self._shard_id} - ) - - @property - def latency(self) -> Optional[float]: - """Returns the latency between heartbeat and ACK in seconds.""" - if self._last_heartbeat_sent is None or self._last_heartbeat_ack is None: - return None - return self._last_heartbeat_ack - self._last_heartbeat_sent - - @property - def last_heartbeat_sent(self) -> Optional[float]: - return self._last_heartbeat_sent - - @property - def last_heartbeat_ack(self) -> Optional[float]: - return self._last_heartbeat_ack + elif event_name == "INTERACTION_CREATE": + # print(f"GATEWAY RECV INTERACTION_CREATE: {raw_event_d_payload}") + if isinstance(raw_event_d_payload, dict): + interaction = Interaction( + data=raw_event_d_payload, client_instance=self._client_instance + ) + await self._dispatcher.dispatch( + "INTERACTION_CREATE", raw_event_d_payload + ) + # Dispatch to a new client method that will then call AppCommandHandler + if hasattr(self._client_instance, "process_interaction"): + asyncio.create_task( + self._client_instance.process_interaction(interaction) + ) # type: ignore + else: + logger.warning( + "Client instance does not have process_interaction method for INTERACTION_CREATE." + ) + else: + logger.error( + "INTERACTION_CREATE event 'd' payload is not a dict: %s", + raw_event_d_payload, + ) + elif event_name == "RESUMED": + logger.info("Gateway RESUMED successfully.") + # RESUMED 'd' payload is often an empty object or debug info. + # Ensure it's a dict for the dispatcher. + event_data_to_dispatch = ( + raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {} + ) + await self._dispatcher.dispatch(event_name, event_data_to_dispatch) + await self._dispatcher.dispatch( + "SHARD_RESUME", {"shard_id": self._shard_id} + ) + elif event_name: + # For other events, ensure 'd' is a dict, or pass {} if 'd' is null/missing. + # Models/parsers in EventDispatcher will need to handle potentially empty dicts. + event_data_to_dispatch = ( + raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {} + ) + # print(f"GATEWAY RECV EVENT: {event_name} | DATA: {event_data_to_dispatch}") + await self._dispatcher.dispatch(event_name, event_data_to_dispatch) + else: + logger.warning("Received dispatch with no event name: %s", data) + + async def _process_message(self, msg: aiohttp.WSMessage): + """Processes a single message from the WebSocket.""" + if msg.type == aiohttp.WSMsgType.TEXT: + try: + data = json.loads(msg.data) + except json.JSONDecodeError: + logger.error("Failed to decode JSON from Gateway: %s", msg.data[:200]) + return + elif msg.type == aiohttp.WSMsgType.BINARY: + decompressed_data = await self._decompress_message(msg.data) + if decompressed_data is None: + logger.error( + "Failed to decompress or decode binary message from Gateway." + ) + return + data = decompressed_data + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error( + "WebSocket error: %s", + self._ws.exception() if self._ws else "Unknown WSError", + ) + raise GatewayException( + f"WebSocket error: {self._ws.exception() if self._ws else 'Unknown WSError'}" + ) + elif msg.type == aiohttp.WSMsgType.CLOSED: + close_code = ( + self._ws.close_code + if self._ws and hasattr(self._ws, "close_code") + else "N/A" + ) + logger.warning( + "WebSocket connection closed by server. Code: %s", close_code + ) + # Raise an exception to signal the closure to the client's main run loop + raise GatewayException(f"WebSocket closed by server. Code: {close_code}") + else: + logger.warning("Received unhandled WebSocket message type: %s", msg.type) + return + + if self.verbose: + logger.debug("GATEWAY RECV: %s", data) + op = data.get("op") + # 'd' payload (event_data) is handled specifically by each opcode handler below + + if op == GatewayOpcode.DISPATCH: + await self._handle_dispatch(data) # _handle_dispatch will extract 'd' + elif op == GatewayOpcode.HEARTBEAT: # Server requests a heartbeat + await self._heartbeat() + elif op == GatewayOpcode.RECONNECT: # Server requests a reconnect + logger.info( + "Gateway requested RECONNECT. Closing and will attempt to reconnect." + ) + await self.close(code=4000, reconnect=True) + elif op == GatewayOpcode.INVALID_SESSION: + # The 'd' payload for INVALID_SESSION is a boolean indicating resumability + can_resume = data.get("d") is True + logger.warning( + "Gateway indicated INVALID_SESSION. Resumable: %s", can_resume + ) + if not can_resume: + self._session_id = None # Clear session_id to force re-identify + self._last_sequence = None + # Close and reconnect. The connect logic will decide to resume or identify. + await self.close(code=4000 if can_resume else 4009, reconnect=True) + elif op == GatewayOpcode.HELLO: + hello_d_payload = data.get("d") + if ( + not isinstance(hello_d_payload, dict) + or "heartbeat_interval" not in hello_d_payload + ): + logger.error( + "HELLO event 'd' payload is invalid or missing heartbeat_interval: %s", + hello_d_payload, + ) + await self.close(code=1011) # Internal error, malformed HELLO + return + self._heartbeat_interval = hello_d_payload["heartbeat_interval"] + logger.info( + "Gateway HELLO. Heartbeat interval: %sms.", self._heartbeat_interval + ) + # Start heartbeating + if self._keep_alive_task: + self._keep_alive_task.cancel() + self._keep_alive_task = self._loop.create_task(self._keep_alive()) + + # Identify or Resume + if self._session_id and self._resume_gateway_url: # Check if we can resume + logger.info("Attempting to RESUME session.") + await self._resume() + else: + logger.info("Performing initial IDENTIFY.") + await self._identify() + elif op == GatewayOpcode.HEARTBEAT_ACK: + self._last_heartbeat_ack = time.monotonic() + # print("Received heartbeat ACK.") + pass # Good, connection is alive + else: + logger.warning( + "Received unhandled Gateway Opcode: %s with data: %s", op, data + ) + + async def _receive_loop(self): + """Continuously receives and processes messages from the WebSocket.""" + if not self._ws or self._ws.closed: + logger.warning( + "Receive loop cannot start: WebSocket is not connected or closed." + ) + return + + try: + async for msg in self._ws: + await self._process_message(msg) + except asyncio.CancelledError: + logger.debug("Receive_loop task cancelled.") + except aiohttp.ClientConnectionError as e: + logger.warning( + "ClientConnectionError in receive_loop: %s. Attempting reconnect.", e + ) + await self.close(code=1006, reconnect=True) # Abnormal closure + except Exception as e: + logger.error("Unexpected error in receive_loop: %s", e) + traceback.print_exc() + await self.close(code=1011, reconnect=True) + finally: + logger.info("Receive_loop ended.") + # If the loop ends unexpectedly (not due to explicit close), + # the main client might want to try reconnecting. + + async def connect(self): + """Connects to the Discord Gateway.""" + if self._ws and not self._ws.closed: + logger.warning("Gateway already connected or connecting.") + return + + gateway_url = ( + self._resume_gateway_url or (await self._http.get_gateway_bot())["url"] + ) + if not gateway_url.endswith("?v=10&encoding=json&compress=zlib-stream"): + gateway_url += "?v=10&encoding=json&compress=zlib-stream" + + logger.info("Connecting to Gateway: %s", gateway_url) + try: + await self._http._ensure_session() # Ensure the HTTP client's session is active + assert ( + self._http._session is not None + ), "HTTPClient session not initialized after ensure_session" + self._ws = await self._http._session.ws_connect(gateway_url, max_msg_size=0) + logger.info("Gateway WebSocket connection established.") + + if self._receive_task: + self._receive_task.cancel() + self._receive_task = self._loop.create_task(self._receive_loop()) + + await self._dispatcher.dispatch( + "SHARD_CONNECT", {"shard_id": self._shard_id} + ) + + except aiohttp.ClientConnectorError as e: + raise GatewayException( + f"Failed to connect to Gateway (Connector Error): {e}" + ) from e + except aiohttp.WSServerHandshakeError as e: + if e.status == 401: # Unauthorized during handshake + raise AuthenticationError( + f"Gateway handshake failed (401 Unauthorized): {e.message}. Check your bot token." + ) from e + raise GatewayException( + f"Gateway handshake failed (Status: {e.status}): {e.message}" + ) from e + except Exception as e: # Catch other potential errors during connection + raise GatewayException( + f"An unexpected error occurred during Gateway connection: {e}" + ) from e + + async def close(self, code: int = 1000, *, reconnect: bool = False): + """Closes the Gateway connection.""" + logger.info("Closing Gateway connection with code %s...", code) + if self._keep_alive_task and not self._keep_alive_task.done(): + self._keep_alive_task.cancel() + try: + await self._keep_alive_task + except asyncio.CancelledError: + pass # Expected + + if self._receive_task and not self._receive_task.done(): + current = asyncio.current_task(loop=self._loop) + self._receive_task.cancel() + if self._receive_task is not current: + try: + await self._receive_task + except asyncio.CancelledError: + pass # Expected + + if self._ws and not self._ws.closed: + await self._ws.close(code=code) + logger.info("Gateway WebSocket closed.") + + self._ws = None + # Do not reset session_id, last_sequence, or resume_gateway_url here + # if the close code indicates a resumable disconnect (e.g. 4000-4009, or server-initiated RECONNECT) + # The connect logic will decide whether to resume or re-identify. + # However, if it's a non-resumable close (e.g. Invalid Session non-resumable), clear them. + if code == 4009: # Invalid session, not resumable + logger.info("Clearing session state due to non-resumable invalid session.") + self._session_id = None + self._last_sequence = None + self._resume_gateway_url = None # This might be re-fetched anyway + + await self._dispatcher.dispatch( + "SHARD_DISCONNECT", {"shard_id": self._shard_id} + ) + + @property + def latency(self) -> Optional[float]: + """Returns the latency between heartbeat and ACK in seconds.""" + if self._last_heartbeat_sent is None or self._last_heartbeat_ack is None: + return None + return self._last_heartbeat_ack - self._last_heartbeat_sent + + @property + def last_heartbeat_sent(self) -> Optional[float]: + return self._last_heartbeat_sent + + @property + def last_heartbeat_ack(self) -> Optional[float]: + return self._last_heartbeat_ack diff --git a/disagreement/http.py b/disagreement/http.py index 354bff3..3a81990 100644 --- a/disagreement/http.py +++ b/disagreement/http.py @@ -1,373 +1,373 @@ -# disagreement/http.py - -""" -HTTP client for interacting with the Discord REST API. -""" - -import asyncio -import logging -import aiohttp # pylint: disable=import-error -import json -from urllib.parse import quote -from typing import Optional, Dict, Any, Union, TYPE_CHECKING, List - -from .errors import ( - HTTPException, - RateLimitError, - AuthenticationError, - DisagreementException, -) -from . import __version__ # For User-Agent -from .rate_limiter import RateLimiter -from .interactions import InteractionResponsePayload - -if TYPE_CHECKING: - from .client import Client - from .models import Message, Webhook, File, StageInstance, Invite - from .interactions import ApplicationCommand, Snowflake - -# Discord API constants -API_BASE_URL = "https://discord.com/api/v10" # Using API v10 - -logger = logging.getLogger(__name__) - - -class HTTPClient: - """Handles HTTP requests to the Discord API.""" - - def __init__( - self, - token: str, - client_session: Optional[aiohttp.ClientSession] = None, - verbose: bool = False, - **session_kwargs: Any, - ): - """Create a new HTTP client. - - Parameters - ---------- - token: - Bot token for authentication. - client_session: - Optional existing :class:`aiohttp.ClientSession`. - verbose: - If ``True``, log HTTP requests and responses. - **session_kwargs: - Additional options forwarded to :class:`aiohttp.ClientSession`, such - as ``proxy`` or ``connector``. - """ - - self.token = token - self._session: Optional[aiohttp.ClientSession] = client_session - self._session_kwargs: Dict[str, Any] = session_kwargs - self.user_agent = f"DiscordBot (https://github.com/Slipstreamm/disagreement, {__version__})" # Customize URL - - self.verbose = verbose - - self._rate_limiter = RateLimiter() - - async def _ensure_session(self): - if self._session is None or self._session.closed: - self._session = aiohttp.ClientSession(**self._session_kwargs) - - async def close(self): - """Closes the underlying aiohttp.ClientSession.""" - if self._session and not self._session.closed: - await self._session.close() - - async def request( - self, - method: str, - endpoint: str, - payload: Optional[ - Union[Dict[str, Any], List[Dict[str, Any]], aiohttp.FormData] - ] = None, - params: Optional[Dict[str, Any]] = None, - is_json: bool = True, - use_auth_header: bool = True, - custom_headers: Optional[Dict[str, str]] = None, - ) -> Any: - """Makes an HTTP request to the Discord API.""" - await self._ensure_session() - - url = f"{API_BASE_URL}{endpoint}" - final_headers: Dict[str, str] = { # Renamed to final_headers - "User-Agent": self.user_agent, - } - if use_auth_header: - final_headers["Authorization"] = f"Bot {self.token}" - - if is_json and payload: - final_headers["Content-Type"] = "application/json" - - if custom_headers: # Merge custom headers - final_headers.update(custom_headers) - - if self.verbose: - logger.debug( - "HTTP REQUEST: %s %s | payload=%s params=%s", - method, - url, - payload, - params, - ) - - route = f"{method.upper()}:{endpoint}" - - for attempt in range(5): # Max 5 retries for rate limits - await self._rate_limiter.acquire(route) - assert self._session is not None, "ClientSession not initialized" - async with self._session.request( - method, - url, - json=payload if is_json else None, - data=payload if not is_json else None, - headers=final_headers, - params=params, - ) as response: - - data = None - try: - if response.headers.get("Content-Type", "").startswith( - "application/json" - ): - data = await response.json() - else: - # For non-JSON responses, like fetching images or other files - # We might return the raw response or handle it differently - # For now, let's assume most API calls expect JSON - data = await response.text() - except (aiohttp.ContentTypeError, json.JSONDecodeError): - data = ( - await response.text() - ) # Fallback to text if JSON parsing fails - - if self.verbose: - logger.debug( - "HTTP RESPONSE: %s %s | %s", response.status, url, data - ) - - self._rate_limiter.release(route, response.headers) - - if 200 <= response.status < 300: - if response.status == 204: - return None - return data - - # Rate limit handling - if response.status == 429: # Rate limited - retry_after_str = response.headers.get("Retry-After", "1") - try: - retry_after = float(retry_after_str) - except ValueError: - retry_after = 1.0 # Default retry if header is malformed - - is_global = ( - response.headers.get("X-RateLimit-Global", "false").lower() - == "true" - ) - - error_message = f"Rate limited on {method} {endpoint}." - if data and isinstance(data, dict) and "message" in data: - error_message += f" Discord says: {data['message']}" - - await self._rate_limiter.handle_rate_limit( - route, retry_after, is_global - ) - - if attempt < 4: # Don't log on the last attempt before raising - logger.warning( - "%s Retrying after %ss (Attempt %s/5). Global: %s", - error_message, - retry_after, - attempt + 1, - is_global, - ) - continue # Retry the request - else: # Last attempt failed - raise RateLimitError( - response, - message=error_message, - retry_after=retry_after, - is_global=is_global, - ) - - # Other error handling - if response.status == 401: # Unauthorized - raise AuthenticationError(response, "Invalid token provided.") - if response.status == 403: # Forbidden - raise HTTPException( - response, - "Missing permissions or access denied.", - status=response.status, - text=str(data), - ) - - # General HTTP error - error_text = str(data) if data else "Unknown error" - discord_error_code = ( - data.get("code") if isinstance(data, dict) else None - ) - raise HTTPException( - response, - f"API Error on {method} {endpoint}: {error_text}", - status=response.status, - text=error_text, - error_code=discord_error_code, - ) - - # Should not be reached if retries are exhausted by RateLimitError - raise DisagreementException( - f"Failed request to {method} {endpoint} after multiple retries." - ) - - # --- Specific API call methods --- - - async def get_gateway_bot(self) -> Dict[str, Any]: - """Gets the WSS URL and sharding information for the Gateway.""" - return await self.request("GET", "/gateway/bot") - - async def send_message( - self, - channel_id: str, - content: Optional[str] = None, - tts: bool = False, - embeds: Optional[List[Dict[str, Any]]] = None, - components: Optional[List[Dict[str, Any]]] = None, - allowed_mentions: Optional[dict] = None, - message_reference: Optional[Dict[str, Any]] = None, - attachments: Optional[List[Any]] = None, - files: Optional[List[Any]] = None, - flags: Optional[int] = None, - ) -> Dict[str, Any]: - """Sends a message to a channel. - - Parameters - ---------- - attachments: - A list of attachment payloads to include with the message. - files: - A list of :class:`File` objects containing binary data to upload. - - Returns - ------- - Dict[str, Any] - The created message data. - """ - payload: Dict[str, Any] = {} - if content is not None: # Content is optional if embeds/components are present - payload["content"] = content - if tts: - payload["tts"] = True - if embeds: - payload["embeds"] = embeds - if components: - payload["components"] = components - if allowed_mentions: - payload["allowed_mentions"] = allowed_mentions - all_files: List["File"] = [] - if attachments is not None: - payload["attachments"] = [] - for a in attachments: - if hasattr(a, "data") and hasattr(a, "filename"): - idx = len(all_files) - all_files.append(a) - payload["attachments"].append({"id": idx, "filename": a.filename}) - else: - payload["attachments"].append( - a.to_dict() if hasattr(a, "to_dict") else a - ) - if files is not None: - for f in files: - if hasattr(f, "data") and hasattr(f, "filename"): - idx = len(all_files) - all_files.append(f) - if "attachments" not in payload: - payload["attachments"] = [] - payload["attachments"].append({"id": idx, "filename": f.filename}) - else: - raise TypeError("files must be File objects") - if flags: - payload["flags"] = flags - if message_reference: - payload["message_reference"] = message_reference - - if not payload: - raise ValueError("Message must have content, embeds, or components.") - - if all_files: - form = aiohttp.FormData() - form.add_field( - "payload_json", json.dumps(payload), content_type="application/json" - ) - for idx, f in enumerate(all_files): - form.add_field( - f"files[{idx}]", - f.data, - filename=f.filename, - content_type="application/octet-stream", - ) - return await self.request( - "POST", - f"/channels/{channel_id}/messages", - payload=form, - is_json=False, - ) - - return await self.request( - "POST", f"/channels/{channel_id}/messages", payload=payload - ) - - async def edit_message( - self, - channel_id: str, - message_id: str, - payload: Dict[str, Any], - ) -> Dict[str, Any]: - """Edits a message in a channel.""" - - return await self.request( - "PATCH", - f"/channels/{channel_id}/messages/{message_id}", - payload=payload, - ) - - async def get_message( - self, channel_id: "Snowflake", message_id: "Snowflake" - ) -> Dict[str, Any]: - """Fetches a message from a channel.""" - - return await self.request( - "GET", f"/channels/{channel_id}/messages/{message_id}" - ) - - async def delete_message( - self, channel_id: "Snowflake", message_id: "Snowflake" - ) -> None: - """Deletes a message in a channel.""" - - await self.request("DELETE", f"/channels/{channel_id}/messages/{message_id}") - - async def create_reaction( - self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str - ) -> None: - """Adds a reaction to a message as the current user.""" - encoded = quote(emoji) - await self.request( - "PUT", - f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/@me", - ) - - async def delete_reaction( - self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str - ) -> None: - """Removes the current user's reaction from a message.""" - encoded = quote(emoji) - await self.request( - "DELETE", - f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/@me", - ) - +# disagreement/http.py + +""" +HTTP client for interacting with the Discord REST API. +""" + +import asyncio +import logging +import aiohttp # pylint: disable=import-error +import json +from urllib.parse import quote +from typing import Optional, Dict, Any, Union, TYPE_CHECKING, List + +from .errors import ( + HTTPException, + RateLimitError, + AuthenticationError, + DisagreementException, +) +from . import __version__ # For User-Agent +from .rate_limiter import RateLimiter +from .interactions import InteractionResponsePayload + +if TYPE_CHECKING: + from .client import Client + from .models import Message, Webhook, File, StageInstance, Invite + from .interactions import ApplicationCommand, Snowflake + +# Discord API constants +API_BASE_URL = "https://discord.com/api/v10" # Using API v10 + +logger = logging.getLogger(__name__) + + +class HTTPClient: + """Handles HTTP requests to the Discord API.""" + + def __init__( + self, + token: str, + client_session: Optional[aiohttp.ClientSession] = None, + verbose: bool = False, + **session_kwargs: Any, + ): + """Create a new HTTP client. + + Parameters + ---------- + token: + Bot token for authentication. + client_session: + Optional existing :class:`aiohttp.ClientSession`. + verbose: + If ``True``, log HTTP requests and responses. + **session_kwargs: + Additional options forwarded to :class:`aiohttp.ClientSession`, such + as ``proxy`` or ``connector``. + """ + + self.token = token + self._session: Optional[aiohttp.ClientSession] = client_session + self._session_kwargs: Dict[str, Any] = session_kwargs + self.user_agent = f"DiscordBot (https://github.com/Slipstreamm/disagreement, {__version__})" # Customize URL + + self.verbose = verbose + + self._rate_limiter = RateLimiter() + + async def _ensure_session(self): + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession(**self._session_kwargs) + + async def close(self): + """Closes the underlying aiohttp.ClientSession.""" + if self._session and not self._session.closed: + await self._session.close() + + async def request( + self, + method: str, + endpoint: str, + payload: Optional[ + Union[Dict[str, Any], List[Dict[str, Any]], aiohttp.FormData] + ] = None, + params: Optional[Dict[str, Any]] = None, + is_json: bool = True, + use_auth_header: bool = True, + custom_headers: Optional[Dict[str, str]] = None, + ) -> Any: + """Makes an HTTP request to the Discord API.""" + await self._ensure_session() + + url = f"{API_BASE_URL}{endpoint}" + final_headers: Dict[str, str] = { # Renamed to final_headers + "User-Agent": self.user_agent, + } + if use_auth_header: + final_headers["Authorization"] = f"Bot {self.token}" + + if is_json and payload: + final_headers["Content-Type"] = "application/json" + + if custom_headers: # Merge custom headers + final_headers.update(custom_headers) + + if self.verbose: + logger.debug( + "HTTP REQUEST: %s %s | payload=%s params=%s", + method, + url, + payload, + params, + ) + + route = f"{method.upper()}:{endpoint}" + + for attempt in range(5): # Max 5 retries for rate limits + await self._rate_limiter.acquire(route) + assert self._session is not None, "ClientSession not initialized" + async with self._session.request( + method, + url, + json=payload if is_json else None, + data=payload if not is_json else None, + headers=final_headers, + params=params, + ) as response: + + data = None + try: + if response.headers.get("Content-Type", "").startswith( + "application/json" + ): + data = await response.json() + else: + # For non-JSON responses, like fetching images or other files + # We might return the raw response or handle it differently + # For now, let's assume most API calls expect JSON + data = await response.text() + except (aiohttp.ContentTypeError, json.JSONDecodeError): + data = ( + await response.text() + ) # Fallback to text if JSON parsing fails + + if self.verbose: + logger.debug( + "HTTP RESPONSE: %s %s | %s", response.status, url, data + ) + + self._rate_limiter.release(route, response.headers) + + if 200 <= response.status < 300: + if response.status == 204: + return None + return data + + # Rate limit handling + if response.status == 429: # Rate limited + retry_after_str = response.headers.get("Retry-After", "1") + try: + retry_after = float(retry_after_str) + except ValueError: + retry_after = 1.0 # Default retry if header is malformed + + is_global = ( + response.headers.get("X-RateLimit-Global", "false").lower() + == "true" + ) + + error_message = f"Rate limited on {method} {endpoint}." + if data and isinstance(data, dict) and "message" in data: + error_message += f" Discord says: {data['message']}" + + await self._rate_limiter.handle_rate_limit( + route, retry_after, is_global + ) + + if attempt < 4: # Don't log on the last attempt before raising + logger.warning( + "%s Retrying after %ss (Attempt %s/5). Global: %s", + error_message, + retry_after, + attempt + 1, + is_global, + ) + continue # Retry the request + else: # Last attempt failed + raise RateLimitError( + response, + message=error_message, + retry_after=retry_after, + is_global=is_global, + ) + + # Other error handling + if response.status == 401: # Unauthorized + raise AuthenticationError(response, "Invalid token provided.") + if response.status == 403: # Forbidden + raise HTTPException( + response, + "Missing permissions or access denied.", + status=response.status, + text=str(data), + ) + + # General HTTP error + error_text = str(data) if data else "Unknown error" + discord_error_code = ( + data.get("code") if isinstance(data, dict) else None + ) + raise HTTPException( + response, + f"API Error on {method} {endpoint}: {error_text}", + status=response.status, + text=error_text, + error_code=discord_error_code, + ) + + # Should not be reached if retries are exhausted by RateLimitError + raise DisagreementException( + f"Failed request to {method} {endpoint} after multiple retries." + ) + + # --- Specific API call methods --- + + async def get_gateway_bot(self) -> Dict[str, Any]: + """Gets the WSS URL and sharding information for the Gateway.""" + return await self.request("GET", "/gateway/bot") + + async def send_message( + self, + channel_id: str, + content: Optional[str] = None, + tts: bool = False, + embeds: Optional[List[Dict[str, Any]]] = None, + components: Optional[List[Dict[str, Any]]] = None, + allowed_mentions: Optional[dict] = None, + message_reference: Optional[Dict[str, Any]] = None, + attachments: Optional[List[Any]] = None, + files: Optional[List[Any]] = None, + flags: Optional[int] = None, + ) -> Dict[str, Any]: + """Sends a message to a channel. + + Parameters + ---------- + attachments: + A list of attachment payloads to include with the message. + files: + A list of :class:`File` objects containing binary data to upload. + + Returns + ------- + Dict[str, Any] + The created message data. + """ + payload: Dict[str, Any] = {} + if content is not None: # Content is optional if embeds/components are present + payload["content"] = content + if tts: + payload["tts"] = True + if embeds: + payload["embeds"] = embeds + if components: + payload["components"] = components + if allowed_mentions: + payload["allowed_mentions"] = allowed_mentions + all_files: List["File"] = [] + if attachments is not None: + payload["attachments"] = [] + for a in attachments: + if hasattr(a, "data") and hasattr(a, "filename"): + idx = len(all_files) + all_files.append(a) + payload["attachments"].append({"id": idx, "filename": a.filename}) + else: + payload["attachments"].append( + a.to_dict() if hasattr(a, "to_dict") else a + ) + if files is not None: + for f in files: + if hasattr(f, "data") and hasattr(f, "filename"): + idx = len(all_files) + all_files.append(f) + if "attachments" not in payload: + payload["attachments"] = [] + payload["attachments"].append({"id": idx, "filename": f.filename}) + else: + raise TypeError("files must be File objects") + if flags: + payload["flags"] = flags + if message_reference: + payload["message_reference"] = message_reference + + if not payload: + raise ValueError("Message must have content, embeds, or components.") + + if all_files: + form = aiohttp.FormData() + form.add_field( + "payload_json", json.dumps(payload), content_type="application/json" + ) + for idx, f in enumerate(all_files): + form.add_field( + f"files[{idx}]", + f.data, + filename=f.filename, + content_type="application/octet-stream", + ) + return await self.request( + "POST", + f"/channels/{channel_id}/messages", + payload=form, + is_json=False, + ) + + return await self.request( + "POST", f"/channels/{channel_id}/messages", payload=payload + ) + + async def edit_message( + self, + channel_id: str, + message_id: str, + payload: Dict[str, Any], + ) -> Dict[str, Any]: + """Edits a message in a channel.""" + + return await self.request( + "PATCH", + f"/channels/{channel_id}/messages/{message_id}", + payload=payload, + ) + + async def get_message( + self, channel_id: "Snowflake", message_id: "Snowflake" + ) -> Dict[str, Any]: + """Fetches a message from a channel.""" + + return await self.request( + "GET", f"/channels/{channel_id}/messages/{message_id}" + ) + + async def delete_message( + self, channel_id: "Snowflake", message_id: "Snowflake" + ) -> None: + """Deletes a message in a channel.""" + + await self.request("DELETE", f"/channels/{channel_id}/messages/{message_id}") + + async def create_reaction( + self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str + ) -> None: + """Adds a reaction to a message as the current user.""" + encoded = quote(emoji) + await self.request( + "PUT", + f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/@me", + ) + + async def delete_reaction( + self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str + ) -> None: + """Removes the current user's reaction from a message.""" + encoded = quote(emoji) + await self.request( + "DELETE", + f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/@me", + ) + async def delete_user_reaction( self, channel_id: "Snowflake", @@ -382,38 +382,38 @@ class HTTPClient: f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/{user_id}", ) - async def get_reactions( - self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str - ) -> List[Dict[str, Any]]: - """Fetches the users that reacted with a specific emoji.""" - encoded = quote(emoji) - return await self.request( - "GET", - f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}", - ) - - async def clear_reactions( - self, channel_id: "Snowflake", message_id: "Snowflake" - ) -> None: - """Removes all reactions from a message.""" - - await self.request( - "DELETE", - f"/channels/{channel_id}/messages/{message_id}/reactions", - ) - - async def bulk_delete_messages( - self, channel_id: "Snowflake", messages: List["Snowflake"] - ) -> List["Snowflake"]: - """Bulk deletes messages in a channel and returns their IDs.""" - - await self.request( - "POST", - f"/channels/{channel_id}/messages/bulk-delete", - payload={"messages": messages}, - ) - return messages - + async def get_reactions( + self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str + ) -> List[Dict[str, Any]]: + """Fetches the users that reacted with a specific emoji.""" + encoded = quote(emoji) + return await self.request( + "GET", + f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}", + ) + + async def clear_reactions( + self, channel_id: "Snowflake", message_id: "Snowflake" + ) -> None: + """Removes all reactions from a message.""" + + await self.request( + "DELETE", + f"/channels/{channel_id}/messages/{message_id}/reactions", + ) + + async def bulk_delete_messages( + self, channel_id: "Snowflake", messages: List["Snowflake"] + ) -> List["Snowflake"]: + """Bulk deletes messages in a channel and returns their IDs.""" + + await self.request( + "POST", + f"/channels/{channel_id}/messages/bulk-delete", + payload={"messages": messages}, + ) + return messages + async def get_pinned_messages( self, channel_id: "Snowflake" ) -> List[Dict[str, Any]]: @@ -435,26 +435,26 @@ class HTTPClient: await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}") - async def delete_channel( - self, channel_id: str, reason: Optional[str] = None - ) -> None: - """Deletes a channel. - - If the channel is a guild channel, requires the MANAGE_CHANNELS permission. - If the channel is a thread, requires the MANAGE_THREADS permission (if locked) or - be the thread creator (if not locked). - Deleting a category does not delete its child channels. - """ - custom_headers = {} - if reason: - custom_headers["X-Audit-Log-Reason"] = reason - - await self.request( - "DELETE", - f"/channels/{channel_id}", - custom_headers=custom_headers if custom_headers else None, - ) - + async def delete_channel( + self, channel_id: str, reason: Optional[str] = None + ) -> None: + """Deletes a channel. + + If the channel is a guild channel, requires the MANAGE_CHANNELS permission. + If the channel is a thread, requires the MANAGE_THREADS permission (if locked) or + be the thread creator (if not locked). + Deleting a category does not delete its child channels. + """ + custom_headers = {} + if reason: + custom_headers["X-Audit-Log-Reason"] = reason + + await self.request( + "DELETE", + f"/channels/{channel_id}", + custom_headers=custom_headers if custom_headers else None, + ) + async def edit_channel( self, channel_id: "Snowflake", @@ -470,625 +470,625 @@ class HTTPClient: custom_headers=headers, ) - async def get_channel(self, channel_id: str) -> Dict[str, Any]: - """Fetches a channel by ID.""" - return await self.request("GET", f"/channels/{channel_id}") - - async def get_channel_invites( - self, channel_id: "Snowflake" - ) -> List[Dict[str, Any]]: - """Fetches the invites for a channel.""" - - return await self.request("GET", f"/channels/{channel_id}/invites") - - async def create_invite( - self, channel_id: "Snowflake", payload: Dict[str, Any] - ) -> "Invite": - """Creates an invite for a channel.""" - - data = await self.request( - "POST", f"/channels/{channel_id}/invites", payload=payload - ) - from .models import Invite - - return Invite.from_dict(data) - - async def delete_invite(self, code: str) -> None: - """Deletes an invite by code.""" - - await self.request("DELETE", f"/invites/{code}") - - async def create_webhook( - self, channel_id: "Snowflake", payload: Dict[str, Any] - ) -> "Webhook": - """Creates a webhook in the specified channel.""" - - data = await self.request( - "POST", f"/channels/{channel_id}/webhooks", payload=payload - ) - from .models import Webhook - - return Webhook(data) - - async def edit_webhook( - self, webhook_id: "Snowflake", payload: Dict[str, Any] - ) -> "Webhook": - """Edits an existing webhook.""" - - data = await self.request("PATCH", f"/webhooks/{webhook_id}", payload=payload) - from .models import Webhook - - return Webhook(data) - - async def delete_webhook(self, webhook_id: "Snowflake") -> None: - """Deletes a webhook.""" - - await self.request("DELETE", f"/webhooks/{webhook_id}") - - async def execute_webhook( - self, - webhook_id: "Snowflake", - token: str, - *, - content: Optional[str] = None, - tts: bool = False, - embeds: Optional[List[Dict[str, Any]]] = None, - components: Optional[List[Dict[str, Any]]] = None, - allowed_mentions: Optional[dict] = None, - attachments: Optional[List[Any]] = None, - files: Optional[List[Any]] = None, - flags: Optional[int] = None, - username: Optional[str] = None, - avatar_url: Optional[str] = None, - ) -> Dict[str, Any]: - """Executes a webhook and returns the created message.""" - - payload: Dict[str, Any] = {} - if content is not None: - payload["content"] = content - if tts: - payload["tts"] = True - if embeds: - payload["embeds"] = embeds - if components: - payload["components"] = components - if allowed_mentions: - payload["allowed_mentions"] = allowed_mentions - if username: - payload["username"] = username - if avatar_url: - payload["avatar_url"] = avatar_url - - all_files: List["File"] = [] - if attachments is not None: - payload["attachments"] = [] - for a in attachments: - if hasattr(a, "data") and hasattr(a, "filename"): - idx = len(all_files) - all_files.append(a) - payload["attachments"].append({"id": idx, "filename": a.filename}) - else: - payload["attachments"].append( - a.to_dict() if hasattr(a, "to_dict") else a - ) - if files is not None: - for f in files: - if hasattr(f, "data") and hasattr(f, "filename"): - idx = len(all_files) - all_files.append(f) - if "attachments" not in payload: - payload["attachments"] = [] - payload["attachments"].append({"id": idx, "filename": f.filename}) - else: - raise TypeError("files must be File objects") - if flags: - payload["flags"] = flags - - if all_files: - form = aiohttp.FormData() - form.add_field( - "payload_json", json.dumps(payload), content_type="application/json" - ) - for idx, f in enumerate(all_files): - form.add_field( - f"files[{idx}]", - f.data, - filename=f.filename, - content_type="application/octet-stream", - ) - return await self.request( - "POST", - f"/webhooks/{webhook_id}/{token}", - payload=form, - is_json=False, - use_auth_header=False, - ) - - return await self.request( - "POST", - f"/webhooks/{webhook_id}/{token}", - payload=payload, - use_auth_header=False, - ) - - async def get_user(self, user_id: "Snowflake") -> Dict[str, Any]: - """Fetches a user object for a given user ID.""" - return await self.request("GET", f"/users/{user_id}") - - async def get_guild_member( - self, guild_id: "Snowflake", user_id: "Snowflake" - ) -> Dict[str, Any]: - """Returns a guild member object for the specified user.""" - return await self.request("GET", f"/guilds/{guild_id}/members/{user_id}") - - async def kick_member( - self, guild_id: "Snowflake", user_id: "Snowflake", reason: Optional[str] = None - ) -> None: - """Kicks a member from the guild.""" - headers = {"X-Audit-Log-Reason": reason} if reason else None - await self.request( - "DELETE", - f"/guilds/{guild_id}/members/{user_id}", - custom_headers=headers, - ) - - async def ban_member( - self, - guild_id: "Snowflake", - user_id: "Snowflake", - *, - delete_message_seconds: int = 0, - reason: Optional[str] = None, - ) -> None: - """Bans a member from the guild.""" - payload = {} - if delete_message_seconds: - payload["delete_message_seconds"] = delete_message_seconds - headers = {"X-Audit-Log-Reason": reason} if reason else None - await self.request( - "PUT", - f"/guilds/{guild_id}/bans/{user_id}", - payload=payload if payload else None, - custom_headers=headers, - ) - - async def timeout_member( - self, - guild_id: "Snowflake", - user_id: "Snowflake", - *, - until: Optional[str], - reason: Optional[str] = None, - ) -> Dict[str, Any]: - """Times out a member until the given ISO8601 timestamp.""" - payload = {"communication_disabled_until": until} - headers = {"X-Audit-Log-Reason": reason} if reason else None - return await self.request( - "PATCH", - f"/guilds/{guild_id}/members/{user_id}", - payload=payload, - custom_headers=headers, - ) - - async def get_guild_roles(self, guild_id: "Snowflake") -> List[Dict[str, Any]]: - """Returns a list of role objects for the guild.""" - return await self.request("GET", f"/guilds/{guild_id}/roles") - - async def get_guild(self, guild_id: "Snowflake") -> Dict[str, Any]: - """Fetches a guild object for a given guild ID.""" - return await self.request("GET", f"/guilds/{guild_id}") - - async def get_guild_templates(self, guild_id: "Snowflake") -> List[Dict[str, Any]]: - """Fetches all templates for the given guild.""" - return await self.request("GET", f"/guilds/{guild_id}/templates") - - async def create_guild_template( - self, guild_id: "Snowflake", payload: Dict[str, Any] - ) -> Dict[str, Any]: - """Creates a guild template.""" - return await self.request( - "POST", f"/guilds/{guild_id}/templates", payload=payload - ) - - async def sync_guild_template( - self, guild_id: "Snowflake", template_code: str - ) -> Dict[str, Any]: - """Syncs a guild template to the guild's current state.""" - return await self.request( - "PUT", - f"/guilds/{guild_id}/templates/{template_code}", - ) - - async def delete_guild_template( - self, guild_id: "Snowflake", template_code: str - ) -> None: - """Deletes a guild template.""" - await self.request("DELETE", f"/guilds/{guild_id}/templates/{template_code}") - - async def get_guild_scheduled_events( - self, guild_id: "Snowflake" - ) -> List[Dict[str, Any]]: - """Returns a list of scheduled events for the guild.""" - - return await self.request("GET", f"/guilds/{guild_id}/scheduled-events") - - async def get_guild_scheduled_event( - self, guild_id: "Snowflake", event_id: "Snowflake" - ) -> Dict[str, Any]: - """Returns a guild scheduled event.""" - - return await self.request( - "GET", f"/guilds/{guild_id}/scheduled-events/{event_id}" - ) - - async def create_guild_scheduled_event( - self, guild_id: "Snowflake", payload: Dict[str, Any] - ) -> Dict[str, Any]: - """Creates a guild scheduled event.""" - - return await self.request( - "POST", f"/guilds/{guild_id}/scheduled-events", payload=payload - ) - - async def edit_guild_scheduled_event( - self, guild_id: "Snowflake", event_id: "Snowflake", payload: Dict[str, Any] - ) -> Dict[str, Any]: - """Edits a guild scheduled event.""" - - return await self.request( - "PATCH", - f"/guilds/{guild_id}/scheduled-events/{event_id}", - payload=payload, - ) - - async def delete_guild_scheduled_event( - self, guild_id: "Snowflake", event_id: "Snowflake" - ) -> None: - """Deletes a guild scheduled event.""" - - await self.request("DELETE", f"/guilds/{guild_id}/scheduled-events/{event_id}") - - async def get_audit_logs( - self, guild_id: "Snowflake", **filters: Any - ) -> Dict[str, Any]: - """Fetches audit log entries for a guild.""" - params = {k: v for k, v in filters.items() if v is not None} - return await self.request( - "GET", - f"/guilds/{guild_id}/audit-logs", - params=params if params else None, - ) - - # Add other methods like: - # async def get_guild(self, guild_id: str) -> Dict[str, Any]: ... - # async def create_reaction(self, channel_id: str, message_id: str, emoji: str) -> None: ... - # etc. - # --- Application Command Endpoints --- - - # Global Application Commands - async def get_global_application_commands( - self, application_id: "Snowflake", with_localizations: bool = False - ) -> List["ApplicationCommand"]: - """Fetches all global commands for your application.""" - params = {"with_localizations": str(with_localizations).lower()} - data = await self.request( - "GET", f"/applications/{application_id}/commands", params=params - ) - from .interactions import ApplicationCommand # Ensure constructor is available - - return [ApplicationCommand(cmd_data) for cmd_data in data] - - async def create_global_application_command( - self, application_id: "Snowflake", payload: Dict[str, Any] - ) -> "ApplicationCommand": - """Creates a new global command.""" - data = await self.request( - "POST", f"/applications/{application_id}/commands", payload=payload - ) - from .interactions import ApplicationCommand - - return ApplicationCommand(data) - - async def get_global_application_command( - self, application_id: "Snowflake", command_id: "Snowflake" - ) -> "ApplicationCommand": - """Fetches a specific global command.""" - data = await self.request( - "GET", f"/applications/{application_id}/commands/{command_id}" - ) - from .interactions import ApplicationCommand - - return ApplicationCommand(data) - - async def edit_global_application_command( - self, - application_id: "Snowflake", - command_id: "Snowflake", - payload: Dict[str, Any], - ) -> "ApplicationCommand": - """Edits a specific global command.""" - data = await self.request( - "PATCH", - f"/applications/{application_id}/commands/{command_id}", - payload=payload, - ) - from .interactions import ApplicationCommand - - return ApplicationCommand(data) - - async def delete_global_application_command( - self, application_id: "Snowflake", command_id: "Snowflake" - ) -> None: - """Deletes a specific global command.""" - await self.request( - "DELETE", f"/applications/{application_id}/commands/{command_id}" - ) - - async def bulk_overwrite_global_application_commands( - self, application_id: "Snowflake", payload: List[Dict[str, Any]] - ) -> List["ApplicationCommand"]: - """Bulk overwrites all global commands for your application.""" - data = await self.request( - "PUT", f"/applications/{application_id}/commands", payload=payload - ) - from .interactions import ApplicationCommand - - return [ApplicationCommand(cmd_data) for cmd_data in data] - - # Guild Application Commands - async def get_guild_application_commands( - self, - application_id: "Snowflake", - guild_id: "Snowflake", - with_localizations: bool = False, - ) -> List["ApplicationCommand"]: - """Fetches all commands for your application for a specific guild.""" - params = {"with_localizations": str(with_localizations).lower()} - data = await self.request( - "GET", - f"/applications/{application_id}/guilds/{guild_id}/commands", - params=params, - ) - from .interactions import ApplicationCommand - - return [ApplicationCommand(cmd_data) for cmd_data in data] - - async def create_guild_application_command( - self, - application_id: "Snowflake", - guild_id: "Snowflake", - payload: Dict[str, Any], - ) -> "ApplicationCommand": - """Creates a new guild command.""" - data = await self.request( - "POST", - f"/applications/{application_id}/guilds/{guild_id}/commands", - payload=payload, - ) - from .interactions import ApplicationCommand - - return ApplicationCommand(data) - - async def get_guild_application_command( - self, - application_id: "Snowflake", - guild_id: "Snowflake", - command_id: "Snowflake", - ) -> "ApplicationCommand": - """Fetches a specific guild command.""" - data = await self.request( - "GET", - f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}", - ) - from .interactions import ApplicationCommand - - return ApplicationCommand(data) - - async def edit_guild_application_command( - self, - application_id: "Snowflake", - guild_id: "Snowflake", - command_id: "Snowflake", - payload: Dict[str, Any], - ) -> "ApplicationCommand": - """Edits a specific guild command.""" - data = await self.request( - "PATCH", - f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}", - payload=payload, - ) - from .interactions import ApplicationCommand - - return ApplicationCommand(data) - - async def delete_guild_application_command( - self, - application_id: "Snowflake", - guild_id: "Snowflake", - command_id: "Snowflake", - ) -> None: - """Deletes a specific guild command.""" - await self.request( - "DELETE", - f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}", - ) - - async def bulk_overwrite_guild_application_commands( - self, - application_id: "Snowflake", - guild_id: "Snowflake", - payload: List[Dict[str, Any]], - ) -> List["ApplicationCommand"]: - """Bulk overwrites all commands for your application for a specific guild.""" - data = await self.request( - "PUT", - f"/applications/{application_id}/guilds/{guild_id}/commands", - payload=payload, - ) - from .interactions import ApplicationCommand - - return [ApplicationCommand(cmd_data) for cmd_data in data] - - # --- Interaction Response Endpoints --- - # Note: These methods return Dict[str, Any] representing the Message data. - # The caller (e.g., AppCommandHandler) will be responsible for constructing Message models - # if needed, as Message model instantiation requires a `client_instance`. - - async def create_interaction_response( - self, - interaction_id: "Snowflake", - interaction_token: str, - payload: Union["InteractionResponsePayload", Dict[str, Any]], - *, - ephemeral: bool = False, - ) -> None: - """Creates a response to an Interaction. - - Parameters - ---------- - ephemeral: bool - Ignored parameter for test compatibility. - """ - # Interaction responses do not use the bot token in the Authorization header. - # They are authenticated by the interaction_token in the URL. - payload_data: Dict[str, Any] - if isinstance(payload, InteractionResponsePayload): - payload_data = payload.to_dict() - else: - payload_data = payload - - await self.request( - "POST", - f"/interactions/{interaction_id}/{interaction_token}/callback", - payload=payload_data, - use_auth_header=False, - ) - - async def get_original_interaction_response( - self, application_id: "Snowflake", interaction_token: str - ) -> Dict[str, Any]: - """Gets the initial Interaction response.""" - # This endpoint uses the bot token for auth. - return await self.request( - "GET", f"/webhooks/{application_id}/{interaction_token}/messages/@original" - ) - - async def edit_original_interaction_response( - self, - application_id: "Snowflake", - interaction_token: str, - payload: Dict[str, Any], - ) -> Dict[str, Any]: - """Edits the initial Interaction response.""" - return await self.request( - "PATCH", - f"/webhooks/{application_id}/{interaction_token}/messages/@original", - payload=payload, - use_auth_header=False, - ) # Docs imply webhook-style auth - - async def delete_original_interaction_response( - self, application_id: "Snowflake", interaction_token: str - ) -> None: - """Deletes the initial Interaction response.""" - await self.request( - "DELETE", - f"/webhooks/{application_id}/{interaction_token}/messages/@original", - use_auth_header=False, - ) # Docs imply webhook-style auth - - async def create_followup_message( - self, - application_id: "Snowflake", - interaction_token: str, - payload: Dict[str, Any], - ) -> Dict[str, Any]: - """Creates a followup message for an Interaction.""" - # Followup messages are sent to a webhook endpoint. - return await self.request( - "POST", - f"/webhooks/{application_id}/{interaction_token}", - payload=payload, - use_auth_header=False, - ) # Docs imply webhook-style auth - - async def edit_followup_message( - self, - application_id: "Snowflake", - interaction_token: str, - message_id: "Snowflake", - payload: Dict[str, Any], - ) -> Dict[str, Any]: - """Edits a followup message for an Interaction.""" - return await self.request( - "PATCH", - f"/webhooks/{application_id}/{interaction_token}/messages/{message_id}", - payload=payload, - use_auth_header=False, - ) # Docs imply webhook-style auth - - async def delete_followup_message( - self, - application_id: "Snowflake", - interaction_token: str, - message_id: "Snowflake", - ) -> None: - """Deletes a followup message for an Interaction.""" - await self.request( - "DELETE", - f"/webhooks/{application_id}/{interaction_token}/messages/{message_id}", - use_auth_header=False, - ) - - async def trigger_typing(self, channel_id: str) -> None: - """Sends a typing indicator to the specified channel.""" - await self.request("POST", f"/channels/{channel_id}/typing") - - async def start_stage_instance( - self, payload: Dict[str, Any], reason: Optional[str] = None - ) -> "StageInstance": - """Starts a stage instance.""" - - headers = {"X-Audit-Log-Reason": reason} if reason else None - data = await self.request( - "POST", "/stage-instances", payload=payload, custom_headers=headers - ) - from .models import StageInstance - - return StageInstance(data) - - async def edit_stage_instance( - self, - channel_id: "Snowflake", - payload: Dict[str, Any], - reason: Optional[str] = None, - ) -> "StageInstance": - """Edits an existing stage instance.""" - - headers = {"X-Audit-Log-Reason": reason} if reason else None - data = await self.request( - "PATCH", - f"/stage-instances/{channel_id}", - payload=payload, - custom_headers=headers, - ) - from .models import StageInstance - - return StageInstance(data) - - async def end_stage_instance( - self, channel_id: "Snowflake", reason: Optional[str] = None - ) -> None: - """Ends a stage instance.""" - - headers = {"X-Audit-Log-Reason": reason} if reason else None - await self.request( - "DELETE", f"/stage-instances/{channel_id}", custom_headers=headers - ) - - async def get_voice_regions(self) -> List[Dict[str, Any]]: - """Returns available voice regions.""" - return await self.request("GET", "/voice/regions") + async def get_channel(self, channel_id: str) -> Dict[str, Any]: + """Fetches a channel by ID.""" + return await self.request("GET", f"/channels/{channel_id}") + + async def get_channel_invites( + self, channel_id: "Snowflake" + ) -> List[Dict[str, Any]]: + """Fetches the invites for a channel.""" + + return await self.request("GET", f"/channels/{channel_id}/invites") + + async def create_invite( + self, channel_id: "Snowflake", payload: Dict[str, Any] + ) -> "Invite": + """Creates an invite for a channel.""" + + data = await self.request( + "POST", f"/channels/{channel_id}/invites", payload=payload + ) + from .models import Invite + + return Invite.from_dict(data) + + async def delete_invite(self, code: str) -> None: + """Deletes an invite by code.""" + + await self.request("DELETE", f"/invites/{code}") + + async def create_webhook( + self, channel_id: "Snowflake", payload: Dict[str, Any] + ) -> "Webhook": + """Creates a webhook in the specified channel.""" + + data = await self.request( + "POST", f"/channels/{channel_id}/webhooks", payload=payload + ) + from .models import Webhook + + return Webhook(data) + + async def edit_webhook( + self, webhook_id: "Snowflake", payload: Dict[str, Any] + ) -> "Webhook": + """Edits an existing webhook.""" + + data = await self.request("PATCH", f"/webhooks/{webhook_id}", payload=payload) + from .models import Webhook + + return Webhook(data) + + async def delete_webhook(self, webhook_id: "Snowflake") -> None: + """Deletes a webhook.""" + + await self.request("DELETE", f"/webhooks/{webhook_id}") + + async def execute_webhook( + self, + webhook_id: "Snowflake", + token: str, + *, + content: Optional[str] = None, + tts: bool = False, + embeds: Optional[List[Dict[str, Any]]] = None, + components: Optional[List[Dict[str, Any]]] = None, + allowed_mentions: Optional[dict] = None, + attachments: Optional[List[Any]] = None, + files: Optional[List[Any]] = None, + flags: Optional[int] = None, + username: Optional[str] = None, + avatar_url: Optional[str] = None, + ) -> Dict[str, Any]: + """Executes a webhook and returns the created message.""" + + payload: Dict[str, Any] = {} + if content is not None: + payload["content"] = content + if tts: + payload["tts"] = True + if embeds: + payload["embeds"] = embeds + if components: + payload["components"] = components + if allowed_mentions: + payload["allowed_mentions"] = allowed_mentions + if username: + payload["username"] = username + if avatar_url: + payload["avatar_url"] = avatar_url + + all_files: List["File"] = [] + if attachments is not None: + payload["attachments"] = [] + for a in attachments: + if hasattr(a, "data") and hasattr(a, "filename"): + idx = len(all_files) + all_files.append(a) + payload["attachments"].append({"id": idx, "filename": a.filename}) + else: + payload["attachments"].append( + a.to_dict() if hasattr(a, "to_dict") else a + ) + if files is not None: + for f in files: + if hasattr(f, "data") and hasattr(f, "filename"): + idx = len(all_files) + all_files.append(f) + if "attachments" not in payload: + payload["attachments"] = [] + payload["attachments"].append({"id": idx, "filename": f.filename}) + else: + raise TypeError("files must be File objects") + if flags: + payload["flags"] = flags + + if all_files: + form = aiohttp.FormData() + form.add_field( + "payload_json", json.dumps(payload), content_type="application/json" + ) + for idx, f in enumerate(all_files): + form.add_field( + f"files[{idx}]", + f.data, + filename=f.filename, + content_type="application/octet-stream", + ) + return await self.request( + "POST", + f"/webhooks/{webhook_id}/{token}", + payload=form, + is_json=False, + use_auth_header=False, + ) + + return await self.request( + "POST", + f"/webhooks/{webhook_id}/{token}", + payload=payload, + use_auth_header=False, + ) + + async def get_user(self, user_id: "Snowflake") -> Dict[str, Any]: + """Fetches a user object for a given user ID.""" + return await self.request("GET", f"/users/{user_id}") + + async def get_guild_member( + self, guild_id: "Snowflake", user_id: "Snowflake" + ) -> Dict[str, Any]: + """Returns a guild member object for the specified user.""" + return await self.request("GET", f"/guilds/{guild_id}/members/{user_id}") + + async def kick_member( + self, guild_id: "Snowflake", user_id: "Snowflake", reason: Optional[str] = None + ) -> None: + """Kicks a member from the guild.""" + headers = {"X-Audit-Log-Reason": reason} if reason else None + await self.request( + "DELETE", + f"/guilds/{guild_id}/members/{user_id}", + custom_headers=headers, + ) + + async def ban_member( + self, + guild_id: "Snowflake", + user_id: "Snowflake", + *, + delete_message_seconds: int = 0, + reason: Optional[str] = None, + ) -> None: + """Bans a member from the guild.""" + payload = {} + if delete_message_seconds: + payload["delete_message_seconds"] = delete_message_seconds + headers = {"X-Audit-Log-Reason": reason} if reason else None + await self.request( + "PUT", + f"/guilds/{guild_id}/bans/{user_id}", + payload=payload if payload else None, + custom_headers=headers, + ) + + async def timeout_member( + self, + guild_id: "Snowflake", + user_id: "Snowflake", + *, + until: Optional[str], + reason: Optional[str] = None, + ) -> Dict[str, Any]: + """Times out a member until the given ISO8601 timestamp.""" + payload = {"communication_disabled_until": until} + headers = {"X-Audit-Log-Reason": reason} if reason else None + return await self.request( + "PATCH", + f"/guilds/{guild_id}/members/{user_id}", + payload=payload, + custom_headers=headers, + ) + + async def get_guild_roles(self, guild_id: "Snowflake") -> List[Dict[str, Any]]: + """Returns a list of role objects for the guild.""" + return await self.request("GET", f"/guilds/{guild_id}/roles") + + async def get_guild(self, guild_id: "Snowflake") -> Dict[str, Any]: + """Fetches a guild object for a given guild ID.""" + return await self.request("GET", f"/guilds/{guild_id}") + + async def get_guild_templates(self, guild_id: "Snowflake") -> List[Dict[str, Any]]: + """Fetches all templates for the given guild.""" + return await self.request("GET", f"/guilds/{guild_id}/templates") + + async def create_guild_template( + self, guild_id: "Snowflake", payload: Dict[str, Any] + ) -> Dict[str, Any]: + """Creates a guild template.""" + return await self.request( + "POST", f"/guilds/{guild_id}/templates", payload=payload + ) + + async def sync_guild_template( + self, guild_id: "Snowflake", template_code: str + ) -> Dict[str, Any]: + """Syncs a guild template to the guild's current state.""" + return await self.request( + "PUT", + f"/guilds/{guild_id}/templates/{template_code}", + ) + + async def delete_guild_template( + self, guild_id: "Snowflake", template_code: str + ) -> None: + """Deletes a guild template.""" + await self.request("DELETE", f"/guilds/{guild_id}/templates/{template_code}") + + async def get_guild_scheduled_events( + self, guild_id: "Snowflake" + ) -> List[Dict[str, Any]]: + """Returns a list of scheduled events for the guild.""" + + return await self.request("GET", f"/guilds/{guild_id}/scheduled-events") + + async def get_guild_scheduled_event( + self, guild_id: "Snowflake", event_id: "Snowflake" + ) -> Dict[str, Any]: + """Returns a guild scheduled event.""" + + return await self.request( + "GET", f"/guilds/{guild_id}/scheduled-events/{event_id}" + ) + + async def create_guild_scheduled_event( + self, guild_id: "Snowflake", payload: Dict[str, Any] + ) -> Dict[str, Any]: + """Creates a guild scheduled event.""" + + return await self.request( + "POST", f"/guilds/{guild_id}/scheduled-events", payload=payload + ) + + async def edit_guild_scheduled_event( + self, guild_id: "Snowflake", event_id: "Snowflake", payload: Dict[str, Any] + ) -> Dict[str, Any]: + """Edits a guild scheduled event.""" + + return await self.request( + "PATCH", + f"/guilds/{guild_id}/scheduled-events/{event_id}", + payload=payload, + ) + + async def delete_guild_scheduled_event( + self, guild_id: "Snowflake", event_id: "Snowflake" + ) -> None: + """Deletes a guild scheduled event.""" + + await self.request("DELETE", f"/guilds/{guild_id}/scheduled-events/{event_id}") + + async def get_audit_logs( + self, guild_id: "Snowflake", **filters: Any + ) -> Dict[str, Any]: + """Fetches audit log entries for a guild.""" + params = {k: v for k, v in filters.items() if v is not None} + return await self.request( + "GET", + f"/guilds/{guild_id}/audit-logs", + params=params if params else None, + ) + + # Add other methods like: + # async def get_guild(self, guild_id: str) -> Dict[str, Any]: ... + # async def create_reaction(self, channel_id: str, message_id: str, emoji: str) -> None: ... + # etc. + # --- Application Command Endpoints --- + + # Global Application Commands + async def get_global_application_commands( + self, application_id: "Snowflake", with_localizations: bool = False + ) -> List["ApplicationCommand"]: + """Fetches all global commands for your application.""" + params = {"with_localizations": str(with_localizations).lower()} + data = await self.request( + "GET", f"/applications/{application_id}/commands", params=params + ) + from .interactions import ApplicationCommand # Ensure constructor is available + + return [ApplicationCommand(cmd_data) for cmd_data in data] + + async def create_global_application_command( + self, application_id: "Snowflake", payload: Dict[str, Any] + ) -> "ApplicationCommand": + """Creates a new global command.""" + data = await self.request( + "POST", f"/applications/{application_id}/commands", payload=payload + ) + from .interactions import ApplicationCommand + + return ApplicationCommand(data) + + async def get_global_application_command( + self, application_id: "Snowflake", command_id: "Snowflake" + ) -> "ApplicationCommand": + """Fetches a specific global command.""" + data = await self.request( + "GET", f"/applications/{application_id}/commands/{command_id}" + ) + from .interactions import ApplicationCommand + + return ApplicationCommand(data) + + async def edit_global_application_command( + self, + application_id: "Snowflake", + command_id: "Snowflake", + payload: Dict[str, Any], + ) -> "ApplicationCommand": + """Edits a specific global command.""" + data = await self.request( + "PATCH", + f"/applications/{application_id}/commands/{command_id}", + payload=payload, + ) + from .interactions import ApplicationCommand + + return ApplicationCommand(data) + + async def delete_global_application_command( + self, application_id: "Snowflake", command_id: "Snowflake" + ) -> None: + """Deletes a specific global command.""" + await self.request( + "DELETE", f"/applications/{application_id}/commands/{command_id}" + ) + + async def bulk_overwrite_global_application_commands( + self, application_id: "Snowflake", payload: List[Dict[str, Any]] + ) -> List["ApplicationCommand"]: + """Bulk overwrites all global commands for your application.""" + data = await self.request( + "PUT", f"/applications/{application_id}/commands", payload=payload + ) + from .interactions import ApplicationCommand + + return [ApplicationCommand(cmd_data) for cmd_data in data] + + # Guild Application Commands + async def get_guild_application_commands( + self, + application_id: "Snowflake", + guild_id: "Snowflake", + with_localizations: bool = False, + ) -> List["ApplicationCommand"]: + """Fetches all commands for your application for a specific guild.""" + params = {"with_localizations": str(with_localizations).lower()} + data = await self.request( + "GET", + f"/applications/{application_id}/guilds/{guild_id}/commands", + params=params, + ) + from .interactions import ApplicationCommand + + return [ApplicationCommand(cmd_data) for cmd_data in data] + + async def create_guild_application_command( + self, + application_id: "Snowflake", + guild_id: "Snowflake", + payload: Dict[str, Any], + ) -> "ApplicationCommand": + """Creates a new guild command.""" + data = await self.request( + "POST", + f"/applications/{application_id}/guilds/{guild_id}/commands", + payload=payload, + ) + from .interactions import ApplicationCommand + + return ApplicationCommand(data) + + async def get_guild_application_command( + self, + application_id: "Snowflake", + guild_id: "Snowflake", + command_id: "Snowflake", + ) -> "ApplicationCommand": + """Fetches a specific guild command.""" + data = await self.request( + "GET", + f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}", + ) + from .interactions import ApplicationCommand + + return ApplicationCommand(data) + + async def edit_guild_application_command( + self, + application_id: "Snowflake", + guild_id: "Snowflake", + command_id: "Snowflake", + payload: Dict[str, Any], + ) -> "ApplicationCommand": + """Edits a specific guild command.""" + data = await self.request( + "PATCH", + f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}", + payload=payload, + ) + from .interactions import ApplicationCommand + + return ApplicationCommand(data) + + async def delete_guild_application_command( + self, + application_id: "Snowflake", + guild_id: "Snowflake", + command_id: "Snowflake", + ) -> None: + """Deletes a specific guild command.""" + await self.request( + "DELETE", + f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}", + ) + + async def bulk_overwrite_guild_application_commands( + self, + application_id: "Snowflake", + guild_id: "Snowflake", + payload: List[Dict[str, Any]], + ) -> List["ApplicationCommand"]: + """Bulk overwrites all commands for your application for a specific guild.""" + data = await self.request( + "PUT", + f"/applications/{application_id}/guilds/{guild_id}/commands", + payload=payload, + ) + from .interactions import ApplicationCommand + + return [ApplicationCommand(cmd_data) for cmd_data in data] + + # --- Interaction Response Endpoints --- + # Note: These methods return Dict[str, Any] representing the Message data. + # The caller (e.g., AppCommandHandler) will be responsible for constructing Message models + # if needed, as Message model instantiation requires a `client_instance`. + + async def create_interaction_response( + self, + interaction_id: "Snowflake", + interaction_token: str, + payload: Union["InteractionResponsePayload", Dict[str, Any]], + *, + ephemeral: bool = False, + ) -> None: + """Creates a response to an Interaction. + + Parameters + ---------- + ephemeral: bool + Ignored parameter for test compatibility. + """ + # Interaction responses do not use the bot token in the Authorization header. + # They are authenticated by the interaction_token in the URL. + payload_data: Dict[str, Any] + if isinstance(payload, InteractionResponsePayload): + payload_data = payload.to_dict() + else: + payload_data = payload + + await self.request( + "POST", + f"/interactions/{interaction_id}/{interaction_token}/callback", + payload=payload_data, + use_auth_header=False, + ) + + async def get_original_interaction_response( + self, application_id: "Snowflake", interaction_token: str + ) -> Dict[str, Any]: + """Gets the initial Interaction response.""" + # This endpoint uses the bot token for auth. + return await self.request( + "GET", f"/webhooks/{application_id}/{interaction_token}/messages/@original" + ) + + async def edit_original_interaction_response( + self, + application_id: "Snowflake", + interaction_token: str, + payload: Dict[str, Any], + ) -> Dict[str, Any]: + """Edits the initial Interaction response.""" + return await self.request( + "PATCH", + f"/webhooks/{application_id}/{interaction_token}/messages/@original", + payload=payload, + use_auth_header=False, + ) # Docs imply webhook-style auth + + async def delete_original_interaction_response( + self, application_id: "Snowflake", interaction_token: str + ) -> None: + """Deletes the initial Interaction response.""" + await self.request( + "DELETE", + f"/webhooks/{application_id}/{interaction_token}/messages/@original", + use_auth_header=False, + ) # Docs imply webhook-style auth + + async def create_followup_message( + self, + application_id: "Snowflake", + interaction_token: str, + payload: Dict[str, Any], + ) -> Dict[str, Any]: + """Creates a followup message for an Interaction.""" + # Followup messages are sent to a webhook endpoint. + return await self.request( + "POST", + f"/webhooks/{application_id}/{interaction_token}", + payload=payload, + use_auth_header=False, + ) # Docs imply webhook-style auth + + async def edit_followup_message( + self, + application_id: "Snowflake", + interaction_token: str, + message_id: "Snowflake", + payload: Dict[str, Any], + ) -> Dict[str, Any]: + """Edits a followup message for an Interaction.""" + return await self.request( + "PATCH", + f"/webhooks/{application_id}/{interaction_token}/messages/{message_id}", + payload=payload, + use_auth_header=False, + ) # Docs imply webhook-style auth + + async def delete_followup_message( + self, + application_id: "Snowflake", + interaction_token: str, + message_id: "Snowflake", + ) -> None: + """Deletes a followup message for an Interaction.""" + await self.request( + "DELETE", + f"/webhooks/{application_id}/{interaction_token}/messages/{message_id}", + use_auth_header=False, + ) + + async def trigger_typing(self, channel_id: str) -> None: + """Sends a typing indicator to the specified channel.""" + await self.request("POST", f"/channels/{channel_id}/typing") + + async def start_stage_instance( + self, payload: Dict[str, Any], reason: Optional[str] = None + ) -> "StageInstance": + """Starts a stage instance.""" + + headers = {"X-Audit-Log-Reason": reason} if reason else None + data = await self.request( + "POST", "/stage-instances", payload=payload, custom_headers=headers + ) + from .models import StageInstance + + return StageInstance(data) + + async def edit_stage_instance( + self, + channel_id: "Snowflake", + payload: Dict[str, Any], + reason: Optional[str] = None, + ) -> "StageInstance": + """Edits an existing stage instance.""" + + headers = {"X-Audit-Log-Reason": reason} if reason else None + data = await self.request( + "PATCH", + f"/stage-instances/{channel_id}", + payload=payload, + custom_headers=headers, + ) + from .models import StageInstance + + return StageInstance(data) + + async def end_stage_instance( + self, channel_id: "Snowflake", reason: Optional[str] = None + ) -> None: + """Ends a stage instance.""" + + headers = {"X-Audit-Log-Reason": reason} if reason else None + await self.request( + "DELETE", f"/stage-instances/{channel_id}", custom_headers=headers + ) + + async def get_voice_regions(self) -> List[Dict[str, Any]]: + """Returns available voice regions.""" + return await self.request("GET", "/voice/regions") async def start_thread_from_message( self, diff --git a/disagreement/models.py b/disagreement/models.py index 37c9baf..be671a0 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -137,106 +137,106 @@ class Message: await self._client._http.unpin_message(self.channel_id, self.id) self.pinned = False - async def reply( - self, - content: Optional[str] = None, - *, # Make additional params keyword-only - tts: bool = False, - embed: Optional["Embed"] = None, - embeds: Optional[List["Embed"]] = None, - components: Optional[List["ActionRow"]] = None, - allowed_mentions: Optional[Dict[str, Any]] = None, - mention_author: Optional[bool] = None, - flags: Optional[int] = None, - view: Optional["View"] = None, - ) -> "Message": - """|coro| - - Sends a reply to the message. - This is a shorthand for `Client.send_message` in the message's channel. - - Parameters: - content (Optional[str]): The content of the message. - tts (bool): Whether the message should be sent with text-to-speech. - embed (Optional[Embed]): A single embed to send. Cannot be used with `embeds`. - embeds (Optional[List[Embed]]): A list of embeds to send. - components (Optional[List[ActionRow]]): A list of ActionRow components. - allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message. - mention_author (Optional[bool]): Whether to mention the author in the reply. If ``None`` the - client's :attr:`mention_replies` setting is used. - flags (Optional[int]): Message flags. - view (Optional[View]): A view to send with the message. - - Returns: - Message: The message that was sent. - - Raises: - HTTPException: Sending the message failed. - ValueError: If both `embed` and `embeds` are provided. - """ - # Determine allowed mentions for the reply - if mention_author is None: - mention_author = getattr(self._client, "mention_replies", False) - - if allowed_mentions is None: - allowed_mentions = {"replied_user": mention_author} - else: - allowed_mentions = dict(allowed_mentions) - allowed_mentions.setdefault("replied_user", mention_author) - - # Client.send_message is already updated to handle these parameters - return await self._client.send_message( - channel_id=self.channel_id, - content=content, - tts=tts, - embed=embed, - embeds=embeds, - components=components, - allowed_mentions=allowed_mentions, - message_reference={ - "message_id": self.id, - "channel_id": self.channel_id, - "guild_id": self.guild_id, - }, - flags=flags, - view=view, - ) - - async def edit( - self, - *, - content: Optional[str] = None, - embed: Optional["Embed"] = None, - embeds: Optional[List["Embed"]] = None, - components: Optional[List["ActionRow"]] = None, - allowed_mentions: Optional[Dict[str, Any]] = None, - flags: Optional[int] = None, - view: Optional["View"] = None, - ) -> "Message": - """|coro| - - Edits this message. - - Parameters are the same as :meth:`Client.edit_message`. - """ - - return await self._client.edit_message( - channel_id=self.channel_id, - message_id=self.id, - content=content, - embed=embed, - embeds=embeds, - components=components, - allowed_mentions=allowed_mentions, - flags=flags, - view=view, - ) - - async def add_reaction(self, emoji: str) -> None: - """|coro| Add a reaction to this message.""" - - await self._client.add_reaction(self.channel_id, self.id, emoji) - + async def reply( + self, + content: Optional[str] = None, + *, # Make additional params keyword-only + tts: bool = False, + embed: Optional["Embed"] = None, + embeds: Optional[List["Embed"]] = None, + components: Optional[List["ActionRow"]] = None, + allowed_mentions: Optional[Dict[str, Any]] = None, + mention_author: Optional[bool] = None, + flags: Optional[int] = None, + view: Optional["View"] = None, + ) -> "Message": + """|coro| + + Sends a reply to the message. + This is a shorthand for `Client.send_message` in the message's channel. + + Parameters: + content (Optional[str]): The content of the message. + tts (bool): Whether the message should be sent with text-to-speech. + embed (Optional[Embed]): A single embed to send. Cannot be used with `embeds`. + embeds (Optional[List[Embed]]): A list of embeds to send. + components (Optional[List[ActionRow]]): A list of ActionRow components. + allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message. + mention_author (Optional[bool]): Whether to mention the author in the reply. If ``None`` the + client's :attr:`mention_replies` setting is used. + flags (Optional[int]): Message flags. + view (Optional[View]): A view to send with the message. + + Returns: + Message: The message that was sent. + + Raises: + HTTPException: Sending the message failed. + ValueError: If both `embed` and `embeds` are provided. + """ + # Determine allowed mentions for the reply + if mention_author is None: + mention_author = getattr(self._client, "mention_replies", False) + + if allowed_mentions is None: + allowed_mentions = {"replied_user": mention_author} + else: + allowed_mentions = dict(allowed_mentions) + allowed_mentions.setdefault("replied_user", mention_author) + + # Client.send_message is already updated to handle these parameters + return await self._client.send_message( + channel_id=self.channel_id, + content=content, + tts=tts, + embed=embed, + embeds=embeds, + components=components, + allowed_mentions=allowed_mentions, + message_reference={ + "message_id": self.id, + "channel_id": self.channel_id, + "guild_id": self.guild_id, + }, + flags=flags, + view=view, + ) + + async def edit( + self, + *, + content: Optional[str] = None, + embed: Optional["Embed"] = None, + embeds: Optional[List["Embed"]] = None, + components: Optional[List["ActionRow"]] = None, + allowed_mentions: Optional[Dict[str, Any]] = None, + flags: Optional[int] = None, + view: Optional["View"] = None, + ) -> "Message": + """|coro| + + Edits this message. + + Parameters are the same as :meth:`Client.edit_message`. + """ + + return await self._client.edit_message( + channel_id=self.channel_id, + message_id=self.id, + content=content, + embed=embed, + embeds=embeds, + components=components, + allowed_mentions=allowed_mentions, + flags=flags, + view=view, + ) + + async def add_reaction(self, emoji: str) -> None: + """|coro| Add a reaction to this message.""" + + await self._client.add_reaction(self.channel_id, self.id, emoji) + async def remove_reaction(self, emoji: str, member: Optional[User] = None) -> None: """|coro| Removes a reaction from this message. @@ -248,31 +248,31 @@ class Message: ) else: await self._client.remove_reaction(self.channel_id, self.id, emoji) - - async def clear_reactions(self) -> None: - """|coro| Remove all reactions from this message.""" - - await self._client.clear_reactions(self.channel_id, self.id) - - async def delete(self, delay: Optional[float] = None) -> None: - """|coro| - - Deletes this message. - - Parameters - ---------- - delay: - If provided, wait this many seconds before deleting. - """ - - if delay is not None: - await asyncio.sleep(delay) - - await self._client._http.delete_message(self.channel_id, self.id) - - def __repr__(self) -> str: - return f"" - + + async def clear_reactions(self) -> None: + """|coro| Remove all reactions from this message.""" + + await self._client.clear_reactions(self.channel_id, self.id) + + async def delete(self, delay: Optional[float] = None) -> None: + """|coro| + + Deletes this message. + + Parameters + ---------- + delay: + If provided, wait this many seconds before deleting. + """ + + if delay is not None: + await asyncio.sleep(delay) + + await self._client._http.delete_message(self.channel_id, self.id) + + def __repr__(self) -> str: + return f"" + async def create_thread( self, name: str, @@ -1539,959 +1539,959 @@ class Thread(TextChannel): # Threads are a specialized TextChannel data = await self._client._http.edit_channel(self.id, payload, reason=reason) return cast("Thread", self._client.parse_channel(data)) - -class DMChannel(Channel): - """Represents a Direct Message channel.""" - - def __init__(self, data: Dict[str, Any], client_instance: "Client"): - super().__init__(data, client_instance) - self.last_message_id: Optional[str] = data.get("last_message_id") - self.recipients: List[User] = [ - User(u_data) for u_data in data.get("recipients", []) - ] - - @property - def recipient(self) -> Optional[User]: - return self.recipients[0] if self.recipients else None - - async def send( - self, - content: Optional[str] = None, - *, - embed: Optional[Embed] = None, - embeds: Optional[List[Embed]] = None, - components: Optional[List["ActionRow"]] = None, # Added components - ) -> "Message": - if not hasattr(self._client, "send_message"): - raise NotImplementedError( - "Client.send_message is required for DMChannel.send" - ) - - return await self._client.send_message( - channel_id=self.id, - content=content, - embed=embed, - embeds=embeds, - components=components, - ) - - async def history( - self, - *, - limit: Optional[int] = 100, - before: "Snowflake | None" = None, - ): - """An async iterator over messages in this DM.""" - - params: Dict[str, Union[int, str]] = {} - if before is not None: - params["before"] = before - - fetched = 0 - while True: - to_fetch = 100 if limit is None else min(100, limit - fetched) - if to_fetch <= 0: - break - params["limit"] = to_fetch - messages = await self._client._http.request( - "GET", f"/channels/{self.id}/messages", params=params.copy() - ) - if not messages: - break - params["before"] = messages[-1]["id"] - for msg in messages: - yield Message(msg, self._client) - fetched += 1 - if limit is not None and fetched >= limit: - return - - def __repr__(self) -> str: - recipient_repr = self.recipient.username if self.recipient else "Unknown" - return f"" - - -class PartialChannel: - """Represents a partial channel object, often from interactions.""" - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - self._client: Optional["Client"] = client_instance - self.id: str = data["id"] - self.name: Optional[str] = data.get("name") - self._type_val: int = int(data["type"]) - self.permissions: Optional[str] = data.get("permissions") - - @property - def type(self) -> ChannelType: - return ChannelType(self._type_val) - - @property - def mention(self) -> str: - return f"<#{self.id}>" - - async def fetch_full_channel(self) -> Optional[Channel]: - if not self._client or not hasattr(self._client, "fetch_channel"): - # Log or raise if fetching is not possible - return None - try: - # This assumes Client.fetch_channel exists and returns a full Channel object - return await self._client.fetch_channel(self.id) - except HTTPException as exc: - print(f"HTTP error while fetching channel {self.id}: {exc}") - except (json.JSONDecodeError, KeyError, ValueError) as exc: - print(f"Failed to parse channel {self.id}: {exc}") - except DisagreementException as exc: - print(f"Error fetching channel {self.id}: {exc}") - return None - - def __repr__(self) -> str: - type_name = self.type.name if hasattr(self.type, "name") else self._type_val - return f"" - - -class Webhook: - """Represents a Discord Webhook.""" - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - self._client: Optional["Client"] = client_instance - self.id: str = data["id"] - self.type: int = int(data.get("type", 1)) - self.guild_id: Optional[str] = data.get("guild_id") - self.channel_id: Optional[str] = data.get("channel_id") - self.name: Optional[str] = data.get("name") - self.avatar: Optional[str] = data.get("avatar") - self.token: Optional[str] = data.get("token") - self.application_id: Optional[str] = data.get("application_id") - self.url: Optional[str] = data.get("url") - self.user: Optional[User] = User(data["user"]) if data.get("user") else None - - def __repr__(self) -> str: - return f"" - - @classmethod - def from_url( - cls, url: str, session: Optional[aiohttp.ClientSession] = None - ) -> "Webhook": - """Create a minimal :class:`Webhook` from a webhook URL. - - Parameters - ---------- - url: - The full Discord webhook URL. - session: - Unused for now. Present for API compatibility. - - Returns - ------- - Webhook - A webhook instance containing only the ``id``, ``token`` and ``url``. - """ - - parts = url.rstrip("/").split("/") - if len(parts) < 2: - raise ValueError("Invalid webhook URL") - token = parts[-1] - webhook_id = parts[-2] - - return cls({"id": webhook_id, "token": token, "url": url}) - - async def send( - self, - content: Optional[str] = None, - *, - username: Optional[str] = None, - avatar_url: Optional[str] = None, - tts: bool = False, - embed: Optional["Embed"] = None, - embeds: Optional[List["Embed"]] = None, - components: Optional[List["ActionRow"]] = None, - allowed_mentions: Optional[Dict[str, Any]] = None, - attachments: Optional[List[Any]] = None, - files: Optional[List[Any]] = None, - flags: Optional[int] = None, - ) -> "Message": - """Send a message using this webhook.""" - - if not self._client: - raise DisagreementException("Webhook is not bound to a Client") - assert self.token is not None, "Webhook token missing" - - if embed and embeds: - raise ValueError("Cannot provide both embed and embeds.") - - final_embeds_payload: Optional[List[Dict[str, Any]]] = None - if embed: - final_embeds_payload = [embed.to_dict()] - elif embeds: - final_embeds_payload = [e.to_dict() for e in embeds] - - components_payload: Optional[List[Dict[str, Any]]] = None - if components: - components_payload = [c.to_dict() for c in components] - - message_data = await self._client._http.execute_webhook( - self.id, - self.token, - content=content, - tts=tts, - embeds=final_embeds_payload, - components=components_payload, - allowed_mentions=allowed_mentions, - attachments=attachments, - files=files, - flags=flags, - username=username, - avatar_url=avatar_url, - ) - - return self._client.parse_message(message_data) - - -class GuildTemplate: - """Represents a guild template.""" - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - self._client = client_instance - self.code: str = data["code"] - self.name: str = data["name"] - self.description: Optional[str] = data.get("description") - self.usage_count: int = data.get("usage_count", 0) - self.creator_id: str = data.get("creator_id", "") - self.creator: Optional[User] = ( - User(data["creator"]) if data.get("creator") else None - ) - self.created_at: Optional[str] = data.get("created_at") - self.updated_at: Optional[str] = data.get("updated_at") - self.source_guild_id: Optional[str] = data.get("source_guild_id") - self.serialized_source_guild: Dict[str, Any] = data.get( - "serialized_source_guild", {} - ) - self.is_dirty: Optional[bool] = data.get("is_dirty") - - def __repr__(self) -> str: - return f"" - - -# --- Message Components --- - - -class Component: - """Base class for message components.""" - - def __init__(self, type: ComponentType): - self.type: ComponentType = type - self.custom_id: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - payload: Dict[str, Any] = {"type": self.type.value} - if self.custom_id: - payload["custom_id"] = self.custom_id - return payload - - -class ActionRow(Component): - """Represents an Action Row, a container for other components.""" - - def __init__(self, components: Optional[List[Component]] = None): - super().__init__(ComponentType.ACTION_ROW) - self.components: List[Component] = components or [] - - def add_component(self, component: Component): - if isinstance(component, ActionRow): - raise ValueError("Cannot nest ActionRows inside another ActionRow.") - - select_types = { - ComponentType.STRING_SELECT, - ComponentType.USER_SELECT, - ComponentType.ROLE_SELECT, - ComponentType.MENTIONABLE_SELECT, - ComponentType.CHANNEL_SELECT, - } - - if component.type in select_types: - if self.components: - raise ValueError( - "Select menu components must be the only component in an ActionRow." - ) - self.components.append(component) - return self - - if any(c.type in select_types for c in self.components): - raise ValueError( - "Cannot add components to an ActionRow that already contains a select menu." - ) - - if len(self.components) >= 5: - raise ValueError("ActionRow cannot have more than 5 components.") - - self.components.append(component) - return self - - def to_dict(self) -> Dict[str, Any]: - payload = super().to_dict() - payload["components"] = [c.to_dict() for c in self.components] - return payload - - @classmethod - def from_dict( - cls, data: Dict[str, Any], client: Optional["Client"] = None - ) -> "ActionRow": - """Deserialize an action row payload.""" - from .components import component_factory - - row = cls() - for comp_data in data.get("components", []): - try: - row.add_component(component_factory(comp_data, client)) - except Exception: - # Skip components that fail to parse for now - continue - return row - - -class Button(Component): - """Represents a button component.""" - - def __init__( - self, - *, # Make parameters keyword-only for clarity - style: ButtonStyle, - label: Optional[str] = None, - emoji: Optional["PartialEmoji"] = None, # Changed to PartialEmoji type - custom_id: Optional[str] = None, - url: Optional[str] = None, - disabled: bool = False, - ): - super().__init__(ComponentType.BUTTON) - - if style == ButtonStyle.LINK and url is None: - raise ValueError("Link buttons must have a URL.") - if style != ButtonStyle.LINK and custom_id is None: - raise ValueError("Non-link buttons must have a custom_id.") - if label is None and emoji is None: - raise ValueError("Button must have a label or an emoji.") - - self.style: ButtonStyle = style - self.label: Optional[str] = label - self.emoji: Optional[PartialEmoji] = emoji - self.custom_id = custom_id - self.url: Optional[str] = url - self.disabled: bool = disabled - - def to_dict(self) -> Dict[str, Any]: - payload = super().to_dict() - payload["style"] = self.style.value - if self.label: - payload["label"] = self.label - if self.emoji: - payload["emoji"] = self.emoji.to_dict() # Call to_dict() - if self.custom_id: - payload["custom_id"] = self.custom_id - if self.url: - payload["url"] = self.url - if self.disabled: - payload["disabled"] = self.disabled - return payload - - -class SelectOption: - """Represents an option in a select menu.""" - - def __init__( - self, - *, # Make parameters keyword-only - label: str, - value: str, - description: Optional[str] = None, - emoji: Optional["PartialEmoji"] = None, # Changed to PartialEmoji type - default: bool = False, - ): - self.label: str = label - self.value: str = value - self.description: Optional[str] = description - self.emoji: Optional["PartialEmoji"] = emoji - self.default: bool = default - - def to_dict(self) -> Dict[str, Any]: - payload: Dict[str, Any] = { - "label": self.label, - "value": self.value, - } - if self.description: - payload["description"] = self.description - if self.emoji: - payload["emoji"] = self.emoji.to_dict() # Call to_dict() - if self.default: - payload["default"] = self.default - return payload - - -class SelectMenu(Component): - """Represents a select menu component. - - Currently supports STRING_SELECT (type 3). - User (5), Role (6), Mentionable (7), Channel (8) selects are not yet fully modeled. - """ - - def __init__( - self, - *, # Make parameters keyword-only - custom_id: str, - options: List[SelectOption], - placeholder: Optional[str] = None, - min_values: int = 1, - max_values: int = 1, - disabled: bool = False, - channel_types: Optional[List[ChannelType]] = None, - # For other select types, specific fields would be needed. - # This constructor primarily targets STRING_SELECT (type 3). - type: ComponentType = ComponentType.STRING_SELECT, # Default to string select - ): - super().__init__(type) # Pass the specific select menu type - - if not (1 <= len(options) <= 25): - raise ValueError("Select menu must have between 1 and 25 options.") - if not ( - 0 <= min_values <= 25 - ): # Discord docs say min_values can be 0 for some types - raise ValueError("min_values must be between 0 and 25.") - if not (1 <= max_values <= 25): - raise ValueError("max_values must be between 1 and 25.") - if min_values > max_values: - raise ValueError("min_values cannot be greater than max_values.") - - self.custom_id = custom_id - self.options: List[SelectOption] = options - self.placeholder: Optional[str] = placeholder - self.min_values: int = min_values - self.max_values: int = max_values - self.disabled: bool = disabled - self.channel_types: Optional[List[ChannelType]] = channel_types - - def to_dict(self) -> Dict[str, Any]: - payload = super().to_dict() # Gets {"type": self.type.value} - payload["custom_id"] = self.custom_id - payload["options"] = [opt.to_dict() for opt in self.options] - if self.placeholder: - payload["placeholder"] = self.placeholder - payload["min_values"] = self.min_values - payload["max_values"] = self.max_values - if self.disabled: - payload["disabled"] = self.disabled - if self.type == ComponentType.CHANNEL_SELECT and self.channel_types: - payload["channel_types"] = [ct.value for ct in self.channel_types] - return payload - - -class UnfurledMediaItem: - """Represents an unfurled media item.""" - - def __init__( - self, - url: str, - proxy_url: Optional[str] = None, - height: Optional[int] = None, - width: Optional[int] = None, - content_type: Optional[str] = None, - ): - self.url = url - self.proxy_url = proxy_url - self.height = height - self.width = width - self.content_type = content_type - - def to_dict(self) -> Dict[str, Any]: - return { - "url": self.url, - "proxy_url": self.proxy_url, - "height": self.height, - "width": self.width, - "content_type": self.content_type, - } - - -class MediaGalleryItem: - """Represents an item in a media gallery.""" - - def __init__( - self, - media: UnfurledMediaItem, - description: Optional[str] = None, - spoiler: bool = False, - ): - self.media = media - self.description = description - self.spoiler = spoiler - - def to_dict(self) -> Dict[str, Any]: - return { - "media": self.media.to_dict(), - "description": self.description, - "spoiler": self.spoiler, - } - - -class TextDisplay(Component): - """Represents a text display component.""" - - def __init__(self, content: str, id: Optional[int] = None): - super().__init__(ComponentType.TEXT_DISPLAY) - self.content = content - self.id = id - - def to_dict(self) -> Dict[str, Any]: - payload = super().to_dict() - payload["content"] = self.content - if self.id is not None: - payload["id"] = self.id - return payload - - -class Thumbnail(Component): - """Represents a thumbnail component.""" - - def __init__( - self, - media: UnfurledMediaItem, - description: Optional[str] = None, - spoiler: bool = False, - id: Optional[int] = None, - ): - super().__init__(ComponentType.THUMBNAIL) - self.media = media - self.description = description - self.spoiler = spoiler - self.id = id - - def to_dict(self) -> Dict[str, Any]: - payload = super().to_dict() - payload["media"] = self.media.to_dict() - if self.description: - payload["description"] = self.description - if self.spoiler: - payload["spoiler"] = self.spoiler - if self.id is not None: - payload["id"] = self.id - return payload - - -class Section(Component): - """Represents a section component.""" - - def __init__( - self, - components: List[TextDisplay], - accessory: Optional[Union[Thumbnail, Button]] = None, - id: Optional[int] = None, - ): - super().__init__(ComponentType.SECTION) - self.components = components - self.accessory = accessory - self.id = id - - def to_dict(self) -> Dict[str, Any]: - payload = super().to_dict() - payload["components"] = [c.to_dict() for c in self.components] - if self.accessory: - payload["accessory"] = self.accessory.to_dict() - if self.id is not None: - payload["id"] = self.id - return payload - - -class MediaGallery(Component): - """Represents a media gallery component.""" - - def __init__(self, items: List[MediaGalleryItem], id: Optional[int] = None): - super().__init__(ComponentType.MEDIA_GALLERY) - self.items = items - self.id = id - - def to_dict(self) -> Dict[str, Any]: - payload = super().to_dict() - payload["items"] = [i.to_dict() for i in self.items] - if self.id is not None: - payload["id"] = self.id - return payload - - -class FileComponent(Component): - """Represents a file component.""" - - def __init__( - self, file: UnfurledMediaItem, spoiler: bool = False, id: Optional[int] = None - ): - super().__init__(ComponentType.FILE) - self.file = file - self.spoiler = spoiler - self.id = id - - def to_dict(self) -> Dict[str, Any]: - payload = super().to_dict() - payload["file"] = self.file.to_dict() - if self.spoiler: - payload["spoiler"] = self.spoiler - if self.id is not None: - payload["id"] = self.id - return payload - - -class Separator(Component): - """Represents a separator component.""" - - def __init__( - self, divider: bool = True, spacing: int = 1, id: Optional[int] = None - ): - super().__init__(ComponentType.SEPARATOR) - self.divider = divider - self.spacing = spacing - self.id = id - - def to_dict(self) -> Dict[str, Any]: - payload = super().to_dict() - payload["divider"] = self.divider - payload["spacing"] = self.spacing - if self.id is not None: - payload["id"] = self.id - return payload - - -class Container(Component): - """Represents a container component.""" - - def __init__( - self, - components: List[Component], - accent_color: Color | int | str | None = None, - spoiler: bool = False, - id: Optional[int] = None, - ): - super().__init__(ComponentType.CONTAINER) - self.components = components - self.accent_color = Color.parse(accent_color) - self.spoiler = spoiler - self.id = id - - def to_dict(self) -> Dict[str, Any]: - payload = super().to_dict() - payload["components"] = [c.to_dict() for c in self.components] - if self.accent_color: - payload["accent_color"] = self.accent_color.value - if self.spoiler: - payload["spoiler"] = self.spoiler - if self.id is not None: - payload["id"] = self.id - return payload - - -class WelcomeChannel: - """Represents a channel shown in the server's welcome screen. - - Attributes: - channel_id (str): The ID of the channel. - description (str): The description shown for the channel. - emoji_id (Optional[str]): The ID of the emoji, if custom. - emoji_name (Optional[str]): The name of the emoji if custom, or the unicode character if standard. - """ - - def __init__(self, data: Dict[str, Any]): - self.channel_id: str = data["channel_id"] - self.description: str = data["description"] - self.emoji_id: Optional[str] = data.get("emoji_id") - self.emoji_name: Optional[str] = data.get("emoji_name") - - def __repr__(self) -> str: - return ( - f"" - ) - - -class WelcomeScreen: - """Represents the welcome screen of a Community guild. - - Attributes: - description (Optional[str]): The server description shown in the welcome screen. - welcome_channels (List[WelcomeChannel]): The channels shown in the welcome screen. - """ - - def __init__(self, data: Dict[str, Any], client_instance: "Client"): - self._client: "Client" = ( - client_instance # May be useful for fetching channel objects - ) - self.description: Optional[str] = data.get("description") - self.welcome_channels: List[WelcomeChannel] = [ - WelcomeChannel(wc_data) for wc_data in data.get("welcome_channels", []) - ] - - def __repr__(self) -> str: - return f"" - - -class ThreadMember: - """Represents a member of a thread. - - Attributes: - id (Optional[str]): The ID of the thread. Not always present. - user_id (Optional[str]): The ID of the user. Not always present. - join_timestamp (str): When the user joined the thread (ISO8601 timestamp). - flags (int): User-specific flags for thread settings. - member (Optional[Member]): The guild member object for this user, if resolved. - Only available from GUILD_MEMBERS intent and if fetched. - """ - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): # client_instance for member resolution - self._client: Optional["Client"] = client_instance - self.id: Optional[str] = data.get("id") # Thread ID - self.user_id: Optional[str] = data.get("user_id") - self.join_timestamp: str = data["join_timestamp"] - self.flags: int = data["flags"] - - # The 'member' field in ThreadMember payload is a full guild member object. - # This is present in some contexts like when listing thread members. - self.member: Optional[Member] = ( - Member(data["member"], client_instance) if data.get("member") else None - ) - - # Note: The 'presence' field is not included as it's often unavailable or too dynamic for a simple model. - - def __repr__(self) -> str: - return f"" - - -class PresenceUpdate: - """Represents a PRESENCE_UPDATE event.""" - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - self._client = client_instance - self.user = User(data["user"]) - self.guild_id: Optional[str] = data.get("guild_id") - self.status: Optional[str] = data.get("status") - self.activities: List[Dict[str, Any]] = data.get("activities", []) - self.client_status: Dict[str, Any] = data.get("client_status", {}) - - def __repr__(self) -> str: - return f"" - - -class TypingStart: - """Represents a TYPING_START event.""" - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - self._client = client_instance - self.channel_id: str = data["channel_id"] - self.guild_id: Optional[str] = data.get("guild_id") - self.user_id: str = data["user_id"] - self.timestamp: int = data["timestamp"] - self.member: Optional[Member] = ( - Member(data["member"], client_instance) if data.get("member") else None - ) - - def __repr__(self) -> str: - return f"" - - -class Reaction: - """Represents a message reaction event.""" - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - self._client = client_instance - self.user_id: str = data["user_id"] - self.channel_id: str = data["channel_id"] - self.message_id: str = data["message_id"] - self.guild_id: Optional[str] = data.get("guild_id") - self.member: Optional[Member] = ( - Member(data["member"], client_instance) if data.get("member") else None - ) - self.emoji: Dict[str, Any] = data.get("emoji", {}) - - def __repr__(self) -> str: - emoji_value = self.emoji.get("name") or self.emoji.get("id") - return f"" - - -class ScheduledEvent: - """Represents a guild scheduled event.""" - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - self._client = client_instance - self.id: str = data["id"] - self.guild_id: str = data["guild_id"] - self.channel_id: Optional[str] = data.get("channel_id") - self.creator_id: Optional[str] = data.get("creator_id") - self.name: str = data["name"] - self.description: Optional[str] = data.get("description") - self.scheduled_start_time: str = data["scheduled_start_time"] - self.scheduled_end_time: Optional[str] = data.get("scheduled_end_time") - self.privacy_level: GuildScheduledEventPrivacyLevel = ( - GuildScheduledEventPrivacyLevel(data["privacy_level"]) - ) - self.status: GuildScheduledEventStatus = GuildScheduledEventStatus( - data["status"] - ) - self.entity_type: GuildScheduledEventEntityType = GuildScheduledEventEntityType( - data["entity_type"] - ) - self.entity_id: Optional[str] = data.get("entity_id") - self.entity_metadata: Optional[Dict[str, Any]] = data.get("entity_metadata") - self.creator: Optional[User] = ( - User(data["creator"]) if data.get("creator") else None - ) - self.user_count: Optional[int] = data.get("user_count") - self.image: Optional[str] = data.get("image") - - def __repr__(self) -> str: - return f"" - - -@dataclass -class Invite: - """Represents a Discord invite.""" - - code: str - channel_id: Optional[str] - guild_id: Optional[str] - inviter_id: Optional[str] - uses: Optional[int] - max_uses: Optional[int] - max_age: Optional[int] - temporary: Optional[bool] - created_at: Optional[str] - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "Invite": - channel = data.get("channel") - guild = data.get("guild") - inviter = data.get("inviter") - return cls( - code=data["code"], - channel_id=(channel or {}).get("id") if channel else data.get("channel_id"), - guild_id=(guild or {}).get("id") if guild else data.get("guild_id"), - inviter_id=(inviter or {}).get("id"), - uses=data.get("uses"), - max_uses=data.get("max_uses"), - max_age=data.get("max_age"), - temporary=data.get("temporary"), - created_at=data.get("created_at"), - ) - - def __repr__(self) -> str: - return f"" - - -class GuildMemberRemove: - """Represents a GUILD_MEMBER_REMOVE event.""" - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - self._client = client_instance - self.guild_id: str = data["guild_id"] - self.user: User = User(data["user"]) - - def __repr__(self) -> str: - return ( - f"" - ) - - -class GuildBanAdd: - """Represents a GUILD_BAN_ADD event.""" - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - self._client = client_instance - self.guild_id: str = data["guild_id"] - self.user: User = User(data["user"]) - - def __repr__(self) -> str: - return f"" - - -class GuildBanRemove: - """Represents a GUILD_BAN_REMOVE event.""" - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - self._client = client_instance - self.guild_id: str = data["guild_id"] - self.user: User = User(data["user"]) - - def __repr__(self) -> str: - return f"" - - -class GuildRoleUpdate: - """Represents a GUILD_ROLE_UPDATE event.""" - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - self._client = client_instance - self.guild_id: str = data["guild_id"] - self.role: Role = Role(data["role"]) - - def __repr__(self) -> str: - return f"" - - -class AuditLogEntry: - """Represents a single entry in a guild's audit log.""" - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ) -> None: - self._client = client_instance - self.id: str = data["id"] - self.user_id: Optional[str] = data.get("user_id") - self.target_id: Optional[str] = data.get("target_id") - self.action_type: int = data["action_type"] - self.reason: Optional[str] = data.get("reason") - self.changes: List[Dict[str, Any]] = data.get("changes", []) - self.options: Optional[Dict[str, Any]] = data.get("options") - - def __repr__(self) -> str: - return f"" - - -def channel_factory(data: Dict[str, Any], client: "Client") -> Channel: - """Create a channel object from raw API data.""" - channel_type = data.get("type") - - if channel_type in ( - ChannelType.GUILD_TEXT.value, - ChannelType.GUILD_ANNOUNCEMENT.value, - ): - return TextChannel(data, client) - if channel_type == ChannelType.GUILD_VOICE.value: - return VoiceChannel(data, client) - if channel_type == ChannelType.GUILD_STAGE_VOICE.value: - return StageChannel(data, client) - if channel_type == ChannelType.GUILD_CATEGORY.value: - return CategoryChannel(data, client) - if channel_type in ( - ChannelType.ANNOUNCEMENT_THREAD.value, - ChannelType.PUBLIC_THREAD.value, - ChannelType.PRIVATE_THREAD.value, - ): - return Thread(data, client) - if channel_type in (ChannelType.DM.value, ChannelType.GROUP_DM.value): - return DMChannel(data, client) - - return Channel(data, client) + +class DMChannel(Channel): + """Represents a Direct Message channel.""" + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + super().__init__(data, client_instance) + self.last_message_id: Optional[str] = data.get("last_message_id") + self.recipients: List[User] = [ + User(u_data) for u_data in data.get("recipients", []) + ] + + @property + def recipient(self) -> Optional[User]: + return self.recipients[0] if self.recipients else None + + async def send( + self, + content: Optional[str] = None, + *, + embed: Optional[Embed] = None, + embeds: Optional[List[Embed]] = None, + components: Optional[List["ActionRow"]] = None, # Added components + ) -> "Message": + if not hasattr(self._client, "send_message"): + raise NotImplementedError( + "Client.send_message is required for DMChannel.send" + ) + + return await self._client.send_message( + channel_id=self.id, + content=content, + embed=embed, + embeds=embeds, + components=components, + ) + + async def history( + self, + *, + limit: Optional[int] = 100, + before: "Snowflake | None" = None, + ): + """An async iterator over messages in this DM.""" + + params: Dict[str, Union[int, str]] = {} + if before is not None: + params["before"] = before + + fetched = 0 + while True: + to_fetch = 100 if limit is None else min(100, limit - fetched) + if to_fetch <= 0: + break + params["limit"] = to_fetch + messages = await self._client._http.request( + "GET", f"/channels/{self.id}/messages", params=params.copy() + ) + if not messages: + break + params["before"] = messages[-1]["id"] + for msg in messages: + yield Message(msg, self._client) + fetched += 1 + if limit is not None and fetched >= limit: + return + + def __repr__(self) -> str: + recipient_repr = self.recipient.username if self.recipient else "Unknown" + return f"" + + +class PartialChannel: + """Represents a partial channel object, often from interactions.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client: Optional["Client"] = client_instance + self.id: str = data["id"] + self.name: Optional[str] = data.get("name") + self._type_val: int = int(data["type"]) + self.permissions: Optional[str] = data.get("permissions") + + @property + def type(self) -> ChannelType: + return ChannelType(self._type_val) + + @property + def mention(self) -> str: + return f"<#{self.id}>" + + async def fetch_full_channel(self) -> Optional[Channel]: + if not self._client or not hasattr(self._client, "fetch_channel"): + # Log or raise if fetching is not possible + return None + try: + # This assumes Client.fetch_channel exists and returns a full Channel object + return await self._client.fetch_channel(self.id) + except HTTPException as exc: + print(f"HTTP error while fetching channel {self.id}: {exc}") + except (json.JSONDecodeError, KeyError, ValueError) as exc: + print(f"Failed to parse channel {self.id}: {exc}") + except DisagreementException as exc: + print(f"Error fetching channel {self.id}: {exc}") + return None + + def __repr__(self) -> str: + type_name = self.type.name if hasattr(self.type, "name") else self._type_val + return f"" + + +class Webhook: + """Represents a Discord Webhook.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client: Optional["Client"] = client_instance + self.id: str = data["id"] + self.type: int = int(data.get("type", 1)) + self.guild_id: Optional[str] = data.get("guild_id") + self.channel_id: Optional[str] = data.get("channel_id") + self.name: Optional[str] = data.get("name") + self.avatar: Optional[str] = data.get("avatar") + self.token: Optional[str] = data.get("token") + self.application_id: Optional[str] = data.get("application_id") + self.url: Optional[str] = data.get("url") + self.user: Optional[User] = User(data["user"]) if data.get("user") else None + + def __repr__(self) -> str: + return f"" + + @classmethod + def from_url( + cls, url: str, session: Optional[aiohttp.ClientSession] = None + ) -> "Webhook": + """Create a minimal :class:`Webhook` from a webhook URL. + + Parameters + ---------- + url: + The full Discord webhook URL. + session: + Unused for now. Present for API compatibility. + + Returns + ------- + Webhook + A webhook instance containing only the ``id``, ``token`` and ``url``. + """ + + parts = url.rstrip("/").split("/") + if len(parts) < 2: + raise ValueError("Invalid webhook URL") + token = parts[-1] + webhook_id = parts[-2] + + return cls({"id": webhook_id, "token": token, "url": url}) + + async def send( + self, + content: Optional[str] = None, + *, + username: Optional[str] = None, + avatar_url: Optional[str] = None, + tts: bool = False, + embed: Optional["Embed"] = None, + embeds: Optional[List["Embed"]] = None, + components: Optional[List["ActionRow"]] = None, + allowed_mentions: Optional[Dict[str, Any]] = None, + attachments: Optional[List[Any]] = None, + files: Optional[List[Any]] = None, + flags: Optional[int] = None, + ) -> "Message": + """Send a message using this webhook.""" + + if not self._client: + raise DisagreementException("Webhook is not bound to a Client") + assert self.token is not None, "Webhook token missing" + + if embed and embeds: + raise ValueError("Cannot provide both embed and embeds.") + + final_embeds_payload: Optional[List[Dict[str, Any]]] = None + if embed: + final_embeds_payload = [embed.to_dict()] + elif embeds: + final_embeds_payload = [e.to_dict() for e in embeds] + + components_payload: Optional[List[Dict[str, Any]]] = None + if components: + components_payload = [c.to_dict() for c in components] + + message_data = await self._client._http.execute_webhook( + self.id, + self.token, + content=content, + tts=tts, + embeds=final_embeds_payload, + components=components_payload, + allowed_mentions=allowed_mentions, + attachments=attachments, + files=files, + flags=flags, + username=username, + avatar_url=avatar_url, + ) + + return self._client.parse_message(message_data) + + +class GuildTemplate: + """Represents a guild template.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client = client_instance + self.code: str = data["code"] + self.name: str = data["name"] + self.description: Optional[str] = data.get("description") + self.usage_count: int = data.get("usage_count", 0) + self.creator_id: str = data.get("creator_id", "") + self.creator: Optional[User] = ( + User(data["creator"]) if data.get("creator") else None + ) + self.created_at: Optional[str] = data.get("created_at") + self.updated_at: Optional[str] = data.get("updated_at") + self.source_guild_id: Optional[str] = data.get("source_guild_id") + self.serialized_source_guild: Dict[str, Any] = data.get( + "serialized_source_guild", {} + ) + self.is_dirty: Optional[bool] = data.get("is_dirty") + + def __repr__(self) -> str: + return f"" + + +# --- Message Components --- + + +class Component: + """Base class for message components.""" + + def __init__(self, type: ComponentType): + self.type: ComponentType = type + self.custom_id: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {"type": self.type.value} + if self.custom_id: + payload["custom_id"] = self.custom_id + return payload + + +class ActionRow(Component): + """Represents an Action Row, a container for other components.""" + + def __init__(self, components: Optional[List[Component]] = None): + super().__init__(ComponentType.ACTION_ROW) + self.components: List[Component] = components or [] + + def add_component(self, component: Component): + if isinstance(component, ActionRow): + raise ValueError("Cannot nest ActionRows inside another ActionRow.") + + select_types = { + ComponentType.STRING_SELECT, + ComponentType.USER_SELECT, + ComponentType.ROLE_SELECT, + ComponentType.MENTIONABLE_SELECT, + ComponentType.CHANNEL_SELECT, + } + + if component.type in select_types: + if self.components: + raise ValueError( + "Select menu components must be the only component in an ActionRow." + ) + self.components.append(component) + return self + + if any(c.type in select_types for c in self.components): + raise ValueError( + "Cannot add components to an ActionRow that already contains a select menu." + ) + + if len(self.components) >= 5: + raise ValueError("ActionRow cannot have more than 5 components.") + + self.components.append(component) + return self + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["components"] = [c.to_dict() for c in self.components] + return payload + + @classmethod + def from_dict( + cls, data: Dict[str, Any], client: Optional["Client"] = None + ) -> "ActionRow": + """Deserialize an action row payload.""" + from .components import component_factory + + row = cls() + for comp_data in data.get("components", []): + try: + row.add_component(component_factory(comp_data, client)) + except Exception: + # Skip components that fail to parse for now + continue + return row + + +class Button(Component): + """Represents a button component.""" + + def __init__( + self, + *, # Make parameters keyword-only for clarity + style: ButtonStyle, + label: Optional[str] = None, + emoji: Optional["PartialEmoji"] = None, # Changed to PartialEmoji type + custom_id: Optional[str] = None, + url: Optional[str] = None, + disabled: bool = False, + ): + super().__init__(ComponentType.BUTTON) + + if style == ButtonStyle.LINK and url is None: + raise ValueError("Link buttons must have a URL.") + if style != ButtonStyle.LINK and custom_id is None: + raise ValueError("Non-link buttons must have a custom_id.") + if label is None and emoji is None: + raise ValueError("Button must have a label or an emoji.") + + self.style: ButtonStyle = style + self.label: Optional[str] = label + self.emoji: Optional[PartialEmoji] = emoji + self.custom_id = custom_id + self.url: Optional[str] = url + self.disabled: bool = disabled + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["style"] = self.style.value + if self.label: + payload["label"] = self.label + if self.emoji: + payload["emoji"] = self.emoji.to_dict() # Call to_dict() + if self.custom_id: + payload["custom_id"] = self.custom_id + if self.url: + payload["url"] = self.url + if self.disabled: + payload["disabled"] = self.disabled + return payload + + +class SelectOption: + """Represents an option in a select menu.""" + + def __init__( + self, + *, # Make parameters keyword-only + label: str, + value: str, + description: Optional[str] = None, + emoji: Optional["PartialEmoji"] = None, # Changed to PartialEmoji type + default: bool = False, + ): + self.label: str = label + self.value: str = value + self.description: Optional[str] = description + self.emoji: Optional["PartialEmoji"] = emoji + self.default: bool = default + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = { + "label": self.label, + "value": self.value, + } + if self.description: + payload["description"] = self.description + if self.emoji: + payload["emoji"] = self.emoji.to_dict() # Call to_dict() + if self.default: + payload["default"] = self.default + return payload + + +class SelectMenu(Component): + """Represents a select menu component. + + Currently supports STRING_SELECT (type 3). + User (5), Role (6), Mentionable (7), Channel (8) selects are not yet fully modeled. + """ + + def __init__( + self, + *, # Make parameters keyword-only + custom_id: str, + options: List[SelectOption], + placeholder: Optional[str] = None, + min_values: int = 1, + max_values: int = 1, + disabled: bool = False, + channel_types: Optional[List[ChannelType]] = None, + # For other select types, specific fields would be needed. + # This constructor primarily targets STRING_SELECT (type 3). + type: ComponentType = ComponentType.STRING_SELECT, # Default to string select + ): + super().__init__(type) # Pass the specific select menu type + + if not (1 <= len(options) <= 25): + raise ValueError("Select menu must have between 1 and 25 options.") + if not ( + 0 <= min_values <= 25 + ): # Discord docs say min_values can be 0 for some types + raise ValueError("min_values must be between 0 and 25.") + if not (1 <= max_values <= 25): + raise ValueError("max_values must be between 1 and 25.") + if min_values > max_values: + raise ValueError("min_values cannot be greater than max_values.") + + self.custom_id = custom_id + self.options: List[SelectOption] = options + self.placeholder: Optional[str] = placeholder + self.min_values: int = min_values + self.max_values: int = max_values + self.disabled: bool = disabled + self.channel_types: Optional[List[ChannelType]] = channel_types + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() # Gets {"type": self.type.value} + payload["custom_id"] = self.custom_id + payload["options"] = [opt.to_dict() for opt in self.options] + if self.placeholder: + payload["placeholder"] = self.placeholder + payload["min_values"] = self.min_values + payload["max_values"] = self.max_values + if self.disabled: + payload["disabled"] = self.disabled + if self.type == ComponentType.CHANNEL_SELECT and self.channel_types: + payload["channel_types"] = [ct.value for ct in self.channel_types] + return payload + + +class UnfurledMediaItem: + """Represents an unfurled media item.""" + + def __init__( + self, + url: str, + proxy_url: Optional[str] = None, + height: Optional[int] = None, + width: Optional[int] = None, + content_type: Optional[str] = None, + ): + self.url = url + self.proxy_url = proxy_url + self.height = height + self.width = width + self.content_type = content_type + + def to_dict(self) -> Dict[str, Any]: + return { + "url": self.url, + "proxy_url": self.proxy_url, + "height": self.height, + "width": self.width, + "content_type": self.content_type, + } + + +class MediaGalleryItem: + """Represents an item in a media gallery.""" + + def __init__( + self, + media: UnfurledMediaItem, + description: Optional[str] = None, + spoiler: bool = False, + ): + self.media = media + self.description = description + self.spoiler = spoiler + + def to_dict(self) -> Dict[str, Any]: + return { + "media": self.media.to_dict(), + "description": self.description, + "spoiler": self.spoiler, + } + + +class TextDisplay(Component): + """Represents a text display component.""" + + def __init__(self, content: str, id: Optional[int] = None): + super().__init__(ComponentType.TEXT_DISPLAY) + self.content = content + self.id = id + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["content"] = self.content + if self.id is not None: + payload["id"] = self.id + return payload + + +class Thumbnail(Component): + """Represents a thumbnail component.""" + + def __init__( + self, + media: UnfurledMediaItem, + description: Optional[str] = None, + spoiler: bool = False, + id: Optional[int] = None, + ): + super().__init__(ComponentType.THUMBNAIL) + self.media = media + self.description = description + self.spoiler = spoiler + self.id = id + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["media"] = self.media.to_dict() + if self.description: + payload["description"] = self.description + if self.spoiler: + payload["spoiler"] = self.spoiler + if self.id is not None: + payload["id"] = self.id + return payload + + +class Section(Component): + """Represents a section component.""" + + def __init__( + self, + components: List[TextDisplay], + accessory: Optional[Union[Thumbnail, Button]] = None, + id: Optional[int] = None, + ): + super().__init__(ComponentType.SECTION) + self.components = components + self.accessory = accessory + self.id = id + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["components"] = [c.to_dict() for c in self.components] + if self.accessory: + payload["accessory"] = self.accessory.to_dict() + if self.id is not None: + payload["id"] = self.id + return payload + + +class MediaGallery(Component): + """Represents a media gallery component.""" + + def __init__(self, items: List[MediaGalleryItem], id: Optional[int] = None): + super().__init__(ComponentType.MEDIA_GALLERY) + self.items = items + self.id = id + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["items"] = [i.to_dict() for i in self.items] + if self.id is not None: + payload["id"] = self.id + return payload + + +class FileComponent(Component): + """Represents a file component.""" + + def __init__( + self, file: UnfurledMediaItem, spoiler: bool = False, id: Optional[int] = None + ): + super().__init__(ComponentType.FILE) + self.file = file + self.spoiler = spoiler + self.id = id + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["file"] = self.file.to_dict() + if self.spoiler: + payload["spoiler"] = self.spoiler + if self.id is not None: + payload["id"] = self.id + return payload + + +class Separator(Component): + """Represents a separator component.""" + + def __init__( + self, divider: bool = True, spacing: int = 1, id: Optional[int] = None + ): + super().__init__(ComponentType.SEPARATOR) + self.divider = divider + self.spacing = spacing + self.id = id + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["divider"] = self.divider + payload["spacing"] = self.spacing + if self.id is not None: + payload["id"] = self.id + return payload + + +class Container(Component): + """Represents a container component.""" + + def __init__( + self, + components: List[Component], + accent_color: Color | int | str | None = None, + spoiler: bool = False, + id: Optional[int] = None, + ): + super().__init__(ComponentType.CONTAINER) + self.components = components + self.accent_color = Color.parse(accent_color) + self.spoiler = spoiler + self.id = id + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["components"] = [c.to_dict() for c in self.components] + if self.accent_color: + payload["accent_color"] = self.accent_color.value + if self.spoiler: + payload["spoiler"] = self.spoiler + if self.id is not None: + payload["id"] = self.id + return payload + + +class WelcomeChannel: + """Represents a channel shown in the server's welcome screen. + + Attributes: + channel_id (str): The ID of the channel. + description (str): The description shown for the channel. + emoji_id (Optional[str]): The ID of the emoji, if custom. + emoji_name (Optional[str]): The name of the emoji if custom, or the unicode character if standard. + """ + + def __init__(self, data: Dict[str, Any]): + self.channel_id: str = data["channel_id"] + self.description: str = data["description"] + self.emoji_id: Optional[str] = data.get("emoji_id") + self.emoji_name: Optional[str] = data.get("emoji_name") + + def __repr__(self) -> str: + return ( + f"" + ) + + +class WelcomeScreen: + """Represents the welcome screen of a Community guild. + + Attributes: + description (Optional[str]): The server description shown in the welcome screen. + welcome_channels (List[WelcomeChannel]): The channels shown in the welcome screen. + """ + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + self._client: "Client" = ( + client_instance # May be useful for fetching channel objects + ) + self.description: Optional[str] = data.get("description") + self.welcome_channels: List[WelcomeChannel] = [ + WelcomeChannel(wc_data) for wc_data in data.get("welcome_channels", []) + ] + + def __repr__(self) -> str: + return f"" + + +class ThreadMember: + """Represents a member of a thread. + + Attributes: + id (Optional[str]): The ID of the thread. Not always present. + user_id (Optional[str]): The ID of the user. Not always present. + join_timestamp (str): When the user joined the thread (ISO8601 timestamp). + flags (int): User-specific flags for thread settings. + member (Optional[Member]): The guild member object for this user, if resolved. + Only available from GUILD_MEMBERS intent and if fetched. + """ + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): # client_instance for member resolution + self._client: Optional["Client"] = client_instance + self.id: Optional[str] = data.get("id") # Thread ID + self.user_id: Optional[str] = data.get("user_id") + self.join_timestamp: str = data["join_timestamp"] + self.flags: int = data["flags"] + + # The 'member' field in ThreadMember payload is a full guild member object. + # This is present in some contexts like when listing thread members. + self.member: Optional[Member] = ( + Member(data["member"], client_instance) if data.get("member") else None + ) + + # Note: The 'presence' field is not included as it's often unavailable or too dynamic for a simple model. + + def __repr__(self) -> str: + return f"" + + +class PresenceUpdate: + """Represents a PRESENCE_UPDATE event.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client = client_instance + self.user = User(data["user"]) + self.guild_id: Optional[str] = data.get("guild_id") + self.status: Optional[str] = data.get("status") + self.activities: List[Dict[str, Any]] = data.get("activities", []) + self.client_status: Dict[str, Any] = data.get("client_status", {}) + + def __repr__(self) -> str: + return f"" + + +class TypingStart: + """Represents a TYPING_START event.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client = client_instance + self.channel_id: str = data["channel_id"] + self.guild_id: Optional[str] = data.get("guild_id") + self.user_id: str = data["user_id"] + self.timestamp: int = data["timestamp"] + self.member: Optional[Member] = ( + Member(data["member"], client_instance) if data.get("member") else None + ) + + def __repr__(self) -> str: + return f"" + + +class Reaction: + """Represents a message reaction event.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client = client_instance + self.user_id: str = data["user_id"] + self.channel_id: str = data["channel_id"] + self.message_id: str = data["message_id"] + self.guild_id: Optional[str] = data.get("guild_id") + self.member: Optional[Member] = ( + Member(data["member"], client_instance) if data.get("member") else None + ) + self.emoji: Dict[str, Any] = data.get("emoji", {}) + + def __repr__(self) -> str: + emoji_value = self.emoji.get("name") or self.emoji.get("id") + return f"" + + +class ScheduledEvent: + """Represents a guild scheduled event.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client = client_instance + self.id: str = data["id"] + self.guild_id: str = data["guild_id"] + self.channel_id: Optional[str] = data.get("channel_id") + self.creator_id: Optional[str] = data.get("creator_id") + self.name: str = data["name"] + self.description: Optional[str] = data.get("description") + self.scheduled_start_time: str = data["scheduled_start_time"] + self.scheduled_end_time: Optional[str] = data.get("scheduled_end_time") + self.privacy_level: GuildScheduledEventPrivacyLevel = ( + GuildScheduledEventPrivacyLevel(data["privacy_level"]) + ) + self.status: GuildScheduledEventStatus = GuildScheduledEventStatus( + data["status"] + ) + self.entity_type: GuildScheduledEventEntityType = GuildScheduledEventEntityType( + data["entity_type"] + ) + self.entity_id: Optional[str] = data.get("entity_id") + self.entity_metadata: Optional[Dict[str, Any]] = data.get("entity_metadata") + self.creator: Optional[User] = ( + User(data["creator"]) if data.get("creator") else None + ) + self.user_count: Optional[int] = data.get("user_count") + self.image: Optional[str] = data.get("image") + + def __repr__(self) -> str: + return f"" + + +@dataclass +class Invite: + """Represents a Discord invite.""" + + code: str + channel_id: Optional[str] + guild_id: Optional[str] + inviter_id: Optional[str] + uses: Optional[int] + max_uses: Optional[int] + max_age: Optional[int] + temporary: Optional[bool] + created_at: Optional[str] + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Invite": + channel = data.get("channel") + guild = data.get("guild") + inviter = data.get("inviter") + return cls( + code=data["code"], + channel_id=(channel or {}).get("id") if channel else data.get("channel_id"), + guild_id=(guild or {}).get("id") if guild else data.get("guild_id"), + inviter_id=(inviter or {}).get("id"), + uses=data.get("uses"), + max_uses=data.get("max_uses"), + max_age=data.get("max_age"), + temporary=data.get("temporary"), + created_at=data.get("created_at"), + ) + + def __repr__(self) -> str: + return f"" + + +class GuildMemberRemove: + """Represents a GUILD_MEMBER_REMOVE event.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client = client_instance + self.guild_id: str = data["guild_id"] + self.user: User = User(data["user"]) + + def __repr__(self) -> str: + return ( + f"" + ) + + +class GuildBanAdd: + """Represents a GUILD_BAN_ADD event.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client = client_instance + self.guild_id: str = data["guild_id"] + self.user: User = User(data["user"]) + + def __repr__(self) -> str: + return f"" + + +class GuildBanRemove: + """Represents a GUILD_BAN_REMOVE event.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client = client_instance + self.guild_id: str = data["guild_id"] + self.user: User = User(data["user"]) + + def __repr__(self) -> str: + return f"" + + +class GuildRoleUpdate: + """Represents a GUILD_ROLE_UPDATE event.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client = client_instance + self.guild_id: str = data["guild_id"] + self.role: Role = Role(data["role"]) + + def __repr__(self) -> str: + return f"" + + +class AuditLogEntry: + """Represents a single entry in a guild's audit log.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ) -> None: + self._client = client_instance + self.id: str = data["id"] + self.user_id: Optional[str] = data.get("user_id") + self.target_id: Optional[str] = data.get("target_id") + self.action_type: int = data["action_type"] + self.reason: Optional[str] = data.get("reason") + self.changes: List[Dict[str, Any]] = data.get("changes", []) + self.options: Optional[Dict[str, Any]] = data.get("options") + + def __repr__(self) -> str: + return f"" + + +def channel_factory(data: Dict[str, Any], client: "Client") -> Channel: + """Create a channel object from raw API data.""" + channel_type = data.get("type") + + if channel_type in ( + ChannelType.GUILD_TEXT.value, + ChannelType.GUILD_ANNOUNCEMENT.value, + ): + return TextChannel(data, client) + if channel_type == ChannelType.GUILD_VOICE.value: + return VoiceChannel(data, client) + if channel_type == ChannelType.GUILD_STAGE_VOICE.value: + return StageChannel(data, client) + if channel_type == ChannelType.GUILD_CATEGORY.value: + return CategoryChannel(data, client) + if channel_type in ( + ChannelType.ANNOUNCEMENT_THREAD.value, + ChannelType.PUBLIC_THREAD.value, + ChannelType.PRIVATE_THREAD.value, + ): + return Thread(data, client) + if channel_type in (ChannelType.DM.value, ChannelType.GROUP_DM.value): + return DMChannel(data, client) + + return Channel(data, client) diff --git a/disagreement/ui/view.py b/disagreement/ui/view.py index c91cb6c..a1a13da 100644 --- a/disagreement/ui/view.py +++ b/disagreement/ui/view.py @@ -1,167 +1,167 @@ -from __future__ import annotations - -import asyncio -import uuid -from typing import Any, Callable, Coroutine, Dict, List, Optional, TYPE_CHECKING - -from ..models import ActionRow -from .item import Item - -if TYPE_CHECKING: - from ..client import Client - from ..interactions import Interaction - - -class View: - """Represents a container for UI components that can be sent with a message. - - Args: - timeout (Optional[float]): The number of seconds to wait for an interaction before the view times out. - Defaults to 180. - """ - - def __init__(self, *, timeout: Optional[float] = 180.0): - self.timeout = timeout - self.id = str(uuid.uuid4()) - self.__children: List[Item] = [] - self.__stopped = asyncio.Event() - self._client: Optional[Client] = None - self._message_id: Optional[str] = None - +from __future__ import annotations + +import asyncio +import uuid +from typing import Any, Callable, Coroutine, Dict, List, Optional, TYPE_CHECKING + +from ..models import ActionRow +from .item import Item + +if TYPE_CHECKING: + from ..client import Client + from ..interactions import Interaction + + +class View: + """Represents a container for UI components that can be sent with a message. + + Args: + timeout (Optional[float]): The number of seconds to wait for an interaction before the view times out. + Defaults to 180. + """ + + def __init__(self, *, timeout: Optional[float] = 180.0): + self.timeout = timeout + self.id = str(uuid.uuid4()) + self.__children: List[Item] = [] + self.__stopped = asyncio.Event() + self._client: Optional[Client] = None + self._message_id: Optional[str] = None + # The below is a bit of a hack to support items defined as class members # e.g. button = Button(...) - for item in self.__class__.__dict__.values(): - if isinstance(item, Item): - self.add_item(item) - - @property - def children(self) -> List[Item]: - return self.__children - - def add_item(self, item: Item): - """Adds an item to the view.""" - if not isinstance(item, Item): - raise TypeError("Only instances of 'Item' can be added to a View.") - - if len(self.__children) >= 25: - raise ValueError("A view can only have a maximum of 25 components.") - + for item in self.__class__.__dict__.values(): + if isinstance(item, Item): + self.add_item(item) + + @property + def children(self) -> List[Item]: + return self.__children + + def add_item(self, item: Item): + """Adds an item to the view.""" + if not isinstance(item, Item): + raise TypeError("Only instances of 'Item' can be added to a View.") + + if len(self.__children) >= 25: + raise ValueError("A view can only have a maximum of 25 components.") + if self.timeout is None and item.custom_id is None: raise ValueError( "All components in a persistent view must have a 'custom_id'." ) - item._view = self - self.__children.append(item) - - @property - def message_id(self) -> Optional[str]: - return self._message_id - - @message_id.setter - def message_id(self, value: str): - self._message_id = value - - def to_components(self) -> List[ActionRow]: - """Converts the view's children into a list of ActionRow components. - - This retains the original, simple layout behaviour where each item is - placed in its own :class:`ActionRow` to ensure backward compatibility. - """ - - rows: List[ActionRow] = [] - - for item in self.children: + item._view = self + self.__children.append(item) + + @property + def message_id(self) -> Optional[str]: + return self._message_id + + @message_id.setter + def message_id(self, value: str): + self._message_id = value + + def to_components(self) -> List[ActionRow]: + """Converts the view's children into a list of ActionRow components. + + This retains the original, simple layout behaviour where each item is + placed in its own :class:`ActionRow` to ensure backward compatibility. + """ + + rows: List[ActionRow] = [] + + for item in self.children: rows.append(ActionRow(components=[item])) - - return rows - - def layout_components_advanced(self) -> List[ActionRow]: - """Group compatible components into rows following Discord rules.""" - - rows: List[ActionRow] = [] - - for item in self.children: - if item.custom_id is None: - item.custom_id = ( - f"{self.id}:{item.__class__.__name__}:{len(self.__children)}" - ) - - target_row = item.row - if target_row is not None: - if not 0 <= target_row <= 4: - raise ValueError("Row index must be between 0 and 4.") - - while len(rows) <= target_row: - if len(rows) >= 5: - raise ValueError("A view can have at most 5 action rows.") - rows.append(ActionRow()) - - rows[target_row].add_component(item) - continue - - placed = False - for row in rows: - try: - row.add_component(item) - placed = True - break - except ValueError: - continue - - if not placed: - if len(rows) >= 5: - raise ValueError("A view can have at most 5 action rows.") - new_row = ActionRow([item]) - rows.append(new_row) - - return rows - - def to_components_payload(self) -> List[Dict[str, Any]]: - """Converts the view's children into a list of component dictionaries - that can be sent to the Discord API.""" - return [row.to_dict() for row in self.to_components()] - - async def _dispatch(self, interaction: Interaction): - """Called by the client to dispatch an interaction to the correct item.""" - if self.timeout is not None: - self.__stopped.set() # Reset the timeout on each interaction - self.__stopped.clear() - - if interaction.data: - custom_id = interaction.data.custom_id - for child in self.children: - if child.custom_id == custom_id: - if child.callback: - await child.callback(self, interaction) - break - - async def wait(self) -> bool: - """Waits until the view has stopped interacting.""" - return await self.__stopped.wait() - - def stop(self): - """Stops the view from listening to interactions.""" - if not self.__stopped.is_set(): - self.__stopped.set() - - async def on_timeout(self): - """Called when the view times out.""" - pass # User can override this - - async def _start(self, client: Client): - """Starts the view's internal listener.""" - self._client = client - if self.timeout is not None: - asyncio.create_task(self._timeout_task()) - - async def _timeout_task(self): - """The task that waits for the timeout and then stops the view.""" - try: - await asyncio.wait_for(self.wait(), timeout=self.timeout) - except asyncio.TimeoutError: - self.stop() - await self.on_timeout() - if self._client and self._message_id: - # Remove the view from the client's listeners - self._client._views.pop(self._message_id, None) + + return rows + + def layout_components_advanced(self) -> List[ActionRow]: + """Group compatible components into rows following Discord rules.""" + + rows: List[ActionRow] = [] + + for item in self.children: + if item.custom_id is None: + item.custom_id = ( + f"{self.id}:{item.__class__.__name__}:{len(self.__children)}" + ) + + target_row = item.row + if target_row is not None: + if not 0 <= target_row <= 4: + raise ValueError("Row index must be between 0 and 4.") + + while len(rows) <= target_row: + if len(rows) >= 5: + raise ValueError("A view can have at most 5 action rows.") + rows.append(ActionRow()) + + rows[target_row].add_component(item) + continue + + placed = False + for row in rows: + try: + row.add_component(item) + placed = True + break + except ValueError: + continue + + if not placed: + if len(rows) >= 5: + raise ValueError("A view can have at most 5 action rows.") + new_row = ActionRow([item]) + rows.append(new_row) + + return rows + + def to_components_payload(self) -> List[Dict[str, Any]]: + """Converts the view's children into a list of component dictionaries + that can be sent to the Discord API.""" + return [row.to_dict() for row in self.to_components()] + + async def _dispatch(self, interaction: Interaction): + """Called by the client to dispatch an interaction to the correct item.""" + if self.timeout is not None: + self.__stopped.set() # Reset the timeout on each interaction + self.__stopped.clear() + + if interaction.data: + custom_id = interaction.data.custom_id + for child in self.children: + if child.custom_id == custom_id: + if child.callback: + await child.callback(self, interaction) + break + + async def wait(self) -> bool: + """Waits until the view has stopped interacting.""" + return await self.__stopped.wait() + + def stop(self): + """Stops the view from listening to interactions.""" + if not self.__stopped.is_set(): + self.__stopped.set() + + async def on_timeout(self): + """Called when the view times out.""" + pass # User can override this + + async def _start(self, client: Client): + """Starts the view's internal listener.""" + self._client = client + if self.timeout is not None: + asyncio.create_task(self._timeout_task()) + + async def _timeout_task(self): + """The task that waits for the timeout and then stops the view.""" + try: + await asyncio.wait_for(self.wait(), timeout=self.timeout) + except asyncio.TimeoutError: + self.stop() + await self.on_timeout() + if self._client and self._message_id: + # Remove the view from the client's listeners + self._client._views.pop(self._message_id, None) diff --git a/disagreement/voice_client.py b/disagreement/voice_client.py index 38b8d09..c771869 100644 --- a/disagreement/voice_client.py +++ b/disagreement/voice_client.py @@ -1,130 +1,130 @@ -# disagreement/voice_client.py -"""Voice gateway and UDP audio client.""" - -from __future__ import annotations - -import asyncio -import contextlib -import socket +# disagreement/voice_client.py +"""Voice gateway and UDP audio client.""" + +from __future__ import annotations + +import asyncio +import contextlib +import socket import threading from typing import TYPE_CHECKING, Optional, Sequence - -import aiohttp + +import aiohttp # The following import is correct, but may be flagged by Pylance if the virtual # environment is not configured correctly. from nacl.secret import SecretBox - + from .audio import AudioSink, AudioSource, FFmpegAudioSource from .models import User if TYPE_CHECKING: from .client import Client - - -class VoiceClient: - """Handles the Discord voice WebSocket connection and UDP streaming.""" - - def __init__( - self, + + +class VoiceClient: + """Handles the Discord voice WebSocket connection and UDP streaming.""" + + def __init__( + self, client: Client, - endpoint: str, - session_id: str, - token: str, - guild_id: int, - user_id: int, - *, - ws=None, - udp: Optional[socket.socket] = None, - loop: Optional[asyncio.AbstractEventLoop] = None, - verbose: bool = False, - ) -> None: + endpoint: str, + session_id: str, + token: str, + guild_id: int, + user_id: int, + *, + ws=None, + udp: Optional[socket.socket] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + verbose: bool = False, + ) -> None: self.client = client - self.endpoint = endpoint - self.session_id = session_id - self.token = token - self.guild_id = str(guild_id) - self.user_id = str(user_id) - self._ws: Optional[aiohttp.ClientWebSocketResponse] = ws - self._udp = udp - self._session: Optional[aiohttp.ClientSession] = None - self._heartbeat_task: Optional[asyncio.Task] = None + self.endpoint = endpoint + self.session_id = session_id + self.token = token + self.guild_id = str(guild_id) + self.user_id = str(user_id) + self._ws: Optional[aiohttp.ClientWebSocketResponse] = ws + self._udp = udp + self._session: Optional[aiohttp.ClientSession] = None + self._heartbeat_task: Optional[asyncio.Task] = None self._receive_task: Optional[asyncio.Task] = None self._udp_receive_thread: Optional[threading.Thread] = None - self._heartbeat_interval: Optional[float] = None + self._heartbeat_interval: Optional[float] = None try: self._loop = loop or asyncio.get_running_loop() except RuntimeError: self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) - self.verbose = verbose - self.ssrc: Optional[int] = None - self.secret_key: Optional[Sequence[int]] = None - self._server_ip: Optional[str] = None - self._server_port: Optional[int] = None - self._current_source: Optional[AudioSource] = None - self._play_task: Optional[asyncio.Task] = None + self.verbose = verbose + self.ssrc: Optional[int] = None + self.secret_key: Optional[Sequence[int]] = None + self._server_ip: Optional[str] = None + self._server_port: Optional[int] = None + self._current_source: Optional[AudioSource] = None + self._play_task: Optional[asyncio.Task] = None self._sink: Optional[AudioSink] = None self._ssrc_map: dict[int, int] = {} self._ssrc_lock = threading.Lock() - - async def connect(self) -> None: - if self._ws is None: - self._session = aiohttp.ClientSession() - self._ws = await self._session.ws_connect(self.endpoint) - - hello = await self._ws.receive_json() - self._heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000 - self._heartbeat_task = self._loop.create_task(self._heartbeat()) - - await self._ws.send_json( - { - "op": 0, - "d": { - "server_id": self.guild_id, - "user_id": self.user_id, - "session_id": self.session_id, - "token": self.token, - }, - } - ) - - ready = await self._ws.receive_json() - data = ready["d"] - self.ssrc = data["ssrc"] - self._server_ip = data["ip"] - self._server_port = data["port"] - - if self._udp is None: - self._udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self._udp.connect((self._server_ip, self._server_port)) - - await self._ws.send_json( - { - "op": 1, - "d": { - "protocol": "udp", - "data": { - "address": self._udp.getsockname()[0], - "port": self._udp.getsockname()[1], - "mode": "xsalsa20_poly1305", - }, - }, - } - ) - - session_desc = await self._ws.receive_json() - self.secret_key = session_desc["d"].get("secret_key") - - async def _heartbeat(self) -> None: - assert self._ws is not None - assert self._heartbeat_interval is not None - try: - while True: - await self._ws.send_json({"op": 3, "d": int(self._loop.time() * 1000)}) - await asyncio.sleep(self._heartbeat_interval) - except asyncio.CancelledError: - pass - + + async def connect(self) -> None: + if self._ws is None: + self._session = aiohttp.ClientSession() + self._ws = await self._session.ws_connect(self.endpoint) + + hello = await self._ws.receive_json() + self._heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000 + self._heartbeat_task = self._loop.create_task(self._heartbeat()) + + await self._ws.send_json( + { + "op": 0, + "d": { + "server_id": self.guild_id, + "user_id": self.user_id, + "session_id": self.session_id, + "token": self.token, + }, + } + ) + + ready = await self._ws.receive_json() + data = ready["d"] + self.ssrc = data["ssrc"] + self._server_ip = data["ip"] + self._server_port = data["port"] + + if self._udp is None: + self._udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self._udp.connect((self._server_ip, self._server_port)) + + await self._ws.send_json( + { + "op": 1, + "d": { + "protocol": "udp", + "data": { + "address": self._udp.getsockname()[0], + "port": self._udp.getsockname()[1], + "mode": "xsalsa20_poly1305", + }, + }, + } + ) + + session_desc = await self._ws.receive_json() + self.secret_key = session_desc["d"].get("secret_key") + + async def _heartbeat(self) -> None: + assert self._ws is not None + assert self._heartbeat_interval is not None + try: + while True: + await self._ws.send_json({"op": 3, "d": int(self._loop.time() * 1000)}) + await asyncio.sleep(self._heartbeat_interval) + except asyncio.CancelledError: + pass + async def _receive_loop(self) -> None: assert self._ws is not None while True: @@ -168,48 +168,48 @@ class VoiceClient: if self.verbose: print(f"Error in UDP receive loop: {e}") - async def send_audio_frame(self, frame: bytes) -> None: - if not self._udp: - raise RuntimeError("UDP socket not initialised") - self._udp.send(frame) - - async def _play_loop(self) -> None: - assert self._current_source is not None - try: - while True: - data = await self._current_source.read() - if not data: - break - await self.send_audio_frame(data) - finally: - await self._current_source.close() - self._current_source = None - self._play_task = None - - async def stop(self) -> None: - if self._play_task: - self._play_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._play_task - self._play_task = None - if self._current_source: - await self._current_source.close() - self._current_source = None - - async def play(self, source: AudioSource, *, wait: bool = True) -> None: - """|coro| Play an :class:`AudioSource` on the voice connection.""" - - await self.stop() - self._current_source = source - self._play_task = self._loop.create_task(self._play_loop()) - if wait: - await self._play_task - - async def play_file(self, filename: str, *, wait: bool = True) -> None: - """|coro| Stream an audio file or URL using FFmpeg.""" - - await self.play(FFmpegAudioSource(filename), wait=wait) - + async def send_audio_frame(self, frame: bytes) -> None: + if not self._udp: + raise RuntimeError("UDP socket not initialised") + self._udp.send(frame) + + async def _play_loop(self) -> None: + assert self._current_source is not None + try: + while True: + data = await self._current_source.read() + if not data: + break + await self.send_audio_frame(data) + finally: + await self._current_source.close() + self._current_source = None + self._play_task = None + + async def stop(self) -> None: + if self._play_task: + self._play_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._play_task + self._play_task = None + if self._current_source: + await self._current_source.close() + self._current_source = None + + async def play(self, source: AudioSource, *, wait: bool = True) -> None: + """|coro| Play an :class:`AudioSource` on the voice connection.""" + + await self.stop() + self._current_source = source + self._play_task = self._loop.create_task(self._play_loop()) + if wait: + await self._play_task + + async def play_file(self, filename: str, *, wait: bool = True) -> None: + """|coro| Stream an audio file or URL using FFmpeg.""" + + await self.play(FFmpegAudioSource(filename), wait=wait) + def listen(self, sink: AudioSink) -> None: """Start listening to voice and routing to a sink.""" if not isinstance(sink, AudioSink): @@ -222,22 +222,22 @@ class VoiceClient: ) self._udp_receive_thread.start() - async def close(self) -> None: - await self.stop() - if self._heartbeat_task: - self._heartbeat_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._heartbeat_task + async def close(self) -> None: + await self.stop() + if self._heartbeat_task: + self._heartbeat_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._heartbeat_task if self._receive_task: self._receive_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._receive_task - if self._ws: - await self._ws.close() - if self._session: - await self._session.close() - if self._udp: - self._udp.close() + if self._ws: + await self._ws.close() + if self._session: + await self._session.close() + if self._udp: + self._udp.close() if self._udp_receive_thread: self._udp_receive_thread.join(timeout=1) if self._sink: diff --git a/pyproject.toml b/pyproject.toml index aefbedb..e7d3856 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,57 +1,57 @@ -[project] -name = "disagreement" -version = "0.2.0rc1" -description = "A Python library for the Discord API." -readme = "README.md" -requires-python = ">=3.10" -license = {text = "BSD 3-Clause"} -authors = [ - {name = "Slipstream", email = "me@slipstreamm.dev"} -] -keywords = ["discord", "api", "bot", "async", "aiohttp"] -classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: BSD License", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - "Topic :: Software Development :: Libraries", - "Topic :: Software Development :: Libraries :: Python Modules", - "Topic :: Internet", -] - -dependencies = [ - "aiohttp>=3.9.0,<4.0.0", +[project] +name = "disagreement" +version = "0.2.0rc1" +description = "A Python library for the Discord API." +readme = "README.md" +requires-python = ">=3.10" +license = {text = "BSD 3-Clause"} +authors = [ + {name = "Slipstream", email = "me@slipstreamm.dev"} +] +keywords = ["discord", "api", "bot", "async", "aiohttp"] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Internet", +] + +dependencies = [ + "aiohttp>=3.9.0,<4.0.0", "PyNaCl>=1.5.0,<2.0.0", -] - -[project.optional-dependencies] -test = [ - "pytest>=8.0.0", - "pytest-asyncio>=1.0.0", - "hypothesis>=6.132.0", -] -dev = [ - "python-dotenv>=1.0.0", -] - -[project.urls] -Homepage = "https://github.com/Slipstreamm/disagreement" -Issues = "https://github.com/Slipstreamm/disagreement/issues" - -[build-system] -requires = ["setuptools>=61.0"] -build-backend = "setuptools.build_meta" - -# Optional: for linting/formatting, e.g., Ruff -# [tool.ruff] -# line-length = 88 -# select = ["E", "W", "F", "I", "UP", "C4", "B"] # Example rule set -# ignore = [] - -# [tool.ruff.format] -# quote-style = "double" +] + +[project.optional-dependencies] +test = [ + "pytest>=8.0.0", + "pytest-asyncio>=1.0.0", + "hypothesis>=6.132.0", +] +dev = [ + "python-dotenv>=1.0.0", +] + +[project.urls] +Homepage = "https://github.com/Slipstreamm/disagreement" +Issues = "https://github.com/Slipstreamm/disagreement/issues" + +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +# Optional: for linting/formatting, e.g., Ruff +# [tool.ruff] +# line-length = 88 +# select = ["E", "W", "F", "I", "UP", "C4", "B"] # Example rule set +# ignore = [] + +# [tool.ruff.format] +# quote-style = "double"