From a41a3019273c9543b21355c6ae5701fb09789c7e Mon Sep 17 00:00:00 2001 From: Slipstreamm Date: Sat, 14 Jun 2025 23:49:33 -0600 Subject: [PATCH] fix(core): Improve client ready state and user parsing The `_ready_event` is now set in `GatewayClient` immediately after receiving the `READY` payload, before dispatching `on_ready` to user code. This ensures `Client.wait_until_ready()` and `Client.is_ready()` accurately reflect the client's state before dependent user logic executes. This change allows simplifying `Client.sync_commands` by removing redundant `wait_until_ready()` calls and `application_id` checks, as the application ID is guaranteed to be available upon READY. Additionally, `User` model initialization is improved to correctly handle nested user data found in certain API payloads (e.g., within `member` objects in events like `PresenceUpdate`). Add `SOUNDBOARD` and `VIDEO_QUALITY_720_60FPS` to `GuildFeature` enum. --- disagreement/__init__.py | 2 +- disagreement/client.py | 176 +++++++++++++-------------- disagreement/enums.py | 2 + disagreement/gateway.py | 4 + disagreement/models.py | 252 +++++++++++++++++++-------------------- pyproject.toml | 2 +- 6 files changed, 217 insertions(+), 221 deletions(-) diff --git a/disagreement/__init__.py b/disagreement/__init__.py index 0a3c12d..01b306f 100644 --- a/disagreement/__init__.py +++ b/disagreement/__init__.py @@ -12,7 +12,7 @@ __title__ = "disagreement" __author__ = "Slipstream" __license__ = "BSD 3-Clause License" __copyright__ = "Copyright 2025 Slipstream" -__version__ = "0.8.0" +__version__ = "0.8.1" from .client import Client, AutoShardedClient from .models import ( diff --git a/disagreement/client.py b/disagreement/client.py index 502af57..49c5db8 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -4,18 +4,18 @@ The main Client class for interacting with the Discord API. import asyncio import signal -from typing import ( - Optional, - Callable, - Any, - TYPE_CHECKING, - Awaitable, - AsyncIterator, - Union, - List, - Dict, - cast, -) +from typing import ( + Optional, + Callable, + Any, + TYPE_CHECKING, + Awaitable, + AsyncIterator, + Union, + List, + Dict, + cast, +) from types import ModuleType from .http import HTTPClient @@ -263,16 +263,16 @@ class Client: raise except DisagreementException as e: # Includes GatewayException print(f"Failed to connect (Attempt {attempt + 1}/{max_retries}): {e}") - if attempt < max_retries - 1: - print(f"Retrying in {retry_delay} seconds...") - await asyncio.sleep(retry_delay) - retry_delay = min( - retry_delay * 2, 60 - ) # Exponential backoff up to 60s - else: - print("Max connection retries reached. Giving up.") - await self.close() # Ensure cleanup - raise + if attempt < max_retries - 1: + print(f"Retrying in {retry_delay} seconds...") + await asyncio.sleep(retry_delay) + retry_delay = min( + retry_delay * 2, 60 + ) # Exponential backoff up to 60s + else: + print("Max connection retries reached. Giving up.") + await self.close() # Ensure cleanup + raise if max_retries == 0: # If max_retries was 0, means no retries attempted raise DisagreementException("Connection failed with 0 retries allowed.") @@ -530,29 +530,29 @@ class Client: print(f"Message: {message.content}") """ - def decorator( - coro: Callable[..., Awaitable[None]], - ) -> Callable[..., Awaitable[None]]: - if not asyncio.iscoroutinefunction(coro): - raise TypeError("Event registered must be a coroutine function.") - self._event_dispatcher.register(event_name.upper(), coro) - return coro - - return decorator - - def add_listener( - self, event_name: str, coro: Callable[..., Awaitable[None]] - ) -> None: - """Register ``coro`` to listen for ``event_name``.""" - - self._event_dispatcher.register(event_name, coro) - - def remove_listener( - self, event_name: str, coro: Callable[..., Awaitable[None]] - ) -> None: - """Remove ``coro`` from ``event_name`` listeners.""" - - self._event_dispatcher.unregister(event_name, coro) + def decorator( + coro: Callable[..., Awaitable[None]], + ) -> Callable[..., Awaitable[None]]: + if not asyncio.iscoroutinefunction(coro): + raise TypeError("Event registered must be a coroutine function.") + self._event_dispatcher.register(event_name.upper(), coro) + return coro + + return decorator + + def add_listener( + self, event_name: str, coro: Callable[..., Awaitable[None]] + ) -> None: + """Register ``coro`` to listen for ``event_name``.""" + + self._event_dispatcher.register(event_name, coro) + + def remove_listener( + self, event_name: str, coro: Callable[..., Awaitable[None]] + ) -> None: + """Remove ``coro`` from ``event_name`` listeners.""" + + self._event_dispatcher.unregister(event_name, coro) async def _process_message_for_commands(self, message: "Message") -> None: """Internal listener to process messages for commands.""" @@ -755,7 +755,7 @@ class Client: """Parses user data and returns a User object, updating cache.""" from .models import User # Ensure User model is available - user = User(data, client_instance=self) + user = User(data, client_instance=self) self._users.set(user.id, user) # Cache the user return user @@ -1011,10 +1011,10 @@ class Client: # --- API Methods --- - async def send_message( - self, - channel_id: str, - content: Optional[str] = None, + async def send_message( + self, + channel_id: str, + content: Optional[str] = None, *, # Make additional params keyword-only tts: bool = False, embed: Optional["Embed"] = None, @@ -1106,24 +1106,24 @@ class Client: view.message_id = message_id self._views[message_id] = view - return self.parse_message(message_data) - - async def create_dm(self, user_id: Snowflake) -> "DMChannel": - """|coro| Create or fetch a DM channel with a user.""" - from .models import DMChannel - - dm_data = await self._http.create_dm(user_id) - return cast(DMChannel, self.parse_channel(dm_data)) - - async def send_dm( - self, - user_id: Snowflake, - content: Optional[str] = None, - **kwargs: Any, - ) -> "Message": - """|coro| Convenience method to send a direct message to a user.""" - channel = await self.create_dm(user_id) - return await self.send_message(channel.id, content=content, **kwargs) + return self.parse_message(message_data) + + async def create_dm(self, user_id: Snowflake) -> "DMChannel": + """|coro| Create or fetch a DM channel with a user.""" + from .models import DMChannel + + dm_data = await self._http.create_dm(user_id) + return cast(DMChannel, self.parse_channel(dm_data)) + + async def send_dm( + self, + user_id: Snowflake, + content: Optional[str] = None, + **kwargs: Any, + ) -> "Message": + """|coro| Convenience method to send a direct message to a user.""" + channel = await self.create_dm(user_id) + return await self.send_message(channel.id, content=content, **kwargs) def typing(self, channel_id: str) -> Typing: """Return a context manager to show a typing indicator in a channel.""" @@ -1343,8 +1343,8 @@ class Client: return self._messages.get(message_id) - async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]: - """Fetches a guild by ID from Discord and caches it.""" + async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]: + """Fetches a guild by ID from Discord and caches it.""" if self._closed: raise DisagreementException("Client is closed.") @@ -1358,19 +1358,19 @@ class Client: return self.parse_guild(guild_data) except DisagreementException as e: print(f"Failed to fetch guild {guild_id}: {e}") - return None - - async def fetch_guilds(self) -> List["Guild"]: - """Fetch all guilds the current user is in.""" - - if self._closed: - raise DisagreementException("Client is closed.") - - data = await self._http.get_current_user_guilds() - guilds: List["Guild"] = [] - for guild_data in data: - guilds.append(self.parse_guild(guild_data)) - return guilds + return None + + async def fetch_guilds(self) -> List["Guild"]: + """Fetch all guilds the current user is in.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + data = await self._http.get_current_user_guilds() + guilds: List["Guild"] = [] + for guild_data in data: + guilds.append(self.parse_guild(guild_data)) + return guilds async def fetch_channel(self, channel_id: Snowflake) -> Optional["Channel"]: """Fetches a channel from Discord by its ID and updates the cache.""" @@ -1665,16 +1665,6 @@ class Client: "Ensure the client is connected and READY." ) return - if not self.is_ready(): - print( - "Warning: Client is not ready. Waiting for client to be ready before syncing commands." - ) - await self.wait_until_ready() - if not self.application_id: - print( - "Error: application_id still not set after client is ready. Cannot sync commands." - ) - return await self.app_command_handler.sync_commands( application_id=self.application_id, guild_id=guild_id diff --git a/disagreement/enums.py b/disagreement/enums.py index 7b4d6c7..eb49683 100644 --- a/disagreement/enums.py +++ b/disagreement/enums.py @@ -268,6 +268,8 @@ class GuildFeature(str, Enum): # Changed from IntEnum to Enum VERIFIED = "VERIFIED" VIP_REGIONS = "VIP_REGIONS" WELCOME_SCREEN_ENABLED = "WELCOME_SCREEN_ENABLED" + SOUNDBOARD = "SOUNDBOARD" + VIDEO_QUALITY_720_60FPS = "VIDEO_QUALITY_720_60FPS" # Add more as they become known or needed # This allows GuildFeature("UNKNOWN_FEATURE_STRING") to work diff --git a/disagreement/gateway.py b/disagreement/gateway.py index 4559e6f..a39ed58 100644 --- a/disagreement/gateway.py +++ b/disagreement/gateway.py @@ -334,6 +334,10 @@ class GatewayClient: self._resume_gateway_url, ) + # The client is now ready for operations. Set the event before dispatching to user code. + self._client_instance._ready_event.set() + logger.info("Client is now marked as ready.") + await self._dispatcher.dispatch(event_name, raw_event_d_payload) elif event_name == "GUILD_MEMBERS_CHUNK": if isinstance(raw_event_d_payload, dict): diff --git a/disagreement/models.py b/disagreement/models.py index 8ac6b66..a38cbcb 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -46,16 +46,18 @@ if TYPE_CHECKING: from .components import component_factory -class User: - """Represents a Discord User.""" - - def __init__(self, data: dict, client_instance: Optional["Client"] = None) -> None: - self._client = client_instance - self.id: str = data["id"] - self.username: Optional[str] = data.get("username") - self.discriminator: Optional[str] = data.get("discriminator") - self.bot: bool = data.get("bot", False) - self.avatar: Optional[str] = data.get("avatar") +class User: + """Represents a Discord User.""" + + def __init__(self, data: dict, client_instance: Optional["Client"] = None) -> None: + self._client = client_instance + if "id" not in data and "user" in data: + data = data["user"] + self.id: str = data["id"] + self.username: Optional[str] = data.get("username") + self.discriminator: Optional[str] = data.get("discriminator") + self.bot: bool = data.get("bot", False) + self.avatar: Optional[str] = data.get("avatar") @property def mention(self) -> str: @@ -63,23 +65,23 @@ class User: return f"<@{self.id}>" def __repr__(self) -> str: - username = self.username or "Unknown" - disc = self.discriminator or "????" - return f"" - - async def send( - self, - content: Optional[str] = None, - *, - client: Optional["Client"] = None, - **kwargs: Any, - ) -> "Message": - """Send a direct message to this user.""" - - target_client = client or self._client - if target_client is None: - raise DisagreementException("User.send requires a Client instance") - return await target_client.send_dm(self.id, content=content, **kwargs) + username = self.username or "Unknown" + disc = self.discriminator or "????" + return f"" + + async def send( + self, + content: Optional[str] = None, + *, + client: Optional["Client"] = None, + **kwargs: Any, + ) -> "Message": + """Send a direct message to this user.""" + + target_client = client or self._client + if target_client is None: + raise DisagreementException("User.send requires a Client instance") + return await target_client.send_dm(self.id, content=content, **kwargs) class Message: @@ -105,7 +107,7 @@ class Message: self.id: str = data["id"] self.channel_id: str = data["channel_id"] self.guild_id: Optional[str] = data.get("guild_id") - self.author: User = User(data["author"], client_instance) + self.author: User = User(data["author"], client_instance) self.content: str = data["content"] self.timestamp: str = data["timestamp"] if data.get("components"): @@ -115,21 +117,21 @@ class Message: ] else: self.components = None - self.attachments: List[Attachment] = [ - Attachment(a) for a in data.get("attachments", []) - ] - self.pinned: bool = data.get("pinned", False) - # Add other fields as needed, e.g., attachments, embeds, reactions, etc. - # self.mentions: List[User] = [User(u) for u in data.get("mentions", [])] - # self.mention_roles: List[str] = data.get("mention_roles", []) - # self.mention_everyone: bool = data.get("mention_everyone", False) - - @property - def jump_url(self) -> str: - """Return a URL that jumps to this message in the Discord client.""" - - guild_or_dm = self.guild_id or "@me" - return f"https://discord.com/channels/{guild_or_dm}/{self.channel_id}/{self.id}" + self.attachments: List[Attachment] = [ + Attachment(a) for a in data.get("attachments", []) + ] + self.pinned: bool = data.get("pinned", False) + # Add other fields as needed, e.g., attachments, embeds, reactions, etc. + # self.mentions: List[User] = [User(u) for u in data.get("mentions", [])] + # self.mention_roles: List[str] = data.get("mention_roles", []) + # self.mention_everyone: bool = data.get("mention_everyone", False) + + @property + def jump_url(self) -> str: + """Return a URL that jumps to this message in the Discord client.""" + + guild_or_dm = self.guild_id or "@me" + return f"https://discord.com/channels/{guild_or_dm}/{self.channel_id}/{self.id}" @property def clean_content(self) -> str: @@ -203,14 +205,14 @@ class Message: ValueError: If both `embed` and `embeds` are provided. """ # Determine allowed mentions for the reply - if mention_author is None: - mention_author = getattr(self._client, "mention_replies", False) - - if allowed_mentions is None: - allowed_mentions = dict(getattr(self._client, "allowed_mentions", {}) or {}) - else: - allowed_mentions = dict(allowed_mentions) - allowed_mentions.setdefault("replied_user", mention_author) + if mention_author is None: + mention_author = getattr(self._client, "mention_replies", False) + + if allowed_mentions is None: + allowed_mentions = dict(getattr(self._client, "allowed_mentions", {}) or {}) + else: + allowed_mentions = dict(allowed_mentions) + allowed_mentions.setdefault("replied_user", mention_author) # Client.send_message is already updated to handle these parameters return await self._client.send_message( @@ -640,31 +642,31 @@ class File: self.data = data -class AllowedMentions: +class AllowedMentions: """Represents allowed mentions for a message or interaction response.""" - def __init__(self, data: Dict[str, Any]): - self.parse: List[str] = data.get("parse", []) - self.roles: List[str] = data.get("roles", []) - self.users: List[str] = data.get("users", []) - self.replied_user: bool = data.get("replied_user", False) - - @classmethod - def all(cls) -> "AllowedMentions": - """Return an instance allowing all mention types.""" - - return cls( - { - "parse": ["users", "roles", "everyone"], - "replied_user": True, - } - ) - - @classmethod - def none(cls) -> "AllowedMentions": - """Return an instance disallowing all mentions.""" - - return cls({"parse": [], "replied_user": False}) + def __init__(self, data: Dict[str, Any]): + self.parse: List[str] = data.get("parse", []) + self.roles: List[str] = data.get("roles", []) + self.users: List[str] = data.get("users", []) + self.replied_user: bool = data.get("replied_user", False) + + @classmethod + def all(cls) -> "AllowedMentions": + """Return an instance allowing all mention types.""" + + return cls( + { + "parse": ["users", "roles", "everyone"], + "replied_user": True, + } + ) + + @classmethod + def none(cls) -> "AllowedMentions": + """Return an instance disallowing all mentions.""" + + return cls({"parse": [], "replied_user": False}) def to_dict(self) -> Dict[str, Any]: payload: Dict[str, Any] = {"parse": self.parse} @@ -752,12 +754,10 @@ class Member(User): # Member inherits from User ) # Pass user_data or data if user_data is empty self.nick: Optional[str] = data.get("nick") - self.avatar: Optional[str] = data.get("avatar") # Guild-specific avatar hash - self.roles: List[str] = data.get("roles", []) # List of role IDs - self.joined_at: str = data["joined_at"] # ISO8601 timestamp - self.premium_since: Optional[str] = data.get( - "premium_since" - ) # ISO8601 timestamp + self.avatar: Optional[str] = data.get("avatar") + self.roles: List[str] = data.get("roles", []) + self.joined_at: str = data["joined_at"] + self.premium_since: Optional[str] = data.get("premium_since") self.deaf: bool = data.get("deaf", False) self.mute: bool = data.get("mute", False) self.pending: bool = data.get("pending", False) @@ -782,10 +782,10 @@ class Member(User): # Member inherits from User return f"" @property - def display_name(self) -> str: - """Return the nickname if set, otherwise the username.""" - - return self.nick or self.username or "" + def display_name(self) -> str: + """Return the nickname if set, otherwise the username.""" + + return self.nick or self.username or "" async def kick(self, *, reason: Optional[str] = None) -> None: if not self.guild_id or not self._client: @@ -1192,13 +1192,13 @@ class Guild: The matching member if found, otherwise ``None``. """ - lowered = name.lower() - for member in self._members.values(): - if member.username and member.username.lower() == lowered: - return member - if member.nick and member.nick.lower() == lowered: - return member - return None + lowered = name.lower() + for member in self._members.values(): + if member.username and member.username.lower() == lowered: + return member + if member.nick and member.nick.lower() == lowered: + return member + return None def get_role(self, role_id: str) -> Optional[Role]: return next((role for role in self.roles if role.id == role_id), None) @@ -2480,7 +2480,7 @@ class PresenceUpdate: self, data: Dict[str, Any], client_instance: Optional["Client"] = None ): self._client = client_instance - self.user = User(data["user"], client_instance) + self.user = User(data["user"], client_instance) self.guild_id: Optional[str] = data.get("guild_id") self.status: Optional[str] = data.get("status") self.activities: List[Activity] = [] @@ -2500,7 +2500,7 @@ class PresenceUpdate: return f"" -class TypingStart: +class TypingStart: """Represents a TYPING_START event.""" def __init__( @@ -2513,39 +2513,39 @@ class TypingStart: self.timestamp: int = data["timestamp"] self.member: Optional[Member] = ( Member(data["member"], client_instance) if data.get("member") else None - ) - - def __repr__(self) -> str: - return f"" - - -class VoiceStateUpdate: - """Represents a VOICE_STATE_UPDATE event.""" - - def __init__( - self, data: Dict[str, Any], client_instance: Optional["Client"] = None - ): - self._client = client_instance - self.guild_id: Optional[str] = data.get("guild_id") - self.channel_id: Optional[str] = data.get("channel_id") - self.user_id: str = data["user_id"] - self.member: Optional[Member] = ( - Member(data["member"], client_instance) if data.get("member") else None - ) - self.session_id: str = data["session_id"] - self.deaf: bool = data.get("deaf", False) - self.mute: bool = data.get("mute", False) - self.self_deaf: bool = data.get("self_deaf", False) - self.self_mute: bool = data.get("self_mute", False) - self.self_stream: Optional[bool] = data.get("self_stream") - self.self_video: bool = data.get("self_video", False) - self.suppress: bool = data.get("suppress", False) - - def __repr__(self) -> str: - return ( - f"" - ) + ) + + def __repr__(self) -> str: + return f"" + + +class VoiceStateUpdate: + """Represents a VOICE_STATE_UPDATE event.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client = client_instance + self.guild_id: Optional[str] = data.get("guild_id") + self.channel_id: Optional[str] = data.get("channel_id") + self.user_id: str = data["user_id"] + self.member: Optional[Member] = ( + Member(data["member"], client_instance) if data.get("member") else None + ) + self.session_id: str = data["session_id"] + self.deaf: bool = data.get("deaf", False) + self.mute: bool = data.get("mute", False) + self.self_deaf: bool = data.get("self_deaf", False) + self.self_mute: bool = data.get("self_mute", False) + self.self_stream: Optional[bool] = data.get("self_stream") + self.self_video: bool = data.get("self_video", False) + self.suppress: bool = data.get("suppress", False) + + def __repr__(self) -> str: + return ( + f"" + ) class Reaction: diff --git a/pyproject.toml b/pyproject.toml index 7d88a82..d0d5268 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "disagreement" -version = "0.8.0" +version = "0.8.1" description = "A Python library for the Discord API." readme = "README.md" requires-python = ">=3.10"