diff --git a/disagreement/audio.py b/disagreement/audio.py index f70369c..9e58530 100644 --- a/disagreement/audio.py +++ b/disagreement/audio.py @@ -114,3 +114,20 @@ class FFmpegAudioSource(AudioSource): if isinstance(self.source, io.IOBase): with contextlib.suppress(Exception): self.source.close() + +class AudioSink: + """Abstract base class for audio sinks.""" + + def write(self, user, data): + """Write a chunk of PCM audio. + + Subclasses must implement this. The data is raw PCM at 48kHz + stereo. + """ + + raise NotImplementedError + + def close(self) -> None: + """Cleanup the sink when the voice client disconnects.""" + + return None diff --git a/disagreement/cache.py b/disagreement/cache.py index 666d46b..178e8ae 100644 --- a/disagreement/cache.py +++ b/disagreement/cache.py @@ -4,7 +4,8 @@ import time from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar if TYPE_CHECKING: - from .models import Channel, Guild + from .models import Channel, Guild, Member + from .caching import MemberCacheFlags T = TypeVar("T") @@ -53,3 +54,32 @@ class GuildCache(Cache["Guild"]): class ChannelCache(Cache["Channel"]): """Cache specifically for :class:`Channel` objects.""" + + +class MemberCache(Cache["Member"]): + """ + A cache for :class:`Member` objects that respects :class:`MemberCacheFlags`. + """ + + def __init__(self, flags: MemberCacheFlags, ttl: Optional[float] = None) -> None: + super().__init__(ttl) + self.flags = flags + + def _should_cache(self, member: Member) -> bool: + """Determines if a member should be cached based on the flags.""" + if self.flags.all: + return True + if self.flags.none: + return False + + if self.flags.online and member.status != "offline": + return True + if self.flags.voice and member.voice_state is not None: + return True + if self.flags.joined and getattr(member, "_just_joined", False): + return True + return False + + def set(self, key: str, value: Member) -> None: + if self._should_cache(value): + super().set(key, value) diff --git a/disagreement/caching.py b/disagreement/caching.py new file mode 100644 index 0000000..7aca2b8 --- /dev/null +++ b/disagreement/caching.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +import operator +from typing import Any, Callable, ClassVar, Dict, Iterator, Tuple + + +class _MemberCacheFlagValue: + flag: int + + def __init__(self, func: Callable[[Any], bool]): + self.flag = getattr(func, 'flag', 0) + self.__doc__ = func.__doc__ + + def __get__(self, instance: 'MemberCacheFlags', owner: type) -> Any: + if instance is None: + return self + return instance.value & self.flag != 0 + + def __set__(self, instance: Any, value: bool) -> None: + if value: + instance.value |= self.flag + else: + instance.value &= ~self.flag + + def __repr__(self) -> str: + return f'<{self.__class__.__name__} flag={self.flag}>' + + +def flag_value(flag: int) -> Callable[[Callable[[Any], bool]], _MemberCacheFlagValue]: + def decorator(func: Callable[[Any], bool]) -> _MemberCacheFlagValue: + setattr(func, 'flag', flag) + return _MemberCacheFlagValue(func) + return decorator + + +class MemberCacheFlags: + __slots__ = ('value',) + + VALID_FLAGS: ClassVar[Dict[str, int]] = { + 'joined': 1 << 0, + 'voice': 1 << 1, + 'online': 1 << 2, + } + DEFAULT_FLAGS: ClassVar[int] = 1 | 2 | 4 + ALL_FLAGS: ClassVar[int] = sum(VALID_FLAGS.values()) + + def __init__(self, **kwargs: bool): + self.value = self.DEFAULT_FLAGS + for key, value in kwargs.items(): + if key not in self.VALID_FLAGS: + raise TypeError(f'{key!r} is not a valid member cache flag.') + setattr(self, key, value) + + @classmethod + def _from_value(cls, value: int) -> MemberCacheFlags: + self = cls.__new__(cls) + self.value = value + return self + + def __eq__(self, other: object) -> bool: + return isinstance(other, MemberCacheFlags) and self.value == other.value + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + def __hash__(self) -> int: + return hash(self.value) + + def __repr__(self) -> str: + return f'' + + def __iter__(self) -> Iterator[Tuple[str, bool]]: + for name in self.VALID_FLAGS: + yield name, getattr(self, name) + + def __int__(self) -> int: + return self.value + + def __index__(self) -> int: + return self.value + + @classmethod + def all(cls) -> MemberCacheFlags: + """A factory method that creates a :class:`MemberCacheFlags` with all flags enabled.""" + return cls._from_value(cls.ALL_FLAGS) + + @classmethod + def none(cls) -> MemberCacheFlags: + """A factory method that creates a :class:`MemberCacheFlags` with all flags disabled.""" + return cls._from_value(0) + + @classmethod + def only_joined(cls) -> MemberCacheFlags: + """A factory method that creates a :class:`MemberCacheFlags` with only the `joined` flag enabled.""" + return cls._from_value(cls.VALID_FLAGS['joined']) + + @classmethod + def only_voice(cls) -> MemberCacheFlags: + """A factory method that creates a :class:`MemberCacheFlags` with only the `voice` flag enabled.""" + return cls._from_value(cls.VALID_FLAGS['voice']) + + @classmethod + def only_online(cls) -> MemberCacheFlags: + """A factory method that creates a :class:`MemberCacheFlags` with only the `online` flag enabled.""" + return cls._from_value(cls.VALID_FLAGS['online']) + + @flag_value(1 << 0) + def joined(self) -> bool: + """Whether to cache members that have just joined the guild.""" + return False + + @flag_value(1 << 1) + def voice(self) -> bool: + """Whether to cache members that are in a voice channel.""" + return False + + @flag_value(1 << 2) + def online(self) -> bool: + """Whether to cache members that are online.""" + return False \ No newline at end of file diff --git a/disagreement/client.py b/disagreement/client.py index df73b26..23919eb 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -1,584 +1,587 @@ -# disagreement/client.py - -""" -The main Client class for interacting with the Discord API. -""" - -import asyncio -import signal -from typing import ( - Optional, - Callable, - Any, - TYPE_CHECKING, - Awaitable, - AsyncIterator, - Union, - List, - Dict, -) -from types import ModuleType - -from .http import HTTPClient -from .gateway import GatewayClient -from .shard_manager import ShardManager -from .event_dispatcher import EventDispatcher -from .enums import GatewayIntent, InteractionType, GatewayOpcode, VoiceRegion -from .errors import DisagreementException, AuthenticationError -from .typing import Typing -from .ext.commands.core import CommandHandler -from .ext.commands.cog import Cog -from .ext.app_commands.handler import AppCommandHandler -from .ext.app_commands.context import AppCommandContext -from .ext import loader as ext_loader -from .interactions import Interaction, Snowflake -from .error_handler import setup_global_error_handler -from .voice_client import VoiceClient - -if TYPE_CHECKING: - from .models import ( - Message, - Embed, - ActionRow, - Guild, - Channel, - User, - Member, - Role, - TextChannel, - VoiceChannel, - CategoryChannel, - Thread, - DMChannel, - Webhook, - GuildTemplate, - ScheduledEvent, - AuditLogEntry, - Invite, - ) - from .ui.view import View - from .enums import ChannelType as EnumChannelType - from .ext.commands.core import CommandContext - from .ext.commands.errors import CommandError, CommandInvokeError - from .ext.app_commands.commands import AppCommand, AppCommandGroup - - -class Client: - """ - Represents a client connection that connects to Discord. - This class is used to interact with the Discord WebSocket and API. - - Args: - token (str): The bot token for authentication. - intents (Optional[int]): The Gateway Intents to use. Defaults to `GatewayIntent.default()`. - You might need to enable privileged intents in your bot's application page. - loop (Optional[asyncio.AbstractEventLoop]): The event loop to use for asynchronous operations. - Defaults to `asyncio.get_event_loop()`. - command_prefix (Union[str, List[str], Callable[['Client', Message], Union[str, List[str]]]]): - The prefix(es) for commands. Defaults to '!'. - verbose (bool): If True, print raw HTTP and Gateway traffic for debugging. - http_options (Optional[Dict[str, Any]]): Extra options passed to - :class:`HTTPClient` for creating the internal - :class:`aiohttp.ClientSession`. - """ - - def __init__( - self, - token: str, - intents: Optional[int] = None, - loop: Optional[asyncio.AbstractEventLoop] = None, - command_prefix: Union[ - str, List[str], Callable[["Client", "Message"], Union[str, List[str]]] - ] = "!", - application_id: Optional[Union[str, int]] = None, - verbose: bool = False, - mention_replies: bool = False, - shard_count: Optional[int] = None, - gateway_max_retries: int = 5, - gateway_max_backoff: float = 60.0, - http_options: Optional[Dict[str, Any]] = None, - ): - if not token: - raise ValueError("A bot token must be provided.") - - self.token: str = token - self.intents: int = intents if intents is not None else GatewayIntent.default() - self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() - self.application_id: Optional[Snowflake] = ( - str(application_id) if application_id else None - ) - setup_global_error_handler(self.loop) - - self.verbose: bool = verbose - self._http: HTTPClient = HTTPClient( - token=self.token, - verbose=verbose, - **(http_options or {}), - ) - self._event_dispatcher: EventDispatcher = EventDispatcher(client_instance=self) - self._gateway: Optional[GatewayClient] = ( - None # Initialized in run() or connect() - ) - self.shard_count: Optional[int] = shard_count - self.gateway_max_retries: int = gateway_max_retries - self.gateway_max_backoff: float = gateway_max_backoff - self._shard_manager: Optional[ShardManager] = None - - # Initialize CommandHandler - self.command_handler: CommandHandler = CommandHandler( - client=self, prefix=command_prefix - ) - self.app_command_handler: AppCommandHandler = AppCommandHandler(client=self) - # Register internal listener for processing commands from messages - self._event_dispatcher.register( - "MESSAGE_CREATE", self._process_message_for_commands - ) - - self._closed: bool = False - self._ready_event: asyncio.Event = asyncio.Event() - self.user: Optional["User"] = ( - None # The bot's own user object, populated on READY - ) - - # Internal Caches - self._guilds: Dict[Snowflake, "Guild"] = {} - self._channels: Dict[Snowflake, "Channel"] = ( - {} - ) # Stores all channel types by ID - self._users: Dict[Snowflake, Any] = ( - {} - ) # Placeholder for User model cache if needed - self._messages: Dict[Snowflake, "Message"] = {} - self._views: Dict[Snowflake, "View"] = {} - self._voice_clients: Dict[Snowflake, VoiceClient] = {} - self._webhooks: Dict[Snowflake, "Webhook"] = {} - - # Default whether replies mention the user - self.mention_replies: bool = mention_replies - - # Basic signal handling for graceful shutdown - # This might be better handled by the user's application code, but can be a nice default. - # For more robust handling, consider libraries or more advanced patterns. - try: - self.loop.add_signal_handler( - signal.SIGINT, lambda: self.loop.create_task(self.close()) - ) - self.loop.add_signal_handler( - signal.SIGTERM, lambda: self.loop.create_task(self.close()) - ) - except NotImplementedError: - # add_signal_handler is not available on all platforms (e.g., Windows default event loop policy) - # Users on these platforms would need to handle shutdown differently. - print( - "Warning: Signal handlers for SIGINT/SIGTERM could not be added. " - "Graceful shutdown via signals might not work as expected on this platform." - ) - - async def _initialize_gateway(self): - """Initializes the GatewayClient if it doesn't exist.""" - if self._gateway is None: - self._gateway = GatewayClient( - http_client=self._http, - event_dispatcher=self._event_dispatcher, - token=self.token, - intents=self.intents, - client_instance=self, - verbose=self.verbose, - max_retries=self.gateway_max_retries, - max_backoff=self.gateway_max_backoff, - ) - - async def _initialize_shard_manager(self) -> None: - """Initializes the :class:`ShardManager` if not already created.""" - if self._shard_manager is None: - count = self.shard_count or 1 - self._shard_manager = ShardManager(self, count) - - async def connect(self, reconnect: bool = True) -> None: - """ - Establishes a connection to Discord. This includes logging in and connecting to the Gateway. - This method is a coroutine. - - Args: - reconnect (bool): Whether to automatically attempt to reconnect on disconnect. - (Note: Basic reconnect logic is within GatewayClient for now) - - Raises: - GatewayException: If the connection to the gateway fails. - AuthenticationError: If the token is invalid. - """ - if self._closed: - raise DisagreementException("Client is closed and cannot connect.") - if self.shard_count and self.shard_count > 1: - await self._initialize_shard_manager() - assert self._shard_manager is not None - await self._shard_manager.start() - print( - f"Client connected using {self.shard_count} shards, waiting for READY signal..." - ) - await self.wait_until_ready() - print("Client is READY!") - return - - await self._initialize_gateway() - assert self._gateway is not None # Should be initialized by now - - retry_delay = 5 # seconds - max_retries = 5 # For initial connection attempts by Client.run, Gateway has its own internal retries for some cases. - - for attempt in range(max_retries): - try: - await self._gateway.connect() - # After successful connection, GatewayClient's HELLO handler will trigger IDENTIFY/RESUME - # and its READY handler will set self._ready_event via dispatcher. - print("Client connected to Gateway, waiting for READY signal...") - await self.wait_until_ready() # Wait for the READY event from Gateway - print("Client is READY!") - return # Successfully connected and ready - except AuthenticationError: # Non-recoverable by retry here - print("Authentication failed. Please check your bot token.") - await self.close() # Ensure cleanup - raise - except DisagreementException as e: # Includes GatewayException - print(f"Failed to connect (Attempt {attempt + 1}/{max_retries}): {e}") - if attempt < max_retries - 1: - print(f"Retrying in {retry_delay} seconds...") - await asyncio.sleep(retry_delay) - retry_delay = min( - retry_delay * 2, 60 - ) # Exponential backoff up to 60s - else: - print("Max connection retries reached. Giving up.") - await self.close() # Ensure cleanup - raise - # Should not be reached if max_retries is > 0 - if max_retries == 0: # If max_retries was 0, means no retries attempted - raise DisagreementException("Connection failed with 0 retries allowed.") - - async def run(self) -> None: - """ - A blocking call that connects the client to Discord and runs until the client is closed. - This method is a coroutine. - It handles login, Gateway connection, and keeping the connection alive. - """ - if self._closed: - raise DisagreementException("Client is already closed.") - - try: - await self.connect() - # The GatewayClient's _receive_loop will keep running. - # This run method effectively waits until the client is closed or an unhandled error occurs. - # A more robust implementation might have a main loop here that monitors gateway health. - # For now, we rely on the gateway's tasks. - while not self._closed: - if ( - self._gateway - and self._gateway._receive_task - and self._gateway._receive_task.done() - ): - # If receive task ended unexpectedly, try to handle it or re-raise - try: - exc = self._gateway._receive_task.exception() - if exc: - print( - f"Gateway receive task ended with exception: {exc}. Attempting to reconnect..." - ) - # This is a basic reconnect strategy from the client side. - # GatewayClient itself might handle some reconnects. - await self.close_gateway( - code=1000 - ) # Close current gateway state - await asyncio.sleep(5) # Wait before reconnecting - if ( - not self._closed - ): # If client wasn't closed by the exception handler - await self.connect() - else: - break # Client was closed, exit run loop - else: - print( - "Gateway receive task ended without exception. Assuming clean shutdown or reconnect handled internally." - ) - if ( - not self._closed - ): # If not explicitly closed, might be an issue - print( - "Warning: Gateway receive task ended but client not closed. This might indicate an issue." - ) - # Consider a more robust health check or reconnect strategy here. - await asyncio.sleep( - 1 - ) # Prevent tight loop if something is wrong - else: - break # Client was closed - except asyncio.CancelledError: - print("Gateway receive task was cancelled.") - break # Exit if cancelled - except Exception as e: - print(f"Error checking gateway receive task: {e}") - break # Exit on other errors - await asyncio.sleep(1) # Main loop check interval - except DisagreementException as e: - print(f"Client run loop encountered an error: {e}") - # Error already logged by connect or other methods - except asyncio.CancelledError: - print("Client run loop was cancelled.") - finally: - if not self._closed: - await self.close() - - async def close(self) -> None: - """ - Closes the connection to Discord. This method is a coroutine. - """ - if self._closed: - return - - self._closed = True - print("Closing client...") - - if self._shard_manager: - await self._shard_manager.close() - self._shard_manager = None - if self._gateway: - await self._gateway.close() - - if self._http: # HTTPClient has its own session to close - await self._http.close() - - self._ready_event.set() # Ensure any waiters for ready are unblocked - print("Client closed.") - - async def __aenter__(self) -> "Client": - """Enter the context manager by connecting to Discord.""" - await self.connect() - return self - - async def __aexit__( - self, - exc_type: Optional[type], - exc: Optional[BaseException], - tb: Optional[BaseException], - ) -> bool: - """Exit the context manager and close the client.""" - await self.close() - return False - - async def close_gateway(self, code: int = 1000) -> None: - """Closes only the gateway connection, allowing for potential reconnect.""" - if self._shard_manager: - await self._shard_manager.close() - self._shard_manager = None - if self._gateway: - await self._gateway.close(code=code) - self._gateway = None - self._ready_event.clear() # No longer ready if gateway is closed - - def is_closed(self) -> bool: - """Indicates if the client has been closed.""" - return self._closed - - def is_ready(self) -> bool: - """Indicates if the client has successfully connected to the Gateway and is ready.""" - return self._ready_event.is_set() - - @property - def latency(self) -> Optional[float]: - """Returns the gateway latency in seconds, or ``None`` if unavailable.""" - if self._gateway: - return self._gateway.latency - return None - - async def wait_until_ready(self) -> None: - """|coro| - Waits until the client is fully connected to Discord and the initial state is processed. - This is mainly useful for waiting for the READY event from the Gateway. - """ - await self._ready_event.wait() - - async def wait_for( - self, - event_name: str, - check: Optional[Callable[[Any], bool]] = None, - timeout: Optional[float] = None, - ) -> Any: - """|coro| - Waits for a specific event to occur that satisfies the ``check``. - - Parameters - ---------- - event_name: str - The name of the event to wait for. - check: Optional[Callable[[Any], bool]] - A function that determines whether the received event should resolve the wait. - timeout: Optional[float] - How long to wait for the event before raising :class:`asyncio.TimeoutError`. - """ - - future: asyncio.Future = self.loop.create_future() - self._event_dispatcher.add_waiter(event_name, future, check) - try: - return await asyncio.wait_for(future, timeout=timeout) - finally: - self._event_dispatcher.remove_waiter(event_name, future) - - async def change_presence( - self, - status: str, - activity_name: Optional[str] = None, - activity_type: int = 0, - since: int = 0, - afk: bool = False, - ): - """ - Changes the client's presence on Discord. - - Args: - status (str): The new status for the client (e.g., "online", "idle", "dnd", "invisible"). - activity_name (Optional[str]): The name of the activity. - activity_type (int): The type of the activity. - since (int): The timestamp (in milliseconds) of when the client went idle. - afk (bool): Whether the client is AFK. - """ - if self._closed: - raise DisagreementException("Client is closed.") - - if self._gateway: - await self._gateway.update_presence( - status=status, - activity_name=activity_name, - activity_type=activity_type, - since=since, - afk=afk, - ) - - # --- Event Handling --- - - def event( - self, coro: Callable[..., Awaitable[None]] - ) -> Callable[..., Awaitable[None]]: - """ - A decorator that registers an event to listen to. - The name of the coroutine is used as the event name. - Example: - @client.event - async def on_ready(): # Will listen for the 'READY' event - print("Bot is ready!") - - @client.event - async def on_message(message: disagreement.Message): # Will listen for 'MESSAGE_CREATE' - print(f"Message from {message.author}: {message.content}") - """ - if not asyncio.iscoroutinefunction(coro): - raise TypeError("Event registered must be a coroutine function.") - - event_name = coro.__name__ - # Map common function names to Discord event types - # e.g., on_ready -> READY, on_message -> MESSAGE_CREATE - if event_name.startswith("on_"): - discord_event_name = event_name[3:].upper() - mapping = { - "MESSAGE": "MESSAGE_CREATE", - "MESSAGE_EDIT": "MESSAGE_UPDATE", - "MESSAGE_UPDATE": "MESSAGE_UPDATE", - "MESSAGE_DELETE": "MESSAGE_DELETE", - "REACTION_ADD": "MESSAGE_REACTION_ADD", - "REACTION_REMOVE": "MESSAGE_REACTION_REMOVE", - } - discord_event_name = mapping.get(discord_event_name, discord_event_name) - self._event_dispatcher.register(discord_event_name, coro) - else: - # If not starting with "on_", assume it's the direct Discord event name (e.g. "TYPING_START") - # Or raise an error if a specific format is required. - # For now, let's assume direct mapping if no "on_" prefix. - self._event_dispatcher.register(event_name.upper(), coro) - - return coro # Return the original coroutine - - def on_event( - self, event_name: str - ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: - """ - A decorator that registers an event to listen to with a specific event name. - Example: - @client.on_event('MESSAGE_CREATE') - async def my_message_handler(message: disagreement.Message): - print(f"Message: {message.content}") - """ - - def decorator( - coro: Callable[..., Awaitable[None]], - ) -> Callable[..., Awaitable[None]]: - if not asyncio.iscoroutinefunction(coro): - raise TypeError("Event registered must be a coroutine function.") - self._event_dispatcher.register(event_name.upper(), coro) - return coro - - return decorator - - async def _process_message_for_commands(self, message: "Message") -> None: - """Internal listener to process messages for commands.""" - # Make sure message object is valid and not from a bot (optional, common check) - if ( - not message or not message.author or message.author.bot - ): # Add .bot check to User model - return - await self.command_handler.process_commands(message) - - # --- Command Framework Methods --- - - def add_cog(self, cog: Cog) -> None: - """ - Adds a Cog to the bot. - Cogs are classes that group commands, listeners, and state. - This will also discover and register any application commands defined in the cog. - - Args: - cog (Cog): An instance of a class derived from `disagreement.ext.commands.Cog`. - """ - # Add to prefix command handler - self.command_handler.add_cog( - cog - ) # This should call cog._inject() internally or cog._inject() is called on Cog init - - # Discover and add application commands from the cog - # AppCommand and AppCommandGroup are already imported in TYPE_CHECKING block - for app_cmd_obj in cog.get_app_commands_and_groups(): # Uses the new method - # The cog attribute should have been set within Cog._inject() for AppCommands - self.app_command_handler.add_command(app_cmd_obj) - print( - f"Registered app command/group '{app_cmd_obj.name}' from cog '{cog.cog_name}'." - ) - - def remove_cog(self, cog_name: str) -> Optional[Cog]: - """ - Removes a Cog from the bot. - - Args: - cog_name (str): The name of the Cog to remove. - - Returns: - Optional[Cog]: The Cog that was removed, or None if not found. - """ - removed_cog = self.command_handler.remove_cog(cog_name) - if removed_cog: - # Also remove associated application commands - # This requires AppCommand to store a reference to its cog, or iterate all app_commands. - # Assuming AppCommand has a .cog attribute, which is set in Cog._inject() - # And AppCommandGroup might store commands that have .cog attribute - for app_cmd_or_group in removed_cog.get_app_commands_and_groups(): - # The AppCommandHandler.remove_command needs to handle both AppCommand and AppCommandGroup - self.app_command_handler.remove_command( - app_cmd_or_group.name - ) # Assuming name is unique enough for removal here - print( - f"Removed app command/group '{app_cmd_or_group.name}' from cog '{cog_name}'." - ) - # Note: AppCommandHandler.remove_command might need to be more specific if names aren't globally unique - # (e.g. if it needs type or if groups and commands can share names). - # For now, assuming name is sufficient for removal from the handler's flat list. - return removed_cog - +# disagreement/client.py + +""" +The main Client class for interacting with the Discord API. +""" + +import asyncio +import signal +from typing import ( + Optional, + Callable, + Any, + TYPE_CHECKING, + Awaitable, + AsyncIterator, + Union, + List, + Dict, +) +from types import ModuleType + +from .http import HTTPClient +from .gateway import GatewayClient +from .shard_manager import ShardManager +from .event_dispatcher import EventDispatcher +from .enums import GatewayIntent, InteractionType, GatewayOpcode, VoiceRegion +from .errors import DisagreementException, AuthenticationError +from .typing import Typing +from .caching import MemberCacheFlags +from .cache import Cache, GuildCache, ChannelCache, MemberCache +from .ext.commands.core import Command, CommandHandler, Group +from .ext.commands.cog import Cog +from .ext.app_commands.handler import AppCommandHandler +from .ext.app_commands.context import AppCommandContext +from .ext import loader as ext_loader +from .interactions import Interaction, Snowflake +from .error_handler import setup_global_error_handler +from .voice_client import VoiceClient + +if TYPE_CHECKING: + from .models import ( + Message, + Embed, + ActionRow, + Guild, + Channel, + User, + Member, + Role, + TextChannel, + VoiceChannel, + CategoryChannel, + Thread, + DMChannel, + Webhook, + GuildTemplate, + ScheduledEvent, + AuditLogEntry, + Invite, + ) + from .ui.view import View + from .enums import ChannelType as EnumChannelType + from .ext.commands.core import CommandContext + from .ext.commands.errors import CommandError, CommandInvokeError + from .ext.app_commands.commands import AppCommand, AppCommandGroup + + +class Client: + """ + Represents a client connection that connects to Discord. + This class is used to interact with the Discord WebSocket and API. + + Args: + token (str): The bot token for authentication. + intents (Optional[int]): The Gateway Intents to use. Defaults to `GatewayIntent.default()`. + You might need to enable privileged intents in your bot's application page. + loop (Optional[asyncio.AbstractEventLoop]): The event loop to use for asynchronous operations. + Defaults to `asyncio.get_event_loop()`. + command_prefix (Union[str, List[str], Callable[['Client', Message], Union[str, List[str]]]]): + The prefix(es) for commands. Defaults to '!'. + verbose (bool): If True, print raw HTTP and Gateway traffic for debugging. + http_options (Optional[Dict[str, Any]]): Extra options passed to + :class:`HTTPClient` for creating the internal + :class:`aiohttp.ClientSession`. + """ + + def __init__( + self, + token: str, + intents: Optional[int] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + command_prefix: Union[ + str, List[str], Callable[["Client", "Message"], Union[str, List[str]]] + ] = "!", + application_id: Optional[Union[str, int]] = None, + verbose: bool = False, + mention_replies: bool = False, + shard_count: Optional[int] = None, + gateway_max_retries: int = 5, + gateway_max_backoff: float = 60.0, + member_cache_flags: Optional[MemberCacheFlags] = None, + http_options: Optional[Dict[str, Any]] = None, + ): + if not token: + raise ValueError("A bot token must be provided.") + + self.token: str = token + self.member_cache_flags: MemberCacheFlags = ( + member_cache_flags if member_cache_flags is not None else MemberCacheFlags() + ) + self.intents: int = intents if intents is not None else GatewayIntent.default() + self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() + self.application_id: Optional[Snowflake] = ( + str(application_id) if application_id else None + ) + setup_global_error_handler(self.loop) + + self.verbose: bool = verbose + self._http: HTTPClient = HTTPClient( + token=self.token, + verbose=verbose, + **(http_options or {}), + ) + self._event_dispatcher: EventDispatcher = EventDispatcher(client_instance=self) + self._gateway: Optional[GatewayClient] = ( + None # Initialized in run() or connect() + ) + self.shard_count: Optional[int] = shard_count + self.gateway_max_retries: int = gateway_max_retries + self.gateway_max_backoff: float = gateway_max_backoff + self._shard_manager: Optional[ShardManager] = None + + # Initialize CommandHandler + self.command_handler: CommandHandler = CommandHandler( + client=self, prefix=command_prefix + ) + self.app_command_handler: AppCommandHandler = AppCommandHandler(client=self) + # Register internal listener for processing commands from messages + self._event_dispatcher.register( + "MESSAGE_CREATE", self._process_message_for_commands + ) + + self._closed: bool = False + self._ready_event: asyncio.Event = asyncio.Event() + self.user: Optional["User"] = ( + None # The bot's own user object, populated on READY + ) + + # Internal Caches + self._guilds: GuildCache = GuildCache() + self._channels: ChannelCache = ChannelCache() + self._users: Cache["User"] = Cache() + self._messages: Cache["Message"] = Cache(ttl=3600) # Cache messages for an hour + self._views: Dict[Snowflake, "View"] = {} + self._persistent_views: Dict[str, "View"] = {} + self._voice_clients: Dict[Snowflake, VoiceClient] = {} + self._webhooks: Dict[Snowflake, "Webhook"] = {} + + # Default whether replies mention the user + self.mention_replies: bool = mention_replies + + # Basic signal handling for graceful shutdown + # This might be better handled by the user's application code, but can be a nice default. + # For more robust handling, consider libraries or more advanced patterns. + try: + self.loop.add_signal_handler( + signal.SIGINT, lambda: self.loop.create_task(self.close()) + ) + self.loop.add_signal_handler( + signal.SIGTERM, lambda: self.loop.create_task(self.close()) + ) + except NotImplementedError: + # add_signal_handler is not available on all platforms (e.g., Windows default event loop policy) + # Users on these platforms would need to handle shutdown differently. + print( + "Warning: Signal handlers for SIGINT/SIGTERM could not be added. " + "Graceful shutdown via signals might not work as expected on this platform." + ) + + async def _initialize_gateway(self): + """Initializes the GatewayClient if it doesn't exist.""" + if self._gateway is None: + self._gateway = GatewayClient( + http_client=self._http, + event_dispatcher=self._event_dispatcher, + token=self.token, + intents=self.intents, + client_instance=self, + verbose=self.verbose, + max_retries=self.gateway_max_retries, + max_backoff=self.gateway_max_backoff, + ) + + async def _initialize_shard_manager(self) -> None: + """Initializes the :class:`ShardManager` if not already created.""" + if self._shard_manager is None: + count = self.shard_count or 1 + self._shard_manager = ShardManager(self, count) + + async def connect(self, reconnect: bool = True) -> None: + """ + Establishes a connection to Discord. This includes logging in and connecting to the Gateway. + This method is a coroutine. + + Args: + reconnect (bool): Whether to automatically attempt to reconnect on disconnect. + (Note: Basic reconnect logic is within GatewayClient for now) + + Raises: + GatewayException: If the connection to the gateway fails. + AuthenticationError: If the token is invalid. + """ + if self._closed: + raise DisagreementException("Client is closed and cannot connect.") + if self.shard_count and self.shard_count > 1: + await self._initialize_shard_manager() + assert self._shard_manager is not None + await self._shard_manager.start() + print( + f"Client connected using {self.shard_count} shards, waiting for READY signal..." + ) + await self.wait_until_ready() + print("Client is READY!") + return + + await self._initialize_gateway() + assert self._gateway is not None # Should be initialized by now + + retry_delay = 5 # seconds + max_retries = 5 # For initial connection attempts by Client.run, Gateway has its own internal retries for some cases. + + for attempt in range(max_retries): + try: + await self._gateway.connect() + # After successful connection, GatewayClient's HELLO handler will trigger IDENTIFY/RESUME + # and its READY handler will set self._ready_event via dispatcher. + print("Client connected to Gateway, waiting for READY signal...") + await self.wait_until_ready() # Wait for the READY event from Gateway + print("Client is READY!") + return # Successfully connected and ready + except AuthenticationError: # Non-recoverable by retry here + print("Authentication failed. Please check your bot token.") + await self.close() # Ensure cleanup + raise + except DisagreementException as e: # Includes GatewayException + print(f"Failed to connect (Attempt {attempt + 1}/{max_retries}): {e}") + if attempt < max_retries - 1: + print(f"Retrying in {retry_delay} seconds...") + await asyncio.sleep(retry_delay) + retry_delay = min( + retry_delay * 2, 60 + ) # Exponential backoff up to 60s + else: + print("Max connection retries reached. Giving up.") + await self.close() # Ensure cleanup + raise + # Should not be reached if max_retries is > 0 + if max_retries == 0: # If max_retries was 0, means no retries attempted + raise DisagreementException("Connection failed with 0 retries allowed.") + + async def run(self) -> None: + """ + A blocking call that connects the client to Discord and runs until the client is closed. + This method is a coroutine. + It handles login, Gateway connection, and keeping the connection alive. + """ + if self._closed: + raise DisagreementException("Client is already closed.") + + try: + await self.connect() + # The GatewayClient's _receive_loop will keep running. + # This run method effectively waits until the client is closed or an unhandled error occurs. + # A more robust implementation might have a main loop here that monitors gateway health. + # For now, we rely on the gateway's tasks. + while not self._closed: + if ( + self._gateway + and self._gateway._receive_task + and self._gateway._receive_task.done() + ): + # If receive task ended unexpectedly, try to handle it or re-raise + try: + exc = self._gateway._receive_task.exception() + if exc: + print( + f"Gateway receive task ended with exception: {exc}. Attempting to reconnect..." + ) + # This is a basic reconnect strategy from the client side. + # GatewayClient itself might handle some reconnects. + await self.close_gateway( + code=1000 + ) # Close current gateway state + await asyncio.sleep(5) # Wait before reconnecting + if ( + not self._closed + ): # If client wasn't closed by the exception handler + await self.connect() + else: + break # Client was closed, exit run loop + else: + print( + "Gateway receive task ended without exception. Assuming clean shutdown or reconnect handled internally." + ) + if ( + not self._closed + ): # If not explicitly closed, might be an issue + print( + "Warning: Gateway receive task ended but client not closed. This might indicate an issue." + ) + # Consider a more robust health check or reconnect strategy here. + await asyncio.sleep( + 1 + ) # Prevent tight loop if something is wrong + else: + break # Client was closed + except asyncio.CancelledError: + print("Gateway receive task was cancelled.") + break # Exit if cancelled + except Exception as e: + print(f"Error checking gateway receive task: {e}") + break # Exit on other errors + await asyncio.sleep(1) # Main loop check interval + except DisagreementException as e: + print(f"Client run loop encountered an error: {e}") + # Error already logged by connect or other methods + except asyncio.CancelledError: + print("Client run loop was cancelled.") + finally: + if not self._closed: + await self.close() + + async def close(self) -> None: + """ + Closes the connection to Discord. This method is a coroutine. + """ + if self._closed: + return + + self._closed = True + print("Closing client...") + + if self._shard_manager: + await self._shard_manager.close() + self._shard_manager = None + if self._gateway: + await self._gateway.close() + + if self._http: # HTTPClient has its own session to close + await self._http.close() + + self._ready_event.set() # Ensure any waiters for ready are unblocked + print("Client closed.") + + async def __aenter__(self) -> "Client": + """Enter the context manager by connecting to Discord.""" + await self.connect() + return self + + async def __aexit__( + self, + exc_type: Optional[type], + exc: Optional[BaseException], + tb: Optional[BaseException], + ) -> bool: + """Exit the context manager and close the client.""" + await self.close() + return False + + async def close_gateway(self, code: int = 1000) -> None: + """Closes only the gateway connection, allowing for potential reconnect.""" + if self._shard_manager: + await self._shard_manager.close() + self._shard_manager = None + if self._gateway: + await self._gateway.close(code=code) + self._gateway = None + self._ready_event.clear() # No longer ready if gateway is closed + + def is_closed(self) -> bool: + """Indicates if the client has been closed.""" + return self._closed + + def is_ready(self) -> bool: + """Indicates if the client has successfully connected to the Gateway and is ready.""" + return self._ready_event.is_set() + + @property + def latency(self) -> Optional[float]: + """Returns the gateway latency in seconds, or ``None`` if unavailable.""" + if self._gateway: + return self._gateway.latency + return None + + async def wait_until_ready(self) -> None: + """|coro| + Waits until the client is fully connected to Discord and the initial state is processed. + This is mainly useful for waiting for the READY event from the Gateway. + """ + await self._ready_event.wait() + + async def wait_for( + self, + event_name: str, + check: Optional[Callable[[Any], bool]] = None, + timeout: Optional[float] = None, + ) -> Any: + """|coro| + Waits for a specific event to occur that satisfies the ``check``. + + Parameters + ---------- + event_name: str + The name of the event to wait for. + check: Optional[Callable[[Any], bool]] + A function that determines whether the received event should resolve the wait. + timeout: Optional[float] + How long to wait for the event before raising :class:`asyncio.TimeoutError`. + """ + + future: asyncio.Future = self.loop.create_future() + self._event_dispatcher.add_waiter(event_name, future, check) + try: + return await asyncio.wait_for(future, timeout=timeout) + finally: + self._event_dispatcher.remove_waiter(event_name, future) + + async def change_presence( + self, + status: str, + activity_name: Optional[str] = None, + activity_type: int = 0, + since: int = 0, + afk: bool = False, + ): + """ + Changes the client's presence on Discord. + + Args: + status (str): The new status for the client (e.g., "online", "idle", "dnd", "invisible"). + activity_name (Optional[str]): The name of the activity. + activity_type (int): The type of the activity. + since (int): The timestamp (in milliseconds) of when the client went idle. + afk (bool): Whether the client is AFK. + """ + if self._closed: + raise DisagreementException("Client is closed.") + + if self._gateway: + await self._gateway.update_presence( + status=status, + activity_name=activity_name, + activity_type=activity_type, + since=since, + afk=afk, + ) + + # --- Event Handling --- + + def event( + self, coro: Callable[..., Awaitable[None]] + ) -> Callable[..., Awaitable[None]]: + """ + A decorator that registers an event to listen to. + The name of the coroutine is used as the event name. + Example: + @client.event + async def on_ready(): # Will listen for the 'READY' event + print("Bot is ready!") + + @client.event + async def on_message(message: disagreement.Message): # Will listen for 'MESSAGE_CREATE' + print(f"Message from {message.author}: {message.content}") + """ + if not asyncio.iscoroutinefunction(coro): + raise TypeError("Event registered must be a coroutine function.") + + event_name = coro.__name__ + # Map common function names to Discord event types + # e.g., on_ready -> READY, on_message -> MESSAGE_CREATE + if event_name.startswith("on_"): + discord_event_name = event_name[3:].upper() + mapping = { + "MESSAGE": "MESSAGE_CREATE", + "MESSAGE_EDIT": "MESSAGE_UPDATE", + "MESSAGE_UPDATE": "MESSAGE_UPDATE", + "MESSAGE_DELETE": "MESSAGE_DELETE", + "REACTION_ADD": "MESSAGE_REACTION_ADD", + "REACTION_REMOVE": "MESSAGE_REACTION_REMOVE", + } + discord_event_name = mapping.get(discord_event_name, discord_event_name) + self._event_dispatcher.register(discord_event_name, coro) + else: + # If not starting with "on_", assume it's the direct Discord event name (e.g. "TYPING_START") + # Or raise an error if a specific format is required. + # For now, let's assume direct mapping if no "on_" prefix. + self._event_dispatcher.register(event_name.upper(), coro) + + return coro # Return the original coroutine + + def on_event( + self, event_name: str + ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: + """ + A decorator that registers an event to listen to with a specific event name. + Example: + @client.on_event('MESSAGE_CREATE') + async def my_message_handler(message: disagreement.Message): + print(f"Message: {message.content}") + """ + + def decorator( + coro: Callable[..., Awaitable[None]], + ) -> Callable[..., Awaitable[None]]: + if not asyncio.iscoroutinefunction(coro): + raise TypeError("Event registered must be a coroutine function.") + self._event_dispatcher.register(event_name.upper(), coro) + return coro + + return decorator + + async def _process_message_for_commands(self, message: "Message") -> None: + """Internal listener to process messages for commands.""" + # Make sure message object is valid and not from a bot (optional, common check) + if ( + not message or not message.author or message.author.bot + ): # Add .bot check to User model + return + await self.command_handler.process_commands(message) + + # --- Command Framework Methods --- + + def add_cog(self, cog: Cog) -> None: + """ + Adds a Cog to the bot. + Cogs are classes that group commands, listeners, and state. + This will also discover and register any application commands defined in the cog. + + Args: + cog (Cog): An instance of a class derived from `disagreement.ext.commands.Cog`. + """ + # Add to prefix command handler + self.command_handler.add_cog( + cog + ) # This should call cog._inject() internally or cog._inject() is called on Cog init + + # Discover and add application commands from the cog + # AppCommand and AppCommandGroup are already imported in TYPE_CHECKING block + for app_cmd_obj in cog.get_app_commands_and_groups(): # Uses the new method + # The cog attribute should have been set within Cog._inject() for AppCommands + self.app_command_handler.add_command(app_cmd_obj) + print( + f"Registered app command/group '{app_cmd_obj.name}' from cog '{cog.cog_name}'." + ) + + def remove_cog(self, cog_name: str) -> Optional[Cog]: + """ + Removes a Cog from the bot. + + Args: + cog_name (str): The name of the Cog to remove. + + Returns: + Optional[Cog]: The Cog that was removed, or None if not found. + """ + removed_cog = self.command_handler.remove_cog(cog_name) + if removed_cog: + # Also remove associated application commands + # This requires AppCommand to store a reference to its cog, or iterate all app_commands. + # Assuming AppCommand has a .cog attribute, which is set in Cog._inject() + # And AppCommandGroup might store commands that have .cog attribute + for app_cmd_or_group in removed_cog.get_app_commands_and_groups(): + # The AppCommandHandler.remove_command needs to handle both AppCommand and AppCommandGroup + self.app_command_handler.remove_command( + app_cmd_or_group.name + ) # Assuming name is unique enough for removal here + print( + f"Removed app command/group '{app_cmd_or_group.name}' from cog '{cog_name}'." + ) + # Note: AppCommandHandler.remove_command might need to be more specific if names aren't globally unique + # (e.g. if it needs type or if groups and commands can share names). + # For now, assuming name is sufficient for removal from the handler's flat list. + return removed_cog + def check(self, coro: Callable[["CommandContext"], Awaitable[bool]]): """ A decorator that adds a global check to the bot. @@ -693,804 +696,816 @@ class Client: ctx (CommandContext): The context of the command that completed. """ pass - - # --- Extension Management Methods --- - - def load_extension(self, name: str) -> ModuleType: - """Load an extension by name using :mod:`disagreement.ext.loader`.""" - - return ext_loader.load_extension(name) - - def unload_extension(self, name: str) -> None: - """Unload a previously loaded extension.""" - - ext_loader.unload_extension(name) - - def reload_extension(self, name: str) -> ModuleType: - """Reload an extension by name.""" - - return ext_loader.reload_extension(name) - - # --- Model Parsing and Fetching --- - - def parse_user(self, data: Dict[str, Any]) -> "User": - """Parses user data and returns a User object, updating cache.""" - from .models import User # Ensure User model is available - - user = User(data) - self._users[user.id] = user # Cache the user - return user - - def parse_channel(self, data: Dict[str, Any]) -> "Channel": - """Parses channel data and returns a Channel object, updating caches.""" - - from .models import channel_factory - - channel = channel_factory(data, self) - self._channels[channel.id] = channel - if channel.guild_id: - guild = self._guilds.get(channel.guild_id) - if guild: - guild._channels[channel.id] = channel - return channel - - def parse_message(self, data: Dict[str, Any]) -> "Message": - """Parses message data and returns a Message object, updating cache.""" - - from .models import Message - - message = Message(data, client_instance=self) - self._messages[message.id] = message - return message - - def parse_webhook(self, data: Union[Dict[str, Any], "Webhook"]) -> "Webhook": - """Parses webhook data and returns a Webhook object, updating cache.""" - - from .models import Webhook - - if isinstance(data, Webhook): - webhook = data - webhook._client = self # type: ignore[attr-defined] - else: - webhook = Webhook(data, client_instance=self) - self._webhooks[webhook.id] = webhook - return webhook - - def parse_template(self, data: Dict[str, Any]) -> "GuildTemplate": - """Parses template data into a GuildTemplate object.""" - - from .models import GuildTemplate - - return GuildTemplate(data, client_instance=self) - - def parse_scheduled_event(self, data: Dict[str, Any]) -> "ScheduledEvent": - """Parses scheduled event data and updates cache.""" - - from .models import ScheduledEvent - - event = ScheduledEvent(data, client_instance=self) - # Cache by ID under guild if guild cache exists - guild = self._guilds.get(event.guild_id) - if guild is not None: - events = getattr(guild, "_scheduled_events", {}) - events[event.id] = event - setattr(guild, "_scheduled_events", events) - return event - - def parse_audit_log_entry(self, data: Dict[str, Any]) -> "AuditLogEntry": - """Parses audit log entry data.""" - from .models import AuditLogEntry - - return AuditLogEntry(data, client_instance=self) - - def parse_invite(self, data: Dict[str, Any]) -> "Invite": - """Parses invite data into an :class:`Invite`.""" - - from .models import Invite - - return Invite.from_dict(data) - - async def fetch_user(self, user_id: Snowflake) -> Optional["User"]: - """Fetches a user by ID from Discord.""" - if self._closed: - raise DisagreementException("Client is closed.") - - cached_user = self._users.get(user_id) - if cached_user: - return cached_user # Return cached if available, though fetch implies wanting fresh - - try: - user_data = await self._http.get_user(user_id) - return self.parse_user(user_data) - except DisagreementException as e: # Catch HTTP exceptions from http client - print(f"Failed to fetch user {user_id}: {e}") - return None - - async def fetch_message( - self, channel_id: Snowflake, message_id: Snowflake - ) -> Optional["Message"]: - """Fetches a message by ID from Discord and caches it.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - cached_message = self._messages.get(message_id) - if cached_message: - return cached_message - - try: - message_data = await self._http.get_message(channel_id, message_id) - return self.parse_message(message_data) - except DisagreementException as e: - print( - f"Failed to fetch message {message_id} from channel {channel_id}: {e}" - ) - return None - - def parse_member(self, data: Dict[str, Any], guild_id: Snowflake) -> "Member": - """Parses member data and returns a Member object, updating relevant caches.""" - from .models import Member # Ensure Member model is available - - # Member's __init__ should handle the nested 'user' data. - member = Member(data, client_instance=self) - member.guild_id = str(guild_id) - - # Cache the member in the guild's member cache - guild = self._guilds.get(guild_id) - if guild: - guild._members[member.id] = member # Assuming Guild has _members dict - - # Also cache the user part if not already cached or if this is newer - # Since Member inherits from User, the member object itself is the user. - self._users[member.id] = member - # If 'user' was in data and Member.__init__ used it, it's already part of 'member'. - return member - - async def fetch_member( - self, guild_id: Snowflake, member_id: Snowflake - ) -> Optional["Member"]: - """Fetches a member from a guild by ID.""" - if self._closed: - raise DisagreementException("Client is closed.") - - guild = self.get_guild(guild_id) - if guild: - cached_member = guild.get_member(member_id) # Use Guild's get_member - if cached_member: - return cached_member # Return cached if available - - try: - member_data = await self._http.get_guild_member(guild_id, member_id) - return self.parse_member(member_data, guild_id) - except DisagreementException as e: - print(f"Failed to fetch member {member_id} from guild {guild_id}: {e}") - return None - - def parse_role(self, data: Dict[str, Any], guild_id: Snowflake) -> "Role": - """Parses role data and returns a Role object, updating guild's role cache.""" - from .models import Role # Ensure Role model is available - - role = Role(data) - guild = self._guilds.get(guild_id) - if guild: - # Update the role in the guild's roles list if it exists, or add it. - # Guild.roles is List[Role]. We need to find and replace or append. - found = False - for i, existing_role in enumerate(guild.roles): - if existing_role.id == role.id: - guild.roles[i] = role - found = True - break - if not found: - guild.roles.append(role) - return role - - def parse_guild(self, data: Dict[str, Any]) -> "Guild": - """Parses guild data and returns a Guild object, updating cache.""" - - from .models import Guild - - guild = Guild(data, client_instance=self) - self._guilds[guild.id] = guild - - # Populate channel and member caches if provided - for ch in data.get("channels", []): - channel_obj = self.parse_channel(ch) - guild._channels[channel_obj.id] = channel_obj - - for member in data.get("members", []): - member_obj = self.parse_member(member, guild.id) - guild._members[member_obj.id] = member_obj - - return guild - - async def fetch_roles(self, guild_id: Snowflake) -> List["Role"]: - """Fetches all roles for a given guild and caches them. - - If the guild is not cached, it will be retrieved first using - :meth:`fetch_guild`. - """ - if self._closed: - raise DisagreementException("Client is closed.") - guild = self.get_guild(guild_id) - if not guild: - guild = await self.fetch_guild(guild_id) - if not guild: - return [] - - try: - roles_data = await self._http.get_guild_roles(guild_id) - parsed_roles = [] - for role_data in roles_data: - # parse_role will add/update it in the guild.roles list - parsed_roles.append(self.parse_role(role_data, guild_id)) - guild.roles = parsed_roles # Replace the entire list with the fresh one - return parsed_roles - except DisagreementException as e: - print(f"Failed to fetch roles for guild {guild_id}: {e}") - return [] - - async def fetch_role( - self, guild_id: Snowflake, role_id: Snowflake - ) -> Optional["Role"]: - """Fetches a specific role from a guild by ID. - If roles for the guild aren't cached or might be stale, it fetches all roles first. - """ - guild = self.get_guild(guild_id) - if guild: - # Try to find in existing guild.roles - for role in guild.roles: - if role.id == role_id: - return role - - # If not found in cache or guild doesn't exist yet in cache, fetch all roles for the guild - await self.fetch_roles(guild_id) # This will populate/update guild.roles - - # Try again from the now (hopefully) populated cache - guild = self.get_guild( - guild_id - ) # Re-get guild in case it was populated by fetch_roles - if guild: - for role in guild.roles: - if role.id == role_id: - return role - - return None # Role not found even after fetching - - # --- API Methods --- - - # --- API Methods --- - - async def send_message( - self, - channel_id: str, - content: Optional[str] = None, - *, # Make additional params keyword-only - tts: bool = False, - embed: Optional["Embed"] = None, - embeds: Optional[List["Embed"]] = None, - components: Optional[List["ActionRow"]] = None, - allowed_mentions: Optional[Dict[str, Any]] = None, - message_reference: Optional[Dict[str, Any]] = None, - attachments: Optional[List[Any]] = None, - files: Optional[List[Any]] = None, - flags: Optional[int] = None, - view: Optional["View"] = None, - ) -> "Message": - """|coro| - Sends a message to the specified channel. - - Args: - channel_id (str): The ID of the channel to send the message to. - content (Optional[str]): The content of the message. - tts (bool): Whether the message should be sent with text-to-speech. Defaults to False. - embed (Optional[Embed]): A single embed to send. Cannot be used with `embeds`. - embeds (Optional[List[Embed]]): A list of embeds to send. Cannot be used with `embed`. - Discord supports up to 10 embeds per message. - components (Optional[List[ActionRow]]): A list of ActionRow components to include. - allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message. - message_reference (Optional[Dict[str, Any]]): Message reference for replying. - attachments (Optional[List[Any]]): Attachments to include with the message. - files (Optional[List[Any]]): Files to upload with the message. - flags (Optional[int]): Message flags. - view (Optional[View]): A view to send with the message. - - Returns: - Message: The message that was sent. - - Raises: - HTTPException: Sending the message failed. - ValueError: If both `embed` and `embeds` are provided, or if both `components` and `view` are provided. - """ - if self._closed: - raise DisagreementException("Client is closed.") - - if embed and embeds: - raise ValueError("Cannot provide both embed and embeds.") - if components and view: - raise ValueError("Cannot provide both 'components' and 'view'.") - - final_embeds_payload: Optional[List[Dict[str, Any]]] = None - if embed: - final_embeds_payload = [embed.to_dict()] - elif embeds: - from .models import ( - Embed as EmbedModel, - ) - - final_embeds_payload = [ - e.to_dict() for e in embeds if isinstance(e, EmbedModel) - ] - - components_payload: Optional[List[Dict[str, Any]]] = None - if view: - await view._start(self) - components_payload = view.to_components_payload() - elif components: - from .models import Component as ComponentModel - - components_payload = [ - comp.to_dict() - for comp in components - if isinstance(comp, ComponentModel) - ] - - message_data = await self._http.send_message( - channel_id=channel_id, - content=content, - tts=tts, - embeds=final_embeds_payload, - components=components_payload, - allowed_mentions=allowed_mentions, - message_reference=message_reference, - attachments=attachments, - files=files, - flags=flags, - ) - - if view: - message_id = message_data["id"] - view.message_id = message_id - self._views[message_id] = view - - return self.parse_message(message_data) - - def typing(self, channel_id: str) -> Typing: - """Return a context manager to show a typing indicator in a channel.""" - - return Typing(self, channel_id) - - async def join_voice( - self, - guild_id: Snowflake, - channel_id: Snowflake, - *, - self_mute: bool = False, - self_deaf: bool = False, - ) -> VoiceClient: - """|coro| Join a voice channel and return a :class:`VoiceClient`.""" - - if self._closed: - raise DisagreementException("Client is closed.") - if not self.is_ready(): - await self.wait_until_ready() - if self._gateway is None: - raise DisagreementException("Gateway is not connected.") - if not self.user: - raise DisagreementException("Client user unavailable.") - assert self.user is not None - user_id = self.user.id - - if guild_id in self._voice_clients: - return self._voice_clients[guild_id] - - payload = { - "op": GatewayOpcode.VOICE_STATE_UPDATE, - "d": { - "guild_id": str(guild_id), - "channel_id": str(channel_id), - "self_mute": self_mute, - "self_deaf": self_deaf, - }, - } - await self._gateway._send_json(payload) # type: ignore[attr-defined] - - server = await self.wait_for( - "VOICE_SERVER_UPDATE", - check=lambda d: d.get("guild_id") == str(guild_id), - timeout=10, - ) - state = await self.wait_for( - "VOICE_STATE_UPDATE", - check=lambda d, uid=user_id: d.get("guild_id") == str(guild_id) - and d.get("user_id") == str(uid), - timeout=10, - ) - - endpoint = f"wss://{server['endpoint']}?v=10" - token = server["token"] - session_id = state["session_id"] - - voice = VoiceClient( + + # --- Extension Management Methods --- + + def load_extension(self, name: str) -> ModuleType: + """Load an extension by name using :mod:`disagreement.ext.loader`.""" + + return ext_loader.load_extension(name) + + def unload_extension(self, name: str) -> None: + """Unload a previously loaded extension.""" + + ext_loader.unload_extension(name) + + def reload_extension(self, name: str) -> ModuleType: + """Reload an extension by name.""" + + return ext_loader.reload_extension(name) + + # --- Model Parsing and Fetching --- + + def parse_user(self, data: Dict[str, Any]) -> "User": + """Parses user data and returns a User object, updating cache.""" + from .models import User # Ensure User model is available + + user = User(data) + self._users.set(user.id, user) # Cache the user + return user + + def parse_channel(self, data: Dict[str, Any]) -> "Channel": + """Parses channel data and returns a Channel object, updating caches.""" + + from .models import channel_factory + + channel = channel_factory(data, self) + self._channels.set(channel.id, channel) + if channel.guild_id: + guild = self._guilds.get(channel.guild_id) + if guild: + guild._channels.set(channel.id, channel) + return channel + + def parse_message(self, data: Dict[str, Any]) -> "Message": + """Parses message data and returns a Message object, updating cache.""" + + from .models import Message + + message = Message(data, client_instance=self) + self._messages.set(message.id, message) + return message + + def parse_webhook(self, data: Union[Dict[str, Any], "Webhook"]) -> "Webhook": + """Parses webhook data and returns a Webhook object, updating cache.""" + + from .models import Webhook + + if isinstance(data, Webhook): + webhook = data + webhook._client = self # type: ignore[attr-defined] + else: + webhook = Webhook(data, client_instance=self) + self._webhooks[webhook.id] = webhook + return webhook + + def parse_template(self, data: Dict[str, Any]) -> "GuildTemplate": + """Parses template data into a GuildTemplate object.""" + + from .models import GuildTemplate + + return GuildTemplate(data, client_instance=self) + + def parse_scheduled_event(self, data: Dict[str, Any]) -> "ScheduledEvent": + """Parses scheduled event data and updates cache.""" + + from .models import ScheduledEvent + + event = ScheduledEvent(data, client_instance=self) + # Cache by ID under guild if guild cache exists + guild = self._guilds.get(event.guild_id) + if guild is not None: + events = getattr(guild, "_scheduled_events", {}) + events[event.id] = event + setattr(guild, "_scheduled_events", events) + return event + + def parse_audit_log_entry(self, data: Dict[str, Any]) -> "AuditLogEntry": + """Parses audit log entry data.""" + from .models import AuditLogEntry + + return AuditLogEntry(data, client_instance=self) + + def parse_invite(self, data: Dict[str, Any]) -> "Invite": + """Parses invite data into an :class:`Invite`.""" + + from .models import Invite + + return Invite.from_dict(data) + + async def fetch_user(self, user_id: Snowflake) -> Optional["User"]: + """Fetches a user by ID from Discord.""" + if self._closed: + raise DisagreementException("Client is closed.") + + cached_user = self._users.get(user_id) + if cached_user: + return cached_user + + try: + user_data = await self._http.get_user(user_id) + return self.parse_user(user_data) + except DisagreementException as e: # Catch HTTP exceptions from http client + print(f"Failed to fetch user {user_id}: {e}") + return None + + async def fetch_message( + self, channel_id: Snowflake, message_id: Snowflake + ) -> Optional["Message"]: + """Fetches a message by ID from Discord and caches it.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + cached_message = self._messages.get(message_id) + if cached_message: + return cached_message + + try: + message_data = await self._http.get_message(channel_id, message_id) + return self.parse_message(message_data) + except DisagreementException as e: + print( + f"Failed to fetch message {message_id} from channel {channel_id}: {e}" + ) + return None + + def parse_member( + self, data: Dict[str, Any], guild_id: Snowflake, *, just_joined: bool = False + ) -> "Member": + """Parses member data and returns a Member object, updating relevant caches.""" + from .models import Member + + member = Member(data, client_instance=self) + member.guild_id = str(guild_id) + + if just_joined: + setattr(member, "_just_joined", True) + + guild = self._guilds.get(guild_id) + if guild: + guild._members.set(member.id, member) + + if just_joined and hasattr(member, "_just_joined"): + delattr(member, "_just_joined") + + self._users.set(member.id, member) + return member + + async def fetch_member( + self, guild_id: Snowflake, member_id: Snowflake + ) -> Optional["Member"]: + """Fetches a member from a guild by ID.""" + if self._closed: + raise DisagreementException("Client is closed.") + + guild = self.get_guild(guild_id) + if guild: + cached_member = guild.get_member(member_id) # Use Guild's get_member + if cached_member: + return cached_member # Return cached if available + + try: + member_data = await self._http.get_guild_member(guild_id, member_id) + return self.parse_member(member_data, guild_id) + except DisagreementException as e: + print(f"Failed to fetch member {member_id} from guild {guild_id}: {e}") + return None + + def parse_role(self, data: Dict[str, Any], guild_id: Snowflake) -> "Role": + """Parses role data and returns a Role object, updating guild's role cache.""" + from .models import Role # Ensure Role model is available + + role = Role(data) + guild = self._guilds.get(guild_id) + if guild: + # Update the role in the guild's roles list if it exists, or add it. + # Guild.roles is List[Role]. We need to find and replace or append. + found = False + for i, existing_role in enumerate(guild.roles): + if existing_role.id == role.id: + guild.roles[i] = role + found = True + break + if not found: + guild.roles.append(role) + return role + + def parse_guild(self, data: Dict[str, Any]) -> "Guild": + """Parses guild data and returns a Guild object, updating cache.""" + from .models import Guild + + guild = Guild(data, client_instance=self) + self._guilds.set(guild.id, guild) + + presences = {p["user"]["id"]: p for p in data.get("presences", [])} + voice_states = {vs["user_id"]: vs for vs in data.get("voice_states", [])} + + for ch_data in data.get("channels", []): + self.parse_channel(ch_data) + + for member_data in data.get("members", []): + user_id = member_data.get("user", {}).get("id") + if user_id: + presence = presences.get(user_id) + if presence: + member_data["status"] = presence.get("status", "offline") + + voice_state = voice_states.get(user_id) + if voice_state: + member_data["voice_state"] = voice_state + + self.parse_member(member_data, guild.id) + + return guild + + async def fetch_roles(self, guild_id: Snowflake) -> List["Role"]: + """Fetches all roles for a given guild and caches them. + + If the guild is not cached, it will be retrieved first using + :meth:`fetch_guild`. + """ + if self._closed: + raise DisagreementException("Client is closed.") + guild = self.get_guild(guild_id) + if not guild: + guild = await self.fetch_guild(guild_id) + if not guild: + return [] + + try: + roles_data = await self._http.get_guild_roles(guild_id) + parsed_roles = [] + for role_data in roles_data: + # parse_role will add/update it in the guild.roles list + parsed_roles.append(self.parse_role(role_data, guild_id)) + guild.roles = parsed_roles # Replace the entire list with the fresh one + return parsed_roles + except DisagreementException as e: + print(f"Failed to fetch roles for guild {guild_id}: {e}") + return [] + + async def fetch_role( + self, guild_id: Snowflake, role_id: Snowflake + ) -> Optional["Role"]: + """Fetches a specific role from a guild by ID. + If roles for the guild aren't cached or might be stale, it fetches all roles first. + """ + guild = self.get_guild(guild_id) + if guild: + # Try to find in existing guild.roles + for role in guild.roles: + if role.id == role_id: + return role + + # If not found in cache or guild doesn't exist yet in cache, fetch all roles for the guild + await self.fetch_roles(guild_id) # This will populate/update guild.roles + + # Try again from the now (hopefully) populated cache + guild = self.get_guild( + guild_id + ) # Re-get guild in case it was populated by fetch_roles + if guild: + for role in guild.roles: + if role.id == role_id: + return role + + return None # Role not found even after fetching + + # --- API Methods --- + + # --- API Methods --- + + async def send_message( + self, + channel_id: str, + content: Optional[str] = None, + *, # Make additional params keyword-only + tts: bool = False, + embed: Optional["Embed"] = None, + embeds: Optional[List["Embed"]] = None, + components: Optional[List["ActionRow"]] = None, + allowed_mentions: Optional[Dict[str, Any]] = None, + message_reference: Optional[Dict[str, Any]] = None, + attachments: Optional[List[Any]] = None, + files: Optional[List[Any]] = None, + flags: Optional[int] = None, + view: Optional["View"] = None, + ) -> "Message": + """|coro| + Sends a message to the specified channel. + + Args: + channel_id (str): The ID of the channel to send the message to. + content (Optional[str]): The content of the message. + tts (bool): Whether the message should be sent with text-to-speech. Defaults to False. + embed (Optional[Embed]): A single embed to send. Cannot be used with `embeds`. + embeds (Optional[List[Embed]]): A list of embeds to send. Cannot be used with `embed`. + Discord supports up to 10 embeds per message. + components (Optional[List[ActionRow]]): A list of ActionRow components to include. + allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message. + message_reference (Optional[Dict[str, Any]]): Message reference for replying. + attachments (Optional[List[Any]]): Attachments to include with the message. + files (Optional[List[Any]]): Files to upload with the message. + flags (Optional[int]): Message flags. + view (Optional[View]): A view to send with the message. + + Returns: + Message: The message that was sent. + + Raises: + HTTPException: Sending the message failed. + ValueError: If both `embed` and `embeds` are provided, or if both `components` and `view` are provided. + """ + if self._closed: + raise DisagreementException("Client is closed.") + + if embed and embeds: + raise ValueError("Cannot provide both embed and embeds.") + if components and view: + raise ValueError("Cannot provide both 'components' and 'view'.") + + final_embeds_payload: Optional[List[Dict[str, Any]]] = None + if embed: + final_embeds_payload = [embed.to_dict()] + elif embeds: + from .models import ( + Embed as EmbedModel, + ) + + final_embeds_payload = [ + e.to_dict() for e in embeds if isinstance(e, EmbedModel) + ] + + components_payload: Optional[List[Dict[str, Any]]] = None + if view: + await view._start(self) + components_payload = view.to_components_payload() + elif components: + from .models import Component as ComponentModel + + components_payload = [ + comp.to_dict() + for comp in components + if isinstance(comp, ComponentModel) + ] + + message_data = await self._http.send_message( + channel_id=channel_id, + content=content, + tts=tts, + embeds=final_embeds_payload, + components=components_payload, + allowed_mentions=allowed_mentions, + message_reference=message_reference, + attachments=attachments, + files=files, + flags=flags, + ) + + if view: + message_id = message_data["id"] + view.message_id = message_id + self._views[message_id] = view + + return self.parse_message(message_data) + + def typing(self, channel_id: str) -> Typing: + """Return a context manager to show a typing indicator in a channel.""" + + return Typing(self, channel_id) + + async def join_voice( + self, + guild_id: Snowflake, + channel_id: Snowflake, + *, + self_mute: bool = False, + self_deaf: bool = False, + ) -> VoiceClient: + """|coro| Join a voice channel and return a :class:`VoiceClient`.""" + + if self._closed: + raise DisagreementException("Client is closed.") + if not self.is_ready(): + await self.wait_until_ready() + if self._gateway is None: + raise DisagreementException("Gateway is not connected.") + if not self.user: + raise DisagreementException("Client user unavailable.") + assert self.user is not None + user_id = self.user.id + + if guild_id in self._voice_clients: + return self._voice_clients[guild_id] + + payload = { + "op": GatewayOpcode.VOICE_STATE_UPDATE, + "d": { + "guild_id": str(guild_id), + "channel_id": str(channel_id), + "self_mute": self_mute, + "self_deaf": self_deaf, + }, + } + await self._gateway._send_json(payload) # type: ignore[attr-defined] + + server = await self.wait_for( + "VOICE_SERVER_UPDATE", + check=lambda d: d.get("guild_id") == str(guild_id), + timeout=10, + ) + state = await self.wait_for( + "VOICE_STATE_UPDATE", + check=lambda d, uid=user_id: d.get("guild_id") == str(guild_id) + and d.get("user_id") == str(uid), + timeout=10, + ) + + endpoint = f"wss://{server['endpoint']}?v=10" + token = server["token"] + session_id = state["session_id"] + + voice = VoiceClient( self, - endpoint, - session_id, - token, - int(guild_id), - int(self.user.id), - verbose=self.verbose, - ) - await voice.connect() - self._voice_clients[guild_id] = voice - return voice - - async def add_reaction(self, channel_id: str, message_id: str, emoji: str) -> None: - """|coro| Add a reaction to a message.""" - - await self.create_reaction(channel_id, message_id, emoji) - - async def remove_reaction( - self, channel_id: str, message_id: str, emoji: str - ) -> None: - """|coro| Remove the bot's reaction from a message.""" - - await self.delete_reaction(channel_id, message_id, emoji) - - async def clear_reactions(self, channel_id: str, message_id: str) -> None: - """|coro| Remove all reactions from a message.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - await self._http.clear_reactions(channel_id, message_id) - - async def create_reaction( - self, channel_id: str, message_id: str, emoji: str - ) -> None: - """|coro| Add a reaction to a message.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - await self._http.create_reaction(channel_id, message_id, emoji) - - user_id = getattr(getattr(self, "user", None), "id", None) - payload = { - "user_id": user_id, - "channel_id": channel_id, - "message_id": message_id, - "emoji": {"name": emoji, "id": None}, - } - if hasattr(self, "_event_dispatcher"): - await self._event_dispatcher.dispatch("MESSAGE_REACTION_ADD", payload) - - async def delete_reaction( - self, channel_id: str, message_id: str, emoji: str - ) -> None: - """|coro| Remove the bot's reaction from a message.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - await self._http.delete_reaction(channel_id, message_id, emoji) - - user_id = getattr(getattr(self, "user", None), "id", None) - payload = { - "user_id": user_id, - "channel_id": channel_id, - "message_id": message_id, - "emoji": {"name": emoji, "id": None}, - } - if hasattr(self, "_event_dispatcher"): - await self._event_dispatcher.dispatch("MESSAGE_REACTION_REMOVE", payload) - - async def get_reactions( - self, channel_id: str, message_id: str, emoji: str - ) -> List["User"]: - """|coro| Return the users who reacted with the given emoji.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - users_data = await self._http.get_reactions(channel_id, message_id, emoji) - return [self.parse_user(u) for u in users_data] - - async def edit_message( - self, - channel_id: str, - message_id: str, - *, - content: Optional[str] = None, - embed: Optional["Embed"] = None, - embeds: Optional[List["Embed"]] = None, - components: Optional[List["ActionRow"]] = None, - allowed_mentions: Optional[Dict[str, Any]] = None, - flags: Optional[int] = None, - view: Optional["View"] = None, - ) -> "Message": - """Edits a previously sent message.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - if embed and embeds: - raise ValueError("Cannot provide both embed and embeds.") - if components and view: - raise ValueError("Cannot provide both 'components' and 'view'.") - - final_embeds_payload: Optional[List[Dict[str, Any]]] = None - if embed: - final_embeds_payload = [embed.to_dict()] - elif embeds: - final_embeds_payload = [e.to_dict() for e in embeds] - - components_payload: Optional[List[Dict[str, Any]]] = None - if view: - await view._start(self) - components_payload = view.to_components_payload() - elif components: - components_payload = [c.to_dict() for c in components] - - payload: Dict[str, Any] = {} - if content is not None: - payload["content"] = content - if final_embeds_payload is not None: - payload["embeds"] = final_embeds_payload - if components_payload is not None: - payload["components"] = components_payload - if allowed_mentions is not None: - payload["allowed_mentions"] = allowed_mentions - if flags is not None: - payload["flags"] = flags - - message_data = await self._http.edit_message( - channel_id=channel_id, - message_id=message_id, - payload=payload, - ) - - if view: - view.message_id = message_data["id"] - self._views[message_data["id"]] = view - - return self.parse_message(message_data) - - def get_guild(self, guild_id: Snowflake) -> Optional["Guild"]: - """Returns a guild from the internal cache. - - Use :meth:`fetch_guild` to retrieve it from Discord if it's not cached. - """ - - return self._guilds.get(guild_id) - - def get_channel(self, channel_id: Snowflake) -> Optional["Channel"]: - """Returns a channel from the internal cache.""" - - return self._channels.get(channel_id) - - def get_message(self, message_id: Snowflake) -> Optional["Message"]: - """Returns a message from the internal cache.""" - - return self._messages.get(message_id) - - async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]: - """Fetches a guild by ID from Discord and caches it.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - cached_guild = self._guilds.get(guild_id) - if cached_guild: - return cached_guild - - try: - guild_data = await self._http.get_guild(guild_id) - return self.parse_guild(guild_data) - except DisagreementException as e: - print(f"Failed to fetch guild {guild_id}: {e}") - return None - - async def fetch_channel(self, channel_id: Snowflake) -> Optional["Channel"]: - """Fetches a channel from Discord by its ID and updates the cache.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - try: - channel_data = await self._http.get_channel(channel_id) - if not channel_data: - return None - - from .models import channel_factory - - channel = channel_factory(channel_data, self) - - self._channels[channel.id] = channel - return channel - - except DisagreementException as e: # Includes HTTPException - print(f"Failed to fetch channel {channel_id}: {e}") - return None - - async def fetch_audit_logs( - self, guild_id: Snowflake, **filters: Any - ) -> AsyncIterator["AuditLogEntry"]: - """Fetch audit log entries for a guild.""" - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.get_audit_logs(guild_id, **filters) - for entry in data.get("audit_log_entries", []): - yield self.parse_audit_log_entry(entry) - - async def fetch_voice_regions(self) -> List[VoiceRegion]: - """Fetches available voice regions.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.get_voice_regions() - regions = [] - for region in data: - region_id = region.get("id") - if region_id: - regions.append(VoiceRegion(region_id)) - return regions - - async def create_webhook( - self, channel_id: Snowflake, payload: Dict[str, Any] - ) -> "Webhook": - """|coro| Create a webhook in the given channel.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.create_webhook(channel_id, payload) - return self.parse_webhook(data) - - async def edit_webhook( - self, webhook_id: Snowflake, payload: Dict[str, Any] - ) -> "Webhook": - """|coro| Edit an existing webhook.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.edit_webhook(webhook_id, payload) - return self.parse_webhook(data) - - async def delete_webhook(self, webhook_id: Snowflake) -> None: - """|coro| Delete a webhook by ID.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - await self._http.delete_webhook(webhook_id) - - async def fetch_templates(self, guild_id: Snowflake) -> List["GuildTemplate"]: - """|coro| Fetch all templates for a guild.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.get_guild_templates(guild_id) - return [self.parse_template(t) for t in data] - - async def create_template( - self, guild_id: Snowflake, payload: Dict[str, Any] - ) -> "GuildTemplate": - """|coro| Create a template for a guild.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.create_guild_template(guild_id, payload) - return self.parse_template(data) - - async def sync_template( - self, guild_id: Snowflake, template_code: str - ) -> "GuildTemplate": - """|coro| Sync a template to the guild's current state.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.sync_guild_template(guild_id, template_code) - return self.parse_template(data) - - async def delete_template(self, guild_id: Snowflake, template_code: str) -> None: - """|coro| Delete a guild template.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - await self._http.delete_guild_template(guild_id, template_code) - - async def fetch_scheduled_events( - self, guild_id: Snowflake - ) -> List["ScheduledEvent"]: - """|coro| Fetch all scheduled events for a guild.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.get_guild_scheduled_events(guild_id) - return [self.parse_scheduled_event(ev) for ev in data] - - async def fetch_scheduled_event( - self, guild_id: Snowflake, event_id: Snowflake - ) -> Optional["ScheduledEvent"]: - """|coro| Fetch a single scheduled event.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - try: - data = await self._http.get_guild_scheduled_event(guild_id, event_id) - return self.parse_scheduled_event(data) - except DisagreementException as e: - print(f"Failed to fetch scheduled event {event_id}: {e}") - return None - - async def create_scheduled_event( - self, guild_id: Snowflake, payload: Dict[str, Any] - ) -> "ScheduledEvent": - """|coro| Create a scheduled event in a guild.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.create_guild_scheduled_event(guild_id, payload) - return self.parse_scheduled_event(data) - - async def edit_scheduled_event( - self, guild_id: Snowflake, event_id: Snowflake, payload: Dict[str, Any] - ) -> "ScheduledEvent": - """|coro| Edit an existing scheduled event.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.edit_guild_scheduled_event(guild_id, event_id, payload) - return self.parse_scheduled_event(data) - - async def delete_scheduled_event( - self, guild_id: Snowflake, event_id: Snowflake - ) -> None: - """|coro| Delete a scheduled event.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - await self._http.delete_guild_scheduled_event(guild_id, event_id) - - async def create_invite( - self, channel_id: Snowflake, payload: Dict[str, Any] - ) -> "Invite": - """|coro| Create an invite for the given channel.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - return await self._http.create_invite(channel_id, payload) - - async def delete_invite(self, code: str) -> None: - """|coro| Delete an invite by code.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - await self._http.delete_invite(code) - - async def fetch_invites(self, channel_id: Snowflake) -> List["Invite"]: - """|coro| Fetch all invites for a channel.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.get_channel_invites(channel_id) - return [self.parse_invite(inv) for inv in data] - + endpoint, + session_id, + token, + int(guild_id), + int(self.user.id), + verbose=self.verbose, + ) + await voice.connect() + self._voice_clients[guild_id] = voice + return voice + + async def add_reaction(self, channel_id: str, message_id: str, emoji: str) -> None: + """|coro| Add a reaction to a message.""" + + await self.create_reaction(channel_id, message_id, emoji) + + async def remove_reaction( + self, channel_id: str, message_id: str, emoji: str + ) -> None: + """|coro| Remove the bot's reaction from a message.""" + + await self.delete_reaction(channel_id, message_id, emoji) + + async def clear_reactions(self, channel_id: str, message_id: str) -> None: + """|coro| Remove all reactions from a message.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + await self._http.clear_reactions(channel_id, message_id) + + async def create_reaction( + self, channel_id: str, message_id: str, emoji: str + ) -> None: + """|coro| Add a reaction to a message.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + await self._http.create_reaction(channel_id, message_id, emoji) + + user_id = getattr(getattr(self, "user", None), "id", None) + payload = { + "user_id": user_id, + "channel_id": channel_id, + "message_id": message_id, + "emoji": {"name": emoji, "id": None}, + } + if hasattr(self, "_event_dispatcher"): + await self._event_dispatcher.dispatch("MESSAGE_REACTION_ADD", payload) + + async def delete_reaction( + self, channel_id: str, message_id: str, emoji: str + ) -> None: + """|coro| Remove the bot's reaction from a message.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + await self._http.delete_reaction(channel_id, message_id, emoji) + + user_id = getattr(getattr(self, "user", None), "id", None) + payload = { + "user_id": user_id, + "channel_id": channel_id, + "message_id": message_id, + "emoji": {"name": emoji, "id": None}, + } + if hasattr(self, "_event_dispatcher"): + await self._event_dispatcher.dispatch("MESSAGE_REACTION_REMOVE", payload) + + async def get_reactions( + self, channel_id: str, message_id: str, emoji: str + ) -> List["User"]: + """|coro| Return the users who reacted with the given emoji.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + users_data = await self._http.get_reactions(channel_id, message_id, emoji) + return [self.parse_user(u) for u in users_data] + + async def edit_message( + self, + channel_id: str, + message_id: str, + *, + content: Optional[str] = None, + embed: Optional["Embed"] = None, + embeds: Optional[List["Embed"]] = None, + components: Optional[List["ActionRow"]] = None, + allowed_mentions: Optional[Dict[str, Any]] = None, + flags: Optional[int] = None, + view: Optional["View"] = None, + ) -> "Message": + """Edits a previously sent message.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + if embed and embeds: + raise ValueError("Cannot provide both embed and embeds.") + if components and view: + raise ValueError("Cannot provide both 'components' and 'view'.") + + final_embeds_payload: Optional[List[Dict[str, Any]]] = None + if embed: + final_embeds_payload = [embed.to_dict()] + elif embeds: + final_embeds_payload = [e.to_dict() for e in embeds] + + components_payload: Optional[List[Dict[str, Any]]] = None + if view: + await view._start(self) + components_payload = view.to_components_payload() + elif components: + components_payload = [c.to_dict() for c in components] + + payload: Dict[str, Any] = {} + if content is not None: + payload["content"] = content + if final_embeds_payload is not None: + payload["embeds"] = final_embeds_payload + if components_payload is not None: + payload["components"] = components_payload + if allowed_mentions is not None: + payload["allowed_mentions"] = allowed_mentions + if flags is not None: + payload["flags"] = flags + + message_data = await self._http.edit_message( + channel_id=channel_id, + message_id=message_id, + payload=payload, + ) + + if view: + view.message_id = message_data["id"] + self._views[message_data["id"]] = view + + return self.parse_message(message_data) + + def get_guild(self, guild_id: Snowflake) -> Optional["Guild"]: + """Returns a guild from the internal cache. + + Use :meth:`fetch_guild` to retrieve it from Discord if it's not cached. + """ + + return self._guilds.get(guild_id) + + def get_channel(self, channel_id: Snowflake) -> Optional["Channel"]: + """Returns a channel from the internal cache.""" + + return self._channels.get(channel_id) + + def get_message(self, message_id: Snowflake) -> Optional["Message"]: + """Returns a message from the internal cache.""" + + return self._messages.get(message_id) + + async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]: + """Fetches a guild by ID from Discord and caches it.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + cached_guild = self._guilds.get(guild_id) + if cached_guild: + return cached_guild + + try: + guild_data = await self._http.get_guild(guild_id) + return self.parse_guild(guild_data) + except DisagreementException as e: + print(f"Failed to fetch guild {guild_id}: {e}") + return None + + async def fetch_channel(self, channel_id: Snowflake) -> Optional["Channel"]: + """Fetches a channel from Discord by its ID and updates the cache.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + try: + channel_data = await self._http.get_channel(channel_id) + if not channel_data: + return None + + from .models import channel_factory + + channel = channel_factory(channel_data, self) + + self._channels.set(channel.id, channel) + return channel + + except DisagreementException as e: # Includes HTTPException + print(f"Failed to fetch channel {channel_id}: {e}") + return None + + async def fetch_audit_logs( + self, guild_id: Snowflake, **filters: Any + ) -> AsyncIterator["AuditLogEntry"]: + """Fetch audit log entries for a guild.""" + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.get_audit_logs(guild_id, **filters) + for entry in data.get("audit_log_entries", []): + yield self.parse_audit_log_entry(entry) + + async def fetch_voice_regions(self) -> List[VoiceRegion]: + """Fetches available voice regions.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.get_voice_regions() + regions = [] + for region in data: + region_id = region.get("id") + if region_id: + regions.append(VoiceRegion(region_id)) + return regions + + async def create_webhook( + self, channel_id: Snowflake, payload: Dict[str, Any] + ) -> "Webhook": + """|coro| Create a webhook in the given channel.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.create_webhook(channel_id, payload) + return self.parse_webhook(data) + + async def edit_webhook( + self, webhook_id: Snowflake, payload: Dict[str, Any] + ) -> "Webhook": + """|coro| Edit an existing webhook.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.edit_webhook(webhook_id, payload) + return self.parse_webhook(data) + + async def delete_webhook(self, webhook_id: Snowflake) -> None: + """|coro| Delete a webhook by ID.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + await self._http.delete_webhook(webhook_id) + + async def fetch_templates(self, guild_id: Snowflake) -> List["GuildTemplate"]: + """|coro| Fetch all templates for a guild.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.get_guild_templates(guild_id) + return [self.parse_template(t) for t in data] + + async def create_template( + self, guild_id: Snowflake, payload: Dict[str, Any] + ) -> "GuildTemplate": + """|coro| Create a template for a guild.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.create_guild_template(guild_id, payload) + return self.parse_template(data) + + async def sync_template( + self, guild_id: Snowflake, template_code: str + ) -> "GuildTemplate": + """|coro| Sync a template to the guild's current state.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.sync_guild_template(guild_id, template_code) + return self.parse_template(data) + + async def delete_template(self, guild_id: Snowflake, template_code: str) -> None: + """|coro| Delete a guild template.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + await self._http.delete_guild_template(guild_id, template_code) + + async def fetch_scheduled_events( + self, guild_id: Snowflake + ) -> List["ScheduledEvent"]: + """|coro| Fetch all scheduled events for a guild.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.get_guild_scheduled_events(guild_id) + return [self.parse_scheduled_event(ev) for ev in data] + + async def fetch_scheduled_event( + self, guild_id: Snowflake, event_id: Snowflake + ) -> Optional["ScheduledEvent"]: + """|coro| Fetch a single scheduled event.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + try: + data = await self._http.get_guild_scheduled_event(guild_id, event_id) + return self.parse_scheduled_event(data) + except DisagreementException as e: + print(f"Failed to fetch scheduled event {event_id}: {e}") + return None + + async def create_scheduled_event( + self, guild_id: Snowflake, payload: Dict[str, Any] + ) -> "ScheduledEvent": + """|coro| Create a scheduled event in a guild.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.create_guild_scheduled_event(guild_id, payload) + return self.parse_scheduled_event(data) + + async def edit_scheduled_event( + self, guild_id: Snowflake, event_id: Snowflake, payload: Dict[str, Any] + ) -> "ScheduledEvent": + """|coro| Edit an existing scheduled event.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.edit_guild_scheduled_event(guild_id, event_id, payload) + return self.parse_scheduled_event(data) + + async def delete_scheduled_event( + self, guild_id: Snowflake, event_id: Snowflake + ) -> None: + """|coro| Delete a scheduled event.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + await self._http.delete_guild_scheduled_event(guild_id, event_id) + + async def create_invite( + self, channel_id: Snowflake, payload: Dict[str, Any] + ) -> "Invite": + """|coro| Create an invite for the given channel.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + return await self._http.create_invite(channel_id, payload) + + async def delete_invite(self, code: str) -> None: + """|coro| Delete an invite by code.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + await self._http.delete_invite(code) + + async def fetch_invites(self, channel_id: Snowflake) -> List["Invite"]: + """|coro| Fetch all invites for a channel.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.get_channel_invites(channel_id) + return [self.parse_invite(inv) for inv in data] + def add_persistent_view(self, view: "View") -> None: """ Registers a persistent view with the client. @@ -1522,22 +1537,22 @@ class Client: ) self._persistent_views[item.custom_id] = view - # --- Application Command Methods --- - async def process_interaction(self, interaction: Interaction) -> None: - """Internal method to process an interaction from the gateway.""" - - if hasattr(self, "on_interaction_create"): - asyncio.create_task(self.on_interaction_create(interaction)) - # Route component interactions to the appropriate View - if ( - interaction.type == InteractionType.MESSAGE_COMPONENT - and interaction.message + # --- Application Command Methods --- + async def process_interaction(self, interaction: Interaction) -> None: + """Internal method to process an interaction from the gateway.""" + + if hasattr(self, "on_interaction_create"): + asyncio.create_task(self.on_interaction_create(interaction)) + # Route component interactions to the appropriate View + if ( + interaction.type == InteractionType.MESSAGE_COMPONENT + and interaction.message and interaction.data - ): - view = self._views.get(interaction.message.id) - if view: - asyncio.create_task(view._dispatch(interaction)) - return + ): + view = self._views.get(interaction.message.id) + if view: + asyncio.create_task(view._dispatch(interaction)) + return else: # No active view found, check for persistent views custom_id = interaction.data.custom_id diff --git a/disagreement/event_dispatcher.py b/disagreement/event_dispatcher.py index b6ea317..5f4ebd2 100644 --- a/disagreement/event_dispatcher.py +++ b/disagreement/event_dispatcher.py @@ -76,7 +76,7 @@ class EventDispatcher: """Parses MESSAGE_DELETE and updates message cache.""" message_id = data.get("id") if message_id: - self._client._messages.pop(message_id, None) + self._client._messages.invalidate(message_id) return data def _parse_message_reaction_raw(self, data: Dict[str, Any]) -> Dict[str, Any]: @@ -124,7 +124,7 @@ class EventDispatcher: """Parses GUILD_MEMBER_ADD into a Member object.""" guild_id = str(data.get("guild_id")) - return self._client.parse_member(data, guild_id) + return self._client.parse_member(data, guild_id, just_joined=True) def _parse_guild_member_remove(self, data: Dict[str, Any]): """Parses GUILD_MEMBER_REMOVE into a GuildMemberRemove model.""" diff --git a/disagreement/models.py b/disagreement/models.py index be671a0..aa18424 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -1,116 +1,118 @@ -# disagreement/models.py - -""" -Data models for Discord objects. -""" - -import asyncio -import json -from dataclasses import dataclass -from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union - -import aiohttp # pylint: disable=import-error -from .color import Color -from .errors import DisagreementException, HTTPException -from .enums import ( # These enums will need to be defined in disagreement/enums.py - VerificationLevel, - MessageNotificationLevel, - ExplicitContentFilterLevel, - MFALevel, - GuildNSFWLevel, - PremiumTier, - GuildFeature, - ChannelType, - ComponentType, - ButtonStyle, # Added for Button - GuildScheduledEventPrivacyLevel, - GuildScheduledEventStatus, - GuildScheduledEventEntityType, - # SelectMenuType will be part of ComponentType or a new enum if needed -) -from .permissions import Permissions - - -if TYPE_CHECKING: - from .client import Client # For type hinting to avoid circular imports - from .enums import OverwriteType # For PermissionOverwrite model - from .ui.view import View - from .interactions import Snowflake - - # Forward reference Message if it were used in type hints before its definition - # from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc. - from .components import component_factory - - -class User: - """Represents a Discord User. - - Attributes: - id (str): The user's unique ID. - username (str): The user's username. - discriminator (str): The user's 4-digit discord-tag. - bot (bool): Whether the user belongs to an OAuth2 application. Defaults to False. - avatar (Optional[str]): The user's avatar hash, if any. - """ - - def __init__(self, data: dict): - self.id: str = data["id"] - self.username: str = data["username"] - self.discriminator: str = data["discriminator"] - self.bot: bool = data.get("bot", False) - self.avatar: Optional[str] = data.get("avatar") - - @property - def mention(self) -> str: - """str: Returns a string that allows you to mention the user.""" - return f"<@{self.id}>" - - def __repr__(self) -> str: - return f"" - - -class Message: - """Represents a message sent in a channel on Discord. - - Attributes: - id (str): The message's unique ID. - channel_id (str): The ID of the channel the message was sent in. - guild_id (Optional[str]): The ID of the guild the message was sent in, if applicable. - author (User): The user who sent the message. - content (str): The actual content of the message. - timestamp (str): When this message was sent (ISO8601 timestamp). - components (Optional[List[ActionRow]]): Structured components attached - to the message if present. - attachments (List[Attachment]): Attachments included with the message. - """ - - def __init__(self, data: dict, client_instance: "Client"): - self._client: "Client" = ( - client_instance # Store reference to client for methods like reply - ) - - self.id: str = data["id"] - self.channel_id: str = data["channel_id"] - self.guild_id: Optional[str] = data.get("guild_id") - self.author: User = User(data["author"]) - self.content: str = data["content"] - self.timestamp: str = data["timestamp"] - if data.get("components"): - self.components: Optional[List[ActionRow]] = [ - ActionRow.from_dict(c, client_instance) - for c in data.get("components", []) - ] - else: - self.components = None - self.attachments: List[Attachment] = [ - Attachment(a) for a in data.get("attachments", []) - ] +# disagreement/models.py + +""" +Data models for Discord objects. +""" + +import asyncio +import json +from dataclasses import dataclass +from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast + +from .cache import ChannelCache, MemberCache + +import aiohttp # pylint: disable=import-error +from .color import Color +from .errors import DisagreementException, HTTPException +from .enums import ( # These enums will need to be defined in disagreement/enums.py + VerificationLevel, + MessageNotificationLevel, + ExplicitContentFilterLevel, + MFALevel, + GuildNSFWLevel, + PremiumTier, + GuildFeature, + ChannelType, + ComponentType, + ButtonStyle, # Added for Button + GuildScheduledEventPrivacyLevel, + GuildScheduledEventStatus, + GuildScheduledEventEntityType, + # SelectMenuType will be part of ComponentType or a new enum if needed +) +from .permissions import Permissions + + +if TYPE_CHECKING: + from .client import Client # For type hinting to avoid circular imports + from .enums import OverwriteType # For PermissionOverwrite model + from .ui.view import View + from .interactions import Snowflake + + # Forward reference Message if it were used in type hints before its definition + # from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc. + from .components import component_factory + + +class User: + """Represents a Discord User. + + Attributes: + id (str): The user's unique ID. + username (str): The user's username. + discriminator (str): The user's 4-digit discord-tag. + bot (bool): Whether the user belongs to an OAuth2 application. Defaults to False. + avatar (Optional[str]): The user's avatar hash, if any. + """ + + def __init__(self, data: dict): + self.id: str = data["id"] + self.username: str = data["username"] + self.discriminator: str = data["discriminator"] + self.bot: bool = data.get("bot", False) + self.avatar: Optional[str] = data.get("avatar") + + @property + def mention(self) -> str: + """str: Returns a string that allows you to mention the user.""" + return f"<@{self.id}>" + + def __repr__(self) -> str: + return f"" + + +class Message: + """Represents a message sent in a channel on Discord. + + Attributes: + id (str): The message's unique ID. + channel_id (str): The ID of the channel the message was sent in. + guild_id (Optional[str]): The ID of the guild the message was sent in, if applicable. + author (User): The user who sent the message. + content (str): The actual content of the message. + timestamp (str): When this message was sent (ISO8601 timestamp). + components (Optional[List[ActionRow]]): Structured components attached + to the message if present. + attachments (List[Attachment]): Attachments included with the message. + """ + + def __init__(self, data: dict, client_instance: "Client"): + self._client: "Client" = ( + client_instance # Store reference to client for methods like reply + ) + + self.id: str = data["id"] + self.channel_id: str = data["channel_id"] + self.guild_id: Optional[str] = data.get("guild_id") + self.author: User = User(data["author"]) + self.content: str = data["content"] + self.timestamp: str = data["timestamp"] + if data.get("components"): + self.components: Optional[List[ActionRow]] = [ + ActionRow.from_dict(c, client_instance) + for c in data.get("components", []) + ] + else: + self.components = None + self.attachments: List[Attachment] = [ + Attachment(a) for a in data.get("attachments", []) + ] self.pinned: bool = data.get("pinned", False) - # Add other fields as needed, e.g., attachments, embeds, reactions, etc. - # self.mentions: List[User] = [User(u) for u in data.get("mentions", [])] - # self.mention_roles: List[str] = data.get("mention_roles", []) - # self.mention_everyone: bool = data.get("mention_everyone", False) - + # Add other fields as needed, e.g., attachments, embeds, reactions, etc. + # self.mentions: List[User] = [User(u) for u in data.get("mentions", [])] + # self.mention_roles: List[str] = data.get("mention_roles", []) + # self.mention_everyone: bool = data.get("mention_everyone", False) + async def pin(self) -> None: """|coro| @@ -392,737 +394,739 @@ class PartialMessage: else: await self._client._http.delete_reaction(self.channel.id, self.id, emoji) - -class EmbedFooter: - """Represents an embed footer.""" - - def __init__(self, data: Dict[str, Any]): - self.text: str = data["text"] - self.icon_url: Optional[str] = data.get("icon_url") - self.proxy_icon_url: Optional[str] = data.get("proxy_icon_url") - - def to_dict(self) -> Dict[str, Any]: - payload = {"text": self.text} - if self.icon_url: - payload["icon_url"] = self.icon_url - if self.proxy_icon_url: - payload["proxy_icon_url"] = self.proxy_icon_url - return payload - - -class EmbedImage: - """Represents an embed image.""" - - def __init__(self, data: Dict[str, Any]): - self.url: str = data["url"] - self.proxy_url: Optional[str] = data.get("proxy_url") - self.height: Optional[int] = data.get("height") - self.width: Optional[int] = data.get("width") - - def to_dict(self) -> Dict[str, Any]: - payload: Dict[str, Any] = {"url": self.url} - if self.proxy_url: - payload["proxy_url"] = self.proxy_url - if self.height: - payload["height"] = self.height - if self.width: - payload["width"] = self.width - return payload - - def __repr__(self) -> str: - return f"" - - -class EmbedThumbnail(EmbedImage): # Similar structure to EmbedImage - """Represents an embed thumbnail.""" - - pass - - -class EmbedAuthor: - """Represents an embed author.""" - - def __init__(self, data: Dict[str, Any]): - self.name: str = data["name"] - self.url: Optional[str] = data.get("url") - self.icon_url: Optional[str] = data.get("icon_url") - self.proxy_icon_url: Optional[str] = data.get("proxy_icon_url") - - def to_dict(self) -> Dict[str, Any]: - payload = {"name": self.name} - if self.url: - payload["url"] = self.url - if self.icon_url: - payload["icon_url"] = self.icon_url - if self.proxy_icon_url: - payload["proxy_icon_url"] = self.proxy_icon_url - return payload - - -class EmbedField: - """Represents an embed field.""" - - def __init__(self, data: Dict[str, Any]): - self.name: str = data["name"] - self.value: str = data["value"] - self.inline: bool = data.get("inline", False) - - def to_dict(self) -> Dict[str, Any]: - return {"name": self.name, "value": self.value, "inline": self.inline} - - -class Embed: - """Represents a Discord embed. - - Attributes can be set directly or via methods like `set_author`, `add_field`. - """ - - def __init__(self, data: Optional[Dict[str, Any]] = None): - data = data or {} - self.title: Optional[str] = data.get("title") - self.type: str = data.get("type", "rich") # Default to "rich" for sending - self.description: Optional[str] = data.get("description") - self.url: Optional[str] = data.get("url") - self.timestamp: Optional[str] = data.get("timestamp") # ISO8601 timestamp - self.color = Color.parse(data.get("color")) - - self.footer: Optional[EmbedFooter] = ( - EmbedFooter(data["footer"]) if data.get("footer") else None - ) - self.image: Optional[EmbedImage] = ( - EmbedImage(data["image"]) if data.get("image") else None - ) - self.thumbnail: Optional[EmbedThumbnail] = ( - EmbedThumbnail(data["thumbnail"]) if data.get("thumbnail") else None - ) - # Video and Provider are less common for bot-sent embeds, can be added if needed. - self.author: Optional[EmbedAuthor] = ( - EmbedAuthor(data["author"]) if data.get("author") else None - ) - self.fields: List[EmbedField] = ( - [EmbedField(f) for f in data["fields"]] if data.get("fields") else [] - ) - - def to_dict(self) -> Dict[str, Any]: - payload: Dict[str, Any] = {"type": self.type} - if self.title: - payload["title"] = self.title - if self.description: - payload["description"] = self.description - if self.url: - payload["url"] = self.url - if self.timestamp: - payload["timestamp"] = self.timestamp - if self.color is not None: - payload["color"] = self.color.value - if self.footer: - payload["footer"] = self.footer.to_dict() - if self.image: - payload["image"] = self.image.to_dict() - if self.thumbnail: - payload["thumbnail"] = self.thumbnail.to_dict() - if self.author: - payload["author"] = self.author.to_dict() - if self.fields: - payload["fields"] = [f.to_dict() for f in self.fields] - return payload - - # Convenience methods for building embeds can be added here - # e.g., set_author, add_field, set_footer, set_image, etc. - - -class Attachment: - """Represents a message attachment.""" - - def __init__(self, data: Dict[str, Any]): - self.id: str = data["id"] - self.filename: str = data["filename"] - self.description: Optional[str] = data.get("description") - self.content_type: Optional[str] = data.get("content_type") - self.size: Optional[int] = data.get("size") - self.url: Optional[str] = data.get("url") - self.proxy_url: Optional[str] = data.get("proxy_url") - self.height: Optional[int] = data.get("height") # If image - self.width: Optional[int] = data.get("width") # If image - self.ephemeral: bool = data.get("ephemeral", False) - - def __repr__(self) -> str: - return f"" - - def to_dict(self) -> Dict[str, Any]: - payload: Dict[str, Any] = {"id": self.id, "filename": self.filename} - if self.description is not None: - payload["description"] = self.description - if self.content_type is not None: - payload["content_type"] = self.content_type - if self.size is not None: - payload["size"] = self.size - if self.url is not None: - payload["url"] = self.url - if self.proxy_url is not None: - payload["proxy_url"] = self.proxy_url - if self.height is not None: - payload["height"] = self.height - if self.width is not None: - payload["width"] = self.width - if self.ephemeral: - payload["ephemeral"] = self.ephemeral - return payload - - -class File: - """Represents a file to be uploaded.""" - - def __init__(self, filename: str, data: bytes): - self.filename = filename - self.data = data - - -class AllowedMentions: - """Represents allowed mentions for a message or interaction response.""" - - def __init__(self, data: Dict[str, Any]): - self.parse: List[str] = data.get("parse", []) - self.roles: List[str] = data.get("roles", []) - self.users: List[str] = data.get("users", []) - self.replied_user: bool = data.get("replied_user", False) - - def to_dict(self) -> Dict[str, Any]: - payload: Dict[str, Any] = {"parse": self.parse} - if self.roles: - payload["roles"] = self.roles - if self.users: - payload["users"] = self.users - if self.replied_user: - payload["replied_user"] = self.replied_user - return payload - - -class RoleTags: - """Represents tags for a role.""" - - def __init__(self, data: Dict[str, Any]): - self.bot_id: Optional[str] = data.get("bot_id") - self.integration_id: Optional[str] = data.get("integration_id") - self.premium_subscriber: Optional[bool] = ( - data.get("premium_subscriber") is None - ) # presence of null value means true - - def to_dict(self) -> Dict[str, Any]: - payload = {} - if self.bot_id: - payload["bot_id"] = self.bot_id - if self.integration_id: - payload["integration_id"] = self.integration_id - if self.premium_subscriber: - payload["premium_subscriber"] = None # Explicitly null - return payload - - -class Role: - """Represents a Discord Role.""" - - def __init__(self, data: Dict[str, Any]): - self.id: str = data["id"] - self.name: str = data["name"] - self.color: int = data["color"] - self.hoist: bool = data["hoist"] - self.icon: Optional[str] = data.get("icon") - self.unicode_emoji: Optional[str] = data.get("unicode_emoji") - self.position: int = data["position"] - self.permissions: str = data["permissions"] # String of bitwise permissions - self.managed: bool = data["managed"] - self.mentionable: bool = data["mentionable"] - self.tags: Optional[RoleTags] = ( - RoleTags(data["tags"]) if data.get("tags") else None - ) - - @property - def mention(self) -> str: - """str: Returns a string that allows you to mention the role.""" - return f"<@&{self.id}>" - - def __repr__(self) -> str: - return f"" - - -class Member(User): # Member inherits from User - """Represents a Guild Member. - This class combines User attributes with guild-specific Member attributes. - """ - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - self._client: Optional["Client"] = client_instance - self.guild_id: Optional[str] = None - # User part is nested under 'user' key in member data from gateway/API - user_data = data.get("user", {}) - # If 'id' is not in user_data but is top-level (e.g. from interaction resolved member without user object) - if "id" not in user_data and "id" in data: - # This case is less common for full member objects but can happen. - # We'd need to construct a partial user from top-level member fields if 'user' is missing. - # For now, assume 'user' object is present for full Member hydration. - # If 'user' is missing, the User part might be incomplete. - pass # User fields will be missing or default if 'user' not in data. - - super().__init__( - user_data if user_data else data - ) # Pass user_data or data if user_data is empty - - self.nick: Optional[str] = data.get("nick") - self.avatar: Optional[str] = data.get("avatar") # Guild-specific avatar hash - self.roles: List[str] = data.get("roles", []) # List of role IDs - self.joined_at: str = data["joined_at"] # ISO8601 timestamp - self.premium_since: Optional[str] = data.get( - "premium_since" - ) # ISO8601 timestamp - self.deaf: bool = data.get("deaf", False) - self.mute: bool = data.get("mute", False) - self.pending: bool = data.get("pending", False) - self.permissions: Optional[str] = data.get( - "permissions" - ) # Permissions in the channel, if applicable - self.communication_disabled_until: Optional[str] = data.get( - "communication_disabled_until" - ) # ISO8601 timestamp - - # If 'user' object was present, ensure User attributes are from there - if user_data: - self.id = user_data.get("id", self.id) # Prefer user.id if available - self.username = user_data.get("username", self.username) - self.discriminator = user_data.get("discriminator", self.discriminator) - self.bot = user_data.get("bot", self.bot) - # User's global avatar is User.avatar, Member.avatar is guild-specific - # super() already set self.avatar from user_data if present. - # The self.avatar = data.get("avatar") line above overwrites it with guild avatar. This is correct. - - def __repr__(self) -> str: - return f"" - - @property - def display_name(self) -> str: - """Return the nickname if set, otherwise the username.""" - - return self.nick or self.username - - async def kick(self, *, reason: Optional[str] = None) -> None: - if not self.guild_id or not self._client: - raise DisagreementException("Member.kick requires guild_id and client") - await self._client._http.kick_member(self.guild_id, self.id, reason=reason) - - async def ban( - self, - *, - delete_message_seconds: int = 0, - reason: Optional[str] = None, - ) -> None: - if not self.guild_id or not self._client: - raise DisagreementException("Member.ban requires guild_id and client") - await self._client._http.ban_member( - self.guild_id, - self.id, - delete_message_seconds=delete_message_seconds, - reason=reason, - ) - - async def timeout( - self, until: Optional[str], *, reason: Optional[str] = None - ) -> None: - if not self.guild_id or not self._client: - raise DisagreementException("Member.timeout requires guild_id and client") - await self._client._http.timeout_member( - self.guild_id, - self.id, - until=until, - reason=reason, - ) - - @property - def top_role(self) -> Optional["Role"]: - """Return the member's highest role from the guild cache.""" - - if not self.guild_id or not self._client: - return None - - guild = self._client.get_guild(self.guild_id) - if not guild: - return None - - if not guild.roles and hasattr(self._client, "fetch_roles"): - try: - self._client.loop.run_until_complete( - self._client.fetch_roles(self.guild_id) - ) - except RuntimeError: - future = asyncio.run_coroutine_threadsafe( - self._client.fetch_roles(self.guild_id), self._client.loop - ) - future.result() - - role_objects = [r for r in guild.roles if r.id in self.roles] - if not role_objects: - return None - - return max(role_objects, key=lambda r: r.position) - - -class PartialEmoji: - """Represents a partial emoji, often used in components or reactions. - - This typically means only id, name, and animated are known. - For unicode emojis, id will be None and name will be the unicode character. - """ - - def __init__(self, data: Dict[str, Any]): - self.id: Optional[str] = data.get("id") - self.name: Optional[str] = data.get( - "name" - ) # Can be None for unknown custom emoji, or unicode char - self.animated: bool = data.get("animated", False) - - def to_dict(self) -> Dict[str, Any]: - payload: Dict[str, Any] = {} - if self.id: - payload["id"] = self.id - if self.name: - payload["name"] = self.name - if self.animated: # Only include if true, as per some Discord patterns - payload["animated"] = self.animated - return payload - - def __str__(self) -> str: - if self.id: - return f"<{'a' if self.animated else ''}:{self.name}:{self.id}>" - return self.name or "" # For unicode emoji - - def __repr__(self) -> str: - return ( - f"" - ) - - -def to_partial_emoji( - value: Union[str, "PartialEmoji", None], -) -> Optional["PartialEmoji"]: - """Convert a string or PartialEmoji to a PartialEmoji instance. - - Args: - value: Either a unicode emoji string, a :class:`PartialEmoji`, or ``None``. - - Returns: - A :class:`PartialEmoji` or ``None`` if ``value`` was ``None``. - - Raises: - TypeError: If ``value`` is not ``str`` or :class:`PartialEmoji`. - """ - - if value is None or isinstance(value, PartialEmoji): - return value - if isinstance(value, str): - return PartialEmoji({"name": value, "id": None}) - raise TypeError("emoji must be a str or PartialEmoji") - - -class Emoji(PartialEmoji): - """Represents a custom guild emoji. - - Inherits id, name, animated from PartialEmoji. - """ - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - super().__init__(data) - self._client: Optional["Client"] = ( - client_instance # For potential future methods - ) - - # Roles this emoji is whitelisted to - self.roles: List[str] = data.get("roles", []) # List of role IDs - - # User object for the user that created this emoji (optional, only for GUILD_EMOJIS_AND_STICKERS intent) - self.user: Optional[User] = User(data["user"]) if data.get("user") else None - - self.require_colons: bool = data.get("require_colons", False) - self.managed: bool = data.get( - "managed", False - ) # If this emoji is managed by an integration - self.available: bool = data.get( - "available", True - ) # Whether this emoji can be used - - def __repr__(self) -> str: - return f"" - - -class StickerItem: - """Represents a sticker item, a basic representation of a sticker. - - Used in sticker packs and sometimes in message data. - """ - - def __init__(self, data: Dict[str, Any]): - self.id: str = data["id"] - self.name: str = data["name"] - self.format_type: int = data["format_type"] # StickerFormatType enum - - def __repr__(self) -> str: - return f"" - - -class Sticker(StickerItem): - """Represents a Discord sticker. - - Inherits id, name, format_type from StickerItem. - """ - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - super().__init__(data) - self._client: Optional["Client"] = client_instance - - self.pack_id: Optional[str] = data.get( - "pack_id" - ) # For standard stickers, ID of the pack - self.description: Optional[str] = data.get("description") - self.tags: str = data.get( - "tags", "" - ) # Comma-separated list of tags for guild stickers - # type is StickerType enum (STANDARD or GUILD) - # For guild stickers, this is 2. For standard stickers, this is 1. - self.type: int = data["type"] - self.available: bool = data.get( - "available", True - ) # Whether this sticker can be used - self.guild_id: Optional[str] = data.get( - "guild_id" - ) # ID of the guild that owns this sticker - - # User object of the user that uploaded the guild sticker - self.user: Optional[User] = User(data["user"]) if data.get("user") else None - - self.sort_value: Optional[int] = data.get( - "sort_value" - ) # The standard sticker's sort order within its pack - - def __repr__(self) -> str: - return f"" - - -class StickerPack: - """Represents a pack of standard stickers.""" - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - self._client: Optional["Client"] = client_instance - self.id: str = data["id"] - self.stickers: List[Sticker] = [ - Sticker(s_data, client_instance) for s_data in data.get("stickers", []) - ] - self.name: str = data["name"] - self.sku_id: str = data["sku_id"] - self.cover_sticker_id: Optional[str] = data.get("cover_sticker_id") - self.description: str = data["description"] - self.banner_asset_id: Optional[str] = data.get( - "banner_asset_id" - ) # ID of the pack's banner image - - def __repr__(self) -> str: - return f"" - - -class PermissionOverwrite: - """Represents a permission overwrite for a role or member in a channel.""" - - def __init__(self, data: Dict[str, Any]): - self.id: str = data["id"] # Role or user ID - self._type_val: int = int(data["type"]) # Store raw type for enum property - self.allow: str = data["allow"] # Bitwise value of allowed permissions - self.deny: str = data["deny"] # Bitwise value of denied permissions - - @property - def type(self) -> "OverwriteType": - from .enums import ( - OverwriteType, - ) # Local import to avoid circularity at module level - - return OverwriteType(self._type_val) - - def to_dict(self) -> Dict[str, Any]: - return { - "id": self.id, - "type": self.type.value, - "allow": self.allow, - "deny": self.deny, - } - - def __repr__(self) -> str: - return f"" - - -class Guild: - """Represents a Discord Guild (Server). - - Attributes: - id (str): Guild ID. - name (str): Guild name (2-100 characters, excluding @, #, :, ```). - icon (Optional[str]): Icon hash. - splash (Optional[str]): Splash hash. - discovery_splash (Optional[str]): Discovery splash hash; only present for discoverable guilds. - owner (Optional[bool]): True if the user is the owner of the guild. (Only for /users/@me/guilds endpoint) - owner_id (str): ID of owner. - permissions (Optional[str]): Total permissions for the user in the guild (excludes overwrites). (Only for /users/@me/guilds endpoint) - afk_channel_id (Optional[str]): ID of afk channel. - afk_timeout (int): AFK timeout in seconds. - widget_enabled (Optional[bool]): True if the server widget is enabled. - widget_channel_id (Optional[str]): The channel id that the widget will generate an invite to, or null if set to no invite. - verification_level (VerificationLevel): Verification level required for the guild. - default_message_notifications (MessageNotificationLevel): Default message notifications level. - explicit_content_filter (ExplicitContentFilterLevel): Explicit content filter level. - roles (List[Role]): Roles in the guild. - emojis (List[Dict]): Custom emojis. (Consider creating an Emoji model) - features (List[GuildFeature]): Enabled guild features. - mfa_level (MFALevel): Required MFA level for the guild. - application_id (Optional[str]): Application ID of the guild creator if it is bot-created. - system_channel_id (Optional[str]): The id of the channel where guild notices such as welcome messages and boost events are posted. - system_channel_flags (int): System channel flags. - rules_channel_id (Optional[str]): The id of the channel where Community guilds can display rules. - max_members (Optional[int]): The maximum number of members for the guild. - vanity_url_code (Optional[str]): The vanity url code for the guild. - description (Optional[str]): The description of a Community guild. - banner (Optional[str]): Banner hash. - premium_tier (PremiumTier): Premium tier (Server Boost level). - premium_subscription_count (Optional[int]): The number of boosts this guild currently has. - preferred_locale (str): The preferred locale of a Community guild. Defaults to "en-US". - public_updates_channel_id (Optional[str]): The id of the channel where admins and moderators of Community guilds receive notices from Discord. - max_video_channel_users (Optional[int]): The maximum number of users in a video channel. - welcome_screen (Optional[Dict]): The welcome screen of a Community guild. (Consider a WelcomeScreen model) - nsfw_level (GuildNSFWLevel): Guild NSFW level. - stickers (Optional[List[Dict]]): Custom stickers in the guild. (Consider a Sticker model) - premium_progress_bar_enabled (bool): Whether the guild has the premium progress bar enabled. - """ - - def __init__(self, data: Dict[str, Any], client_instance: "Client"): - self._client: "Client" = client_instance - self.id: str = data["id"] - self.name: str = data["name"] - self.icon: Optional[str] = data.get("icon") - self.splash: Optional[str] = data.get("splash") - self.discovery_splash: Optional[str] = data.get("discovery_splash") - self.owner: Optional[bool] = data.get("owner") - self.owner_id: str = data["owner_id"] - self.permissions: Optional[str] = data.get("permissions") - self.afk_channel_id: Optional[str] = data.get("afk_channel_id") - self.afk_timeout: int = data["afk_timeout"] - self.widget_enabled: Optional[bool] = data.get("widget_enabled") - self.widget_channel_id: Optional[str] = data.get("widget_channel_id") - self.verification_level: VerificationLevel = VerificationLevel( - data["verification_level"] - ) - self.default_message_notifications: MessageNotificationLevel = ( - MessageNotificationLevel(data["default_message_notifications"]) - ) - self.explicit_content_filter: ExplicitContentFilterLevel = ( - ExplicitContentFilterLevel(data["explicit_content_filter"]) - ) - - self.roles: List[Role] = [Role(r) for r in data.get("roles", [])] - self.emojis: List[Emoji] = [ - Emoji(e_data, client_instance) for e_data in data.get("emojis", []) - ] - - # Assuming GuildFeature can be constructed from string feature names or their values - self.features: List[GuildFeature] = [ - GuildFeature(f) if not isinstance(f, GuildFeature) else f - for f in data.get("features", []) - ] - - self.mfa_level: MFALevel = MFALevel(data["mfa_level"]) - self.application_id: Optional[str] = data.get("application_id") - self.system_channel_id: Optional[str] = data.get("system_channel_id") - self.system_channel_flags: int = data["system_channel_flags"] - self.rules_channel_id: Optional[str] = data.get("rules_channel_id") - self.max_members: Optional[int] = data.get("max_members") - self.vanity_url_code: Optional[str] = data.get("vanity_url_code") - self.description: Optional[str] = data.get("description") - self.banner: Optional[str] = data.get("banner") - self.premium_tier: PremiumTier = PremiumTier(data["premium_tier"]) - self.premium_subscription_count: Optional[int] = data.get( - "premium_subscription_count" - ) - self.preferred_locale: str = data.get("preferred_locale", "en-US") - self.public_updates_channel_id: Optional[str] = data.get( - "public_updates_channel_id" - ) - self.max_video_channel_users: Optional[int] = data.get( - "max_video_channel_users" - ) - self.approximate_member_count: Optional[int] = data.get( - "approximate_member_count" - ) - self.approximate_presence_count: Optional[int] = data.get( - "approximate_presence_count" - ) - self.welcome_screen: Optional["WelcomeScreen"] = ( - WelcomeScreen(data["welcome_screen"], client_instance) - if data.get("welcome_screen") - else None - ) - self.nsfw_level: GuildNSFWLevel = GuildNSFWLevel(data["nsfw_level"]) - self.stickers: Optional[List[Sticker]] = ( - [Sticker(s_data, client_instance) for s_data in data.get("stickers", [])] - if data.get("stickers") - else None - ) - self.premium_progress_bar_enabled: bool = data.get( - "premium_progress_bar_enabled", False - ) - - # Internal caches, populated by events or specific fetches - self._channels: Dict[str, "Channel"] = {} - self._members: Dict[str, Member] = {} - self._threads: Dict[str, "Thread"] = {} - - def get_channel(self, channel_id: str) -> Optional["Channel"]: - return self._channels.get(channel_id) - - def get_member(self, user_id: str) -> Optional[Member]: - return self._members.get(user_id) - - def get_member_named(self, name: str) -> Optional[Member]: - """Retrieve a cached member by username or nickname. - - The lookup is case-insensitive and searches both the username and - guild nickname for a match. - - Parameters - ---------- - name: str - The username or nickname to search for. - - Returns - ------- - Optional[Member] - The matching member if found, otherwise ``None``. - """ - - lowered = name.lower() - for member in self._members.values(): - if member.username.lower() == lowered: - return member - if member.nick and member.nick.lower() == lowered: - return member - return None - - def get_role(self, role_id: str) -> Optional[Role]: - return next((role for role in self.roles if role.id == role_id), None) - - def __repr__(self) -> str: - return f"" - + +class EmbedFooter: + """Represents an embed footer.""" + + def __init__(self, data: Dict[str, Any]): + self.text: str = data["text"] + self.icon_url: Optional[str] = data.get("icon_url") + self.proxy_icon_url: Optional[str] = data.get("proxy_icon_url") + + def to_dict(self) -> Dict[str, Any]: + payload = {"text": self.text} + if self.icon_url: + payload["icon_url"] = self.icon_url + if self.proxy_icon_url: + payload["proxy_icon_url"] = self.proxy_icon_url + return payload + + +class EmbedImage: + """Represents an embed image.""" + + def __init__(self, data: Dict[str, Any]): + self.url: str = data["url"] + self.proxy_url: Optional[str] = data.get("proxy_url") + self.height: Optional[int] = data.get("height") + self.width: Optional[int] = data.get("width") + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {"url": self.url} + if self.proxy_url: + payload["proxy_url"] = self.proxy_url + if self.height: + payload["height"] = self.height + if self.width: + payload["width"] = self.width + return payload + + def __repr__(self) -> str: + return f"" + + +class EmbedThumbnail(EmbedImage): # Similar structure to EmbedImage + """Represents an embed thumbnail.""" + + pass + + +class EmbedAuthor: + """Represents an embed author.""" + + def __init__(self, data: Dict[str, Any]): + self.name: str = data["name"] + self.url: Optional[str] = data.get("url") + self.icon_url: Optional[str] = data.get("icon_url") + self.proxy_icon_url: Optional[str] = data.get("proxy_icon_url") + + def to_dict(self) -> Dict[str, Any]: + payload = {"name": self.name} + if self.url: + payload["url"] = self.url + if self.icon_url: + payload["icon_url"] = self.icon_url + if self.proxy_icon_url: + payload["proxy_icon_url"] = self.proxy_icon_url + return payload + + +class EmbedField: + """Represents an embed field.""" + + def __init__(self, data: Dict[str, Any]): + self.name: str = data["name"] + self.value: str = data["value"] + self.inline: bool = data.get("inline", False) + + def to_dict(self) -> Dict[str, Any]: + return {"name": self.name, "value": self.value, "inline": self.inline} + + +class Embed: + """Represents a Discord embed. + + Attributes can be set directly or via methods like `set_author`, `add_field`. + """ + + def __init__(self, data: Optional[Dict[str, Any]] = None): + data = data or {} + self.title: Optional[str] = data.get("title") + self.type: str = data.get("type", "rich") # Default to "rich" for sending + self.description: Optional[str] = data.get("description") + self.url: Optional[str] = data.get("url") + self.timestamp: Optional[str] = data.get("timestamp") # ISO8601 timestamp + self.color = Color.parse(data.get("color")) + + self.footer: Optional[EmbedFooter] = ( + EmbedFooter(data["footer"]) if data.get("footer") else None + ) + self.image: Optional[EmbedImage] = ( + EmbedImage(data["image"]) if data.get("image") else None + ) + self.thumbnail: Optional[EmbedThumbnail] = ( + EmbedThumbnail(data["thumbnail"]) if data.get("thumbnail") else None + ) + # Video and Provider are less common for bot-sent embeds, can be added if needed. + self.author: Optional[EmbedAuthor] = ( + EmbedAuthor(data["author"]) if data.get("author") else None + ) + self.fields: List[EmbedField] = ( + [EmbedField(f) for f in data["fields"]] if data.get("fields") else [] + ) + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {"type": self.type} + if self.title: + payload["title"] = self.title + if self.description: + payload["description"] = self.description + if self.url: + payload["url"] = self.url + if self.timestamp: + payload["timestamp"] = self.timestamp + if self.color is not None: + payload["color"] = self.color.value + if self.footer: + payload["footer"] = self.footer.to_dict() + if self.image: + payload["image"] = self.image.to_dict() + if self.thumbnail: + payload["thumbnail"] = self.thumbnail.to_dict() + if self.author: + payload["author"] = self.author.to_dict() + if self.fields: + payload["fields"] = [f.to_dict() for f in self.fields] + return payload + + # Convenience methods for building embeds can be added here + # e.g., set_author, add_field, set_footer, set_image, etc. + + +class Attachment: + """Represents a message attachment.""" + + def __init__(self, data: Dict[str, Any]): + self.id: str = data["id"] + self.filename: str = data["filename"] + self.description: Optional[str] = data.get("description") + self.content_type: Optional[str] = data.get("content_type") + self.size: Optional[int] = data.get("size") + self.url: Optional[str] = data.get("url") + self.proxy_url: Optional[str] = data.get("proxy_url") + self.height: Optional[int] = data.get("height") # If image + self.width: Optional[int] = data.get("width") # If image + self.ephemeral: bool = data.get("ephemeral", False) + + def __repr__(self) -> str: + return f"" + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {"id": self.id, "filename": self.filename} + if self.description is not None: + payload["description"] = self.description + if self.content_type is not None: + payload["content_type"] = self.content_type + if self.size is not None: + payload["size"] = self.size + if self.url is not None: + payload["url"] = self.url + if self.proxy_url is not None: + payload["proxy_url"] = self.proxy_url + if self.height is not None: + payload["height"] = self.height + if self.width is not None: + payload["width"] = self.width + if self.ephemeral: + payload["ephemeral"] = self.ephemeral + return payload + + +class File: + """Represents a file to be uploaded.""" + + def __init__(self, filename: str, data: bytes): + self.filename = filename + self.data = data + + +class AllowedMentions: + """Represents allowed mentions for a message or interaction response.""" + + def __init__(self, data: Dict[str, Any]): + self.parse: List[str] = data.get("parse", []) + self.roles: List[str] = data.get("roles", []) + self.users: List[str] = data.get("users", []) + self.replied_user: bool = data.get("replied_user", False) + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {"parse": self.parse} + if self.roles: + payload["roles"] = self.roles + if self.users: + payload["users"] = self.users + if self.replied_user: + payload["replied_user"] = self.replied_user + return payload + + +class RoleTags: + """Represents tags for a role.""" + + def __init__(self, data: Dict[str, Any]): + self.bot_id: Optional[str] = data.get("bot_id") + self.integration_id: Optional[str] = data.get("integration_id") + self.premium_subscriber: Optional[bool] = ( + data.get("premium_subscriber") is None + ) # presence of null value means true + + def to_dict(self) -> Dict[str, Any]: + payload = {} + if self.bot_id: + payload["bot_id"] = self.bot_id + if self.integration_id: + payload["integration_id"] = self.integration_id + if self.premium_subscriber: + payload["premium_subscriber"] = None # Explicitly null + return payload + + +class Role: + """Represents a Discord Role.""" + + def __init__(self, data: Dict[str, Any]): + self.id: str = data["id"] + self.name: str = data["name"] + self.color: int = data["color"] + self.hoist: bool = data["hoist"] + self.icon: Optional[str] = data.get("icon") + self.unicode_emoji: Optional[str] = data.get("unicode_emoji") + self.position: int = data["position"] + self.permissions: str = data["permissions"] # String of bitwise permissions + self.managed: bool = data["managed"] + self.mentionable: bool = data["mentionable"] + self.tags: Optional[RoleTags] = ( + RoleTags(data["tags"]) if data.get("tags") else None + ) + + @property + def mention(self) -> str: + """str: Returns a string that allows you to mention the role.""" + return f"<@&{self.id}>" + + def __repr__(self) -> str: + return f"" + + +class Member(User): # Member inherits from User + """Represents a Guild Member. + This class combines User attributes with guild-specific Member attributes. + """ + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client: Optional["Client"] = client_instance + self.guild_id: Optional[str] = None + self.status: Optional[str] = None + self.voice_state: Optional[Dict[str, Any]] = None + # User part is nested under 'user' key in member data from gateway/API + user_data = data.get("user", {}) + # If 'id' is not in user_data but is top-level (e.g. from interaction resolved member without user object) + if "id" not in user_data and "id" in data: + # This case is less common for full member objects but can happen. + # We'd need to construct a partial user from top-level member fields if 'user' is missing. + # For now, assume 'user' object is present for full Member hydration. + # If 'user' is missing, the User part might be incomplete. + pass # User fields will be missing or default if 'user' not in data. + + super().__init__( + user_data if user_data else data + ) # Pass user_data or data if user_data is empty + + self.nick: Optional[str] = data.get("nick") + self.avatar: Optional[str] = data.get("avatar") # Guild-specific avatar hash + self.roles: List[str] = data.get("roles", []) # List of role IDs + self.joined_at: str = data["joined_at"] # ISO8601 timestamp + self.premium_since: Optional[str] = data.get( + "premium_since" + ) # ISO8601 timestamp + self.deaf: bool = data.get("deaf", False) + self.mute: bool = data.get("mute", False) + self.pending: bool = data.get("pending", False) + self.permissions: Optional[str] = data.get( + "permissions" + ) # Permissions in the channel, if applicable + self.communication_disabled_until: Optional[str] = data.get( + "communication_disabled_until" + ) # ISO8601 timestamp + + # If 'user' object was present, ensure User attributes are from there + if user_data: + self.id = user_data.get("id", self.id) # Prefer user.id if available + self.username = user_data.get("username", self.username) + self.discriminator = user_data.get("discriminator", self.discriminator) + self.bot = user_data.get("bot", self.bot) + # User's global avatar is User.avatar, Member.avatar is guild-specific + # super() already set self.avatar from user_data if present. + # The self.avatar = data.get("avatar") line above overwrites it with guild avatar. This is correct. + + def __repr__(self) -> str: + return f"" + + @property + def display_name(self) -> str: + """Return the nickname if set, otherwise the username.""" + + return self.nick or self.username + + async def kick(self, *, reason: Optional[str] = None) -> None: + if not self.guild_id or not self._client: + raise DisagreementException("Member.kick requires guild_id and client") + await self._client._http.kick_member(self.guild_id, self.id, reason=reason) + + async def ban( + self, + *, + delete_message_seconds: int = 0, + reason: Optional[str] = None, + ) -> None: + if not self.guild_id or not self._client: + raise DisagreementException("Member.ban requires guild_id and client") + await self._client._http.ban_member( + self.guild_id, + self.id, + delete_message_seconds=delete_message_seconds, + reason=reason, + ) + + async def timeout( + self, until: Optional[str], *, reason: Optional[str] = None + ) -> None: + if not self.guild_id or not self._client: + raise DisagreementException("Member.timeout requires guild_id and client") + await self._client._http.timeout_member( + self.guild_id, + self.id, + until=until, + reason=reason, + ) + + @property + def top_role(self) -> Optional["Role"]: + """Return the member's highest role from the guild cache.""" + + if not self.guild_id or not self._client: + return None + + guild = self._client.get_guild(self.guild_id) + if not guild: + return None + + if not guild.roles and hasattr(self._client, "fetch_roles"): + try: + self._client.loop.run_until_complete( + self._client.fetch_roles(self.guild_id) + ) + except RuntimeError: + future = asyncio.run_coroutine_threadsafe( + self._client.fetch_roles(self.guild_id), self._client.loop + ) + future.result() + + role_objects = [r for r in guild.roles if r.id in self.roles] + if not role_objects: + return None + + return max(role_objects, key=lambda r: r.position) + + +class PartialEmoji: + """Represents a partial emoji, often used in components or reactions. + + This typically means only id, name, and animated are known. + For unicode emojis, id will be None and name will be the unicode character. + """ + + def __init__(self, data: Dict[str, Any]): + self.id: Optional[str] = data.get("id") + self.name: Optional[str] = data.get( + "name" + ) # Can be None for unknown custom emoji, or unicode char + self.animated: bool = data.get("animated", False) + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {} + if self.id: + payload["id"] = self.id + if self.name: + payload["name"] = self.name + if self.animated: # Only include if true, as per some Discord patterns + payload["animated"] = self.animated + return payload + + def __str__(self) -> str: + if self.id: + return f"<{'a' if self.animated else ''}:{self.name}:{self.id}>" + return self.name or "" # For unicode emoji + + def __repr__(self) -> str: + return ( + f"" + ) + + +def to_partial_emoji( + value: Union[str, "PartialEmoji", None], +) -> Optional["PartialEmoji"]: + """Convert a string or PartialEmoji to a PartialEmoji instance. + + Args: + value: Either a unicode emoji string, a :class:`PartialEmoji`, or ``None``. + + Returns: + A :class:`PartialEmoji` or ``None`` if ``value`` was ``None``. + + Raises: + TypeError: If ``value`` is not ``str`` or :class:`PartialEmoji`. + """ + + if value is None or isinstance(value, PartialEmoji): + return value + if isinstance(value, str): + return PartialEmoji({"name": value, "id": None}) + raise TypeError("emoji must be a str or PartialEmoji") + + +class Emoji(PartialEmoji): + """Represents a custom guild emoji. + + Inherits id, name, animated from PartialEmoji. + """ + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + super().__init__(data) + self._client: Optional["Client"] = ( + client_instance # For potential future methods + ) + + # Roles this emoji is whitelisted to + self.roles: List[str] = data.get("roles", []) # List of role IDs + + # User object for the user that created this emoji (optional, only for GUILD_EMOJIS_AND_STICKERS intent) + self.user: Optional[User] = User(data["user"]) if data.get("user") else None + + self.require_colons: bool = data.get("require_colons", False) + self.managed: bool = data.get( + "managed", False + ) # If this emoji is managed by an integration + self.available: bool = data.get( + "available", True + ) # Whether this emoji can be used + + def __repr__(self) -> str: + return f"" + + +class StickerItem: + """Represents a sticker item, a basic representation of a sticker. + + Used in sticker packs and sometimes in message data. + """ + + def __init__(self, data: Dict[str, Any]): + self.id: str = data["id"] + self.name: str = data["name"] + self.format_type: int = data["format_type"] # StickerFormatType enum + + def __repr__(self) -> str: + return f"" + + +class Sticker(StickerItem): + """Represents a Discord sticker. + + Inherits id, name, format_type from StickerItem. + """ + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + super().__init__(data) + self._client: Optional["Client"] = client_instance + + self.pack_id: Optional[str] = data.get( + "pack_id" + ) # For standard stickers, ID of the pack + self.description: Optional[str] = data.get("description") + self.tags: str = data.get( + "tags", "" + ) # Comma-separated list of tags for guild stickers + # type is StickerType enum (STANDARD or GUILD) + # For guild stickers, this is 2. For standard stickers, this is 1. + self.type: int = data["type"] + self.available: bool = data.get( + "available", True + ) # Whether this sticker can be used + self.guild_id: Optional[str] = data.get( + "guild_id" + ) # ID of the guild that owns this sticker + + # User object of the user that uploaded the guild sticker + self.user: Optional[User] = User(data["user"]) if data.get("user") else None + + self.sort_value: Optional[int] = data.get( + "sort_value" + ) # The standard sticker's sort order within its pack + + def __repr__(self) -> str: + return f"" + + +class StickerPack: + """Represents a pack of standard stickers.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client: Optional["Client"] = client_instance + self.id: str = data["id"] + self.stickers: List[Sticker] = [ + Sticker(s_data, client_instance) for s_data in data.get("stickers", []) + ] + self.name: str = data["name"] + self.sku_id: str = data["sku_id"] + self.cover_sticker_id: Optional[str] = data.get("cover_sticker_id") + self.description: str = data["description"] + self.banner_asset_id: Optional[str] = data.get( + "banner_asset_id" + ) # ID of the pack's banner image + + def __repr__(self) -> str: + return f"" + + +class PermissionOverwrite: + """Represents a permission overwrite for a role or member in a channel.""" + + def __init__(self, data: Dict[str, Any]): + self.id: str = data["id"] # Role or user ID + self._type_val: int = int(data["type"]) # Store raw type for enum property + self.allow: str = data["allow"] # Bitwise value of allowed permissions + self.deny: str = data["deny"] # Bitwise value of denied permissions + + @property + def type(self) -> "OverwriteType": + from .enums import ( + OverwriteType, + ) # Local import to avoid circularity at module level + + return OverwriteType(self._type_val) + + def to_dict(self) -> Dict[str, Any]: + return { + "id": self.id, + "type": self.type.value, + "allow": self.allow, + "deny": self.deny, + } + + def __repr__(self) -> str: + return f"" + + +class Guild: + """Represents a Discord Guild (Server). + + Attributes: + id (str): Guild ID. + name (str): Guild name (2-100 characters, excluding @, #, :, ```). + icon (Optional[str]): Icon hash. + splash (Optional[str]): Splash hash. + discovery_splash (Optional[str]): Discovery splash hash; only present for discoverable guilds. + owner (Optional[bool]): True if the user is the owner of the guild. (Only for /users/@me/guilds endpoint) + owner_id (str): ID of owner. + permissions (Optional[str]): Total permissions for the user in the guild (excludes overwrites). (Only for /users/@me/guilds endpoint) + afk_channel_id (Optional[str]): ID of afk channel. + afk_timeout (int): AFK timeout in seconds. + widget_enabled (Optional[bool]): True if the server widget is enabled. + widget_channel_id (Optional[str]): The channel id that the widget will generate an invite to, or null if set to no invite. + verification_level (VerificationLevel): Verification level required for the guild. + default_message_notifications (MessageNotificationLevel): Default message notifications level. + explicit_content_filter (ExplicitContentFilterLevel): Explicit content filter level. + roles (List[Role]): Roles in the guild. + emojis (List[Dict]): Custom emojis. (Consider creating an Emoji model) + features (List[GuildFeature]): Enabled guild features. + mfa_level (MFALevel): Required MFA level for the guild. + application_id (Optional[str]): Application ID of the guild creator if it is bot-created. + system_channel_id (Optional[str]): The id of the channel where guild notices such as welcome messages and boost events are posted. + system_channel_flags (int): System channel flags. + rules_channel_id (Optional[str]): The id of the channel where Community guilds can display rules. + max_members (Optional[int]): The maximum number of members for the guild. + vanity_url_code (Optional[str]): The vanity url code for the guild. + description (Optional[str]): The description of a Community guild. + banner (Optional[str]): Banner hash. + premium_tier (PremiumTier): Premium tier (Server Boost level). + premium_subscription_count (Optional[int]): The number of boosts this guild currently has. + preferred_locale (str): The preferred locale of a Community guild. Defaults to "en-US". + public_updates_channel_id (Optional[str]): The id of the channel where admins and moderators of Community guilds receive notices from Discord. + max_video_channel_users (Optional[int]): The maximum number of users in a video channel. + welcome_screen (Optional[Dict]): The welcome screen of a Community guild. (Consider a WelcomeScreen model) + nsfw_level (GuildNSFWLevel): Guild NSFW level. + stickers (Optional[List[Dict]]): Custom stickers in the guild. (Consider a Sticker model) + premium_progress_bar_enabled (bool): Whether the guild has the premium progress bar enabled. + """ + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + self._client: "Client" = client_instance + self.id: str = data["id"] + self.name: str = data["name"] + self.icon: Optional[str] = data.get("icon") + self.splash: Optional[str] = data.get("splash") + self.discovery_splash: Optional[str] = data.get("discovery_splash") + self.owner: Optional[bool] = data.get("owner") + self.owner_id: str = data["owner_id"] + self.permissions: Optional[str] = data.get("permissions") + self.afk_channel_id: Optional[str] = data.get("afk_channel_id") + self.afk_timeout: int = data["afk_timeout"] + self.widget_enabled: Optional[bool] = data.get("widget_enabled") + self.widget_channel_id: Optional[str] = data.get("widget_channel_id") + self.verification_level: VerificationLevel = VerificationLevel( + data["verification_level"] + ) + self.default_message_notifications: MessageNotificationLevel = ( + MessageNotificationLevel(data["default_message_notifications"]) + ) + self.explicit_content_filter: ExplicitContentFilterLevel = ( + ExplicitContentFilterLevel(data["explicit_content_filter"]) + ) + + self.roles: List[Role] = [Role(r) for r in data.get("roles", [])] + self.emojis: List[Emoji] = [ + Emoji(e_data, client_instance) for e_data in data.get("emojis", []) + ] + + # Assuming GuildFeature can be constructed from string feature names or their values + self.features: List[GuildFeature] = [ + GuildFeature(f) if not isinstance(f, GuildFeature) else f + for f in data.get("features", []) + ] + + self.mfa_level: MFALevel = MFALevel(data["mfa_level"]) + self.application_id: Optional[str] = data.get("application_id") + self.system_channel_id: Optional[str] = data.get("system_channel_id") + self.system_channel_flags: int = data["system_channel_flags"] + self.rules_channel_id: Optional[str] = data.get("rules_channel_id") + self.max_members: Optional[int] = data.get("max_members") + self.vanity_url_code: Optional[str] = data.get("vanity_url_code") + self.description: Optional[str] = data.get("description") + self.banner: Optional[str] = data.get("banner") + self.premium_tier: PremiumTier = PremiumTier(data["premium_tier"]) + self.premium_subscription_count: Optional[int] = data.get( + "premium_subscription_count" + ) + self.preferred_locale: str = data.get("preferred_locale", "en-US") + self.public_updates_channel_id: Optional[str] = data.get( + "public_updates_channel_id" + ) + self.max_video_channel_users: Optional[int] = data.get( + "max_video_channel_users" + ) + self.approximate_member_count: Optional[int] = data.get( + "approximate_member_count" + ) + self.approximate_presence_count: Optional[int] = data.get( + "approximate_presence_count" + ) + self.welcome_screen: Optional["WelcomeScreen"] = ( + WelcomeScreen(data["welcome_screen"], client_instance) + if data.get("welcome_screen") + else None + ) + self.nsfw_level: GuildNSFWLevel = GuildNSFWLevel(data["nsfw_level"]) + self.stickers: Optional[List[Sticker]] = ( + [Sticker(s_data, client_instance) for s_data in data.get("stickers", [])] + if data.get("stickers") + else None + ) + self.premium_progress_bar_enabled: bool = data.get( + "premium_progress_bar_enabled", False + ) + + # Internal caches, populated by events or specific fetches + self._channels: ChannelCache = ChannelCache() + self._members: MemberCache = MemberCache(client_instance.member_cache_flags) + self._threads: Dict[str, "Thread"] = {} + + def get_channel(self, channel_id: str) -> Optional["Channel"]: + return self._channels.get(channel_id) + + def get_member(self, user_id: str) -> Optional[Member]: + return self._members.get(user_id) + + def get_member_named(self, name: str) -> Optional[Member]: + """Retrieve a cached member by username or nickname. + + The lookup is case-insensitive and searches both the username and + guild nickname for a match. + + Parameters + ---------- + name: str + The username or nickname to search for. + + Returns + ------- + Optional[Member] + The matching member if found, otherwise ``None``. + """ + + lowered = name.lower() + for member in self._members.values(): + if member.username.lower() == lowered: + return member + if member.nick and member.nick.lower() == lowered: + return member + return None + + def get_role(self, role_id: str) -> Optional[Role]: + return next((role for role in self.roles if role.id == role_id), None) + + def __repr__(self) -> str: + return f"" + async def fetch_members(self, *, limit: Optional[int] = None) -> List["Member"]: """|coro| @@ -1166,341 +1170,427 @@ class Guild: del self._client._gateway._member_chunk_requests[nonce] raise - -class Channel: - """Base class for Discord channels.""" - - def __init__(self, data: Dict[str, Any], client_instance: "Client"): - self._client: "Client" = client_instance - self.id: str = data["id"] - self._type_val: int = int(data["type"]) # Store raw type for enum property - - self.guild_id: Optional[str] = data.get("guild_id") - self.name: Optional[str] = data.get("name") - self.position: Optional[int] = data.get("position") - self.permission_overwrites: List["PermissionOverwrite"] = [ - PermissionOverwrite(d) for d in data.get("permission_overwrites", []) - ] - self.nsfw: Optional[bool] = data.get("nsfw", False) - self.parent_id: Optional[str] = data.get( - "parent_id" - ) # ID of the parent category channel or thread parent - - @property - def type(self) -> ChannelType: - return ChannelType(self._type_val) - - @property - def mention(self) -> str: - return f"<#{self.id}>" - - async def delete(self, reason: Optional[str] = None): - await self._client._http.delete_channel(self.id, reason=reason) - - def __repr__(self) -> str: - return f"" - - def permission_overwrite_for( - self, target: Union["Role", "Member", str] - ) -> Optional["PermissionOverwrite"]: - """Return the :class:`PermissionOverwrite` for ``target`` if present.""" - - if isinstance(target, str): - target_id = target - else: - target_id = target.id - for overwrite in self.permission_overwrites: - if overwrite.id == target_id: - return overwrite - return None - - @staticmethod - def _apply_overwrite( - perms: Permissions, overwrite: Optional["PermissionOverwrite"] - ) -> Permissions: - if overwrite is None: - return perms - - perms &= ~Permissions(int(overwrite.deny)) - perms |= Permissions(int(overwrite.allow)) - return perms - - def permissions_for(self, member: "Member") -> Permissions: - """Resolve channel permissions for ``member``.""" - - if self.guild_id is None: - return Permissions(~0) - - if not hasattr(self._client, "get_guild"): - return Permissions(0) - - guild = self._client.get_guild(self.guild_id) - if guild is None: - return Permissions(0) - - base = Permissions(0) - - everyone = guild.get_role(guild.id) - if everyone is not None: - base |= Permissions(int(everyone.permissions)) - - for rid in member.roles: - role = guild.get_role(rid) - if role is not None: - base |= Permissions(int(role.permissions)) - - if base & Permissions.ADMINISTRATOR: - return Permissions(~0) - - # Apply @everyone overwrite - base = self._apply_overwrite(base, self.permission_overwrite_for(guild.id)) - - # Role overwrites - role_allow = Permissions(0) - role_deny = Permissions(0) - for rid in member.roles: - ow = self.permission_overwrite_for(rid) - if ow is not None: - role_allow |= Permissions(int(ow.allow)) - role_deny |= Permissions(int(ow.deny)) - - base &= ~role_deny - base |= role_allow - - # Member overwrite - base = self._apply_overwrite(base, self.permission_overwrite_for(member.id)) - - return base - - -class TextChannel(Channel): - """Represents a guild text channel or announcement channel.""" - - def __init__(self, data: Dict[str, Any], client_instance: "Client"): - super().__init__(data, client_instance) - self.topic: Optional[str] = data.get("topic") - self.last_message_id: Optional[str] = data.get("last_message_id") - self.rate_limit_per_user: Optional[int] = data.get("rate_limit_per_user", 0) - self.default_auto_archive_duration: Optional[int] = data.get( - "default_auto_archive_duration" - ) - self.last_pin_timestamp: Optional[str] = data.get("last_pin_timestamp") - - def history( - self, - *, - limit: Optional[int] = None, - before: Optional[str] = None, - after: Optional[str] = None, - ) -> AsyncIterator["Message"]: - """Return an async iterator over this channel's messages.""" - - from .utils import message_pager - - return message_pager(self, limit=limit, before=before, after=after) - - async def send( - self, - content: Optional[str] = None, - *, - embed: Optional[Embed] = None, - embeds: Optional[List[Embed]] = None, - components: Optional[List["ActionRow"]] = None, # Added components - ) -> "Message": # Forward reference Message - if not hasattr(self._client, "send_message"): - raise NotImplementedError( - "Client.send_message is required for TextChannel.send" - ) - - return await self._client.send_message( - channel_id=self.id, - content=content, - embed=embed, - embeds=embeds, - components=components, - ) - - async def purge( - self, limit: int, *, before: "Snowflake | None" = None - ) -> List["Snowflake"]: - """Bulk delete messages from this channel.""" - - params: Dict[str, Union[int, str]] = {"limit": limit} - if before is not None: - params["before"] = before - - messages = await self._client._http.request( - "GET", f"/channels/{self.id}/messages", params=params - ) - ids = [m["id"] for m in messages] - if not ids: - return [] - - await self._client._http.bulk_delete_messages(self.id, ids) - for mid in ids: - self._client._messages.pop(mid, None) - return ids - - def __repr__(self) -> str: - return f"" - - -class VoiceChannel(Channel): - """Represents a guild voice channel or stage voice channel.""" - - def __init__(self, data: Dict[str, Any], client_instance: "Client"): - super().__init__(data, client_instance) - self.bitrate: int = data.get("bitrate", 64000) - self.user_limit: int = data.get("user_limit", 0) - self.rtc_region: Optional[str] = data.get("rtc_region") - self.video_quality_mode: Optional[int] = data.get("video_quality_mode") - - def __repr__(self) -> str: - return f"" - - -class StageChannel(VoiceChannel): - """Represents a guild stage channel.""" - - def __repr__(self) -> str: - return f"" - - async def start_stage_instance( - self, - topic: str, - *, - privacy_level: int = 2, - reason: Optional[str] = None, - guild_scheduled_event_id: Optional[str] = None, - ) -> "StageInstance": - if not hasattr(self._client, "_http"): - raise DisagreementException("Client missing HTTP for stage instance") - - payload: Dict[str, Any] = { - "channel_id": self.id, - "topic": topic, - "privacy_level": privacy_level, - } - if guild_scheduled_event_id is not None: - payload["guild_scheduled_event_id"] = guild_scheduled_event_id - - instance = await self._client._http.start_stage_instance(payload, reason=reason) - instance._client = self._client - return instance - - async def edit_stage_instance( - self, - *, - topic: Optional[str] = None, - privacy_level: Optional[int] = None, - reason: Optional[str] = None, - ) -> "StageInstance": - if not hasattr(self._client, "_http"): - raise DisagreementException("Client missing HTTP for stage instance") - - payload: Dict[str, Any] = {} - if topic is not None: - payload["topic"] = topic - if privacy_level is not None: - payload["privacy_level"] = privacy_level - - instance = await self._client._http.edit_stage_instance( - self.id, payload, reason=reason - ) - instance._client = self._client - return instance - - async def end_stage_instance(self, *, reason: Optional[str] = None) -> None: - if not hasattr(self._client, "_http"): - raise DisagreementException("Client missing HTTP for stage instance") - - await self._client._http.end_stage_instance(self.id, reason=reason) - - -class StageInstance: - """Represents a stage instance.""" - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ) -> None: - self._client = client_instance - self.id: str = data["id"] - self.guild_id: Optional[str] = data.get("guild_id") - self.channel_id: str = data["channel_id"] - self.topic: str = data["topic"] - self.privacy_level: int = data.get("privacy_level", 2) - self.discoverable_disabled: bool = data.get("discoverable_disabled", False) - self.guild_scheduled_event_id: Optional[str] = data.get( - "guild_scheduled_event_id" - ) - - def __repr__(self) -> str: - return f"" - - -class CategoryChannel(Channel): - """Represents a guild category channel.""" - - def __init__(self, data: Dict[str, Any], client_instance: "Client"): - super().__init__(data, client_instance) - - @property - def channels(self) -> List[Channel]: - if not self.guild_id or not hasattr(self._client, "get_guild"): - return [] - guild = self._client.get_guild(self.guild_id) - if not guild or not hasattr( - guild, "_channels" - ): # Ensure guild and _channels exist - return [] - - categorized_channels = [ - ch - for ch in guild._channels.values() - if getattr(ch, "parent_id", None) == self.id - ] - return sorted( - categorized_channels, - key=lambda c: c.position if c.position is not None else -1, - ) - - def __repr__(self) -> str: - return f"" - - -class ThreadMetadata: - """Represents the metadata of a thread.""" - - def __init__(self, data: Dict[str, Any]): - self.archived: bool = data["archived"] - self.auto_archive_duration: int = data["auto_archive_duration"] - self.archive_timestamp: str = data["archive_timestamp"] - self.locked: bool = data["locked"] - self.invitable: Optional[bool] = data.get("invitable") - self.create_timestamp: Optional[str] = data.get("create_timestamp") - - -class Thread(TextChannel): # Threads are a specialized TextChannel - """Represents a Discord Thread.""" - - def __init__(self, data: Dict[str, Any], client_instance: "Client"): - super().__init__(data, client_instance) # Handles common text channel fields - self.owner_id: Optional[str] = data.get("owner_id") - # parent_id is already handled by base Channel init if present in data - self.message_count: Optional[int] = data.get("message_count") - self.member_count: Optional[int] = data.get("member_count") - self.thread_metadata: ThreadMetadata = ThreadMetadata(data["thread_metadata"]) - self.member: Optional["ThreadMember"] = ( - ThreadMember(data["member"], client_instance) - if data.get("member") - else None - ) - - def __repr__(self) -> str: - return ( - f"" - ) - + +class Channel: + """Base class for Discord channels.""" + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + self._client: "Client" = client_instance + self.id: str = data["id"] + self._type_val: int = int(data["type"]) # Store raw type for enum property + + self.guild_id: Optional[str] = data.get("guild_id") + self.name: Optional[str] = data.get("name") + self.position: Optional[int] = data.get("position") + self.permission_overwrites: List["PermissionOverwrite"] = [ + PermissionOverwrite(d) for d in data.get("permission_overwrites", []) + ] + self.nsfw: Optional[bool] = data.get("nsfw", False) + self.parent_id: Optional[str] = data.get( + "parent_id" + ) # ID of the parent category channel or thread parent + + @property + def type(self) -> ChannelType: + return ChannelType(self._type_val) + + @property + def mention(self) -> str: + return f"<#{self.id}>" + + async def delete(self, reason: Optional[str] = None): + await self._client._http.delete_channel(self.id, reason=reason) + + def __repr__(self) -> str: + return f"" + + def permission_overwrite_for( + self, target: Union["Role", "Member", str] + ) -> Optional["PermissionOverwrite"]: + """Return the :class:`PermissionOverwrite` for ``target`` if present.""" + + if isinstance(target, str): + target_id = target + else: + target_id = target.id + for overwrite in self.permission_overwrites: + if overwrite.id == target_id: + return overwrite + return None + + @staticmethod + def _apply_overwrite( + perms: Permissions, overwrite: Optional["PermissionOverwrite"] + ) -> Permissions: + if overwrite is None: + return perms + + perms &= ~Permissions(int(overwrite.deny)) + perms |= Permissions(int(overwrite.allow)) + return perms + + def permissions_for(self, member: "Member") -> Permissions: + """Resolve channel permissions for ``member``.""" + + if self.guild_id is None: + return Permissions(~0) + + if not hasattr(self._client, "get_guild"): + return Permissions(0) + + guild = self._client.get_guild(self.guild_id) + if guild is None: + return Permissions(0) + + base = Permissions(0) + + everyone = guild.get_role(guild.id) + if everyone is not None: + base |= Permissions(int(everyone.permissions)) + + for rid in member.roles: + role = guild.get_role(rid) + if role is not None: + base |= Permissions(int(role.permissions)) + + if base & Permissions.ADMINISTRATOR: + return Permissions(~0) + + # Apply @everyone overwrite + base = self._apply_overwrite(base, self.permission_overwrite_for(guild.id)) + + # Role overwrites + role_allow = Permissions(0) + role_deny = Permissions(0) + for rid in member.roles: + ow = self.permission_overwrite_for(rid) + if ow is not None: + role_allow |= Permissions(int(ow.allow)) + role_deny |= Permissions(int(ow.deny)) + + base &= ~role_deny + base |= role_allow + + # Member overwrite + base = self._apply_overwrite(base, self.permission_overwrite_for(member.id)) + + return base + + +class TextChannel(Channel): + """Represents a guild text channel or announcement channel.""" + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + super().__init__(data, client_instance) + self.topic: Optional[str] = data.get("topic") + self.last_message_id: Optional[str] = data.get("last_message_id") + self.rate_limit_per_user: Optional[int] = data.get("rate_limit_per_user", 0) + self.default_auto_archive_duration: Optional[int] = data.get( + "default_auto_archive_duration" + ) + self.last_pin_timestamp: Optional[str] = data.get("last_pin_timestamp") + + def history( + self, + *, + limit: Optional[int] = None, + before: Optional[str] = None, + after: Optional[str] = None, + ) -> AsyncIterator["Message"]: + """Return an async iterator over this channel's messages.""" + + from .utils import message_pager + + return message_pager(self, limit=limit, before=before, after=after) + + async def send( + self, + content: Optional[str] = None, + *, + embed: Optional[Embed] = None, + embeds: Optional[List[Embed]] = None, + components: Optional[List["ActionRow"]] = None, # Added components + ) -> "Message": # Forward reference Message + if not hasattr(self._client, "send_message"): + raise NotImplementedError( + "Client.send_message is required for TextChannel.send" + ) + + return await self._client.send_message( + channel_id=self.id, + content=content, + embed=embed, + embeds=embeds, + components=components, + ) + + async def purge( + self, limit: int, *, before: "Snowflake | None" = None + ) -> List["Snowflake"]: + """Bulk delete messages from this channel.""" + + params: Dict[str, Union[int, str]] = {"limit": limit} + if before is not None: + params["before"] = before + + messages = await self._client._http.request( + "GET", f"/channels/{self.id}/messages", params=params + ) + ids = [m["id"] for m in messages] + if not ids: + return [] + + await self._client._http.bulk_delete_messages(self.id, ids) + for mid in ids: + self._client._messages.invalidate(mid) + return ids + + def get_partial_message(self, id: int) -> "PartialMessage": + """Returns a :class:`PartialMessage` for the given ID. + + This allows performing actions on a message without fetching it first. + + Parameters + ---------- + id: int + The ID of the message to get a partial instance of. + + Returns + ------- + PartialMessage + The partial message instance. + """ + return PartialMessage(id=str(id), channel=self) + + def __repr__(self) -> str: + return f"" + + async def pins(self) -> List["Message"]: + """|coro| + + Fetches all pinned messages in this channel. + + Returns + ------- + List[Message] + The pinned messages. + + Raises + ------ + HTTPException + Fetching the pinned messages failed. + """ + + messages_data = await self._client._http.get_pinned_messages(self.id) + return [self._client.parse_message(m) for m in messages_data] + + async def create_thread( + self, + name: str, + *, + type: ChannelType = ChannelType.PUBLIC_THREAD, + auto_archive_duration: Optional[int] = None, + invitable: Optional[bool] = None, + rate_limit_per_user: Optional[int] = None, + reason: Optional[str] = None, + ) -> "Thread": + """|coro| + + Creates a new thread in this channel. + + Parameters + ---------- + name: str + The name of the thread. + type: ChannelType + The type of thread to create. Defaults to PUBLIC_THREAD. + Can be PUBLIC_THREAD, PRIVATE_THREAD, or ANNOUNCEMENT_THREAD. + auto_archive_duration: Optional[int] + The duration in minutes to automatically archive the thread after recent activity. + invitable: Optional[bool] + Whether non-moderators can invite other non-moderators to a private thread. + Only applicable to private threads. + rate_limit_per_user: Optional[int] + The number of seconds a user has to wait before sending another message. + reason: Optional[str] + The reason for creating the thread. + + Returns + ------- + Thread + The created thread. + """ + payload: Dict[str, Any] = { + "name": name, + "type": type.value, + } + if auto_archive_duration is not None: + payload["auto_archive_duration"] = auto_archive_duration + if invitable is not None and type == ChannelType.PRIVATE_THREAD: + payload["invitable"] = invitable + if rate_limit_per_user is not None: + payload["rate_limit_per_user"] = rate_limit_per_user + + data = await self._client._http.start_thread_without_message(self.id, payload) + return cast("Thread", self._client.parse_channel(data)) + + +class VoiceChannel(Channel): + """Represents a guild voice channel or stage voice channel.""" + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + super().__init__(data, client_instance) + self.bitrate: int = data.get("bitrate", 64000) + self.user_limit: int = data.get("user_limit", 0) + self.rtc_region: Optional[str] = data.get("rtc_region") + self.video_quality_mode: Optional[int] = data.get("video_quality_mode") + + def __repr__(self) -> str: + return f"" + + +class StageChannel(VoiceChannel): + """Represents a guild stage channel.""" + + def __repr__(self) -> str: + return f"" + + async def start_stage_instance( + self, + topic: str, + *, + privacy_level: int = 2, + reason: Optional[str] = None, + guild_scheduled_event_id: Optional[str] = None, + ) -> "StageInstance": + if not hasattr(self._client, "_http"): + raise DisagreementException("Client missing HTTP for stage instance") + + payload: Dict[str, Any] = { + "channel_id": self.id, + "topic": topic, + "privacy_level": privacy_level, + } + if guild_scheduled_event_id is not None: + payload["guild_scheduled_event_id"] = guild_scheduled_event_id + + instance = await self._client._http.start_stage_instance(payload, reason=reason) + instance._client = self._client + return instance + + async def edit_stage_instance( + self, + *, + topic: Optional[str] = None, + privacy_level: Optional[int] = None, + reason: Optional[str] = None, + ) -> "StageInstance": + if not hasattr(self._client, "_http"): + raise DisagreementException("Client missing HTTP for stage instance") + + payload: Dict[str, Any] = {} + if topic is not None: + payload["topic"] = topic + if privacy_level is not None: + payload["privacy_level"] = privacy_level + + instance = await self._client._http.edit_stage_instance( + self.id, payload, reason=reason + ) + instance._client = self._client + return instance + + async def end_stage_instance(self, *, reason: Optional[str] = None) -> None: + if not hasattr(self._client, "_http"): + raise DisagreementException("Client missing HTTP for stage instance") + + await self._client._http.end_stage_instance(self.id, reason=reason) + + +class StageInstance: + """Represents a stage instance.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ) -> None: + self._client = client_instance + self.id: str = data["id"] + self.guild_id: Optional[str] = data.get("guild_id") + self.channel_id: str = data["channel_id"] + self.topic: str = data["topic"] + self.privacy_level: int = data.get("privacy_level", 2) + self.discoverable_disabled: bool = data.get("discoverable_disabled", False) + self.guild_scheduled_event_id: Optional[str] = data.get( + "guild_scheduled_event_id" + ) + + def __repr__(self) -> str: + return f"" + + +class CategoryChannel(Channel): + """Represents a guild category channel.""" + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + super().__init__(data, client_instance) + + @property + def channels(self) -> List[Channel]: + if not self.guild_id or not hasattr(self._client, "get_guild"): + return [] + guild = self._client.get_guild(self.guild_id) + if not guild or not hasattr( + guild, "_channels" + ): # Ensure guild and _channels exist + return [] + + categorized_channels = [ + ch + for ch in guild._channels.values() + if getattr(ch, "parent_id", None) == self.id + ] + return sorted( + categorized_channels, + key=lambda c: c.position if c.position is not None else -1, + ) + + def __repr__(self) -> str: + return f"" + + +class ThreadMetadata: + """Represents the metadata of a thread.""" + + def __init__(self, data: Dict[str, Any]): + self.archived: bool = data["archived"] + self.auto_archive_duration: int = data["auto_archive_duration"] + self.archive_timestamp: str = data["archive_timestamp"] + self.locked: bool = data["locked"] + self.invitable: Optional[bool] = data.get("invitable") + self.create_timestamp: Optional[str] = data.get("create_timestamp") + + +class Thread(TextChannel): # Threads are a specialized TextChannel + """Represents a Discord Thread.""" + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + super().__init__(data, client_instance) # Handles common text channel fields + self.owner_id: Optional[str] = data.get("owner_id") + # parent_id is already handled by base Channel init if present in data + self.message_count: Optional[int] = data.get("message_count") + self.member_count: Optional[int] = data.get("member_count") + self.thread_metadata: ThreadMetadata = ThreadMetadata(data["thread_metadata"]) + self.member: Optional["ThreadMember"] = ( + ThreadMember(data["member"], client_instance) + if data.get("member") + else None + ) + + def __repr__(self) -> str: + return ( + f"" + ) + async def join(self) -> None: """|coro|