From ed83a9da85b2e95e8d60109835909c491f348864 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Wed, 11 Jun 2025 02:11:33 -0600 Subject: [PATCH] Implements caching system with TTL and member filtering Introduces a flexible caching infrastructure with time-to-live support and configurable member caching based on status, voice state, and join events. Adds AudioSink abstract base class to support audio output handling in voice connections. Replaces direct dictionary access with cache objects throughout the client, enabling automatic expiration and intelligent member filtering based on user-defined flags. Updates guild parsing to incorporate presence and voice state data for more accurate member caching decisions. --- disagreement/audio.py | 17 + disagreement/cache.py | 32 +- disagreement/caching.py | 120 ++ disagreement/client.py | 2801 +++++++++++++++--------------- disagreement/event_dispatcher.py | 4 +- disagreement/models.py | 2446 +++++++++++++------------- 6 files changed, 2846 insertions(+), 2574 deletions(-) create mode 100644 disagreement/caching.py diff --git a/disagreement/audio.py b/disagreement/audio.py index f70369c..9e58530 100644 --- a/disagreement/audio.py +++ b/disagreement/audio.py @@ -114,3 +114,20 @@ class FFmpegAudioSource(AudioSource): if isinstance(self.source, io.IOBase): with contextlib.suppress(Exception): self.source.close() + +class AudioSink: + """Abstract base class for audio sinks.""" + + def write(self, user, data): + """Write a chunk of PCM audio. + + Subclasses must implement this. The data is raw PCM at 48kHz + stereo. + """ + + raise NotImplementedError + + def close(self) -> None: + """Cleanup the sink when the voice client disconnects.""" + + return None diff --git a/disagreement/cache.py b/disagreement/cache.py index 666d46b..178e8ae 100644 --- a/disagreement/cache.py +++ b/disagreement/cache.py @@ -4,7 +4,8 @@ import time from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar if TYPE_CHECKING: - from .models import Channel, Guild + from .models import Channel, Guild, Member + from .caching import MemberCacheFlags T = TypeVar("T") @@ -53,3 +54,32 @@ class GuildCache(Cache["Guild"]): class ChannelCache(Cache["Channel"]): """Cache specifically for :class:`Channel` objects.""" + + +class MemberCache(Cache["Member"]): + """ + A cache for :class:`Member` objects that respects :class:`MemberCacheFlags`. + """ + + def __init__(self, flags: MemberCacheFlags, ttl: Optional[float] = None) -> None: + super().__init__(ttl) + self.flags = flags + + def _should_cache(self, member: Member) -> bool: + """Determines if a member should be cached based on the flags.""" + if self.flags.all: + return True + if self.flags.none: + return False + + if self.flags.online and member.status != "offline": + return True + if self.flags.voice and member.voice_state is not None: + return True + if self.flags.joined and getattr(member, "_just_joined", False): + return True + return False + + def set(self, key: str, value: Member) -> None: + if self._should_cache(value): + super().set(key, value) diff --git a/disagreement/caching.py b/disagreement/caching.py new file mode 100644 index 0000000..7aca2b8 --- /dev/null +++ b/disagreement/caching.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +import operator +from typing import Any, Callable, ClassVar, Dict, Iterator, Tuple + + +class _MemberCacheFlagValue: + flag: int + + def __init__(self, func: Callable[[Any], bool]): + self.flag = getattr(func, 'flag', 0) + self.__doc__ = func.__doc__ + + def __get__(self, instance: 'MemberCacheFlags', owner: type) -> Any: + if instance is None: + return self + return instance.value & self.flag != 0 + + def __set__(self, instance: Any, value: bool) -> None: + if value: + instance.value |= self.flag + else: + instance.value &= ~self.flag + + def __repr__(self) -> str: + return f'<{self.__class__.__name__} flag={self.flag}>' + + +def flag_value(flag: int) -> Callable[[Callable[[Any], bool]], _MemberCacheFlagValue]: + def decorator(func: Callable[[Any], bool]) -> _MemberCacheFlagValue: + setattr(func, 'flag', flag) + return _MemberCacheFlagValue(func) + return decorator + + +class MemberCacheFlags: + __slots__ = ('value',) + + VALID_FLAGS: ClassVar[Dict[str, int]] = { + 'joined': 1 << 0, + 'voice': 1 << 1, + 'online': 1 << 2, + } + DEFAULT_FLAGS: ClassVar[int] = 1 | 2 | 4 + ALL_FLAGS: ClassVar[int] = sum(VALID_FLAGS.values()) + + def __init__(self, **kwargs: bool): + self.value = self.DEFAULT_FLAGS + for key, value in kwargs.items(): + if key not in self.VALID_FLAGS: + raise TypeError(f'{key!r} is not a valid member cache flag.') + setattr(self, key, value) + + @classmethod + def _from_value(cls, value: int) -> MemberCacheFlags: + self = cls.__new__(cls) + self.value = value + return self + + def __eq__(self, other: object) -> bool: + return isinstance(other, MemberCacheFlags) and self.value == other.value + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + def __hash__(self) -> int: + return hash(self.value) + + def __repr__(self) -> str: + return f'' + + def __iter__(self) -> Iterator[Tuple[str, bool]]: + for name in self.VALID_FLAGS: + yield name, getattr(self, name) + + def __int__(self) -> int: + return self.value + + def __index__(self) -> int: + return self.value + + @classmethod + def all(cls) -> MemberCacheFlags: + """A factory method that creates a :class:`MemberCacheFlags` with all flags enabled.""" + return cls._from_value(cls.ALL_FLAGS) + + @classmethod + def none(cls) -> MemberCacheFlags: + """A factory method that creates a :class:`MemberCacheFlags` with all flags disabled.""" + return cls._from_value(0) + + @classmethod + def only_joined(cls) -> MemberCacheFlags: + """A factory method that creates a :class:`MemberCacheFlags` with only the `joined` flag enabled.""" + return cls._from_value(cls.VALID_FLAGS['joined']) + + @classmethod + def only_voice(cls) -> MemberCacheFlags: + """A factory method that creates a :class:`MemberCacheFlags` with only the `voice` flag enabled.""" + return cls._from_value(cls.VALID_FLAGS['voice']) + + @classmethod + def only_online(cls) -> MemberCacheFlags: + """A factory method that creates a :class:`MemberCacheFlags` with only the `online` flag enabled.""" + return cls._from_value(cls.VALID_FLAGS['online']) + + @flag_value(1 << 0) + def joined(self) -> bool: + """Whether to cache members that have just joined the guild.""" + return False + + @flag_value(1 << 1) + def voice(self) -> bool: + """Whether to cache members that are in a voice channel.""" + return False + + @flag_value(1 << 2) + def online(self) -> bool: + """Whether to cache members that are online.""" + return False \ No newline at end of file diff --git a/disagreement/client.py b/disagreement/client.py index df73b26..23919eb 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -1,584 +1,587 @@ -# disagreement/client.py - -""" -The main Client class for interacting with the Discord API. -""" - -import asyncio -import signal -from typing import ( - Optional, - Callable, - Any, - TYPE_CHECKING, - Awaitable, - AsyncIterator, - Union, - List, - Dict, -) -from types import ModuleType - -from .http import HTTPClient -from .gateway import GatewayClient -from .shard_manager import ShardManager -from .event_dispatcher import EventDispatcher -from .enums import GatewayIntent, InteractionType, GatewayOpcode, VoiceRegion -from .errors import DisagreementException, AuthenticationError -from .typing import Typing -from .ext.commands.core import CommandHandler -from .ext.commands.cog import Cog -from .ext.app_commands.handler import AppCommandHandler -from .ext.app_commands.context import AppCommandContext -from .ext import loader as ext_loader -from .interactions import Interaction, Snowflake -from .error_handler import setup_global_error_handler -from .voice_client import VoiceClient - -if TYPE_CHECKING: - from .models import ( - Message, - Embed, - ActionRow, - Guild, - Channel, - User, - Member, - Role, - TextChannel, - VoiceChannel, - CategoryChannel, - Thread, - DMChannel, - Webhook, - GuildTemplate, - ScheduledEvent, - AuditLogEntry, - Invite, - ) - from .ui.view import View - from .enums import ChannelType as EnumChannelType - from .ext.commands.core import CommandContext - from .ext.commands.errors import CommandError, CommandInvokeError - from .ext.app_commands.commands import AppCommand, AppCommandGroup - - -class Client: - """ - Represents a client connection that connects to Discord. - This class is used to interact with the Discord WebSocket and API. - - Args: - token (str): The bot token for authentication. - intents (Optional[int]): The Gateway Intents to use. Defaults to `GatewayIntent.default()`. - You might need to enable privileged intents in your bot's application page. - loop (Optional[asyncio.AbstractEventLoop]): The event loop to use for asynchronous operations. - Defaults to `asyncio.get_event_loop()`. - command_prefix (Union[str, List[str], Callable[['Client', Message], Union[str, List[str]]]]): - The prefix(es) for commands. Defaults to '!'. - verbose (bool): If True, print raw HTTP and Gateway traffic for debugging. - http_options (Optional[Dict[str, Any]]): Extra options passed to - :class:`HTTPClient` for creating the internal - :class:`aiohttp.ClientSession`. - """ - - def __init__( - self, - token: str, - intents: Optional[int] = None, - loop: Optional[asyncio.AbstractEventLoop] = None, - command_prefix: Union[ - str, List[str], Callable[["Client", "Message"], Union[str, List[str]]] - ] = "!", - application_id: Optional[Union[str, int]] = None, - verbose: bool = False, - mention_replies: bool = False, - shard_count: Optional[int] = None, - gateway_max_retries: int = 5, - gateway_max_backoff: float = 60.0, - http_options: Optional[Dict[str, Any]] = None, - ): - if not token: - raise ValueError("A bot token must be provided.") - - self.token: str = token - self.intents: int = intents if intents is not None else GatewayIntent.default() - self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() - self.application_id: Optional[Snowflake] = ( - str(application_id) if application_id else None - ) - setup_global_error_handler(self.loop) - - self.verbose: bool = verbose - self._http: HTTPClient = HTTPClient( - token=self.token, - verbose=verbose, - **(http_options or {}), - ) - self._event_dispatcher: EventDispatcher = EventDispatcher(client_instance=self) - self._gateway: Optional[GatewayClient] = ( - None # Initialized in run() or connect() - ) - self.shard_count: Optional[int] = shard_count - self.gateway_max_retries: int = gateway_max_retries - self.gateway_max_backoff: float = gateway_max_backoff - self._shard_manager: Optional[ShardManager] = None - - # Initialize CommandHandler - self.command_handler: CommandHandler = CommandHandler( - client=self, prefix=command_prefix - ) - self.app_command_handler: AppCommandHandler = AppCommandHandler(client=self) - # Register internal listener for processing commands from messages - self._event_dispatcher.register( - "MESSAGE_CREATE", self._process_message_for_commands - ) - - self._closed: bool = False - self._ready_event: asyncio.Event = asyncio.Event() - self.user: Optional["User"] = ( - None # The bot's own user object, populated on READY - ) - - # Internal Caches - self._guilds: Dict[Snowflake, "Guild"] = {} - self._channels: Dict[Snowflake, "Channel"] = ( - {} - ) # Stores all channel types by ID - self._users: Dict[Snowflake, Any] = ( - {} - ) # Placeholder for User model cache if needed - self._messages: Dict[Snowflake, "Message"] = {} - self._views: Dict[Snowflake, "View"] = {} - self._voice_clients: Dict[Snowflake, VoiceClient] = {} - self._webhooks: Dict[Snowflake, "Webhook"] = {} - - # Default whether replies mention the user - self.mention_replies: bool = mention_replies - - # Basic signal handling for graceful shutdown - # This might be better handled by the user's application code, but can be a nice default. - # For more robust handling, consider libraries or more advanced patterns. - try: - self.loop.add_signal_handler( - signal.SIGINT, lambda: self.loop.create_task(self.close()) - ) - self.loop.add_signal_handler( - signal.SIGTERM, lambda: self.loop.create_task(self.close()) - ) - except NotImplementedError: - # add_signal_handler is not available on all platforms (e.g., Windows default event loop policy) - # Users on these platforms would need to handle shutdown differently. - print( - "Warning: Signal handlers for SIGINT/SIGTERM could not be added. " - "Graceful shutdown via signals might not work as expected on this platform." - ) - - async def _initialize_gateway(self): - """Initializes the GatewayClient if it doesn't exist.""" - if self._gateway is None: - self._gateway = GatewayClient( - http_client=self._http, - event_dispatcher=self._event_dispatcher, - token=self.token, - intents=self.intents, - client_instance=self, - verbose=self.verbose, - max_retries=self.gateway_max_retries, - max_backoff=self.gateway_max_backoff, - ) - - async def _initialize_shard_manager(self) -> None: - """Initializes the :class:`ShardManager` if not already created.""" - if self._shard_manager is None: - count = self.shard_count or 1 - self._shard_manager = ShardManager(self, count) - - async def connect(self, reconnect: bool = True) -> None: - """ - Establishes a connection to Discord. This includes logging in and connecting to the Gateway. - This method is a coroutine. - - Args: - reconnect (bool): Whether to automatically attempt to reconnect on disconnect. - (Note: Basic reconnect logic is within GatewayClient for now) - - Raises: - GatewayException: If the connection to the gateway fails. - AuthenticationError: If the token is invalid. - """ - if self._closed: - raise DisagreementException("Client is closed and cannot connect.") - if self.shard_count and self.shard_count > 1: - await self._initialize_shard_manager() - assert self._shard_manager is not None - await self._shard_manager.start() - print( - f"Client connected using {self.shard_count} shards, waiting for READY signal..." - ) - await self.wait_until_ready() - print("Client is READY!") - return - - await self._initialize_gateway() - assert self._gateway is not None # Should be initialized by now - - retry_delay = 5 # seconds - max_retries = 5 # For initial connection attempts by Client.run, Gateway has its own internal retries for some cases. - - for attempt in range(max_retries): - try: - await self._gateway.connect() - # After successful connection, GatewayClient's HELLO handler will trigger IDENTIFY/RESUME - # and its READY handler will set self._ready_event via dispatcher. - print("Client connected to Gateway, waiting for READY signal...") - await self.wait_until_ready() # Wait for the READY event from Gateway - print("Client is READY!") - return # Successfully connected and ready - except AuthenticationError: # Non-recoverable by retry here - print("Authentication failed. Please check your bot token.") - await self.close() # Ensure cleanup - raise - except DisagreementException as e: # Includes GatewayException - print(f"Failed to connect (Attempt {attempt + 1}/{max_retries}): {e}") - if attempt < max_retries - 1: - print(f"Retrying in {retry_delay} seconds...") - await asyncio.sleep(retry_delay) - retry_delay = min( - retry_delay * 2, 60 - ) # Exponential backoff up to 60s - else: - print("Max connection retries reached. Giving up.") - await self.close() # Ensure cleanup - raise - # Should not be reached if max_retries is > 0 - if max_retries == 0: # If max_retries was 0, means no retries attempted - raise DisagreementException("Connection failed with 0 retries allowed.") - - async def run(self) -> None: - """ - A blocking call that connects the client to Discord and runs until the client is closed. - This method is a coroutine. - It handles login, Gateway connection, and keeping the connection alive. - """ - if self._closed: - raise DisagreementException("Client is already closed.") - - try: - await self.connect() - # The GatewayClient's _receive_loop will keep running. - # This run method effectively waits until the client is closed or an unhandled error occurs. - # A more robust implementation might have a main loop here that monitors gateway health. - # For now, we rely on the gateway's tasks. - while not self._closed: - if ( - self._gateway - and self._gateway._receive_task - and self._gateway._receive_task.done() - ): - # If receive task ended unexpectedly, try to handle it or re-raise - try: - exc = self._gateway._receive_task.exception() - if exc: - print( - f"Gateway receive task ended with exception: {exc}. Attempting to reconnect..." - ) - # This is a basic reconnect strategy from the client side. - # GatewayClient itself might handle some reconnects. - await self.close_gateway( - code=1000 - ) # Close current gateway state - await asyncio.sleep(5) # Wait before reconnecting - if ( - not self._closed - ): # If client wasn't closed by the exception handler - await self.connect() - else: - break # Client was closed, exit run loop - else: - print( - "Gateway receive task ended without exception. Assuming clean shutdown or reconnect handled internally." - ) - if ( - not self._closed - ): # If not explicitly closed, might be an issue - print( - "Warning: Gateway receive task ended but client not closed. This might indicate an issue." - ) - # Consider a more robust health check or reconnect strategy here. - await asyncio.sleep( - 1 - ) # Prevent tight loop if something is wrong - else: - break # Client was closed - except asyncio.CancelledError: - print("Gateway receive task was cancelled.") - break # Exit if cancelled - except Exception as e: - print(f"Error checking gateway receive task: {e}") - break # Exit on other errors - await asyncio.sleep(1) # Main loop check interval - except DisagreementException as e: - print(f"Client run loop encountered an error: {e}") - # Error already logged by connect or other methods - except asyncio.CancelledError: - print("Client run loop was cancelled.") - finally: - if not self._closed: - await self.close() - - async def close(self) -> None: - """ - Closes the connection to Discord. This method is a coroutine. - """ - if self._closed: - return - - self._closed = True - print("Closing client...") - - if self._shard_manager: - await self._shard_manager.close() - self._shard_manager = None - if self._gateway: - await self._gateway.close() - - if self._http: # HTTPClient has its own session to close - await self._http.close() - - self._ready_event.set() # Ensure any waiters for ready are unblocked - print("Client closed.") - - async def __aenter__(self) -> "Client": - """Enter the context manager by connecting to Discord.""" - await self.connect() - return self - - async def __aexit__( - self, - exc_type: Optional[type], - exc: Optional[BaseException], - tb: Optional[BaseException], - ) -> bool: - """Exit the context manager and close the client.""" - await self.close() - return False - - async def close_gateway(self, code: int = 1000) -> None: - """Closes only the gateway connection, allowing for potential reconnect.""" - if self._shard_manager: - await self._shard_manager.close() - self._shard_manager = None - if self._gateway: - await self._gateway.close(code=code) - self._gateway = None - self._ready_event.clear() # No longer ready if gateway is closed - - def is_closed(self) -> bool: - """Indicates if the client has been closed.""" - return self._closed - - def is_ready(self) -> bool: - """Indicates if the client has successfully connected to the Gateway and is ready.""" - return self._ready_event.is_set() - - @property - def latency(self) -> Optional[float]: - """Returns the gateway latency in seconds, or ``None`` if unavailable.""" - if self._gateway: - return self._gateway.latency - return None - - async def wait_until_ready(self) -> None: - """|coro| - Waits until the client is fully connected to Discord and the initial state is processed. - This is mainly useful for waiting for the READY event from the Gateway. - """ - await self._ready_event.wait() - - async def wait_for( - self, - event_name: str, - check: Optional[Callable[[Any], bool]] = None, - timeout: Optional[float] = None, - ) -> Any: - """|coro| - Waits for a specific event to occur that satisfies the ``check``. - - Parameters - ---------- - event_name: str - The name of the event to wait for. - check: Optional[Callable[[Any], bool]] - A function that determines whether the received event should resolve the wait. - timeout: Optional[float] - How long to wait for the event before raising :class:`asyncio.TimeoutError`. - """ - - future: asyncio.Future = self.loop.create_future() - self._event_dispatcher.add_waiter(event_name, future, check) - try: - return await asyncio.wait_for(future, timeout=timeout) - finally: - self._event_dispatcher.remove_waiter(event_name, future) - - async def change_presence( - self, - status: str, - activity_name: Optional[str] = None, - activity_type: int = 0, - since: int = 0, - afk: bool = False, - ): - """ - Changes the client's presence on Discord. - - Args: - status (str): The new status for the client (e.g., "online", "idle", "dnd", "invisible"). - activity_name (Optional[str]): The name of the activity. - activity_type (int): The type of the activity. - since (int): The timestamp (in milliseconds) of when the client went idle. - afk (bool): Whether the client is AFK. - """ - if self._closed: - raise DisagreementException("Client is closed.") - - if self._gateway: - await self._gateway.update_presence( - status=status, - activity_name=activity_name, - activity_type=activity_type, - since=since, - afk=afk, - ) - - # --- Event Handling --- - - def event( - self, coro: Callable[..., Awaitable[None]] - ) -> Callable[..., Awaitable[None]]: - """ - A decorator that registers an event to listen to. - The name of the coroutine is used as the event name. - Example: - @client.event - async def on_ready(): # Will listen for the 'READY' event - print("Bot is ready!") - - @client.event - async def on_message(message: disagreement.Message): # Will listen for 'MESSAGE_CREATE' - print(f"Message from {message.author}: {message.content}") - """ - if not asyncio.iscoroutinefunction(coro): - raise TypeError("Event registered must be a coroutine function.") - - event_name = coro.__name__ - # Map common function names to Discord event types - # e.g., on_ready -> READY, on_message -> MESSAGE_CREATE - if event_name.startswith("on_"): - discord_event_name = event_name[3:].upper() - mapping = { - "MESSAGE": "MESSAGE_CREATE", - "MESSAGE_EDIT": "MESSAGE_UPDATE", - "MESSAGE_UPDATE": "MESSAGE_UPDATE", - "MESSAGE_DELETE": "MESSAGE_DELETE", - "REACTION_ADD": "MESSAGE_REACTION_ADD", - "REACTION_REMOVE": "MESSAGE_REACTION_REMOVE", - } - discord_event_name = mapping.get(discord_event_name, discord_event_name) - self._event_dispatcher.register(discord_event_name, coro) - else: - # If not starting with "on_", assume it's the direct Discord event name (e.g. "TYPING_START") - # Or raise an error if a specific format is required. - # For now, let's assume direct mapping if no "on_" prefix. - self._event_dispatcher.register(event_name.upper(), coro) - - return coro # Return the original coroutine - - def on_event( - self, event_name: str - ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: - """ - A decorator that registers an event to listen to with a specific event name. - Example: - @client.on_event('MESSAGE_CREATE') - async def my_message_handler(message: disagreement.Message): - print(f"Message: {message.content}") - """ - - def decorator( - coro: Callable[..., Awaitable[None]], - ) -> Callable[..., Awaitable[None]]: - if not asyncio.iscoroutinefunction(coro): - raise TypeError("Event registered must be a coroutine function.") - self._event_dispatcher.register(event_name.upper(), coro) - return coro - - return decorator - - async def _process_message_for_commands(self, message: "Message") -> None: - """Internal listener to process messages for commands.""" - # Make sure message object is valid and not from a bot (optional, common check) - if ( - not message or not message.author or message.author.bot - ): # Add .bot check to User model - return - await self.command_handler.process_commands(message) - - # --- Command Framework Methods --- - - def add_cog(self, cog: Cog) -> None: - """ - Adds a Cog to the bot. - Cogs are classes that group commands, listeners, and state. - This will also discover and register any application commands defined in the cog. - - Args: - cog (Cog): An instance of a class derived from `disagreement.ext.commands.Cog`. - """ - # Add to prefix command handler - self.command_handler.add_cog( - cog - ) # This should call cog._inject() internally or cog._inject() is called on Cog init - - # Discover and add application commands from the cog - # AppCommand and AppCommandGroup are already imported in TYPE_CHECKING block - for app_cmd_obj in cog.get_app_commands_and_groups(): # Uses the new method - # The cog attribute should have been set within Cog._inject() for AppCommands - self.app_command_handler.add_command(app_cmd_obj) - print( - f"Registered app command/group '{app_cmd_obj.name}' from cog '{cog.cog_name}'." - ) - - def remove_cog(self, cog_name: str) -> Optional[Cog]: - """ - Removes a Cog from the bot. - - Args: - cog_name (str): The name of the Cog to remove. - - Returns: - Optional[Cog]: The Cog that was removed, or None if not found. - """ - removed_cog = self.command_handler.remove_cog(cog_name) - if removed_cog: - # Also remove associated application commands - # This requires AppCommand to store a reference to its cog, or iterate all app_commands. - # Assuming AppCommand has a .cog attribute, which is set in Cog._inject() - # And AppCommandGroup might store commands that have .cog attribute - for app_cmd_or_group in removed_cog.get_app_commands_and_groups(): - # The AppCommandHandler.remove_command needs to handle both AppCommand and AppCommandGroup - self.app_command_handler.remove_command( - app_cmd_or_group.name - ) # Assuming name is unique enough for removal here - print( - f"Removed app command/group '{app_cmd_or_group.name}' from cog '{cog_name}'." - ) - # Note: AppCommandHandler.remove_command might need to be more specific if names aren't globally unique - # (e.g. if it needs type or if groups and commands can share names). - # For now, assuming name is sufficient for removal from the handler's flat list. - return removed_cog - +# disagreement/client.py + +""" +The main Client class for interacting with the Discord API. +""" + +import asyncio +import signal +from typing import ( + Optional, + Callable, + Any, + TYPE_CHECKING, + Awaitable, + AsyncIterator, + Union, + List, + Dict, +) +from types import ModuleType + +from .http import HTTPClient +from .gateway import GatewayClient +from .shard_manager import ShardManager +from .event_dispatcher import EventDispatcher +from .enums import GatewayIntent, InteractionType, GatewayOpcode, VoiceRegion +from .errors import DisagreementException, AuthenticationError +from .typing import Typing +from .caching import MemberCacheFlags +from .cache import Cache, GuildCache, ChannelCache, MemberCache +from .ext.commands.core import Command, CommandHandler, Group +from .ext.commands.cog import Cog +from .ext.app_commands.handler import AppCommandHandler +from .ext.app_commands.context import AppCommandContext +from .ext import loader as ext_loader +from .interactions import Interaction, Snowflake +from .error_handler import setup_global_error_handler +from .voice_client import VoiceClient + +if TYPE_CHECKING: + from .models import ( + Message, + Embed, + ActionRow, + Guild, + Channel, + User, + Member, + Role, + TextChannel, + VoiceChannel, + CategoryChannel, + Thread, + DMChannel, + Webhook, + GuildTemplate, + ScheduledEvent, + AuditLogEntry, + Invite, + ) + from .ui.view import View + from .enums import ChannelType as EnumChannelType + from .ext.commands.core import CommandContext + from .ext.commands.errors import CommandError, CommandInvokeError + from .ext.app_commands.commands import AppCommand, AppCommandGroup + + +class Client: + """ + Represents a client connection that connects to Discord. + This class is used to interact with the Discord WebSocket and API. + + Args: + token (str): The bot token for authentication. + intents (Optional[int]): The Gateway Intents to use. Defaults to `GatewayIntent.default()`. + You might need to enable privileged intents in your bot's application page. + loop (Optional[asyncio.AbstractEventLoop]): The event loop to use for asynchronous operations. + Defaults to `asyncio.get_event_loop()`. + command_prefix (Union[str, List[str], Callable[['Client', Message], Union[str, List[str]]]]): + The prefix(es) for commands. Defaults to '!'. + verbose (bool): If True, print raw HTTP and Gateway traffic for debugging. + http_options (Optional[Dict[str, Any]]): Extra options passed to + :class:`HTTPClient` for creating the internal + :class:`aiohttp.ClientSession`. + """ + + def __init__( + self, + token: str, + intents: Optional[int] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + command_prefix: Union[ + str, List[str], Callable[["Client", "Message"], Union[str, List[str]]] + ] = "!", + application_id: Optional[Union[str, int]] = None, + verbose: bool = False, + mention_replies: bool = False, + shard_count: Optional[int] = None, + gateway_max_retries: int = 5, + gateway_max_backoff: float = 60.0, + member_cache_flags: Optional[MemberCacheFlags] = None, + http_options: Optional[Dict[str, Any]] = None, + ): + if not token: + raise ValueError("A bot token must be provided.") + + self.token: str = token + self.member_cache_flags: MemberCacheFlags = ( + member_cache_flags if member_cache_flags is not None else MemberCacheFlags() + ) + self.intents: int = intents if intents is not None else GatewayIntent.default() + self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() + self.application_id: Optional[Snowflake] = ( + str(application_id) if application_id else None + ) + setup_global_error_handler(self.loop) + + self.verbose: bool = verbose + self._http: HTTPClient = HTTPClient( + token=self.token, + verbose=verbose, + **(http_options or {}), + ) + self._event_dispatcher: EventDispatcher = EventDispatcher(client_instance=self) + self._gateway: Optional[GatewayClient] = ( + None # Initialized in run() or connect() + ) + self.shard_count: Optional[int] = shard_count + self.gateway_max_retries: int = gateway_max_retries + self.gateway_max_backoff: float = gateway_max_backoff + self._shard_manager: Optional[ShardManager] = None + + # Initialize CommandHandler + self.command_handler: CommandHandler = CommandHandler( + client=self, prefix=command_prefix + ) + self.app_command_handler: AppCommandHandler = AppCommandHandler(client=self) + # Register internal listener for processing commands from messages + self._event_dispatcher.register( + "MESSAGE_CREATE", self._process_message_for_commands + ) + + self._closed: bool = False + self._ready_event: asyncio.Event = asyncio.Event() + self.user: Optional["User"] = ( + None # The bot's own user object, populated on READY + ) + + # Internal Caches + self._guilds: GuildCache = GuildCache() + self._channels: ChannelCache = ChannelCache() + self._users: Cache["User"] = Cache() + self._messages: Cache["Message"] = Cache(ttl=3600) # Cache messages for an hour + self._views: Dict[Snowflake, "View"] = {} + self._persistent_views: Dict[str, "View"] = {} + self._voice_clients: Dict[Snowflake, VoiceClient] = {} + self._webhooks: Dict[Snowflake, "Webhook"] = {} + + # Default whether replies mention the user + self.mention_replies: bool = mention_replies + + # Basic signal handling for graceful shutdown + # This might be better handled by the user's application code, but can be a nice default. + # For more robust handling, consider libraries or more advanced patterns. + try: + self.loop.add_signal_handler( + signal.SIGINT, lambda: self.loop.create_task(self.close()) + ) + self.loop.add_signal_handler( + signal.SIGTERM, lambda: self.loop.create_task(self.close()) + ) + except NotImplementedError: + # add_signal_handler is not available on all platforms (e.g., Windows default event loop policy) + # Users on these platforms would need to handle shutdown differently. + print( + "Warning: Signal handlers for SIGINT/SIGTERM could not be added. " + "Graceful shutdown via signals might not work as expected on this platform." + ) + + async def _initialize_gateway(self): + """Initializes the GatewayClient if it doesn't exist.""" + if self._gateway is None: + self._gateway = GatewayClient( + http_client=self._http, + event_dispatcher=self._event_dispatcher, + token=self.token, + intents=self.intents, + client_instance=self, + verbose=self.verbose, + max_retries=self.gateway_max_retries, + max_backoff=self.gateway_max_backoff, + ) + + async def _initialize_shard_manager(self) -> None: + """Initializes the :class:`ShardManager` if not already created.""" + if self._shard_manager is None: + count = self.shard_count or 1 + self._shard_manager = ShardManager(self, count) + + async def connect(self, reconnect: bool = True) -> None: + """ + Establishes a connection to Discord. This includes logging in and connecting to the Gateway. + This method is a coroutine. + + Args: + reconnect (bool): Whether to automatically attempt to reconnect on disconnect. + (Note: Basic reconnect logic is within GatewayClient for now) + + Raises: + GatewayException: If the connection to the gateway fails. + AuthenticationError: If the token is invalid. + """ + if self._closed: + raise DisagreementException("Client is closed and cannot connect.") + if self.shard_count and self.shard_count > 1: + await self._initialize_shard_manager() + assert self._shard_manager is not None + await self._shard_manager.start() + print( + f"Client connected using {self.shard_count} shards, waiting for READY signal..." + ) + await self.wait_until_ready() + print("Client is READY!") + return + + await self._initialize_gateway() + assert self._gateway is not None # Should be initialized by now + + retry_delay = 5 # seconds + max_retries = 5 # For initial connection attempts by Client.run, Gateway has its own internal retries for some cases. + + for attempt in range(max_retries): + try: + await self._gateway.connect() + # After successful connection, GatewayClient's HELLO handler will trigger IDENTIFY/RESUME + # and its READY handler will set self._ready_event via dispatcher. + print("Client connected to Gateway, waiting for READY signal...") + await self.wait_until_ready() # Wait for the READY event from Gateway + print("Client is READY!") + return # Successfully connected and ready + except AuthenticationError: # Non-recoverable by retry here + print("Authentication failed. Please check your bot token.") + await self.close() # Ensure cleanup + raise + except DisagreementException as e: # Includes GatewayException + print(f"Failed to connect (Attempt {attempt + 1}/{max_retries}): {e}") + if attempt < max_retries - 1: + print(f"Retrying in {retry_delay} seconds...") + await asyncio.sleep(retry_delay) + retry_delay = min( + retry_delay * 2, 60 + ) # Exponential backoff up to 60s + else: + print("Max connection retries reached. Giving up.") + await self.close() # Ensure cleanup + raise + # Should not be reached if max_retries is > 0 + if max_retries == 0: # If max_retries was 0, means no retries attempted + raise DisagreementException("Connection failed with 0 retries allowed.") + + async def run(self) -> None: + """ + A blocking call that connects the client to Discord and runs until the client is closed. + This method is a coroutine. + It handles login, Gateway connection, and keeping the connection alive. + """ + if self._closed: + raise DisagreementException("Client is already closed.") + + try: + await self.connect() + # The GatewayClient's _receive_loop will keep running. + # This run method effectively waits until the client is closed or an unhandled error occurs. + # A more robust implementation might have a main loop here that monitors gateway health. + # For now, we rely on the gateway's tasks. + while not self._closed: + if ( + self._gateway + and self._gateway._receive_task + and self._gateway._receive_task.done() + ): + # If receive task ended unexpectedly, try to handle it or re-raise + try: + exc = self._gateway._receive_task.exception() + if exc: + print( + f"Gateway receive task ended with exception: {exc}. Attempting to reconnect..." + ) + # This is a basic reconnect strategy from the client side. + # GatewayClient itself might handle some reconnects. + await self.close_gateway( + code=1000 + ) # Close current gateway state + await asyncio.sleep(5) # Wait before reconnecting + if ( + not self._closed + ): # If client wasn't closed by the exception handler + await self.connect() + else: + break # Client was closed, exit run loop + else: + print( + "Gateway receive task ended without exception. Assuming clean shutdown or reconnect handled internally." + ) + if ( + not self._closed + ): # If not explicitly closed, might be an issue + print( + "Warning: Gateway receive task ended but client not closed. This might indicate an issue." + ) + # Consider a more robust health check or reconnect strategy here. + await asyncio.sleep( + 1 + ) # Prevent tight loop if something is wrong + else: + break # Client was closed + except asyncio.CancelledError: + print("Gateway receive task was cancelled.") + break # Exit if cancelled + except Exception as e: + print(f"Error checking gateway receive task: {e}") + break # Exit on other errors + await asyncio.sleep(1) # Main loop check interval + except DisagreementException as e: + print(f"Client run loop encountered an error: {e}") + # Error already logged by connect or other methods + except asyncio.CancelledError: + print("Client run loop was cancelled.") + finally: + if not self._closed: + await self.close() + + async def close(self) -> None: + """ + Closes the connection to Discord. This method is a coroutine. + """ + if self._closed: + return + + self._closed = True + print("Closing client...") + + if self._shard_manager: + await self._shard_manager.close() + self._shard_manager = None + if self._gateway: + await self._gateway.close() + + if self._http: # HTTPClient has its own session to close + await self._http.close() + + self._ready_event.set() # Ensure any waiters for ready are unblocked + print("Client closed.") + + async def __aenter__(self) -> "Client": + """Enter the context manager by connecting to Discord.""" + await self.connect() + return self + + async def __aexit__( + self, + exc_type: Optional[type], + exc: Optional[BaseException], + tb: Optional[BaseException], + ) -> bool: + """Exit the context manager and close the client.""" + await self.close() + return False + + async def close_gateway(self, code: int = 1000) -> None: + """Closes only the gateway connection, allowing for potential reconnect.""" + if self._shard_manager: + await self._shard_manager.close() + self._shard_manager = None + if self._gateway: + await self._gateway.close(code=code) + self._gateway = None + self._ready_event.clear() # No longer ready if gateway is closed + + def is_closed(self) -> bool: + """Indicates if the client has been closed.""" + return self._closed + + def is_ready(self) -> bool: + """Indicates if the client has successfully connected to the Gateway and is ready.""" + return self._ready_event.is_set() + + @property + def latency(self) -> Optional[float]: + """Returns the gateway latency in seconds, or ``None`` if unavailable.""" + if self._gateway: + return self._gateway.latency + return None + + async def wait_until_ready(self) -> None: + """|coro| + Waits until the client is fully connected to Discord and the initial state is processed. + This is mainly useful for waiting for the READY event from the Gateway. + """ + await self._ready_event.wait() + + async def wait_for( + self, + event_name: str, + check: Optional[Callable[[Any], bool]] = None, + timeout: Optional[float] = None, + ) -> Any: + """|coro| + Waits for a specific event to occur that satisfies the ``check``. + + Parameters + ---------- + event_name: str + The name of the event to wait for. + check: Optional[Callable[[Any], bool]] + A function that determines whether the received event should resolve the wait. + timeout: Optional[float] + How long to wait for the event before raising :class:`asyncio.TimeoutError`. + """ + + future: asyncio.Future = self.loop.create_future() + self._event_dispatcher.add_waiter(event_name, future, check) + try: + return await asyncio.wait_for(future, timeout=timeout) + finally: + self._event_dispatcher.remove_waiter(event_name, future) + + async def change_presence( + self, + status: str, + activity_name: Optional[str] = None, + activity_type: int = 0, + since: int = 0, + afk: bool = False, + ): + """ + Changes the client's presence on Discord. + + Args: + status (str): The new status for the client (e.g., "online", "idle", "dnd", "invisible"). + activity_name (Optional[str]): The name of the activity. + activity_type (int): The type of the activity. + since (int): The timestamp (in milliseconds) of when the client went idle. + afk (bool): Whether the client is AFK. + """ + if self._closed: + raise DisagreementException("Client is closed.") + + if self._gateway: + await self._gateway.update_presence( + status=status, + activity_name=activity_name, + activity_type=activity_type, + since=since, + afk=afk, + ) + + # --- Event Handling --- + + def event( + self, coro: Callable[..., Awaitable[None]] + ) -> Callable[..., Awaitable[None]]: + """ + A decorator that registers an event to listen to. + The name of the coroutine is used as the event name. + Example: + @client.event + async def on_ready(): # Will listen for the 'READY' event + print("Bot is ready!") + + @client.event + async def on_message(message: disagreement.Message): # Will listen for 'MESSAGE_CREATE' + print(f"Message from {message.author}: {message.content}") + """ + if not asyncio.iscoroutinefunction(coro): + raise TypeError("Event registered must be a coroutine function.") + + event_name = coro.__name__ + # Map common function names to Discord event types + # e.g., on_ready -> READY, on_message -> MESSAGE_CREATE + if event_name.startswith("on_"): + discord_event_name = event_name[3:].upper() + mapping = { + "MESSAGE": "MESSAGE_CREATE", + "MESSAGE_EDIT": "MESSAGE_UPDATE", + "MESSAGE_UPDATE": "MESSAGE_UPDATE", + "MESSAGE_DELETE": "MESSAGE_DELETE", + "REACTION_ADD": "MESSAGE_REACTION_ADD", + "REACTION_REMOVE": "MESSAGE_REACTION_REMOVE", + } + discord_event_name = mapping.get(discord_event_name, discord_event_name) + self._event_dispatcher.register(discord_event_name, coro) + else: + # If not starting with "on_", assume it's the direct Discord event name (e.g. "TYPING_START") + # Or raise an error if a specific format is required. + # For now, let's assume direct mapping if no "on_" prefix. + self._event_dispatcher.register(event_name.upper(), coro) + + return coro # Return the original coroutine + + def on_event( + self, event_name: str + ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: + """ + A decorator that registers an event to listen to with a specific event name. + Example: + @client.on_event('MESSAGE_CREATE') + async def my_message_handler(message: disagreement.Message): + print(f"Message: {message.content}") + """ + + def decorator( + coro: Callable[..., Awaitable[None]], + ) -> Callable[..., Awaitable[None]]: + if not asyncio.iscoroutinefunction(coro): + raise TypeError("Event registered must be a coroutine function.") + self._event_dispatcher.register(event_name.upper(), coro) + return coro + + return decorator + + async def _process_message_for_commands(self, message: "Message") -> None: + """Internal listener to process messages for commands.""" + # Make sure message object is valid and not from a bot (optional, common check) + if ( + not message or not message.author or message.author.bot + ): # Add .bot check to User model + return + await self.command_handler.process_commands(message) + + # --- Command Framework Methods --- + + def add_cog(self, cog: Cog) -> None: + """ + Adds a Cog to the bot. + Cogs are classes that group commands, listeners, and state. + This will also discover and register any application commands defined in the cog. + + Args: + cog (Cog): An instance of a class derived from `disagreement.ext.commands.Cog`. + """ + # Add to prefix command handler + self.command_handler.add_cog( + cog + ) # This should call cog._inject() internally or cog._inject() is called on Cog init + + # Discover and add application commands from the cog + # AppCommand and AppCommandGroup are already imported in TYPE_CHECKING block + for app_cmd_obj in cog.get_app_commands_and_groups(): # Uses the new method + # The cog attribute should have been set within Cog._inject() for AppCommands + self.app_command_handler.add_command(app_cmd_obj) + print( + f"Registered app command/group '{app_cmd_obj.name}' from cog '{cog.cog_name}'." + ) + + def remove_cog(self, cog_name: str) -> Optional[Cog]: + """ + Removes a Cog from the bot. + + Args: + cog_name (str): The name of the Cog to remove. + + Returns: + Optional[Cog]: The Cog that was removed, or None if not found. + """ + removed_cog = self.command_handler.remove_cog(cog_name) + if removed_cog: + # Also remove associated application commands + # This requires AppCommand to store a reference to its cog, or iterate all app_commands. + # Assuming AppCommand has a .cog attribute, which is set in Cog._inject() + # And AppCommandGroup might store commands that have .cog attribute + for app_cmd_or_group in removed_cog.get_app_commands_and_groups(): + # The AppCommandHandler.remove_command needs to handle both AppCommand and AppCommandGroup + self.app_command_handler.remove_command( + app_cmd_or_group.name + ) # Assuming name is unique enough for removal here + print( + f"Removed app command/group '{app_cmd_or_group.name}' from cog '{cog_name}'." + ) + # Note: AppCommandHandler.remove_command might need to be more specific if names aren't globally unique + # (e.g. if it needs type or if groups and commands can share names). + # For now, assuming name is sufficient for removal from the handler's flat list. + return removed_cog + def check(self, coro: Callable[["CommandContext"], Awaitable[bool]]): """ A decorator that adds a global check to the bot. @@ -693,804 +696,816 @@ class Client: ctx (CommandContext): The context of the command that completed. """ pass - - # --- Extension Management Methods --- - - def load_extension(self, name: str) -> ModuleType: - """Load an extension by name using :mod:`disagreement.ext.loader`.""" - - return ext_loader.load_extension(name) - - def unload_extension(self, name: str) -> None: - """Unload a previously loaded extension.""" - - ext_loader.unload_extension(name) - - def reload_extension(self, name: str) -> ModuleType: - """Reload an extension by name.""" - - return ext_loader.reload_extension(name) - - # --- Model Parsing and Fetching --- - - def parse_user(self, data: Dict[str, Any]) -> "User": - """Parses user data and returns a User object, updating cache.""" - from .models import User # Ensure User model is available - - user = User(data) - self._users[user.id] = user # Cache the user - return user - - def parse_channel(self, data: Dict[str, Any]) -> "Channel": - """Parses channel data and returns a Channel object, updating caches.""" - - from .models import channel_factory - - channel = channel_factory(data, self) - self._channels[channel.id] = channel - if channel.guild_id: - guild = self._guilds.get(channel.guild_id) - if guild: - guild._channels[channel.id] = channel - return channel - - def parse_message(self, data: Dict[str, Any]) -> "Message": - """Parses message data and returns a Message object, updating cache.""" - - from .models import Message - - message = Message(data, client_instance=self) - self._messages[message.id] = message - return message - - def parse_webhook(self, data: Union[Dict[str, Any], "Webhook"]) -> "Webhook": - """Parses webhook data and returns a Webhook object, updating cache.""" - - from .models import Webhook - - if isinstance(data, Webhook): - webhook = data - webhook._client = self # type: ignore[attr-defined] - else: - webhook = Webhook(data, client_instance=self) - self._webhooks[webhook.id] = webhook - return webhook - - def parse_template(self, data: Dict[str, Any]) -> "GuildTemplate": - """Parses template data into a GuildTemplate object.""" - - from .models import GuildTemplate - - return GuildTemplate(data, client_instance=self) - - def parse_scheduled_event(self, data: Dict[str, Any]) -> "ScheduledEvent": - """Parses scheduled event data and updates cache.""" - - from .models import ScheduledEvent - - event = ScheduledEvent(data, client_instance=self) - # Cache by ID under guild if guild cache exists - guild = self._guilds.get(event.guild_id) - if guild is not None: - events = getattr(guild, "_scheduled_events", {}) - events[event.id] = event - setattr(guild, "_scheduled_events", events) - return event - - def parse_audit_log_entry(self, data: Dict[str, Any]) -> "AuditLogEntry": - """Parses audit log entry data.""" - from .models import AuditLogEntry - - return AuditLogEntry(data, client_instance=self) - - def parse_invite(self, data: Dict[str, Any]) -> "Invite": - """Parses invite data into an :class:`Invite`.""" - - from .models import Invite - - return Invite.from_dict(data) - - async def fetch_user(self, user_id: Snowflake) -> Optional["User"]: - """Fetches a user by ID from Discord.""" - if self._closed: - raise DisagreementException("Client is closed.") - - cached_user = self._users.get(user_id) - if cached_user: - return cached_user # Return cached if available, though fetch implies wanting fresh - - try: - user_data = await self._http.get_user(user_id) - return self.parse_user(user_data) - except DisagreementException as e: # Catch HTTP exceptions from http client - print(f"Failed to fetch user {user_id}: {e}") - return None - - async def fetch_message( - self, channel_id: Snowflake, message_id: Snowflake - ) -> Optional["Message"]: - """Fetches a message by ID from Discord and caches it.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - cached_message = self._messages.get(message_id) - if cached_message: - return cached_message - - try: - message_data = await self._http.get_message(channel_id, message_id) - return self.parse_message(message_data) - except DisagreementException as e: - print( - f"Failed to fetch message {message_id} from channel {channel_id}: {e}" - ) - return None - - def parse_member(self, data: Dict[str, Any], guild_id: Snowflake) -> "Member": - """Parses member data and returns a Member object, updating relevant caches.""" - from .models import Member # Ensure Member model is available - - # Member's __init__ should handle the nested 'user' data. - member = Member(data, client_instance=self) - member.guild_id = str(guild_id) - - # Cache the member in the guild's member cache - guild = self._guilds.get(guild_id) - if guild: - guild._members[member.id] = member # Assuming Guild has _members dict - - # Also cache the user part if not already cached or if this is newer - # Since Member inherits from User, the member object itself is the user. - self._users[member.id] = member - # If 'user' was in data and Member.__init__ used it, it's already part of 'member'. - return member - - async def fetch_member( - self, guild_id: Snowflake, member_id: Snowflake - ) -> Optional["Member"]: - """Fetches a member from a guild by ID.""" - if self._closed: - raise DisagreementException("Client is closed.") - - guild = self.get_guild(guild_id) - if guild: - cached_member = guild.get_member(member_id) # Use Guild's get_member - if cached_member: - return cached_member # Return cached if available - - try: - member_data = await self._http.get_guild_member(guild_id, member_id) - return self.parse_member(member_data, guild_id) - except DisagreementException as e: - print(f"Failed to fetch member {member_id} from guild {guild_id}: {e}") - return None - - def parse_role(self, data: Dict[str, Any], guild_id: Snowflake) -> "Role": - """Parses role data and returns a Role object, updating guild's role cache.""" - from .models import Role # Ensure Role model is available - - role = Role(data) - guild = self._guilds.get(guild_id) - if guild: - # Update the role in the guild's roles list if it exists, or add it. - # Guild.roles is List[Role]. We need to find and replace or append. - found = False - for i, existing_role in enumerate(guild.roles): - if existing_role.id == role.id: - guild.roles[i] = role - found = True - break - if not found: - guild.roles.append(role) - return role - - def parse_guild(self, data: Dict[str, Any]) -> "Guild": - """Parses guild data and returns a Guild object, updating cache.""" - - from .models import Guild - - guild = Guild(data, client_instance=self) - self._guilds[guild.id] = guild - - # Populate channel and member caches if provided - for ch in data.get("channels", []): - channel_obj = self.parse_channel(ch) - guild._channels[channel_obj.id] = channel_obj - - for member in data.get("members", []): - member_obj = self.parse_member(member, guild.id) - guild._members[member_obj.id] = member_obj - - return guild - - async def fetch_roles(self, guild_id: Snowflake) -> List["Role"]: - """Fetches all roles for a given guild and caches them. - - If the guild is not cached, it will be retrieved first using - :meth:`fetch_guild`. - """ - if self._closed: - raise DisagreementException("Client is closed.") - guild = self.get_guild(guild_id) - if not guild: - guild = await self.fetch_guild(guild_id) - if not guild: - return [] - - try: - roles_data = await self._http.get_guild_roles(guild_id) - parsed_roles = [] - for role_data in roles_data: - # parse_role will add/update it in the guild.roles list - parsed_roles.append(self.parse_role(role_data, guild_id)) - guild.roles = parsed_roles # Replace the entire list with the fresh one - return parsed_roles - except DisagreementException as e: - print(f"Failed to fetch roles for guild {guild_id}: {e}") - return [] - - async def fetch_role( - self, guild_id: Snowflake, role_id: Snowflake - ) -> Optional["Role"]: - """Fetches a specific role from a guild by ID. - If roles for the guild aren't cached or might be stale, it fetches all roles first. - """ - guild = self.get_guild(guild_id) - if guild: - # Try to find in existing guild.roles - for role in guild.roles: - if role.id == role_id: - return role - - # If not found in cache or guild doesn't exist yet in cache, fetch all roles for the guild - await self.fetch_roles(guild_id) # This will populate/update guild.roles - - # Try again from the now (hopefully) populated cache - guild = self.get_guild( - guild_id - ) # Re-get guild in case it was populated by fetch_roles - if guild: - for role in guild.roles: - if role.id == role_id: - return role - - return None # Role not found even after fetching - - # --- API Methods --- - - # --- API Methods --- - - async def send_message( - self, - channel_id: str, - content: Optional[str] = None, - *, # Make additional params keyword-only - tts: bool = False, - embed: Optional["Embed"] = None, - embeds: Optional[List["Embed"]] = None, - components: Optional[List["ActionRow"]] = None, - allowed_mentions: Optional[Dict[str, Any]] = None, - message_reference: Optional[Dict[str, Any]] = None, - attachments: Optional[List[Any]] = None, - files: Optional[List[Any]] = None, - flags: Optional[int] = None, - view: Optional["View"] = None, - ) -> "Message": - """|coro| - Sends a message to the specified channel. - - Args: - channel_id (str): The ID of the channel to send the message to. - content (Optional[str]): The content of the message. - tts (bool): Whether the message should be sent with text-to-speech. Defaults to False. - embed (Optional[Embed]): A single embed to send. Cannot be used with `embeds`. - embeds (Optional[List[Embed]]): A list of embeds to send. Cannot be used with `embed`. - Discord supports up to 10 embeds per message. - components (Optional[List[ActionRow]]): A list of ActionRow components to include. - allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message. - message_reference (Optional[Dict[str, Any]]): Message reference for replying. - attachments (Optional[List[Any]]): Attachments to include with the message. - files (Optional[List[Any]]): Files to upload with the message. - flags (Optional[int]): Message flags. - view (Optional[View]): A view to send with the message. - - Returns: - Message: The message that was sent. - - Raises: - HTTPException: Sending the message failed. - ValueError: If both `embed` and `embeds` are provided, or if both `components` and `view` are provided. - """ - if self._closed: - raise DisagreementException("Client is closed.") - - if embed and embeds: - raise ValueError("Cannot provide both embed and embeds.") - if components and view: - raise ValueError("Cannot provide both 'components' and 'view'.") - - final_embeds_payload: Optional[List[Dict[str, Any]]] = None - if embed: - final_embeds_payload = [embed.to_dict()] - elif embeds: - from .models import ( - Embed as EmbedModel, - ) - - final_embeds_payload = [ - e.to_dict() for e in embeds if isinstance(e, EmbedModel) - ] - - components_payload: Optional[List[Dict[str, Any]]] = None - if view: - await view._start(self) - components_payload = view.to_components_payload() - elif components: - from .models import Component as ComponentModel - - components_payload = [ - comp.to_dict() - for comp in components - if isinstance(comp, ComponentModel) - ] - - message_data = await self._http.send_message( - channel_id=channel_id, - content=content, - tts=tts, - embeds=final_embeds_payload, - components=components_payload, - allowed_mentions=allowed_mentions, - message_reference=message_reference, - attachments=attachments, - files=files, - flags=flags, - ) - - if view: - message_id = message_data["id"] - view.message_id = message_id - self._views[message_id] = view - - return self.parse_message(message_data) - - def typing(self, channel_id: str) -> Typing: - """Return a context manager to show a typing indicator in a channel.""" - - return Typing(self, channel_id) - - async def join_voice( - self, - guild_id: Snowflake, - channel_id: Snowflake, - *, - self_mute: bool = False, - self_deaf: bool = False, - ) -> VoiceClient: - """|coro| Join a voice channel and return a :class:`VoiceClient`.""" - - if self._closed: - raise DisagreementException("Client is closed.") - if not self.is_ready(): - await self.wait_until_ready() - if self._gateway is None: - raise DisagreementException("Gateway is not connected.") - if not self.user: - raise DisagreementException("Client user unavailable.") - assert self.user is not None - user_id = self.user.id - - if guild_id in self._voice_clients: - return self._voice_clients[guild_id] - - payload = { - "op": GatewayOpcode.VOICE_STATE_UPDATE, - "d": { - "guild_id": str(guild_id), - "channel_id": str(channel_id), - "self_mute": self_mute, - "self_deaf": self_deaf, - }, - } - await self._gateway._send_json(payload) # type: ignore[attr-defined] - - server = await self.wait_for( - "VOICE_SERVER_UPDATE", - check=lambda d: d.get("guild_id") == str(guild_id), - timeout=10, - ) - state = await self.wait_for( - "VOICE_STATE_UPDATE", - check=lambda d, uid=user_id: d.get("guild_id") == str(guild_id) - and d.get("user_id") == str(uid), - timeout=10, - ) - - endpoint = f"wss://{server['endpoint']}?v=10" - token = server["token"] - session_id = state["session_id"] - - voice = VoiceClient( + + # --- Extension Management Methods --- + + def load_extension(self, name: str) -> ModuleType: + """Load an extension by name using :mod:`disagreement.ext.loader`.""" + + return ext_loader.load_extension(name) + + def unload_extension(self, name: str) -> None: + """Unload a previously loaded extension.""" + + ext_loader.unload_extension(name) + + def reload_extension(self, name: str) -> ModuleType: + """Reload an extension by name.""" + + return ext_loader.reload_extension(name) + + # --- Model Parsing and Fetching --- + + def parse_user(self, data: Dict[str, Any]) -> "User": + """Parses user data and returns a User object, updating cache.""" + from .models import User # Ensure User model is available + + user = User(data) + self._users.set(user.id, user) # Cache the user + return user + + def parse_channel(self, data: Dict[str, Any]) -> "Channel": + """Parses channel data and returns a Channel object, updating caches.""" + + from .models import channel_factory + + channel = channel_factory(data, self) + self._channels.set(channel.id, channel) + if channel.guild_id: + guild = self._guilds.get(channel.guild_id) + if guild: + guild._channels.set(channel.id, channel) + return channel + + def parse_message(self, data: Dict[str, Any]) -> "Message": + """Parses message data and returns a Message object, updating cache.""" + + from .models import Message + + message = Message(data, client_instance=self) + self._messages.set(message.id, message) + return message + + def parse_webhook(self, data: Union[Dict[str, Any], "Webhook"]) -> "Webhook": + """Parses webhook data and returns a Webhook object, updating cache.""" + + from .models import Webhook + + if isinstance(data, Webhook): + webhook = data + webhook._client = self # type: ignore[attr-defined] + else: + webhook = Webhook(data, client_instance=self) + self._webhooks[webhook.id] = webhook + return webhook + + def parse_template(self, data: Dict[str, Any]) -> "GuildTemplate": + """Parses template data into a GuildTemplate object.""" + + from .models import GuildTemplate + + return GuildTemplate(data, client_instance=self) + + def parse_scheduled_event(self, data: Dict[str, Any]) -> "ScheduledEvent": + """Parses scheduled event data and updates cache.""" + + from .models import ScheduledEvent + + event = ScheduledEvent(data, client_instance=self) + # Cache by ID under guild if guild cache exists + guild = self._guilds.get(event.guild_id) + if guild is not None: + events = getattr(guild, "_scheduled_events", {}) + events[event.id] = event + setattr(guild, "_scheduled_events", events) + return event + + def parse_audit_log_entry(self, data: Dict[str, Any]) -> "AuditLogEntry": + """Parses audit log entry data.""" + from .models import AuditLogEntry + + return AuditLogEntry(data, client_instance=self) + + def parse_invite(self, data: Dict[str, Any]) -> "Invite": + """Parses invite data into an :class:`Invite`.""" + + from .models import Invite + + return Invite.from_dict(data) + + async def fetch_user(self, user_id: Snowflake) -> Optional["User"]: + """Fetches a user by ID from Discord.""" + if self._closed: + raise DisagreementException("Client is closed.") + + cached_user = self._users.get(user_id) + if cached_user: + return cached_user + + try: + user_data = await self._http.get_user(user_id) + return self.parse_user(user_data) + except DisagreementException as e: # Catch HTTP exceptions from http client + print(f"Failed to fetch user {user_id}: {e}") + return None + + async def fetch_message( + self, channel_id: Snowflake, message_id: Snowflake + ) -> Optional["Message"]: + """Fetches a message by ID from Discord and caches it.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + cached_message = self._messages.get(message_id) + if cached_message: + return cached_message + + try: + message_data = await self._http.get_message(channel_id, message_id) + return self.parse_message(message_data) + except DisagreementException as e: + print( + f"Failed to fetch message {message_id} from channel {channel_id}: {e}" + ) + return None + + def parse_member( + self, data: Dict[str, Any], guild_id: Snowflake, *, just_joined: bool = False + ) -> "Member": + """Parses member data and returns a Member object, updating relevant caches.""" + from .models import Member + + member = Member(data, client_instance=self) + member.guild_id = str(guild_id) + + if just_joined: + setattr(member, "_just_joined", True) + + guild = self._guilds.get(guild_id) + if guild: + guild._members.set(member.id, member) + + if just_joined and hasattr(member, "_just_joined"): + delattr(member, "_just_joined") + + self._users.set(member.id, member) + return member + + async def fetch_member( + self, guild_id: Snowflake, member_id: Snowflake + ) -> Optional["Member"]: + """Fetches a member from a guild by ID.""" + if self._closed: + raise DisagreementException("Client is closed.") + + guild = self.get_guild(guild_id) + if guild: + cached_member = guild.get_member(member_id) # Use Guild's get_member + if cached_member: + return cached_member # Return cached if available + + try: + member_data = await self._http.get_guild_member(guild_id, member_id) + return self.parse_member(member_data, guild_id) + except DisagreementException as e: + print(f"Failed to fetch member {member_id} from guild {guild_id}: {e}") + return None + + def parse_role(self, data: Dict[str, Any], guild_id: Snowflake) -> "Role": + """Parses role data and returns a Role object, updating guild's role cache.""" + from .models import Role # Ensure Role model is available + + role = Role(data) + guild = self._guilds.get(guild_id) + if guild: + # Update the role in the guild's roles list if it exists, or add it. + # Guild.roles is List[Role]. We need to find and replace or append. + found = False + for i, existing_role in enumerate(guild.roles): + if existing_role.id == role.id: + guild.roles[i] = role + found = True + break + if not found: + guild.roles.append(role) + return role + + def parse_guild(self, data: Dict[str, Any]) -> "Guild": + """Parses guild data and returns a Guild object, updating cache.""" + from .models import Guild + + guild = Guild(data, client_instance=self) + self._guilds.set(guild.id, guild) + + presences = {p["user"]["id"]: p for p in data.get("presences", [])} + voice_states = {vs["user_id"]: vs for vs in data.get("voice_states", [])} + + for ch_data in data.get("channels", []): + self.parse_channel(ch_data) + + for member_data in data.get("members", []): + user_id = member_data.get("user", {}).get("id") + if user_id: + presence = presences.get(user_id) + if presence: + member_data["status"] = presence.get("status", "offline") + + voice_state = voice_states.get(user_id) + if voice_state: + member_data["voice_state"] = voice_state + + self.parse_member(member_data, guild.id) + + return guild + + async def fetch_roles(self, guild_id: Snowflake) -> List["Role"]: + """Fetches all roles for a given guild and caches them. + + If the guild is not cached, it will be retrieved first using + :meth:`fetch_guild`. + """ + if self._closed: + raise DisagreementException("Client is closed.") + guild = self.get_guild(guild_id) + if not guild: + guild = await self.fetch_guild(guild_id) + if not guild: + return [] + + try: + roles_data = await self._http.get_guild_roles(guild_id) + parsed_roles = [] + for role_data in roles_data: + # parse_role will add/update it in the guild.roles list + parsed_roles.append(self.parse_role(role_data, guild_id)) + guild.roles = parsed_roles # Replace the entire list with the fresh one + return parsed_roles + except DisagreementException as e: + print(f"Failed to fetch roles for guild {guild_id}: {e}") + return [] + + async def fetch_role( + self, guild_id: Snowflake, role_id: Snowflake + ) -> Optional["Role"]: + """Fetches a specific role from a guild by ID. + If roles for the guild aren't cached or might be stale, it fetches all roles first. + """ + guild = self.get_guild(guild_id) + if guild: + # Try to find in existing guild.roles + for role in guild.roles: + if role.id == role_id: + return role + + # If not found in cache or guild doesn't exist yet in cache, fetch all roles for the guild + await self.fetch_roles(guild_id) # This will populate/update guild.roles + + # Try again from the now (hopefully) populated cache + guild = self.get_guild( + guild_id + ) # Re-get guild in case it was populated by fetch_roles + if guild: + for role in guild.roles: + if role.id == role_id: + return role + + return None # Role not found even after fetching + + # --- API Methods --- + + # --- API Methods --- + + async def send_message( + self, + channel_id: str, + content: Optional[str] = None, + *, # Make additional params keyword-only + tts: bool = False, + embed: Optional["Embed"] = None, + embeds: Optional[List["Embed"]] = None, + components: Optional[List["ActionRow"]] = None, + allowed_mentions: Optional[Dict[str, Any]] = None, + message_reference: Optional[Dict[str, Any]] = None, + attachments: Optional[List[Any]] = None, + files: Optional[List[Any]] = None, + flags: Optional[int] = None, + view: Optional["View"] = None, + ) -> "Message": + """|coro| + Sends a message to the specified channel. + + Args: + channel_id (str): The ID of the channel to send the message to. + content (Optional[str]): The content of the message. + tts (bool): Whether the message should be sent with text-to-speech. Defaults to False. + embed (Optional[Embed]): A single embed to send. Cannot be used with `embeds`. + embeds (Optional[List[Embed]]): A list of embeds to send. Cannot be used with `embed`. + Discord supports up to 10 embeds per message. + components (Optional[List[ActionRow]]): A list of ActionRow components to include. + allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message. + message_reference (Optional[Dict[str, Any]]): Message reference for replying. + attachments (Optional[List[Any]]): Attachments to include with the message. + files (Optional[List[Any]]): Files to upload with the message. + flags (Optional[int]): Message flags. + view (Optional[View]): A view to send with the message. + + Returns: + Message: The message that was sent. + + Raises: + HTTPException: Sending the message failed. + ValueError: If both `embed` and `embeds` are provided, or if both `components` and `view` are provided. + """ + if self._closed: + raise DisagreementException("Client is closed.") + + if embed and embeds: + raise ValueError("Cannot provide both embed and embeds.") + if components and view: + raise ValueError("Cannot provide both 'components' and 'view'.") + + final_embeds_payload: Optional[List[Dict[str, Any]]] = None + if embed: + final_embeds_payload = [embed.to_dict()] + elif embeds: + from .models import ( + Embed as EmbedModel, + ) + + final_embeds_payload = [ + e.to_dict() for e in embeds if isinstance(e, EmbedModel) + ] + + components_payload: Optional[List[Dict[str, Any]]] = None + if view: + await view._start(self) + components_payload = view.to_components_payload() + elif components: + from .models import Component as ComponentModel + + components_payload = [ + comp.to_dict() + for comp in components + if isinstance(comp, ComponentModel) + ] + + message_data = await self._http.send_message( + channel_id=channel_id, + content=content, + tts=tts, + embeds=final_embeds_payload, + components=components_payload, + allowed_mentions=allowed_mentions, + message_reference=message_reference, + attachments=attachments, + files=files, + flags=flags, + ) + + if view: + message_id = message_data["id"] + view.message_id = message_id + self._views[message_id] = view + + return self.parse_message(message_data) + + def typing(self, channel_id: str) -> Typing: + """Return a context manager to show a typing indicator in a channel.""" + + return Typing(self, channel_id) + + async def join_voice( + self, + guild_id: Snowflake, + channel_id: Snowflake, + *, + self_mute: bool = False, + self_deaf: bool = False, + ) -> VoiceClient: + """|coro| Join a voice channel and return a :class:`VoiceClient`.""" + + if self._closed: + raise DisagreementException("Client is closed.") + if not self.is_ready(): + await self.wait_until_ready() + if self._gateway is None: + raise DisagreementException("Gateway is not connected.") + if not self.user: + raise DisagreementException("Client user unavailable.") + assert self.user is not None + user_id = self.user.id + + if guild_id in self._voice_clients: + return self._voice_clients[guild_id] + + payload = { + "op": GatewayOpcode.VOICE_STATE_UPDATE, + "d": { + "guild_id": str(guild_id), + "channel_id": str(channel_id), + "self_mute": self_mute, + "self_deaf": self_deaf, + }, + } + await self._gateway._send_json(payload) # type: ignore[attr-defined] + + server = await self.wait_for( + "VOICE_SERVER_UPDATE", + check=lambda d: d.get("guild_id") == str(guild_id), + timeout=10, + ) + state = await self.wait_for( + "VOICE_STATE_UPDATE", + check=lambda d, uid=user_id: d.get("guild_id") == str(guild_id) + and d.get("user_id") == str(uid), + timeout=10, + ) + + endpoint = f"wss://{server['endpoint']}?v=10" + token = server["token"] + session_id = state["session_id"] + + voice = VoiceClient( self, - endpoint, - session_id, - token, - int(guild_id), - int(self.user.id), - verbose=self.verbose, - ) - await voice.connect() - self._voice_clients[guild_id] = voice - return voice - - async def add_reaction(self, channel_id: str, message_id: str, emoji: str) -> None: - """|coro| Add a reaction to a message.""" - - await self.create_reaction(channel_id, message_id, emoji) - - async def remove_reaction( - self, channel_id: str, message_id: str, emoji: str - ) -> None: - """|coro| Remove the bot's reaction from a message.""" - - await self.delete_reaction(channel_id, message_id, emoji) - - async def clear_reactions(self, channel_id: str, message_id: str) -> None: - """|coro| Remove all reactions from a message.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - await self._http.clear_reactions(channel_id, message_id) - - async def create_reaction( - self, channel_id: str, message_id: str, emoji: str - ) -> None: - """|coro| Add a reaction to a message.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - await self._http.create_reaction(channel_id, message_id, emoji) - - user_id = getattr(getattr(self, "user", None), "id", None) - payload = { - "user_id": user_id, - "channel_id": channel_id, - "message_id": message_id, - "emoji": {"name": emoji, "id": None}, - } - if hasattr(self, "_event_dispatcher"): - await self._event_dispatcher.dispatch("MESSAGE_REACTION_ADD", payload) - - async def delete_reaction( - self, channel_id: str, message_id: str, emoji: str - ) -> None: - """|coro| Remove the bot's reaction from a message.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - await self._http.delete_reaction(channel_id, message_id, emoji) - - user_id = getattr(getattr(self, "user", None), "id", None) - payload = { - "user_id": user_id, - "channel_id": channel_id, - "message_id": message_id, - "emoji": {"name": emoji, "id": None}, - } - if hasattr(self, "_event_dispatcher"): - await self._event_dispatcher.dispatch("MESSAGE_REACTION_REMOVE", payload) - - async def get_reactions( - self, channel_id: str, message_id: str, emoji: str - ) -> List["User"]: - """|coro| Return the users who reacted with the given emoji.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - users_data = await self._http.get_reactions(channel_id, message_id, emoji) - return [self.parse_user(u) for u in users_data] - - async def edit_message( - self, - channel_id: str, - message_id: str, - *, - content: Optional[str] = None, - embed: Optional["Embed"] = None, - embeds: Optional[List["Embed"]] = None, - components: Optional[List["ActionRow"]] = None, - allowed_mentions: Optional[Dict[str, Any]] = None, - flags: Optional[int] = None, - view: Optional["View"] = None, - ) -> "Message": - """Edits a previously sent message.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - if embed and embeds: - raise ValueError("Cannot provide both embed and embeds.") - if components and view: - raise ValueError("Cannot provide both 'components' and 'view'.") - - final_embeds_payload: Optional[List[Dict[str, Any]]] = None - if embed: - final_embeds_payload = [embed.to_dict()] - elif embeds: - final_embeds_payload = [e.to_dict() for e in embeds] - - components_payload: Optional[List[Dict[str, Any]]] = None - if view: - await view._start(self) - components_payload = view.to_components_payload() - elif components: - components_payload = [c.to_dict() for c in components] - - payload: Dict[str, Any] = {} - if content is not None: - payload["content"] = content - if final_embeds_payload is not None: - payload["embeds"] = final_embeds_payload - if components_payload is not None: - payload["components"] = components_payload - if allowed_mentions is not None: - payload["allowed_mentions"] = allowed_mentions - if flags is not None: - payload["flags"] = flags - - message_data = await self._http.edit_message( - channel_id=channel_id, - message_id=message_id, - payload=payload, - ) - - if view: - view.message_id = message_data["id"] - self._views[message_data["id"]] = view - - return self.parse_message(message_data) - - def get_guild(self, guild_id: Snowflake) -> Optional["Guild"]: - """Returns a guild from the internal cache. - - Use :meth:`fetch_guild` to retrieve it from Discord if it's not cached. - """ - - return self._guilds.get(guild_id) - - def get_channel(self, channel_id: Snowflake) -> Optional["Channel"]: - """Returns a channel from the internal cache.""" - - return self._channels.get(channel_id) - - def get_message(self, message_id: Snowflake) -> Optional["Message"]: - """Returns a message from the internal cache.""" - - return self._messages.get(message_id) - - async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]: - """Fetches a guild by ID from Discord and caches it.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - cached_guild = self._guilds.get(guild_id) - if cached_guild: - return cached_guild - - try: - guild_data = await self._http.get_guild(guild_id) - return self.parse_guild(guild_data) - except DisagreementException as e: - print(f"Failed to fetch guild {guild_id}: {e}") - return None - - async def fetch_channel(self, channel_id: Snowflake) -> Optional["Channel"]: - """Fetches a channel from Discord by its ID and updates the cache.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - try: - channel_data = await self._http.get_channel(channel_id) - if not channel_data: - return None - - from .models import channel_factory - - channel = channel_factory(channel_data, self) - - self._channels[channel.id] = channel - return channel - - except DisagreementException as e: # Includes HTTPException - print(f"Failed to fetch channel {channel_id}: {e}") - return None - - async def fetch_audit_logs( - self, guild_id: Snowflake, **filters: Any - ) -> AsyncIterator["AuditLogEntry"]: - """Fetch audit log entries for a guild.""" - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.get_audit_logs(guild_id, **filters) - for entry in data.get("audit_log_entries", []): - yield self.parse_audit_log_entry(entry) - - async def fetch_voice_regions(self) -> List[VoiceRegion]: - """Fetches available voice regions.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.get_voice_regions() - regions = [] - for region in data: - region_id = region.get("id") - if region_id: - regions.append(VoiceRegion(region_id)) - return regions - - async def create_webhook( - self, channel_id: Snowflake, payload: Dict[str, Any] - ) -> "Webhook": - """|coro| Create a webhook in the given channel.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.create_webhook(channel_id, payload) - return self.parse_webhook(data) - - async def edit_webhook( - self, webhook_id: Snowflake, payload: Dict[str, Any] - ) -> "Webhook": - """|coro| Edit an existing webhook.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.edit_webhook(webhook_id, payload) - return self.parse_webhook(data) - - async def delete_webhook(self, webhook_id: Snowflake) -> None: - """|coro| Delete a webhook by ID.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - await self._http.delete_webhook(webhook_id) - - async def fetch_templates(self, guild_id: Snowflake) -> List["GuildTemplate"]: - """|coro| Fetch all templates for a guild.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.get_guild_templates(guild_id) - return [self.parse_template(t) for t in data] - - async def create_template( - self, guild_id: Snowflake, payload: Dict[str, Any] - ) -> "GuildTemplate": - """|coro| Create a template for a guild.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.create_guild_template(guild_id, payload) - return self.parse_template(data) - - async def sync_template( - self, guild_id: Snowflake, template_code: str - ) -> "GuildTemplate": - """|coro| Sync a template to the guild's current state.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.sync_guild_template(guild_id, template_code) - return self.parse_template(data) - - async def delete_template(self, guild_id: Snowflake, template_code: str) -> None: - """|coro| Delete a guild template.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - await self._http.delete_guild_template(guild_id, template_code) - - async def fetch_scheduled_events( - self, guild_id: Snowflake - ) -> List["ScheduledEvent"]: - """|coro| Fetch all scheduled events for a guild.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.get_guild_scheduled_events(guild_id) - return [self.parse_scheduled_event(ev) for ev in data] - - async def fetch_scheduled_event( - self, guild_id: Snowflake, event_id: Snowflake - ) -> Optional["ScheduledEvent"]: - """|coro| Fetch a single scheduled event.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - try: - data = await self._http.get_guild_scheduled_event(guild_id, event_id) - return self.parse_scheduled_event(data) - except DisagreementException as e: - print(f"Failed to fetch scheduled event {event_id}: {e}") - return None - - async def create_scheduled_event( - self, guild_id: Snowflake, payload: Dict[str, Any] - ) -> "ScheduledEvent": - """|coro| Create a scheduled event in a guild.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.create_guild_scheduled_event(guild_id, payload) - return self.parse_scheduled_event(data) - - async def edit_scheduled_event( - self, guild_id: Snowflake, event_id: Snowflake, payload: Dict[str, Any] - ) -> "ScheduledEvent": - """|coro| Edit an existing scheduled event.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.edit_guild_scheduled_event(guild_id, event_id, payload) - return self.parse_scheduled_event(data) - - async def delete_scheduled_event( - self, guild_id: Snowflake, event_id: Snowflake - ) -> None: - """|coro| Delete a scheduled event.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - await self._http.delete_guild_scheduled_event(guild_id, event_id) - - async def create_invite( - self, channel_id: Snowflake, payload: Dict[str, Any] - ) -> "Invite": - """|coro| Create an invite for the given channel.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - return await self._http.create_invite(channel_id, payload) - - async def delete_invite(self, code: str) -> None: - """|coro| Delete an invite by code.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - await self._http.delete_invite(code) - - async def fetch_invites(self, channel_id: Snowflake) -> List["Invite"]: - """|coro| Fetch all invites for a channel.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.get_channel_invites(channel_id) - return [self.parse_invite(inv) for inv in data] - + endpoint, + session_id, + token, + int(guild_id), + int(self.user.id), + verbose=self.verbose, + ) + await voice.connect() + self._voice_clients[guild_id] = voice + return voice + + async def add_reaction(self, channel_id: str, message_id: str, emoji: str) -> None: + """|coro| Add a reaction to a message.""" + + await self.create_reaction(channel_id, message_id, emoji) + + async def remove_reaction( + self, channel_id: str, message_id: str, emoji: str + ) -> None: + """|coro| Remove the bot's reaction from a message.""" + + await self.delete_reaction(channel_id, message_id, emoji) + + async def clear_reactions(self, channel_id: str, message_id: str) -> None: + """|coro| Remove all reactions from a message.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + await self._http.clear_reactions(channel_id, message_id) + + async def create_reaction( + self, channel_id: str, message_id: str, emoji: str + ) -> None: + """|coro| Add a reaction to a message.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + await self._http.create_reaction(channel_id, message_id, emoji) + + user_id = getattr(getattr(self, "user", None), "id", None) + payload = { + "user_id": user_id, + "channel_id": channel_id, + "message_id": message_id, + "emoji": {"name": emoji, "id": None}, + } + if hasattr(self, "_event_dispatcher"): + await self._event_dispatcher.dispatch("MESSAGE_REACTION_ADD", payload) + + async def delete_reaction( + self, channel_id: str, message_id: str, emoji: str + ) -> None: + """|coro| Remove the bot's reaction from a message.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + await self._http.delete_reaction(channel_id, message_id, emoji) + + user_id = getattr(getattr(self, "user", None), "id", None) + payload = { + "user_id": user_id, + "channel_id": channel_id, + "message_id": message_id, + "emoji": {"name": emoji, "id": None}, + } + if hasattr(self, "_event_dispatcher"): + await self._event_dispatcher.dispatch("MESSAGE_REACTION_REMOVE", payload) + + async def get_reactions( + self, channel_id: str, message_id: str, emoji: str + ) -> List["User"]: + """|coro| Return the users who reacted with the given emoji.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + users_data = await self._http.get_reactions(channel_id, message_id, emoji) + return [self.parse_user(u) for u in users_data] + + async def edit_message( + self, + channel_id: str, + message_id: str, + *, + content: Optional[str] = None, + embed: Optional["Embed"] = None, + embeds: Optional[List["Embed"]] = None, + components: Optional[List["ActionRow"]] = None, + allowed_mentions: Optional[Dict[str, Any]] = None, + flags: Optional[int] = None, + view: Optional["View"] = None, + ) -> "Message": + """Edits a previously sent message.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + if embed and embeds: + raise ValueError("Cannot provide both embed and embeds.") + if components and view: + raise ValueError("Cannot provide both 'components' and 'view'.") + + final_embeds_payload: Optional[List[Dict[str, Any]]] = None + if embed: + final_embeds_payload = [embed.to_dict()] + elif embeds: + final_embeds_payload = [e.to_dict() for e in embeds] + + components_payload: Optional[List[Dict[str, Any]]] = None + if view: + await view._start(self) + components_payload = view.to_components_payload() + elif components: + components_payload = [c.to_dict() for c in components] + + payload: Dict[str, Any] = {} + if content is not None: + payload["content"] = content + if final_embeds_payload is not None: + payload["embeds"] = final_embeds_payload + if components_payload is not None: + payload["components"] = components_payload + if allowed_mentions is not None: + payload["allowed_mentions"] = allowed_mentions + if flags is not None: + payload["flags"] = flags + + message_data = await self._http.edit_message( + channel_id=channel_id, + message_id=message_id, + payload=payload, + ) + + if view: + view.message_id = message_data["id"] + self._views[message_data["id"]] = view + + return self.parse_message(message_data) + + def get_guild(self, guild_id: Snowflake) -> Optional["Guild"]: + """Returns a guild from the internal cache. + + Use :meth:`fetch_guild` to retrieve it from Discord if it's not cached. + """ + + return self._guilds.get(guild_id) + + def get_channel(self, channel_id: Snowflake) -> Optional["Channel"]: + """Returns a channel from the internal cache.""" + + return self._channels.get(channel_id) + + def get_message(self, message_id: Snowflake) -> Optional["Message"]: + """Returns a message from the internal cache.""" + + return self._messages.get(message_id) + + async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]: + """Fetches a guild by ID from Discord and caches it.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + cached_guild = self._guilds.get(guild_id) + if cached_guild: + return cached_guild + + try: + guild_data = await self._http.get_guild(guild_id) + return self.parse_guild(guild_data) + except DisagreementException as e: + print(f"Failed to fetch guild {guild_id}: {e}") + return None + + async def fetch_channel(self, channel_id: Snowflake) -> Optional["Channel"]: + """Fetches a channel from Discord by its ID and updates the cache.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + try: + channel_data = await self._http.get_channel(channel_id) + if not channel_data: + return None + + from .models import channel_factory + + channel = channel_factory(channel_data, self) + + self._channels.set(channel.id, channel) + return channel + + except DisagreementException as e: # Includes HTTPException + print(f"Failed to fetch channel {channel_id}: {e}") + return None + + async def fetch_audit_logs( + self, guild_id: Snowflake, **filters: Any + ) -> AsyncIterator["AuditLogEntry"]: + """Fetch audit log entries for a guild.""" + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.get_audit_logs(guild_id, **filters) + for entry in data.get("audit_log_entries", []): + yield self.parse_audit_log_entry(entry) + + async def fetch_voice_regions(self) -> List[VoiceRegion]: + """Fetches available voice regions.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.get_voice_regions() + regions = [] + for region in data: + region_id = region.get("id") + if region_id: + regions.append(VoiceRegion(region_id)) + return regions + + async def create_webhook( + self, channel_id: Snowflake, payload: Dict[str, Any] + ) -> "Webhook": + """|coro| Create a webhook in the given channel.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.create_webhook(channel_id, payload) + return self.parse_webhook(data) + + async def edit_webhook( + self, webhook_id: Snowflake, payload: Dict[str, Any] + ) -> "Webhook": + """|coro| Edit an existing webhook.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.edit_webhook(webhook_id, payload) + return self.parse_webhook(data) + + async def delete_webhook(self, webhook_id: Snowflake) -> None: + """|coro| Delete a webhook by ID.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + await self._http.delete_webhook(webhook_id) + + async def fetch_templates(self, guild_id: Snowflake) -> List["GuildTemplate"]: + """|coro| Fetch all templates for a guild.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.get_guild_templates(guild_id) + return [self.parse_template(t) for t in data] + + async def create_template( + self, guild_id: Snowflake, payload: Dict[str, Any] + ) -> "GuildTemplate": + """|coro| Create a template for a guild.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.create_guild_template(guild_id, payload) + return self.parse_template(data) + + async def sync_template( + self, guild_id: Snowflake, template_code: str + ) -> "GuildTemplate": + """|coro| Sync a template to the guild's current state.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.sync_guild_template(guild_id, template_code) + return self.parse_template(data) + + async def delete_template(self, guild_id: Snowflake, template_code: str) -> None: + """|coro| Delete a guild template.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + await self._http.delete_guild_template(guild_id, template_code) + + async def fetch_scheduled_events( + self, guild_id: Snowflake + ) -> List["ScheduledEvent"]: + """|coro| Fetch all scheduled events for a guild.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.get_guild_scheduled_events(guild_id) + return [self.parse_scheduled_event(ev) for ev in data] + + async def fetch_scheduled_event( + self, guild_id: Snowflake, event_id: Snowflake + ) -> Optional["ScheduledEvent"]: + """|coro| Fetch a single scheduled event.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + try: + data = await self._http.get_guild_scheduled_event(guild_id, event_id) + return self.parse_scheduled_event(data) + except DisagreementException as e: + print(f"Failed to fetch scheduled event {event_id}: {e}") + return None + + async def create_scheduled_event( + self, guild_id: Snowflake, payload: Dict[str, Any] + ) -> "ScheduledEvent": + """|coro| Create a scheduled event in a guild.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.create_guild_scheduled_event(guild_id, payload) + return self.parse_scheduled_event(data) + + async def edit_scheduled_event( + self, guild_id: Snowflake, event_id: Snowflake, payload: Dict[str, Any] + ) -> "ScheduledEvent": + """|coro| Edit an existing scheduled event.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.edit_guild_scheduled_event(guild_id, event_id, payload) + return self.parse_scheduled_event(data) + + async def delete_scheduled_event( + self, guild_id: Snowflake, event_id: Snowflake + ) -> None: + """|coro| Delete a scheduled event.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + await self._http.delete_guild_scheduled_event(guild_id, event_id) + + async def create_invite( + self, channel_id: Snowflake, payload: Dict[str, Any] + ) -> "Invite": + """|coro| Create an invite for the given channel.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + return await self._http.create_invite(channel_id, payload) + + async def delete_invite(self, code: str) -> None: + """|coro| Delete an invite by code.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + await self._http.delete_invite(code) + + async def fetch_invites(self, channel_id: Snowflake) -> List["Invite"]: + """|coro| Fetch all invites for a channel.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.get_channel_invites(channel_id) + return [self.parse_invite(inv) for inv in data] + def add_persistent_view(self, view: "View") -> None: """ Registers a persistent view with the client. @@ -1522,22 +1537,22 @@ class Client: ) self._persistent_views[item.custom_id] = view - # --- Application Command Methods --- - async def process_interaction(self, interaction: Interaction) -> None: - """Internal method to process an interaction from the gateway.""" - - if hasattr(self, "on_interaction_create"): - asyncio.create_task(self.on_interaction_create(interaction)) - # Route component interactions to the appropriate View - if ( - interaction.type == InteractionType.MESSAGE_COMPONENT - and interaction.message + # --- Application Command Methods --- + async def process_interaction(self, interaction: Interaction) -> None: + """Internal method to process an interaction from the gateway.""" + + if hasattr(self, "on_interaction_create"): + asyncio.create_task(self.on_interaction_create(interaction)) + # Route component interactions to the appropriate View + if ( + interaction.type == InteractionType.MESSAGE_COMPONENT + and interaction.message and interaction.data - ): - view = self._views.get(interaction.message.id) - if view: - asyncio.create_task(view._dispatch(interaction)) - return + ): + view = self._views.get(interaction.message.id) + if view: + asyncio.create_task(view._dispatch(interaction)) + return else: # No active view found, check for persistent views custom_id = interaction.data.custom_id diff --git a/disagreement/event_dispatcher.py b/disagreement/event_dispatcher.py index b6ea317..5f4ebd2 100644 --- a/disagreement/event_dispatcher.py +++ b/disagreement/event_dispatcher.py @@ -76,7 +76,7 @@ class EventDispatcher: """Parses MESSAGE_DELETE and updates message cache.""" message_id = data.get("id") if message_id: - self._client._messages.pop(message_id, None) + self._client._messages.invalidate(message_id) return data def _parse_message_reaction_raw(self, data: Dict[str, Any]) -> Dict[str, Any]: @@ -124,7 +124,7 @@ class EventDispatcher: """Parses GUILD_MEMBER_ADD into a Member object.""" guild_id = str(data.get("guild_id")) - return self._client.parse_member(data, guild_id) + return self._client.parse_member(data, guild_id, just_joined=True) def _parse_guild_member_remove(self, data: Dict[str, Any]): """Parses GUILD_MEMBER_REMOVE into a GuildMemberRemove model.""" diff --git a/disagreement/models.py b/disagreement/models.py index be671a0..aa18424 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -1,116 +1,118 @@ -# disagreement/models.py - -""" -Data models for Discord objects. -""" - -import asyncio -import json -from dataclasses import dataclass -from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union - -import aiohttp # pylint: disable=import-error -from .color import Color -from .errors import DisagreementException, HTTPException -from .enums import ( # These enums will need to be defined in disagreement/enums.py - VerificationLevel, - MessageNotificationLevel, - ExplicitContentFilterLevel, - MFALevel, - GuildNSFWLevel, - PremiumTier, - GuildFeature, - ChannelType, - ComponentType, - ButtonStyle, # Added for Button - GuildScheduledEventPrivacyLevel, - GuildScheduledEventStatus, - GuildScheduledEventEntityType, - # SelectMenuType will be part of ComponentType or a new enum if needed -) -from .permissions import Permissions - - -if TYPE_CHECKING: - from .client import Client # For type hinting to avoid circular imports - from .enums import OverwriteType # For PermissionOverwrite model - from .ui.view import View - from .interactions import Snowflake - - # Forward reference Message if it were used in type hints before its definition - # from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc. - from .components import component_factory - - -class User: - """Represents a Discord User. - - Attributes: - id (str): The user's unique ID. - username (str): The user's username. - discriminator (str): The user's 4-digit discord-tag. - bot (bool): Whether the user belongs to an OAuth2 application. Defaults to False. - avatar (Optional[str]): The user's avatar hash, if any. - """ - - def __init__(self, data: dict): - self.id: str = data["id"] - self.username: str = data["username"] - self.discriminator: str = data["discriminator"] - self.bot: bool = data.get("bot", False) - self.avatar: Optional[str] = data.get("avatar") - - @property - def mention(self) -> str: - """str: Returns a string that allows you to mention the user.""" - return f"<@{self.id}>" - - def __repr__(self) -> str: - return f"" - - -class Message: - """Represents a message sent in a channel on Discord. - - Attributes: - id (str): The message's unique ID. - channel_id (str): The ID of the channel the message was sent in. - guild_id (Optional[str]): The ID of the guild the message was sent in, if applicable. - author (User): The user who sent the message. - content (str): The actual content of the message. - timestamp (str): When this message was sent (ISO8601 timestamp). - components (Optional[List[ActionRow]]): Structured components attached - to the message if present. - attachments (List[Attachment]): Attachments included with the message. - """ - - def __init__(self, data: dict, client_instance: "Client"): - self._client: "Client" = ( - client_instance # Store reference to client for methods like reply - ) - - self.id: str = data["id"] - self.channel_id: str = data["channel_id"] - self.guild_id: Optional[str] = data.get("guild_id") - self.author: User = User(data["author"]) - self.content: str = data["content"] - self.timestamp: str = data["timestamp"] - if data.get("components"): - self.components: Optional[List[ActionRow]] = [ - ActionRow.from_dict(c, client_instance) - for c in data.get("components", []) - ] - else: - self.components = None - self.attachments: List[Attachment] = [ - Attachment(a) for a in data.get("attachments", []) - ] +# disagreement/models.py + +""" +Data models for Discord objects. +""" + +import asyncio +import json +from dataclasses import dataclass +from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast + +from .cache import ChannelCache, MemberCache + +import aiohttp # pylint: disable=import-error +from .color import Color +from .errors import DisagreementException, HTTPException +from .enums import ( # These enums will need to be defined in disagreement/enums.py + VerificationLevel, + MessageNotificationLevel, + ExplicitContentFilterLevel, + MFALevel, + GuildNSFWLevel, + PremiumTier, + GuildFeature, + ChannelType, + ComponentType, + ButtonStyle, # Added for Button + GuildScheduledEventPrivacyLevel, + GuildScheduledEventStatus, + GuildScheduledEventEntityType, + # SelectMenuType will be part of ComponentType or a new enum if needed +) +from .permissions import Permissions + + +if TYPE_CHECKING: + from .client import Client # For type hinting to avoid circular imports + from .enums import OverwriteType # For PermissionOverwrite model + from .ui.view import View + from .interactions import Snowflake + + # Forward reference Message if it were used in type hints before its definition + # from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc. + from .components import component_factory + + +class User: + """Represents a Discord User. + + Attributes: + id (str): The user's unique ID. + username (str): The user's username. + discriminator (str): The user's 4-digit discord-tag. + bot (bool): Whether the user belongs to an OAuth2 application. Defaults to False. + avatar (Optional[str]): The user's avatar hash, if any. + """ + + def __init__(self, data: dict): + self.id: str = data["id"] + self.username: str = data["username"] + self.discriminator: str = data["discriminator"] + self.bot: bool = data.get("bot", False) + self.avatar: Optional[str] = data.get("avatar") + + @property + def mention(self) -> str: + """str: Returns a string that allows you to mention the user.""" + return f"<@{self.id}>" + + def __repr__(self) -> str: + return f"" + + +class Message: + """Represents a message sent in a channel on Discord. + + Attributes: + id (str): The message's unique ID. + channel_id (str): The ID of the channel the message was sent in. + guild_id (Optional[str]): The ID of the guild the message was sent in, if applicable. + author (User): The user who sent the message. + content (str): The actual content of the message. + timestamp (str): When this message was sent (ISO8601 timestamp). + components (Optional[List[ActionRow]]): Structured components attached + to the message if present. + attachments (List[Attachment]): Attachments included with the message. + """ + + def __init__(self, data: dict, client_instance: "Client"): + self._client: "Client" = ( + client_instance # Store reference to client for methods like reply + ) + + self.id: str = data["id"] + self.channel_id: str = data["channel_id"] + self.guild_id: Optional[str] = data.get("guild_id") + self.author: User = User(data["author"]) + self.content: str = data["content"] + self.timestamp: str = data["timestamp"] + if data.get("components"): + self.components: Optional[List[ActionRow]] = [ + ActionRow.from_dict(c, client_instance) + for c in data.get("components", []) + ] + else: + self.components = None + self.attachments: List[Attachment] = [ + Attachment(a) for a in data.get("attachments", []) + ] self.pinned: bool = data.get("pinned", False) - # Add other fields as needed, e.g., attachments, embeds, reactions, etc. - # self.mentions: List[User] = [User(u) for u in data.get("mentions", [])] - # self.mention_roles: List[str] = data.get("mention_roles", []) - # self.mention_everyone: bool = data.get("mention_everyone", False) - + # Add other fields as needed, e.g., attachments, embeds, reactions, etc. + # self.mentions: List[User] = [User(u) for u in data.get("mentions", [])] + # self.mention_roles: List[str] = data.get("mention_roles", []) + # self.mention_everyone: bool = data.get("mention_everyone", False) + async def pin(self) -> None: """|coro| @@ -392,737 +394,739 @@ class PartialMessage: else: await self._client._http.delete_reaction(self.channel.id, self.id, emoji) - -class EmbedFooter: - """Represents an embed footer.""" - - def __init__(self, data: Dict[str, Any]): - self.text: str = data["text"] - self.icon_url: Optional[str] = data.get("icon_url") - self.proxy_icon_url: Optional[str] = data.get("proxy_icon_url") - - def to_dict(self) -> Dict[str, Any]: - payload = {"text": self.text} - if self.icon_url: - payload["icon_url"] = self.icon_url - if self.proxy_icon_url: - payload["proxy_icon_url"] = self.proxy_icon_url - return payload - - -class EmbedImage: - """Represents an embed image.""" - - def __init__(self, data: Dict[str, Any]): - self.url: str = data["url"] - self.proxy_url: Optional[str] = data.get("proxy_url") - self.height: Optional[int] = data.get("height") - self.width: Optional[int] = data.get("width") - - def to_dict(self) -> Dict[str, Any]: - payload: Dict[str, Any] = {"url": self.url} - if self.proxy_url: - payload["proxy_url"] = self.proxy_url - if self.height: - payload["height"] = self.height - if self.width: - payload["width"] = self.width - return payload - - def __repr__(self) -> str: - return f"" - - -class EmbedThumbnail(EmbedImage): # Similar structure to EmbedImage - """Represents an embed thumbnail.""" - - pass - - -class EmbedAuthor: - """Represents an embed author.""" - - def __init__(self, data: Dict[str, Any]): - self.name: str = data["name"] - self.url: Optional[str] = data.get("url") - self.icon_url: Optional[str] = data.get("icon_url") - self.proxy_icon_url: Optional[str] = data.get("proxy_icon_url") - - def to_dict(self) -> Dict[str, Any]: - payload = {"name": self.name} - if self.url: - payload["url"] = self.url - if self.icon_url: - payload["icon_url"] = self.icon_url - if self.proxy_icon_url: - payload["proxy_icon_url"] = self.proxy_icon_url - return payload - - -class EmbedField: - """Represents an embed field.""" - - def __init__(self, data: Dict[str, Any]): - self.name: str = data["name"] - self.value: str = data["value"] - self.inline: bool = data.get("inline", False) - - def to_dict(self) -> Dict[str, Any]: - return {"name": self.name, "value": self.value, "inline": self.inline} - - -class Embed: - """Represents a Discord embed. - - Attributes can be set directly or via methods like `set_author`, `add_field`. - """ - - def __init__(self, data: Optional[Dict[str, Any]] = None): - data = data or {} - self.title: Optional[str] = data.get("title") - self.type: str = data.get("type", "rich") # Default to "rich" for sending - self.description: Optional[str] = data.get("description") - self.url: Optional[str] = data.get("url") - self.timestamp: Optional[str] = data.get("timestamp") # ISO8601 timestamp - self.color = Color.parse(data.get("color")) - - self.footer: Optional[EmbedFooter] = ( - EmbedFooter(data["footer"]) if data.get("footer") else None - ) - self.image: Optional[EmbedImage] = ( - EmbedImage(data["image"]) if data.get("image") else None - ) - self.thumbnail: Optional[EmbedThumbnail] = ( - EmbedThumbnail(data["thumbnail"]) if data.get("thumbnail") else None - ) - # Video and Provider are less common for bot-sent embeds, can be added if needed. - self.author: Optional[EmbedAuthor] = ( - EmbedAuthor(data["author"]) if data.get("author") else None - ) - self.fields: List[EmbedField] = ( - [EmbedField(f) for f in data["fields"]] if data.get("fields") else [] - ) - - def to_dict(self) -> Dict[str, Any]: - payload: Dict[str, Any] = {"type": self.type} - if self.title: - payload["title"] = self.title - if self.description: - payload["description"] = self.description - if self.url: - payload["url"] = self.url - if self.timestamp: - payload["timestamp"] = self.timestamp - if self.color is not None: - payload["color"] = self.color.value - if self.footer: - payload["footer"] = self.footer.to_dict() - if self.image: - payload["image"] = self.image.to_dict() - if self.thumbnail: - payload["thumbnail"] = self.thumbnail.to_dict() - if self.author: - payload["author"] = self.author.to_dict() - if self.fields: - payload["fields"] = [f.to_dict() for f in self.fields] - return payload - - # Convenience methods for building embeds can be added here - # e.g., set_author, add_field, set_footer, set_image, etc. - - -class Attachment: - """Represents a message attachment.""" - - def __init__(self, data: Dict[str, Any]): - self.id: str = data["id"] - self.filename: str = data["filename"] - self.description: Optional[str] = data.get("description") - self.content_type: Optional[str] = data.get("content_type") - self.size: Optional[int] = data.get("size") - self.url: Optional[str] = data.get("url") - self.proxy_url: Optional[str] = data.get("proxy_url") - self.height: Optional[int] = data.get("height") # If image - self.width: Optional[int] = data.get("width") # If image - self.ephemeral: bool = data.get("ephemeral", False) - - def __repr__(self) -> str: - return f"" - - def to_dict(self) -> Dict[str, Any]: - payload: Dict[str, Any] = {"id": self.id, "filename": self.filename} - if self.description is not None: - payload["description"] = self.description - if self.content_type is not None: - payload["content_type"] = self.content_type - if self.size is not None: - payload["size"] = self.size - if self.url is not None: - payload["url"] = self.url - if self.proxy_url is not None: - payload["proxy_url"] = self.proxy_url - if self.height is not None: - payload["height"] = self.height - if self.width is not None: - payload["width"] = self.width - if self.ephemeral: - payload["ephemeral"] = self.ephemeral - return payload - - -class File: - """Represents a file to be uploaded.""" - - def __init__(self, filename: str, data: bytes): - self.filename = filename - self.data = data - - -class AllowedMentions: - """Represents allowed mentions for a message or interaction response.""" - - def __init__(self, data: Dict[str, Any]): - self.parse: List[str] = data.get("parse", []) - self.roles: List[str] = data.get("roles", []) - self.users: List[str] = data.get("users", []) - self.replied_user: bool = data.get("replied_user", False) - - def to_dict(self) -> Dict[str, Any]: - payload: Dict[str, Any] = {"parse": self.parse} - if self.roles: - payload["roles"] = self.roles - if self.users: - payload["users"] = self.users - if self.replied_user: - payload["replied_user"] = self.replied_user - return payload - - -class RoleTags: - """Represents tags for a role.""" - - def __init__(self, data: Dict[str, Any]): - self.bot_id: Optional[str] = data.get("bot_id") - self.integration_id: Optional[str] = data.get("integration_id") - self.premium_subscriber: Optional[bool] = ( - data.get("premium_subscriber") is None - ) # presence of null value means true - - def to_dict(self) -> Dict[str, Any]: - payload = {} - if self.bot_id: - payload["bot_id"] = self.bot_id - if self.integration_id: - payload["integration_id"] = self.integration_id - if self.premium_subscriber: - payload["premium_subscriber"] = None # Explicitly null - return payload - - -class Role: - """Represents a Discord Role.""" - - def __init__(self, data: Dict[str, Any]): - self.id: str = data["id"] - self.name: str = data["name"] - self.color: int = data["color"] - self.hoist: bool = data["hoist"] - self.icon: Optional[str] = data.get("icon") - self.unicode_emoji: Optional[str] = data.get("unicode_emoji") - self.position: int = data["position"] - self.permissions: str = data["permissions"] # String of bitwise permissions - self.managed: bool = data["managed"] - self.mentionable: bool = data["mentionable"] - self.tags: Optional[RoleTags] = ( - RoleTags(data["tags"]) if data.get("tags") else None - ) - - @property - def mention(self) -> str: - """str: Returns a string that allows you to mention the role.""" - return f"<@&{self.id}>" - - def __repr__(self) -> str: - return f"" - - -class Member(User): # Member inherits from User - """Represents a Guild Member. - This class combines User attributes with guild-specific Member attributes. - """ - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - self._client: Optional["Client"] = client_instance - self.guild_id: Optional[str] = None - # User part is nested under 'user' key in member data from gateway/API - user_data = data.get("user", {}) - # If 'id' is not in user_data but is top-level (e.g. from interaction resolved member without user object) - if "id" not in user_data and "id" in data: - # This case is less common for full member objects but can happen. - # We'd need to construct a partial user from top-level member fields if 'user' is missing. - # For now, assume 'user' object is present for full Member hydration. - # If 'user' is missing, the User part might be incomplete. - pass # User fields will be missing or default if 'user' not in data. - - super().__init__( - user_data if user_data else data - ) # Pass user_data or data if user_data is empty - - self.nick: Optional[str] = data.get("nick") - self.avatar: Optional[str] = data.get("avatar") # Guild-specific avatar hash - self.roles: List[str] = data.get("roles", []) # List of role IDs - self.joined_at: str = data["joined_at"] # ISO8601 timestamp - self.premium_since: Optional[str] = data.get( - "premium_since" - ) # ISO8601 timestamp - self.deaf: bool = data.get("deaf", False) - self.mute: bool = data.get("mute", False) - self.pending: bool = data.get("pending", False) - self.permissions: Optional[str] = data.get( - "permissions" - ) # Permissions in the channel, if applicable - self.communication_disabled_until: Optional[str] = data.get( - "communication_disabled_until" - ) # ISO8601 timestamp - - # If 'user' object was present, ensure User attributes are from there - if user_data: - self.id = user_data.get("id", self.id) # Prefer user.id if available - self.username = user_data.get("username", self.username) - self.discriminator = user_data.get("discriminator", self.discriminator) - self.bot = user_data.get("bot", self.bot) - # User's global avatar is User.avatar, Member.avatar is guild-specific - # super() already set self.avatar from user_data if present. - # The self.avatar = data.get("avatar") line above overwrites it with guild avatar. This is correct. - - def __repr__(self) -> str: - return f"" - - @property - def display_name(self) -> str: - """Return the nickname if set, otherwise the username.""" - - return self.nick or self.username - - async def kick(self, *, reason: Optional[str] = None) -> None: - if not self.guild_id or not self._client: - raise DisagreementException("Member.kick requires guild_id and client") - await self._client._http.kick_member(self.guild_id, self.id, reason=reason) - - async def ban( - self, - *, - delete_message_seconds: int = 0, - reason: Optional[str] = None, - ) -> None: - if not self.guild_id or not self._client: - raise DisagreementException("Member.ban requires guild_id and client") - await self._client._http.ban_member( - self.guild_id, - self.id, - delete_message_seconds=delete_message_seconds, - reason=reason, - ) - - async def timeout( - self, until: Optional[str], *, reason: Optional[str] = None - ) -> None: - if not self.guild_id or not self._client: - raise DisagreementException("Member.timeout requires guild_id and client") - await self._client._http.timeout_member( - self.guild_id, - self.id, - until=until, - reason=reason, - ) - - @property - def top_role(self) -> Optional["Role"]: - """Return the member's highest role from the guild cache.""" - - if not self.guild_id or not self._client: - return None - - guild = self._client.get_guild(self.guild_id) - if not guild: - return None - - if not guild.roles and hasattr(self._client, "fetch_roles"): - try: - self._client.loop.run_until_complete( - self._client.fetch_roles(self.guild_id) - ) - except RuntimeError: - future = asyncio.run_coroutine_threadsafe( - self._client.fetch_roles(self.guild_id), self._client.loop - ) - future.result() - - role_objects = [r for r in guild.roles if r.id in self.roles] - if not role_objects: - return None - - return max(role_objects, key=lambda r: r.position) - - -class PartialEmoji: - """Represents a partial emoji, often used in components or reactions. - - This typically means only id, name, and animated are known. - For unicode emojis, id will be None and name will be the unicode character. - """ - - def __init__(self, data: Dict[str, Any]): - self.id: Optional[str] = data.get("id") - self.name: Optional[str] = data.get( - "name" - ) # Can be None for unknown custom emoji, or unicode char - self.animated: bool = data.get("animated", False) - - def to_dict(self) -> Dict[str, Any]: - payload: Dict[str, Any] = {} - if self.id: - payload["id"] = self.id - if self.name: - payload["name"] = self.name - if self.animated: # Only include if true, as per some Discord patterns - payload["animated"] = self.animated - return payload - - def __str__(self) -> str: - if self.id: - return f"<{'a' if self.animated else ''}:{self.name}:{self.id}>" - return self.name or "" # For unicode emoji - - def __repr__(self) -> str: - return ( - f"" - ) - - -def to_partial_emoji( - value: Union[str, "PartialEmoji", None], -) -> Optional["PartialEmoji"]: - """Convert a string or PartialEmoji to a PartialEmoji instance. - - Args: - value: Either a unicode emoji string, a :class:`PartialEmoji`, or ``None``. - - Returns: - A :class:`PartialEmoji` or ``None`` if ``value`` was ``None``. - - Raises: - TypeError: If ``value`` is not ``str`` or :class:`PartialEmoji`. - """ - - if value is None or isinstance(value, PartialEmoji): - return value - if isinstance(value, str): - return PartialEmoji({"name": value, "id": None}) - raise TypeError("emoji must be a str or PartialEmoji") - - -class Emoji(PartialEmoji): - """Represents a custom guild emoji. - - Inherits id, name, animated from PartialEmoji. - """ - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - super().__init__(data) - self._client: Optional["Client"] = ( - client_instance # For potential future methods - ) - - # Roles this emoji is whitelisted to - self.roles: List[str] = data.get("roles", []) # List of role IDs - - # User object for the user that created this emoji (optional, only for GUILD_EMOJIS_AND_STICKERS intent) - self.user: Optional[User] = User(data["user"]) if data.get("user") else None - - self.require_colons: bool = data.get("require_colons", False) - self.managed: bool = data.get( - "managed", False - ) # If this emoji is managed by an integration - self.available: bool = data.get( - "available", True - ) # Whether this emoji can be used - - def __repr__(self) -> str: - return f"" - - -class StickerItem: - """Represents a sticker item, a basic representation of a sticker. - - Used in sticker packs and sometimes in message data. - """ - - def __init__(self, data: Dict[str, Any]): - self.id: str = data["id"] - self.name: str = data["name"] - self.format_type: int = data["format_type"] # StickerFormatType enum - - def __repr__(self) -> str: - return f"" - - -class Sticker(StickerItem): - """Represents a Discord sticker. - - Inherits id, name, format_type from StickerItem. - """ - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - super().__init__(data) - self._client: Optional["Client"] = client_instance - - self.pack_id: Optional[str] = data.get( - "pack_id" - ) # For standard stickers, ID of the pack - self.description: Optional[str] = data.get("description") - self.tags: str = data.get( - "tags", "" - ) # Comma-separated list of tags for guild stickers - # type is StickerType enum (STANDARD or GUILD) - # For guild stickers, this is 2. For standard stickers, this is 1. - self.type: int = data["type"] - self.available: bool = data.get( - "available", True - ) # Whether this sticker can be used - self.guild_id: Optional[str] = data.get( - "guild_id" - ) # ID of the guild that owns this sticker - - # User object of the user that uploaded the guild sticker - self.user: Optional[User] = User(data["user"]) if data.get("user") else None - - self.sort_value: Optional[int] = data.get( - "sort_value" - ) # The standard sticker's sort order within its pack - - def __repr__(self) -> str: - return f"" - - -class StickerPack: - """Represents a pack of standard stickers.""" - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - self._client: Optional["Client"] = client_instance - self.id: str = data["id"] - self.stickers: List[Sticker] = [ - Sticker(s_data, client_instance) for s_data in data.get("stickers", []) - ] - self.name: str = data["name"] - self.sku_id: str = data["sku_id"] - self.cover_sticker_id: Optional[str] = data.get("cover_sticker_id") - self.description: str = data["description"] - self.banner_asset_id: Optional[str] = data.get( - "banner_asset_id" - ) # ID of the pack's banner image - - def __repr__(self) -> str: - return f"" - - -class PermissionOverwrite: - """Represents a permission overwrite for a role or member in a channel.""" - - def __init__(self, data: Dict[str, Any]): - self.id: str = data["id"] # Role or user ID - self._type_val: int = int(data["type"]) # Store raw type for enum property - self.allow: str = data["allow"] # Bitwise value of allowed permissions - self.deny: str = data["deny"] # Bitwise value of denied permissions - - @property - def type(self) -> "OverwriteType": - from .enums import ( - OverwriteType, - ) # Local import to avoid circularity at module level - - return OverwriteType(self._type_val) - - def to_dict(self) -> Dict[str, Any]: - return { - "id": self.id, - "type": self.type.value, - "allow": self.allow, - "deny": self.deny, - } - - def __repr__(self) -> str: - return f"" - - -class Guild: - """Represents a Discord Guild (Server). - - Attributes: - id (str): Guild ID. - name (str): Guild name (2-100 characters, excluding @, #, :, ```). - icon (Optional[str]): Icon hash. - splash (Optional[str]): Splash hash. - discovery_splash (Optional[str]): Discovery splash hash; only present for discoverable guilds. - owner (Optional[bool]): True if the user is the owner of the guild. (Only for /users/@me/guilds endpoint) - owner_id (str): ID of owner. - permissions (Optional[str]): Total permissions for the user in the guild (excludes overwrites). (Only for /users/@me/guilds endpoint) - afk_channel_id (Optional[str]): ID of afk channel. - afk_timeout (int): AFK timeout in seconds. - widget_enabled (Optional[bool]): True if the server widget is enabled. - widget_channel_id (Optional[str]): The channel id that the widget will generate an invite to, or null if set to no invite. - verification_level (VerificationLevel): Verification level required for the guild. - default_message_notifications (MessageNotificationLevel): Default message notifications level. - explicit_content_filter (ExplicitContentFilterLevel): Explicit content filter level. - roles (List[Role]): Roles in the guild. - emojis (List[Dict]): Custom emojis. (Consider creating an Emoji model) - features (List[GuildFeature]): Enabled guild features. - mfa_level (MFALevel): Required MFA level for the guild. - application_id (Optional[str]): Application ID of the guild creator if it is bot-created. - system_channel_id (Optional[str]): The id of the channel where guild notices such as welcome messages and boost events are posted. - system_channel_flags (int): System channel flags. - rules_channel_id (Optional[str]): The id of the channel where Community guilds can display rules. - max_members (Optional[int]): The maximum number of members for the guild. - vanity_url_code (Optional[str]): The vanity url code for the guild. - description (Optional[str]): The description of a Community guild. - banner (Optional[str]): Banner hash. - premium_tier (PremiumTier): Premium tier (Server Boost level). - premium_subscription_count (Optional[int]): The number of boosts this guild currently has. - preferred_locale (str): The preferred locale of a Community guild. Defaults to "en-US". - public_updates_channel_id (Optional[str]): The id of the channel where admins and moderators of Community guilds receive notices from Discord. - max_video_channel_users (Optional[int]): The maximum number of users in a video channel. - welcome_screen (Optional[Dict]): The welcome screen of a Community guild. (Consider a WelcomeScreen model) - nsfw_level (GuildNSFWLevel): Guild NSFW level. - stickers (Optional[List[Dict]]): Custom stickers in the guild. (Consider a Sticker model) - premium_progress_bar_enabled (bool): Whether the guild has the premium progress bar enabled. - """ - - def __init__(self, data: Dict[str, Any], client_instance: "Client"): - self._client: "Client" = client_instance - self.id: str = data["id"] - self.name: str = data["name"] - self.icon: Optional[str] = data.get("icon") - self.splash: Optional[str] = data.get("splash") - self.discovery_splash: Optional[str] = data.get("discovery_splash") - self.owner: Optional[bool] = data.get("owner") - self.owner_id: str = data["owner_id"] - self.permissions: Optional[str] = data.get("permissions") - self.afk_channel_id: Optional[str] = data.get("afk_channel_id") - self.afk_timeout: int = data["afk_timeout"] - self.widget_enabled: Optional[bool] = data.get("widget_enabled") - self.widget_channel_id: Optional[str] = data.get("widget_channel_id") - self.verification_level: VerificationLevel = VerificationLevel( - data["verification_level"] - ) - self.default_message_notifications: MessageNotificationLevel = ( - MessageNotificationLevel(data["default_message_notifications"]) - ) - self.explicit_content_filter: ExplicitContentFilterLevel = ( - ExplicitContentFilterLevel(data["explicit_content_filter"]) - ) - - self.roles: List[Role] = [Role(r) for r in data.get("roles", [])] - self.emojis: List[Emoji] = [ - Emoji(e_data, client_instance) for e_data in data.get("emojis", []) - ] - - # Assuming GuildFeature can be constructed from string feature names or their values - self.features: List[GuildFeature] = [ - GuildFeature(f) if not isinstance(f, GuildFeature) else f - for f in data.get("features", []) - ] - - self.mfa_level: MFALevel = MFALevel(data["mfa_level"]) - self.application_id: Optional[str] = data.get("application_id") - self.system_channel_id: Optional[str] = data.get("system_channel_id") - self.system_channel_flags: int = data["system_channel_flags"] - self.rules_channel_id: Optional[str] = data.get("rules_channel_id") - self.max_members: Optional[int] = data.get("max_members") - self.vanity_url_code: Optional[str] = data.get("vanity_url_code") - self.description: Optional[str] = data.get("description") - self.banner: Optional[str] = data.get("banner") - self.premium_tier: PremiumTier = PremiumTier(data["premium_tier"]) - self.premium_subscription_count: Optional[int] = data.get( - "premium_subscription_count" - ) - self.preferred_locale: str = data.get("preferred_locale", "en-US") - self.public_updates_channel_id: Optional[str] = data.get( - "public_updates_channel_id" - ) - self.max_video_channel_users: Optional[int] = data.get( - "max_video_channel_users" - ) - self.approximate_member_count: Optional[int] = data.get( - "approximate_member_count" - ) - self.approximate_presence_count: Optional[int] = data.get( - "approximate_presence_count" - ) - self.welcome_screen: Optional["WelcomeScreen"] = ( - WelcomeScreen(data["welcome_screen"], client_instance) - if data.get("welcome_screen") - else None - ) - self.nsfw_level: GuildNSFWLevel = GuildNSFWLevel(data["nsfw_level"]) - self.stickers: Optional[List[Sticker]] = ( - [Sticker(s_data, client_instance) for s_data in data.get("stickers", [])] - if data.get("stickers") - else None - ) - self.premium_progress_bar_enabled: bool = data.get( - "premium_progress_bar_enabled", False - ) - - # Internal caches, populated by events or specific fetches - self._channels: Dict[str, "Channel"] = {} - self._members: Dict[str, Member] = {} - self._threads: Dict[str, "Thread"] = {} - - def get_channel(self, channel_id: str) -> Optional["Channel"]: - return self._channels.get(channel_id) - - def get_member(self, user_id: str) -> Optional[Member]: - return self._members.get(user_id) - - def get_member_named(self, name: str) -> Optional[Member]: - """Retrieve a cached member by username or nickname. - - The lookup is case-insensitive and searches both the username and - guild nickname for a match. - - Parameters - ---------- - name: str - The username or nickname to search for. - - Returns - ------- - Optional[Member] - The matching member if found, otherwise ``None``. - """ - - lowered = name.lower() - for member in self._members.values(): - if member.username.lower() == lowered: - return member - if member.nick and member.nick.lower() == lowered: - return member - return None - - def get_role(self, role_id: str) -> Optional[Role]: - return next((role for role in self.roles if role.id == role_id), None) - - def __repr__(self) -> str: - return f"" - + +class EmbedFooter: + """Represents an embed footer.""" + + def __init__(self, data: Dict[str, Any]): + self.text: str = data["text"] + self.icon_url: Optional[str] = data.get("icon_url") + self.proxy_icon_url: Optional[str] = data.get("proxy_icon_url") + + def to_dict(self) -> Dict[str, Any]: + payload = {"text": self.text} + if self.icon_url: + payload["icon_url"] = self.icon_url + if self.proxy_icon_url: + payload["proxy_icon_url"] = self.proxy_icon_url + return payload + + +class EmbedImage: + """Represents an embed image.""" + + def __init__(self, data: Dict[str, Any]): + self.url: str = data["url"] + self.proxy_url: Optional[str] = data.get("proxy_url") + self.height: Optional[int] = data.get("height") + self.width: Optional[int] = data.get("width") + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {"url": self.url} + if self.proxy_url: + payload["proxy_url"] = self.proxy_url + if self.height: + payload["height"] = self.height + if self.width: + payload["width"] = self.width + return payload + + def __repr__(self) -> str: + return f"" + + +class EmbedThumbnail(EmbedImage): # Similar structure to EmbedImage + """Represents an embed thumbnail.""" + + pass + + +class EmbedAuthor: + """Represents an embed author.""" + + def __init__(self, data: Dict[str, Any]): + self.name: str = data["name"] + self.url: Optional[str] = data.get("url") + self.icon_url: Optional[str] = data.get("icon_url") + self.proxy_icon_url: Optional[str] = data.get("proxy_icon_url") + + def to_dict(self) -> Dict[str, Any]: + payload = {"name": self.name} + if self.url: + payload["url"] = self.url + if self.icon_url: + payload["icon_url"] = self.icon_url + if self.proxy_icon_url: + payload["proxy_icon_url"] = self.proxy_icon_url + return payload + + +class EmbedField: + """Represents an embed field.""" + + def __init__(self, data: Dict[str, Any]): + self.name: str = data["name"] + self.value: str = data["value"] + self.inline: bool = data.get("inline", False) + + def to_dict(self) -> Dict[str, Any]: + return {"name": self.name, "value": self.value, "inline": self.inline} + + +class Embed: + """Represents a Discord embed. + + Attributes can be set directly or via methods like `set_author`, `add_field`. + """ + + def __init__(self, data: Optional[Dict[str, Any]] = None): + data = data or {} + self.title: Optional[str] = data.get("title") + self.type: str = data.get("type", "rich") # Default to "rich" for sending + self.description: Optional[str] = data.get("description") + self.url: Optional[str] = data.get("url") + self.timestamp: Optional[str] = data.get("timestamp") # ISO8601 timestamp + self.color = Color.parse(data.get("color")) + + self.footer: Optional[EmbedFooter] = ( + EmbedFooter(data["footer"]) if data.get("footer") else None + ) + self.image: Optional[EmbedImage] = ( + EmbedImage(data["image"]) if data.get("image") else None + ) + self.thumbnail: Optional[EmbedThumbnail] = ( + EmbedThumbnail(data["thumbnail"]) if data.get("thumbnail") else None + ) + # Video and Provider are less common for bot-sent embeds, can be added if needed. + self.author: Optional[EmbedAuthor] = ( + EmbedAuthor(data["author"]) if data.get("author") else None + ) + self.fields: List[EmbedField] = ( + [EmbedField(f) for f in data["fields"]] if data.get("fields") else [] + ) + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {"type": self.type} + if self.title: + payload["title"] = self.title + if self.description: + payload["description"] = self.description + if self.url: + payload["url"] = self.url + if self.timestamp: + payload["timestamp"] = self.timestamp + if self.color is not None: + payload["color"] = self.color.value + if self.footer: + payload["footer"] = self.footer.to_dict() + if self.image: + payload["image"] = self.image.to_dict() + if self.thumbnail: + payload["thumbnail"] = self.thumbnail.to_dict() + if self.author: + payload["author"] = self.author.to_dict() + if self.fields: + payload["fields"] = [f.to_dict() for f in self.fields] + return payload + + # Convenience methods for building embeds can be added here + # e.g., set_author, add_field, set_footer, set_image, etc. + + +class Attachment: + """Represents a message attachment.""" + + def __init__(self, data: Dict[str, Any]): + self.id: str = data["id"] + self.filename: str = data["filename"] + self.description: Optional[str] = data.get("description") + self.content_type: Optional[str] = data.get("content_type") + self.size: Optional[int] = data.get("size") + self.url: Optional[str] = data.get("url") + self.proxy_url: Optional[str] = data.get("proxy_url") + self.height: Optional[int] = data.get("height") # If image + self.width: Optional[int] = data.get("width") # If image + self.ephemeral: bool = data.get("ephemeral", False) + + def __repr__(self) -> str: + return f"" + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {"id": self.id, "filename": self.filename} + if self.description is not None: + payload["description"] = self.description + if self.content_type is not None: + payload["content_type"] = self.content_type + if self.size is not None: + payload["size"] = self.size + if self.url is not None: + payload["url"] = self.url + if self.proxy_url is not None: + payload["proxy_url"] = self.proxy_url + if self.height is not None: + payload["height"] = self.height + if self.width is not None: + payload["width"] = self.width + if self.ephemeral: + payload["ephemeral"] = self.ephemeral + return payload + + +class File: + """Represents a file to be uploaded.""" + + def __init__(self, filename: str, data: bytes): + self.filename = filename + self.data = data + + +class AllowedMentions: + """Represents allowed mentions for a message or interaction response.""" + + def __init__(self, data: Dict[str, Any]): + self.parse: List[str] = data.get("parse", []) + self.roles: List[str] = data.get("roles", []) + self.users: List[str] = data.get("users", []) + self.replied_user: bool = data.get("replied_user", False) + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {"parse": self.parse} + if self.roles: + payload["roles"] = self.roles + if self.users: + payload["users"] = self.users + if self.replied_user: + payload["replied_user"] = self.replied_user + return payload + + +class RoleTags: + """Represents tags for a role.""" + + def __init__(self, data: Dict[str, Any]): + self.bot_id: Optional[str] = data.get("bot_id") + self.integration_id: Optional[str] = data.get("integration_id") + self.premium_subscriber: Optional[bool] = ( + data.get("premium_subscriber") is None + ) # presence of null value means true + + def to_dict(self) -> Dict[str, Any]: + payload = {} + if self.bot_id: + payload["bot_id"] = self.bot_id + if self.integration_id: + payload["integration_id"] = self.integration_id + if self.premium_subscriber: + payload["premium_subscriber"] = None # Explicitly null + return payload + + +class Role: + """Represents a Discord Role.""" + + def __init__(self, data: Dict[str, Any]): + self.id: str = data["id"] + self.name: str = data["name"] + self.color: int = data["color"] + self.hoist: bool = data["hoist"] + self.icon: Optional[str] = data.get("icon") + self.unicode_emoji: Optional[str] = data.get("unicode_emoji") + self.position: int = data["position"] + self.permissions: str = data["permissions"] # String of bitwise permissions + self.managed: bool = data["managed"] + self.mentionable: bool = data["mentionable"] + self.tags: Optional[RoleTags] = ( + RoleTags(data["tags"]) if data.get("tags") else None + ) + + @property + def mention(self) -> str: + """str: Returns a string that allows you to mention the role.""" + return f"<@&{self.id}>" + + def __repr__(self) -> str: + return f"" + + +class Member(User): # Member inherits from User + """Represents a Guild Member. + This class combines User attributes with guild-specific Member attributes. + """ + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client: Optional["Client"] = client_instance + self.guild_id: Optional[str] = None + self.status: Optional[str] = None + self.voice_state: Optional[Dict[str, Any]] = None + # User part is nested under 'user' key in member data from gateway/API + user_data = data.get("user", {}) + # If 'id' is not in user_data but is top-level (e.g. from interaction resolved member without user object) + if "id" not in user_data and "id" in data: + # This case is less common for full member objects but can happen. + # We'd need to construct a partial user from top-level member fields if 'user' is missing. + # For now, assume 'user' object is present for full Member hydration. + # If 'user' is missing, the User part might be incomplete. + pass # User fields will be missing or default if 'user' not in data. + + super().__init__( + user_data if user_data else data + ) # Pass user_data or data if user_data is empty + + self.nick: Optional[str] = data.get("nick") + self.avatar: Optional[str] = data.get("avatar") # Guild-specific avatar hash + self.roles: List[str] = data.get("roles", []) # List of role IDs + self.joined_at: str = data["joined_at"] # ISO8601 timestamp + self.premium_since: Optional[str] = data.get( + "premium_since" + ) # ISO8601 timestamp + self.deaf: bool = data.get("deaf", False) + self.mute: bool = data.get("mute", False) + self.pending: bool = data.get("pending", False) + self.permissions: Optional[str] = data.get( + "permissions" + ) # Permissions in the channel, if applicable + self.communication_disabled_until: Optional[str] = data.get( + "communication_disabled_until" + ) # ISO8601 timestamp + + # If 'user' object was present, ensure User attributes are from there + if user_data: + self.id = user_data.get("id", self.id) # Prefer user.id if available + self.username = user_data.get("username", self.username) + self.discriminator = user_data.get("discriminator", self.discriminator) + self.bot = user_data.get("bot", self.bot) + # User's global avatar is User.avatar, Member.avatar is guild-specific + # super() already set self.avatar from user_data if present. + # The self.avatar = data.get("avatar") line above overwrites it with guild avatar. This is correct. + + def __repr__(self) -> str: + return f"" + + @property + def display_name(self) -> str: + """Return the nickname if set, otherwise the username.""" + + return self.nick or self.username + + async def kick(self, *, reason: Optional[str] = None) -> None: + if not self.guild_id or not self._client: + raise DisagreementException("Member.kick requires guild_id and client") + await self._client._http.kick_member(self.guild_id, self.id, reason=reason) + + async def ban( + self, + *, + delete_message_seconds: int = 0, + reason: Optional[str] = None, + ) -> None: + if not self.guild_id or not self._client: + raise DisagreementException("Member.ban requires guild_id and client") + await self._client._http.ban_member( + self.guild_id, + self.id, + delete_message_seconds=delete_message_seconds, + reason=reason, + ) + + async def timeout( + self, until: Optional[str], *, reason: Optional[str] = None + ) -> None: + if not self.guild_id or not self._client: + raise DisagreementException("Member.timeout requires guild_id and client") + await self._client._http.timeout_member( + self.guild_id, + self.id, + until=until, + reason=reason, + ) + + @property + def top_role(self) -> Optional["Role"]: + """Return the member's highest role from the guild cache.""" + + if not self.guild_id or not self._client: + return None + + guild = self._client.get_guild(self.guild_id) + if not guild: + return None + + if not guild.roles and hasattr(self._client, "fetch_roles"): + try: + self._client.loop.run_until_complete( + self._client.fetch_roles(self.guild_id) + ) + except RuntimeError: + future = asyncio.run_coroutine_threadsafe( + self._client.fetch_roles(self.guild_id), self._client.loop + ) + future.result() + + role_objects = [r for r in guild.roles if r.id in self.roles] + if not role_objects: + return None + + return max(role_objects, key=lambda r: r.position) + + +class PartialEmoji: + """Represents a partial emoji, often used in components or reactions. + + This typically means only id, name, and animated are known. + For unicode emojis, id will be None and name will be the unicode character. + """ + + def __init__(self, data: Dict[str, Any]): + self.id: Optional[str] = data.get("id") + self.name: Optional[str] = data.get( + "name" + ) # Can be None for unknown custom emoji, or unicode char + self.animated: bool = data.get("animated", False) + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {} + if self.id: + payload["id"] = self.id + if self.name: + payload["name"] = self.name + if self.animated: # Only include if true, as per some Discord patterns + payload["animated"] = self.animated + return payload + + def __str__(self) -> str: + if self.id: + return f"<{'a' if self.animated else ''}:{self.name}:{self.id}>" + return self.name or "" # For unicode emoji + + def __repr__(self) -> str: + return ( + f"" + ) + + +def to_partial_emoji( + value: Union[str, "PartialEmoji", None], +) -> Optional["PartialEmoji"]: + """Convert a string or PartialEmoji to a PartialEmoji instance. + + Args: + value: Either a unicode emoji string, a :class:`PartialEmoji`, or ``None``. + + Returns: + A :class:`PartialEmoji` or ``None`` if ``value`` was ``None``. + + Raises: + TypeError: If ``value`` is not ``str`` or :class:`PartialEmoji`. + """ + + if value is None or isinstance(value, PartialEmoji): + return value + if isinstance(value, str): + return PartialEmoji({"name": value, "id": None}) + raise TypeError("emoji must be a str or PartialEmoji") + + +class Emoji(PartialEmoji): + """Represents a custom guild emoji. + + Inherits id, name, animated from PartialEmoji. + """ + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + super().__init__(data) + self._client: Optional["Client"] = ( + client_instance # For potential future methods + ) + + # Roles this emoji is whitelisted to + self.roles: List[str] = data.get("roles", []) # List of role IDs + + # User object for the user that created this emoji (optional, only for GUILD_EMOJIS_AND_STICKERS intent) + self.user: Optional[User] = User(data["user"]) if data.get("user") else None + + self.require_colons: bool = data.get("require_colons", False) + self.managed: bool = data.get( + "managed", False + ) # If this emoji is managed by an integration + self.available: bool = data.get( + "available", True + ) # Whether this emoji can be used + + def __repr__(self) -> str: + return f"" + + +class StickerItem: + """Represents a sticker item, a basic representation of a sticker. + + Used in sticker packs and sometimes in message data. + """ + + def __init__(self, data: Dict[str, Any]): + self.id: str = data["id"] + self.name: str = data["name"] + self.format_type: int = data["format_type"] # StickerFormatType enum + + def __repr__(self) -> str: + return f"" + + +class Sticker(StickerItem): + """Represents a Discord sticker. + + Inherits id, name, format_type from StickerItem. + """ + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + super().__init__(data) + self._client: Optional["Client"] = client_instance + + self.pack_id: Optional[str] = data.get( + "pack_id" + ) # For standard stickers, ID of the pack + self.description: Optional[str] = data.get("description") + self.tags: str = data.get( + "tags", "" + ) # Comma-separated list of tags for guild stickers + # type is StickerType enum (STANDARD or GUILD) + # For guild stickers, this is 2. For standard stickers, this is 1. + self.type: int = data["type"] + self.available: bool = data.get( + "available", True + ) # Whether this sticker can be used + self.guild_id: Optional[str] = data.get( + "guild_id" + ) # ID of the guild that owns this sticker + + # User object of the user that uploaded the guild sticker + self.user: Optional[User] = User(data["user"]) if data.get("user") else None + + self.sort_value: Optional[int] = data.get( + "sort_value" + ) # The standard sticker's sort order within its pack + + def __repr__(self) -> str: + return f"" + + +class StickerPack: + """Represents a pack of standard stickers.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client: Optional["Client"] = client_instance + self.id: str = data["id"] + self.stickers: List[Sticker] = [ + Sticker(s_data, client_instance) for s_data in data.get("stickers", []) + ] + self.name: str = data["name"] + self.sku_id: str = data["sku_id"] + self.cover_sticker_id: Optional[str] = data.get("cover_sticker_id") + self.description: str = data["description"] + self.banner_asset_id: Optional[str] = data.get( + "banner_asset_id" + ) # ID of the pack's banner image + + def __repr__(self) -> str: + return f"" + + +class PermissionOverwrite: + """Represents a permission overwrite for a role or member in a channel.""" + + def __init__(self, data: Dict[str, Any]): + self.id: str = data["id"] # Role or user ID + self._type_val: int = int(data["type"]) # Store raw type for enum property + self.allow: str = data["allow"] # Bitwise value of allowed permissions + self.deny: str = data["deny"] # Bitwise value of denied permissions + + @property + def type(self) -> "OverwriteType": + from .enums import ( + OverwriteType, + ) # Local import to avoid circularity at module level + + return OverwriteType(self._type_val) + + def to_dict(self) -> Dict[str, Any]: + return { + "id": self.id, + "type": self.type.value, + "allow": self.allow, + "deny": self.deny, + } + + def __repr__(self) -> str: + return f"" + + +class Guild: + """Represents a Discord Guild (Server). + + Attributes: + id (str): Guild ID. + name (str): Guild name (2-100 characters, excluding @, #, :, ```). + icon (Optional[str]): Icon hash. + splash (Optional[str]): Splash hash. + discovery_splash (Optional[str]): Discovery splash hash; only present for discoverable guilds. + owner (Optional[bool]): True if the user is the owner of the guild. (Only for /users/@me/guilds endpoint) + owner_id (str): ID of owner. + permissions (Optional[str]): Total permissions for the user in the guild (excludes overwrites). (Only for /users/@me/guilds endpoint) + afk_channel_id (Optional[str]): ID of afk channel. + afk_timeout (int): AFK timeout in seconds. + widget_enabled (Optional[bool]): True if the server widget is enabled. + widget_channel_id (Optional[str]): The channel id that the widget will generate an invite to, or null if set to no invite. + verification_level (VerificationLevel): Verification level required for the guild. + default_message_notifications (MessageNotificationLevel): Default message notifications level. + explicit_content_filter (ExplicitContentFilterLevel): Explicit content filter level. + roles (List[Role]): Roles in the guild. + emojis (List[Dict]): Custom emojis. (Consider creating an Emoji model) + features (List[GuildFeature]): Enabled guild features. + mfa_level (MFALevel): Required MFA level for the guild. + application_id (Optional[str]): Application ID of the guild creator if it is bot-created. + system_channel_id (Optional[str]): The id of the channel where guild notices such as welcome messages and boost events are posted. + system_channel_flags (int): System channel flags. + rules_channel_id (Optional[str]): The id of the channel where Community guilds can display rules. + max_members (Optional[int]): The maximum number of members for the guild. + vanity_url_code (Optional[str]): The vanity url code for the guild. + description (Optional[str]): The description of a Community guild. + banner (Optional[str]): Banner hash. + premium_tier (PremiumTier): Premium tier (Server Boost level). + premium_subscription_count (Optional[int]): The number of boosts this guild currently has. + preferred_locale (str): The preferred locale of a Community guild. Defaults to "en-US". + public_updates_channel_id (Optional[str]): The id of the channel where admins and moderators of Community guilds receive notices from Discord. + max_video_channel_users (Optional[int]): The maximum number of users in a video channel. + welcome_screen (Optional[Dict]): The welcome screen of a Community guild. (Consider a WelcomeScreen model) + nsfw_level (GuildNSFWLevel): Guild NSFW level. + stickers (Optional[List[Dict]]): Custom stickers in the guild. (Consider a Sticker model) + premium_progress_bar_enabled (bool): Whether the guild has the premium progress bar enabled. + """ + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + self._client: "Client" = client_instance + self.id: str = data["id"] + self.name: str = data["name"] + self.icon: Optional[str] = data.get("icon") + self.splash: Optional[str] = data.get("splash") + self.discovery_splash: Optional[str] = data.get("discovery_splash") + self.owner: Optional[bool] = data.get("owner") + self.owner_id: str = data["owner_id"] + self.permissions: Optional[str] = data.get("permissions") + self.afk_channel_id: Optional[str] = data.get("afk_channel_id") + self.afk_timeout: int = data["afk_timeout"] + self.widget_enabled: Optional[bool] = data.get("widget_enabled") + self.widget_channel_id: Optional[str] = data.get("widget_channel_id") + self.verification_level: VerificationLevel = VerificationLevel( + data["verification_level"] + ) + self.default_message_notifications: MessageNotificationLevel = ( + MessageNotificationLevel(data["default_message_notifications"]) + ) + self.explicit_content_filter: ExplicitContentFilterLevel = ( + ExplicitContentFilterLevel(data["explicit_content_filter"]) + ) + + self.roles: List[Role] = [Role(r) for r in data.get("roles", [])] + self.emojis: List[Emoji] = [ + Emoji(e_data, client_instance) for e_data in data.get("emojis", []) + ] + + # Assuming GuildFeature can be constructed from string feature names or their values + self.features: List[GuildFeature] = [ + GuildFeature(f) if not isinstance(f, GuildFeature) else f + for f in data.get("features", []) + ] + + self.mfa_level: MFALevel = MFALevel(data["mfa_level"]) + self.application_id: Optional[str] = data.get("application_id") + self.system_channel_id: Optional[str] = data.get("system_channel_id") + self.system_channel_flags: int = data["system_channel_flags"] + self.rules_channel_id: Optional[str] = data.get("rules_channel_id") + self.max_members: Optional[int] = data.get("max_members") + self.vanity_url_code: Optional[str] = data.get("vanity_url_code") + self.description: Optional[str] = data.get("description") + self.banner: Optional[str] = data.get("banner") + self.premium_tier: PremiumTier = PremiumTier(data["premium_tier"]) + self.premium_subscription_count: Optional[int] = data.get( + "premium_subscription_count" + ) + self.preferred_locale: str = data.get("preferred_locale", "en-US") + self.public_updates_channel_id: Optional[str] = data.get( + "public_updates_channel_id" + ) + self.max_video_channel_users: Optional[int] = data.get( + "max_video_channel_users" + ) + self.approximate_member_count: Optional[int] = data.get( + "approximate_member_count" + ) + self.approximate_presence_count: Optional[int] = data.get( + "approximate_presence_count" + ) + self.welcome_screen: Optional["WelcomeScreen"] = ( + WelcomeScreen(data["welcome_screen"], client_instance) + if data.get("welcome_screen") + else None + ) + self.nsfw_level: GuildNSFWLevel = GuildNSFWLevel(data["nsfw_level"]) + self.stickers: Optional[List[Sticker]] = ( + [Sticker(s_data, client_instance) for s_data in data.get("stickers", [])] + if data.get("stickers") + else None + ) + self.premium_progress_bar_enabled: bool = data.get( + "premium_progress_bar_enabled", False + ) + + # Internal caches, populated by events or specific fetches + self._channels: ChannelCache = ChannelCache() + self._members: MemberCache = MemberCache(client_instance.member_cache_flags) + self._threads: Dict[str, "Thread"] = {} + + def get_channel(self, channel_id: str) -> Optional["Channel"]: + return self._channels.get(channel_id) + + def get_member(self, user_id: str) -> Optional[Member]: + return self._members.get(user_id) + + def get_member_named(self, name: str) -> Optional[Member]: + """Retrieve a cached member by username or nickname. + + The lookup is case-insensitive and searches both the username and + guild nickname for a match. + + Parameters + ---------- + name: str + The username or nickname to search for. + + Returns + ------- + Optional[Member] + The matching member if found, otherwise ``None``. + """ + + lowered = name.lower() + for member in self._members.values(): + if member.username.lower() == lowered: + return member + if member.nick and member.nick.lower() == lowered: + return member + return None + + def get_role(self, role_id: str) -> Optional[Role]: + return next((role for role in self.roles if role.id == role_id), None) + + def __repr__(self) -> str: + return f"" + async def fetch_members(self, *, limit: Optional[int] = None) -> List["Member"]: """|coro| @@ -1166,341 +1170,427 @@ class Guild: del self._client._gateway._member_chunk_requests[nonce] raise - -class Channel: - """Base class for Discord channels.""" - - def __init__(self, data: Dict[str, Any], client_instance: "Client"): - self._client: "Client" = client_instance - self.id: str = data["id"] - self._type_val: int = int(data["type"]) # Store raw type for enum property - - self.guild_id: Optional[str] = data.get("guild_id") - self.name: Optional[str] = data.get("name") - self.position: Optional[int] = data.get("position") - self.permission_overwrites: List["PermissionOverwrite"] = [ - PermissionOverwrite(d) for d in data.get("permission_overwrites", []) - ] - self.nsfw: Optional[bool] = data.get("nsfw", False) - self.parent_id: Optional[str] = data.get( - "parent_id" - ) # ID of the parent category channel or thread parent - - @property - def type(self) -> ChannelType: - return ChannelType(self._type_val) - - @property - def mention(self) -> str: - return f"<#{self.id}>" - - async def delete(self, reason: Optional[str] = None): - await self._client._http.delete_channel(self.id, reason=reason) - - def __repr__(self) -> str: - return f"" - - def permission_overwrite_for( - self, target: Union["Role", "Member", str] - ) -> Optional["PermissionOverwrite"]: - """Return the :class:`PermissionOverwrite` for ``target`` if present.""" - - if isinstance(target, str): - target_id = target - else: - target_id = target.id - for overwrite in self.permission_overwrites: - if overwrite.id == target_id: - return overwrite - return None - - @staticmethod - def _apply_overwrite( - perms: Permissions, overwrite: Optional["PermissionOverwrite"] - ) -> Permissions: - if overwrite is None: - return perms - - perms &= ~Permissions(int(overwrite.deny)) - perms |= Permissions(int(overwrite.allow)) - return perms - - def permissions_for(self, member: "Member") -> Permissions: - """Resolve channel permissions for ``member``.""" - - if self.guild_id is None: - return Permissions(~0) - - if not hasattr(self._client, "get_guild"): - return Permissions(0) - - guild = self._client.get_guild(self.guild_id) - if guild is None: - return Permissions(0) - - base = Permissions(0) - - everyone = guild.get_role(guild.id) - if everyone is not None: - base |= Permissions(int(everyone.permissions)) - - for rid in member.roles: - role = guild.get_role(rid) - if role is not None: - base |= Permissions(int(role.permissions)) - - if base & Permissions.ADMINISTRATOR: - return Permissions(~0) - - # Apply @everyone overwrite - base = self._apply_overwrite(base, self.permission_overwrite_for(guild.id)) - - # Role overwrites - role_allow = Permissions(0) - role_deny = Permissions(0) - for rid in member.roles: - ow = self.permission_overwrite_for(rid) - if ow is not None: - role_allow |= Permissions(int(ow.allow)) - role_deny |= Permissions(int(ow.deny)) - - base &= ~role_deny - base |= role_allow - - # Member overwrite - base = self._apply_overwrite(base, self.permission_overwrite_for(member.id)) - - return base - - -class TextChannel(Channel): - """Represents a guild text channel or announcement channel.""" - - def __init__(self, data: Dict[str, Any], client_instance: "Client"): - super().__init__(data, client_instance) - self.topic: Optional[str] = data.get("topic") - self.last_message_id: Optional[str] = data.get("last_message_id") - self.rate_limit_per_user: Optional[int] = data.get("rate_limit_per_user", 0) - self.default_auto_archive_duration: Optional[int] = data.get( - "default_auto_archive_duration" - ) - self.last_pin_timestamp: Optional[str] = data.get("last_pin_timestamp") - - def history( - self, - *, - limit: Optional[int] = None, - before: Optional[str] = None, - after: Optional[str] = None, - ) -> AsyncIterator["Message"]: - """Return an async iterator over this channel's messages.""" - - from .utils import message_pager - - return message_pager(self, limit=limit, before=before, after=after) - - async def send( - self, - content: Optional[str] = None, - *, - embed: Optional[Embed] = None, - embeds: Optional[List[Embed]] = None, - components: Optional[List["ActionRow"]] = None, # Added components - ) -> "Message": # Forward reference Message - if not hasattr(self._client, "send_message"): - raise NotImplementedError( - "Client.send_message is required for TextChannel.send" - ) - - return await self._client.send_message( - channel_id=self.id, - content=content, - embed=embed, - embeds=embeds, - components=components, - ) - - async def purge( - self, limit: int, *, before: "Snowflake | None" = None - ) -> List["Snowflake"]: - """Bulk delete messages from this channel.""" - - params: Dict[str, Union[int, str]] = {"limit": limit} - if before is not None: - params["before"] = before - - messages = await self._client._http.request( - "GET", f"/channels/{self.id}/messages", params=params - ) - ids = [m["id"] for m in messages] - if not ids: - return [] - - await self._client._http.bulk_delete_messages(self.id, ids) - for mid in ids: - self._client._messages.pop(mid, None) - return ids - - def __repr__(self) -> str: - return f"" - - -class VoiceChannel(Channel): - """Represents a guild voice channel or stage voice channel.""" - - def __init__(self, data: Dict[str, Any], client_instance: "Client"): - super().__init__(data, client_instance) - self.bitrate: int = data.get("bitrate", 64000) - self.user_limit: int = data.get("user_limit", 0) - self.rtc_region: Optional[str] = data.get("rtc_region") - self.video_quality_mode: Optional[int] = data.get("video_quality_mode") - - def __repr__(self) -> str: - return f"" - - -class StageChannel(VoiceChannel): - """Represents a guild stage channel.""" - - def __repr__(self) -> str: - return f"" - - async def start_stage_instance( - self, - topic: str, - *, - privacy_level: int = 2, - reason: Optional[str] = None, - guild_scheduled_event_id: Optional[str] = None, - ) -> "StageInstance": - if not hasattr(self._client, "_http"): - raise DisagreementException("Client missing HTTP for stage instance") - - payload: Dict[str, Any] = { - "channel_id": self.id, - "topic": topic, - "privacy_level": privacy_level, - } - if guild_scheduled_event_id is not None: - payload["guild_scheduled_event_id"] = guild_scheduled_event_id - - instance = await self._client._http.start_stage_instance(payload, reason=reason) - instance._client = self._client - return instance - - async def edit_stage_instance( - self, - *, - topic: Optional[str] = None, - privacy_level: Optional[int] = None, - reason: Optional[str] = None, - ) -> "StageInstance": - if not hasattr(self._client, "_http"): - raise DisagreementException("Client missing HTTP for stage instance") - - payload: Dict[str, Any] = {} - if topic is not None: - payload["topic"] = topic - if privacy_level is not None: - payload["privacy_level"] = privacy_level - - instance = await self._client._http.edit_stage_instance( - self.id, payload, reason=reason - ) - instance._client = self._client - return instance - - async def end_stage_instance(self, *, reason: Optional[str] = None) -> None: - if not hasattr(self._client, "_http"): - raise DisagreementException("Client missing HTTP for stage instance") - - await self._client._http.end_stage_instance(self.id, reason=reason) - - -class StageInstance: - """Represents a stage instance.""" - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ) -> None: - self._client = client_instance - self.id: str = data["id"] - self.guild_id: Optional[str] = data.get("guild_id") - self.channel_id: str = data["channel_id"] - self.topic: str = data["topic"] - self.privacy_level: int = data.get("privacy_level", 2) - self.discoverable_disabled: bool = data.get("discoverable_disabled", False) - self.guild_scheduled_event_id: Optional[str] = data.get( - "guild_scheduled_event_id" - ) - - def __repr__(self) -> str: - return f"" - - -class CategoryChannel(Channel): - """Represents a guild category channel.""" - - def __init__(self, data: Dict[str, Any], client_instance: "Client"): - super().__init__(data, client_instance) - - @property - def channels(self) -> List[Channel]: - if not self.guild_id or not hasattr(self._client, "get_guild"): - return [] - guild = self._client.get_guild(self.guild_id) - if not guild or not hasattr( - guild, "_channels" - ): # Ensure guild and _channels exist - return [] - - categorized_channels = [ - ch - for ch in guild._channels.values() - if getattr(ch, "parent_id", None) == self.id - ] - return sorted( - categorized_channels, - key=lambda c: c.position if c.position is not None else -1, - ) - - def __repr__(self) -> str: - return f"" - - -class ThreadMetadata: - """Represents the metadata of a thread.""" - - def __init__(self, data: Dict[str, Any]): - self.archived: bool = data["archived"] - self.auto_archive_duration: int = data["auto_archive_duration"] - self.archive_timestamp: str = data["archive_timestamp"] - self.locked: bool = data["locked"] - self.invitable: Optional[bool] = data.get("invitable") - self.create_timestamp: Optional[str] = data.get("create_timestamp") - - -class Thread(TextChannel): # Threads are a specialized TextChannel - """Represents a Discord Thread.""" - - def __init__(self, data: Dict[str, Any], client_instance: "Client"): - super().__init__(data, client_instance) # Handles common text channel fields - self.owner_id: Optional[str] = data.get("owner_id") - # parent_id is already handled by base Channel init if present in data - self.message_count: Optional[int] = data.get("message_count") - self.member_count: Optional[int] = data.get("member_count") - self.thread_metadata: ThreadMetadata = ThreadMetadata(data["thread_metadata"]) - self.member: Optional["ThreadMember"] = ( - ThreadMember(data["member"], client_instance) - if data.get("member") - else None - ) - - def __repr__(self) -> str: - return ( - f"" - ) - + +class Channel: + """Base class for Discord channels.""" + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + self._client: "Client" = client_instance + self.id: str = data["id"] + self._type_val: int = int(data["type"]) # Store raw type for enum property + + self.guild_id: Optional[str] = data.get("guild_id") + self.name: Optional[str] = data.get("name") + self.position: Optional[int] = data.get("position") + self.permission_overwrites: List["PermissionOverwrite"] = [ + PermissionOverwrite(d) for d in data.get("permission_overwrites", []) + ] + self.nsfw: Optional[bool] = data.get("nsfw", False) + self.parent_id: Optional[str] = data.get( + "parent_id" + ) # ID of the parent category channel or thread parent + + @property + def type(self) -> ChannelType: + return ChannelType(self._type_val) + + @property + def mention(self) -> str: + return f"<#{self.id}>" + + async def delete(self, reason: Optional[str] = None): + await self._client._http.delete_channel(self.id, reason=reason) + + def __repr__(self) -> str: + return f"" + + def permission_overwrite_for( + self, target: Union["Role", "Member", str] + ) -> Optional["PermissionOverwrite"]: + """Return the :class:`PermissionOverwrite` for ``target`` if present.""" + + if isinstance(target, str): + target_id = target + else: + target_id = target.id + for overwrite in self.permission_overwrites: + if overwrite.id == target_id: + return overwrite + return None + + @staticmethod + def _apply_overwrite( + perms: Permissions, overwrite: Optional["PermissionOverwrite"] + ) -> Permissions: + if overwrite is None: + return perms + + perms &= ~Permissions(int(overwrite.deny)) + perms |= Permissions(int(overwrite.allow)) + return perms + + def permissions_for(self, member: "Member") -> Permissions: + """Resolve channel permissions for ``member``.""" + + if self.guild_id is None: + return Permissions(~0) + + if not hasattr(self._client, "get_guild"): + return Permissions(0) + + guild = self._client.get_guild(self.guild_id) + if guild is None: + return Permissions(0) + + base = Permissions(0) + + everyone = guild.get_role(guild.id) + if everyone is not None: + base |= Permissions(int(everyone.permissions)) + + for rid in member.roles: + role = guild.get_role(rid) + if role is not None: + base |= Permissions(int(role.permissions)) + + if base & Permissions.ADMINISTRATOR: + return Permissions(~0) + + # Apply @everyone overwrite + base = self._apply_overwrite(base, self.permission_overwrite_for(guild.id)) + + # Role overwrites + role_allow = Permissions(0) + role_deny = Permissions(0) + for rid in member.roles: + ow = self.permission_overwrite_for(rid) + if ow is not None: + role_allow |= Permissions(int(ow.allow)) + role_deny |= Permissions(int(ow.deny)) + + base &= ~role_deny + base |= role_allow + + # Member overwrite + base = self._apply_overwrite(base, self.permission_overwrite_for(member.id)) + + return base + + +class TextChannel(Channel): + """Represents a guild text channel or announcement channel.""" + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + super().__init__(data, client_instance) + self.topic: Optional[str] = data.get("topic") + self.last_message_id: Optional[str] = data.get("last_message_id") + self.rate_limit_per_user: Optional[int] = data.get("rate_limit_per_user", 0) + self.default_auto_archive_duration: Optional[int] = data.get( + "default_auto_archive_duration" + ) + self.last_pin_timestamp: Optional[str] = data.get("last_pin_timestamp") + + def history( + self, + *, + limit: Optional[int] = None, + before: Optional[str] = None, + after: Optional[str] = None, + ) -> AsyncIterator["Message"]: + """Return an async iterator over this channel's messages.""" + + from .utils import message_pager + + return message_pager(self, limit=limit, before=before, after=after) + + async def send( + self, + content: Optional[str] = None, + *, + embed: Optional[Embed] = None, + embeds: Optional[List[Embed]] = None, + components: Optional[List["ActionRow"]] = None, # Added components + ) -> "Message": # Forward reference Message + if not hasattr(self._client, "send_message"): + raise NotImplementedError( + "Client.send_message is required for TextChannel.send" + ) + + return await self._client.send_message( + channel_id=self.id, + content=content, + embed=embed, + embeds=embeds, + components=components, + ) + + async def purge( + self, limit: int, *, before: "Snowflake | None" = None + ) -> List["Snowflake"]: + """Bulk delete messages from this channel.""" + + params: Dict[str, Union[int, str]] = {"limit": limit} + if before is not None: + params["before"] = before + + messages = await self._client._http.request( + "GET", f"/channels/{self.id}/messages", params=params + ) + ids = [m["id"] for m in messages] + if not ids: + return [] + + await self._client._http.bulk_delete_messages(self.id, ids) + for mid in ids: + self._client._messages.invalidate(mid) + return ids + + def get_partial_message(self, id: int) -> "PartialMessage": + """Returns a :class:`PartialMessage` for the given ID. + + This allows performing actions on a message without fetching it first. + + Parameters + ---------- + id: int + The ID of the message to get a partial instance of. + + Returns + ------- + PartialMessage + The partial message instance. + """ + return PartialMessage(id=str(id), channel=self) + + def __repr__(self) -> str: + return f"" + + async def pins(self) -> List["Message"]: + """|coro| + + Fetches all pinned messages in this channel. + + Returns + ------- + List[Message] + The pinned messages. + + Raises + ------ + HTTPException + Fetching the pinned messages failed. + """ + + messages_data = await self._client._http.get_pinned_messages(self.id) + return [self._client.parse_message(m) for m in messages_data] + + async def create_thread( + self, + name: str, + *, + type: ChannelType = ChannelType.PUBLIC_THREAD, + auto_archive_duration: Optional[int] = None, + invitable: Optional[bool] = None, + rate_limit_per_user: Optional[int] = None, + reason: Optional[str] = None, + ) -> "Thread": + """|coro| + + Creates a new thread in this channel. + + Parameters + ---------- + name: str + The name of the thread. + type: ChannelType + The type of thread to create. Defaults to PUBLIC_THREAD. + Can be PUBLIC_THREAD, PRIVATE_THREAD, or ANNOUNCEMENT_THREAD. + auto_archive_duration: Optional[int] + The duration in minutes to automatically archive the thread after recent activity. + invitable: Optional[bool] + Whether non-moderators can invite other non-moderators to a private thread. + Only applicable to private threads. + rate_limit_per_user: Optional[int] + The number of seconds a user has to wait before sending another message. + reason: Optional[str] + The reason for creating the thread. + + Returns + ------- + Thread + The created thread. + """ + payload: Dict[str, Any] = { + "name": name, + "type": type.value, + } + if auto_archive_duration is not None: + payload["auto_archive_duration"] = auto_archive_duration + if invitable is not None and type == ChannelType.PRIVATE_THREAD: + payload["invitable"] = invitable + if rate_limit_per_user is not None: + payload["rate_limit_per_user"] = rate_limit_per_user + + data = await self._client._http.start_thread_without_message(self.id, payload) + return cast("Thread", self._client.parse_channel(data)) + + +class VoiceChannel(Channel): + """Represents a guild voice channel or stage voice channel.""" + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + super().__init__(data, client_instance) + self.bitrate: int = data.get("bitrate", 64000) + self.user_limit: int = data.get("user_limit", 0) + self.rtc_region: Optional[str] = data.get("rtc_region") + self.video_quality_mode: Optional[int] = data.get("video_quality_mode") + + def __repr__(self) -> str: + return f"" + + +class StageChannel(VoiceChannel): + """Represents a guild stage channel.""" + + def __repr__(self) -> str: + return f"" + + async def start_stage_instance( + self, + topic: str, + *, + privacy_level: int = 2, + reason: Optional[str] = None, + guild_scheduled_event_id: Optional[str] = None, + ) -> "StageInstance": + if not hasattr(self._client, "_http"): + raise DisagreementException("Client missing HTTP for stage instance") + + payload: Dict[str, Any] = { + "channel_id": self.id, + "topic": topic, + "privacy_level": privacy_level, + } + if guild_scheduled_event_id is not None: + payload["guild_scheduled_event_id"] = guild_scheduled_event_id + + instance = await self._client._http.start_stage_instance(payload, reason=reason) + instance._client = self._client + return instance + + async def edit_stage_instance( + self, + *, + topic: Optional[str] = None, + privacy_level: Optional[int] = None, + reason: Optional[str] = None, + ) -> "StageInstance": + if not hasattr(self._client, "_http"): + raise DisagreementException("Client missing HTTP for stage instance") + + payload: Dict[str, Any] = {} + if topic is not None: + payload["topic"] = topic + if privacy_level is not None: + payload["privacy_level"] = privacy_level + + instance = await self._client._http.edit_stage_instance( + self.id, payload, reason=reason + ) + instance._client = self._client + return instance + + async def end_stage_instance(self, *, reason: Optional[str] = None) -> None: + if not hasattr(self._client, "_http"): + raise DisagreementException("Client missing HTTP for stage instance") + + await self._client._http.end_stage_instance(self.id, reason=reason) + + +class StageInstance: + """Represents a stage instance.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ) -> None: + self._client = client_instance + self.id: str = data["id"] + self.guild_id: Optional[str] = data.get("guild_id") + self.channel_id: str = data["channel_id"] + self.topic: str = data["topic"] + self.privacy_level: int = data.get("privacy_level", 2) + self.discoverable_disabled: bool = data.get("discoverable_disabled", False) + self.guild_scheduled_event_id: Optional[str] = data.get( + "guild_scheduled_event_id" + ) + + def __repr__(self) -> str: + return f"" + + +class CategoryChannel(Channel): + """Represents a guild category channel.""" + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + super().__init__(data, client_instance) + + @property + def channels(self) -> List[Channel]: + if not self.guild_id or not hasattr(self._client, "get_guild"): + return [] + guild = self._client.get_guild(self.guild_id) + if not guild or not hasattr( + guild, "_channels" + ): # Ensure guild and _channels exist + return [] + + categorized_channels = [ + ch + for ch in guild._channels.values() + if getattr(ch, "parent_id", None) == self.id + ] + return sorted( + categorized_channels, + key=lambda c: c.position if c.position is not None else -1, + ) + + def __repr__(self) -> str: + return f"" + + +class ThreadMetadata: + """Represents the metadata of a thread.""" + + def __init__(self, data: Dict[str, Any]): + self.archived: bool = data["archived"] + self.auto_archive_duration: int = data["auto_archive_duration"] + self.archive_timestamp: str = data["archive_timestamp"] + self.locked: bool = data["locked"] + self.invitable: Optional[bool] = data.get("invitable") + self.create_timestamp: Optional[str] = data.get("create_timestamp") + + +class Thread(TextChannel): # Threads are a specialized TextChannel + """Represents a Discord Thread.""" + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + super().__init__(data, client_instance) # Handles common text channel fields + self.owner_id: Optional[str] = data.get("owner_id") + # parent_id is already handled by base Channel init if present in data + self.message_count: Optional[int] = data.get("message_count") + self.member_count: Optional[int] = data.get("member_count") + self.thread_metadata: ThreadMetadata = ThreadMetadata(data["thread_metadata"]) + self.member: Optional["ThreadMember"] = ( + ThreadMember(data["member"], client_instance) + if data.get("member") + else None + ) + + def __repr__(self) -> str: + return ( + f"" + ) + async def join(self) -> None: """|coro|