fix(core): Improve client ready state and user parsing

The `_ready_event` is now set in `GatewayClient` immediately after
receiving the `READY` payload, before dispatching `on_ready` to user code.
This ensures `Client.wait_until_ready()` and `Client.is_ready()`
accurately reflect the client's state before dependent user logic executes.

This change allows simplifying `Client.sync_commands` by removing
redundant `wait_until_ready()` calls and `application_id` checks,
as the application ID is guaranteed to be available upon READY.

Additionally, `User` model initialization is improved to correctly handle
nested user data found in certain API payloads (e.g., within `member`
objects in events like `PresenceUpdate`).

Add `SOUNDBOARD` and `VIDEO_QUALITY_720_60FPS` to `GuildFeature` enum.
This commit is contained in:
Slipstreamm 2025-06-14 23:49:33 -06:00
parent bd16b1c026
commit a41a301927
6 changed files with 217 additions and 221 deletions

View File

@ -12,7 +12,7 @@ __title__ = "disagreement"
__author__ = "Slipstream"
__license__ = "BSD 3-Clause License"
__copyright__ = "Copyright 2025 Slipstream"
__version__ = "0.8.0"
__version__ = "0.8.1"
from .client import Client, AutoShardedClient
from .models import (

View File

@ -4,18 +4,18 @@ The main Client class for interacting with the Discord API.
import asyncio
import signal
from typing import (
Optional,
Callable,
Any,
TYPE_CHECKING,
Awaitable,
AsyncIterator,
Union,
List,
Dict,
cast,
)
from typing import (
Optional,
Callable,
Any,
TYPE_CHECKING,
Awaitable,
AsyncIterator,
Union,
List,
Dict,
cast,
)
from types import ModuleType
from .http import HTTPClient
@ -263,16 +263,16 @@ class Client:
raise
except DisagreementException as e: # Includes GatewayException
print(f"Failed to connect (Attempt {attempt + 1}/{max_retries}): {e}")
if attempt < max_retries - 1:
print(f"Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay)
retry_delay = min(
retry_delay * 2, 60
) # Exponential backoff up to 60s
else:
print("Max connection retries reached. Giving up.")
await self.close() # Ensure cleanup
raise
if attempt < max_retries - 1:
print(f"Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay)
retry_delay = min(
retry_delay * 2, 60
) # Exponential backoff up to 60s
else:
print("Max connection retries reached. Giving up.")
await self.close() # Ensure cleanup
raise
if max_retries == 0: # If max_retries was 0, means no retries attempted
raise DisagreementException("Connection failed with 0 retries allowed.")
@ -530,29 +530,29 @@ class Client:
print(f"Message: {message.content}")
"""
def decorator(
coro: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
if not asyncio.iscoroutinefunction(coro):
raise TypeError("Event registered must be a coroutine function.")
self._event_dispatcher.register(event_name.upper(), coro)
return coro
return decorator
def add_listener(
self, event_name: str, coro: Callable[..., Awaitable[None]]
) -> None:
"""Register ``coro`` to listen for ``event_name``."""
self._event_dispatcher.register(event_name, coro)
def remove_listener(
self, event_name: str, coro: Callable[..., Awaitable[None]]
) -> None:
"""Remove ``coro`` from ``event_name`` listeners."""
self._event_dispatcher.unregister(event_name, coro)
def decorator(
coro: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
if not asyncio.iscoroutinefunction(coro):
raise TypeError("Event registered must be a coroutine function.")
self._event_dispatcher.register(event_name.upper(), coro)
return coro
return decorator
def add_listener(
self, event_name: str, coro: Callable[..., Awaitable[None]]
) -> None:
"""Register ``coro`` to listen for ``event_name``."""
self._event_dispatcher.register(event_name, coro)
def remove_listener(
self, event_name: str, coro: Callable[..., Awaitable[None]]
) -> None:
"""Remove ``coro`` from ``event_name`` listeners."""
self._event_dispatcher.unregister(event_name, coro)
async def _process_message_for_commands(self, message: "Message") -> None:
"""Internal listener to process messages for commands."""
@ -755,7 +755,7 @@ class Client:
"""Parses user data and returns a User object, updating cache."""
from .models import User # Ensure User model is available
user = User(data, client_instance=self)
user = User(data, client_instance=self)
self._users.set(user.id, user) # Cache the user
return user
@ -1011,10 +1011,10 @@ class Client:
# --- API Methods ---
async def send_message(
self,
channel_id: str,
content: Optional[str] = None,
async def send_message(
self,
channel_id: str,
content: Optional[str] = None,
*, # Make additional params keyword-only
tts: bool = False,
embed: Optional["Embed"] = None,
@ -1106,24 +1106,24 @@ class Client:
view.message_id = message_id
self._views[message_id] = view
return self.parse_message(message_data)
async def create_dm(self, user_id: Snowflake) -> "DMChannel":
"""|coro| Create or fetch a DM channel with a user."""
from .models import DMChannel
dm_data = await self._http.create_dm(user_id)
return cast(DMChannel, self.parse_channel(dm_data))
async def send_dm(
self,
user_id: Snowflake,
content: Optional[str] = None,
**kwargs: Any,
) -> "Message":
"""|coro| Convenience method to send a direct message to a user."""
channel = await self.create_dm(user_id)
return await self.send_message(channel.id, content=content, **kwargs)
return self.parse_message(message_data)
async def create_dm(self, user_id: Snowflake) -> "DMChannel":
"""|coro| Create or fetch a DM channel with a user."""
from .models import DMChannel
dm_data = await self._http.create_dm(user_id)
return cast(DMChannel, self.parse_channel(dm_data))
async def send_dm(
self,
user_id: Snowflake,
content: Optional[str] = None,
**kwargs: Any,
) -> "Message":
"""|coro| Convenience method to send a direct message to a user."""
channel = await self.create_dm(user_id)
return await self.send_message(channel.id, content=content, **kwargs)
def typing(self, channel_id: str) -> Typing:
"""Return a context manager to show a typing indicator in a channel."""
@ -1343,8 +1343,8 @@ class Client:
return self._messages.get(message_id)
async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]:
"""Fetches a guild by ID from Discord and caches it."""
async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]:
"""Fetches a guild by ID from Discord and caches it."""
if self._closed:
raise DisagreementException("Client is closed.")
@ -1358,19 +1358,19 @@ class Client:
return self.parse_guild(guild_data)
except DisagreementException as e:
print(f"Failed to fetch guild {guild_id}: {e}")
return None
async def fetch_guilds(self) -> List["Guild"]:
"""Fetch all guilds the current user is in."""
if self._closed:
raise DisagreementException("Client is closed.")
data = await self._http.get_current_user_guilds()
guilds: List["Guild"] = []
for guild_data in data:
guilds.append(self.parse_guild(guild_data))
return guilds
return None
async def fetch_guilds(self) -> List["Guild"]:
"""Fetch all guilds the current user is in."""
if self._closed:
raise DisagreementException("Client is closed.")
data = await self._http.get_current_user_guilds()
guilds: List["Guild"] = []
for guild_data in data:
guilds.append(self.parse_guild(guild_data))
return guilds
async def fetch_channel(self, channel_id: Snowflake) -> Optional["Channel"]:
"""Fetches a channel from Discord by its ID and updates the cache."""
@ -1665,16 +1665,6 @@ class Client:
"Ensure the client is connected and READY."
)
return
if not self.is_ready():
print(
"Warning: Client is not ready. Waiting for client to be ready before syncing commands."
)
await self.wait_until_ready()
if not self.application_id:
print(
"Error: application_id still not set after client is ready. Cannot sync commands."
)
return
await self.app_command_handler.sync_commands(
application_id=self.application_id, guild_id=guild_id

View File

@ -268,6 +268,8 @@ class GuildFeature(str, Enum): # Changed from IntEnum to Enum
VERIFIED = "VERIFIED"
VIP_REGIONS = "VIP_REGIONS"
WELCOME_SCREEN_ENABLED = "WELCOME_SCREEN_ENABLED"
SOUNDBOARD = "SOUNDBOARD"
VIDEO_QUALITY_720_60FPS = "VIDEO_QUALITY_720_60FPS"
# Add more as they become known or needed
# This allows GuildFeature("UNKNOWN_FEATURE_STRING") to work

View File

@ -334,6 +334,10 @@ class GatewayClient:
self._resume_gateway_url,
)
# The client is now ready for operations. Set the event before dispatching to user code.
self._client_instance._ready_event.set()
logger.info("Client is now marked as ready.")
await self._dispatcher.dispatch(event_name, raw_event_d_payload)
elif event_name == "GUILD_MEMBERS_CHUNK":
if isinstance(raw_event_d_payload, dict):

View File

@ -46,16 +46,18 @@ if TYPE_CHECKING:
from .components import component_factory
class User:
"""Represents a Discord User."""
def __init__(self, data: dict, client_instance: Optional["Client"] = None) -> None:
self._client = client_instance
self.id: str = data["id"]
self.username: Optional[str] = data.get("username")
self.discriminator: Optional[str] = data.get("discriminator")
self.bot: bool = data.get("bot", False)
self.avatar: Optional[str] = data.get("avatar")
class User:
"""Represents a Discord User."""
def __init__(self, data: dict, client_instance: Optional["Client"] = None) -> None:
self._client = client_instance
if "id" not in data and "user" in data:
data = data["user"]
self.id: str = data["id"]
self.username: Optional[str] = data.get("username")
self.discriminator: Optional[str] = data.get("discriminator")
self.bot: bool = data.get("bot", False)
self.avatar: Optional[str] = data.get("avatar")
@property
def mention(self) -> str:
@ -63,23 +65,23 @@ class User:
return f"<@{self.id}>"
def __repr__(self) -> str:
username = self.username or "Unknown"
disc = self.discriminator or "????"
return f"<User id='{self.id}' username='{username}' discriminator='{disc}'>"
async def send(
self,
content: Optional[str] = None,
*,
client: Optional["Client"] = None,
**kwargs: Any,
) -> "Message":
"""Send a direct message to this user."""
target_client = client or self._client
if target_client is None:
raise DisagreementException("User.send requires a Client instance")
return await target_client.send_dm(self.id, content=content, **kwargs)
username = self.username or "Unknown"
disc = self.discriminator or "????"
return f"<User id='{self.id}' username='{username}' discriminator='{disc}'>"
async def send(
self,
content: Optional[str] = None,
*,
client: Optional["Client"] = None,
**kwargs: Any,
) -> "Message":
"""Send a direct message to this user."""
target_client = client or self._client
if target_client is None:
raise DisagreementException("User.send requires a Client instance")
return await target_client.send_dm(self.id, content=content, **kwargs)
class Message:
@ -105,7 +107,7 @@ class Message:
self.id: str = data["id"]
self.channel_id: str = data["channel_id"]
self.guild_id: Optional[str] = data.get("guild_id")
self.author: User = User(data["author"], client_instance)
self.author: User = User(data["author"], client_instance)
self.content: str = data["content"]
self.timestamp: str = data["timestamp"]
if data.get("components"):
@ -115,21 +117,21 @@ class Message:
]
else:
self.components = None
self.attachments: List[Attachment] = [
Attachment(a) for a in data.get("attachments", [])
]
self.pinned: bool = data.get("pinned", False)
# Add other fields as needed, e.g., attachments, embeds, reactions, etc.
# self.mentions: List[User] = [User(u) for u in data.get("mentions", [])]
# self.mention_roles: List[str] = data.get("mention_roles", [])
# self.mention_everyone: bool = data.get("mention_everyone", False)
@property
def jump_url(self) -> str:
"""Return a URL that jumps to this message in the Discord client."""
guild_or_dm = self.guild_id or "@me"
return f"https://discord.com/channels/{guild_or_dm}/{self.channel_id}/{self.id}"
self.attachments: List[Attachment] = [
Attachment(a) for a in data.get("attachments", [])
]
self.pinned: bool = data.get("pinned", False)
# Add other fields as needed, e.g., attachments, embeds, reactions, etc.
# self.mentions: List[User] = [User(u) for u in data.get("mentions", [])]
# self.mention_roles: List[str] = data.get("mention_roles", [])
# self.mention_everyone: bool = data.get("mention_everyone", False)
@property
def jump_url(self) -> str:
"""Return a URL that jumps to this message in the Discord client."""
guild_or_dm = self.guild_id or "@me"
return f"https://discord.com/channels/{guild_or_dm}/{self.channel_id}/{self.id}"
@property
def clean_content(self) -> str:
@ -203,14 +205,14 @@ class Message:
ValueError: If both `embed` and `embeds` are provided.
"""
# Determine allowed mentions for the reply
if mention_author is None:
mention_author = getattr(self._client, "mention_replies", False)
if allowed_mentions is None:
allowed_mentions = dict(getattr(self._client, "allowed_mentions", {}) or {})
else:
allowed_mentions = dict(allowed_mentions)
allowed_mentions.setdefault("replied_user", mention_author)
if mention_author is None:
mention_author = getattr(self._client, "mention_replies", False)
if allowed_mentions is None:
allowed_mentions = dict(getattr(self._client, "allowed_mentions", {}) or {})
else:
allowed_mentions = dict(allowed_mentions)
allowed_mentions.setdefault("replied_user", mention_author)
# Client.send_message is already updated to handle these parameters
return await self._client.send_message(
@ -640,31 +642,31 @@ class File:
self.data = data
class AllowedMentions:
class AllowedMentions:
"""Represents allowed mentions for a message or interaction response."""
def __init__(self, data: Dict[str, Any]):
self.parse: List[str] = data.get("parse", [])
self.roles: List[str] = data.get("roles", [])
self.users: List[str] = data.get("users", [])
self.replied_user: bool = data.get("replied_user", False)
@classmethod
def all(cls) -> "AllowedMentions":
"""Return an instance allowing all mention types."""
return cls(
{
"parse": ["users", "roles", "everyone"],
"replied_user": True,
}
)
@classmethod
def none(cls) -> "AllowedMentions":
"""Return an instance disallowing all mentions."""
return cls({"parse": [], "replied_user": False})
def __init__(self, data: Dict[str, Any]):
self.parse: List[str] = data.get("parse", [])
self.roles: List[str] = data.get("roles", [])
self.users: List[str] = data.get("users", [])
self.replied_user: bool = data.get("replied_user", False)
@classmethod
def all(cls) -> "AllowedMentions":
"""Return an instance allowing all mention types."""
return cls(
{
"parse": ["users", "roles", "everyone"],
"replied_user": True,
}
)
@classmethod
def none(cls) -> "AllowedMentions":
"""Return an instance disallowing all mentions."""
return cls({"parse": [], "replied_user": False})
def to_dict(self) -> Dict[str, Any]:
payload: Dict[str, Any] = {"parse": self.parse}
@ -752,12 +754,10 @@ class Member(User): # Member inherits from User
) # Pass user_data or data if user_data is empty
self.nick: Optional[str] = data.get("nick")
self.avatar: Optional[str] = data.get("avatar") # Guild-specific avatar hash
self.roles: List[str] = data.get("roles", []) # List of role IDs
self.joined_at: str = data["joined_at"] # ISO8601 timestamp
self.premium_since: Optional[str] = data.get(
"premium_since"
) # ISO8601 timestamp
self.avatar: Optional[str] = data.get("avatar")
self.roles: List[str] = data.get("roles", [])
self.joined_at: str = data["joined_at"]
self.premium_since: Optional[str] = data.get("premium_since")
self.deaf: bool = data.get("deaf", False)
self.mute: bool = data.get("mute", False)
self.pending: bool = data.get("pending", False)
@ -782,10 +782,10 @@ class Member(User): # Member inherits from User
return f"<Member id='{self.id}' username='{self.username}' nick='{self.nick}'>"
@property
def display_name(self) -> str:
"""Return the nickname if set, otherwise the username."""
return self.nick or self.username or ""
def display_name(self) -> str:
"""Return the nickname if set, otherwise the username."""
return self.nick or self.username or ""
async def kick(self, *, reason: Optional[str] = None) -> None:
if not self.guild_id or not self._client:
@ -1192,13 +1192,13 @@ class Guild:
The matching member if found, otherwise ``None``.
"""
lowered = name.lower()
for member in self._members.values():
if member.username and member.username.lower() == lowered:
return member
if member.nick and member.nick.lower() == lowered:
return member
return None
lowered = name.lower()
for member in self._members.values():
if member.username and member.username.lower() == lowered:
return member
if member.nick and member.nick.lower() == lowered:
return member
return None
def get_role(self, role_id: str) -> Optional[Role]:
return next((role for role in self.roles if role.id == role_id), None)
@ -2480,7 +2480,7 @@ class PresenceUpdate:
self, data: Dict[str, Any], client_instance: Optional["Client"] = None
):
self._client = client_instance
self.user = User(data["user"], client_instance)
self.user = User(data["user"], client_instance)
self.guild_id: Optional[str] = data.get("guild_id")
self.status: Optional[str] = data.get("status")
self.activities: List[Activity] = []
@ -2500,7 +2500,7 @@ class PresenceUpdate:
return f"<PresenceUpdate user_id='{self.user.id}' guild_id='{self.guild_id}' status='{self.status}'>"
class TypingStart:
class TypingStart:
"""Represents a TYPING_START event."""
def __init__(
@ -2513,39 +2513,39 @@ class TypingStart:
self.timestamp: int = data["timestamp"]
self.member: Optional[Member] = (
Member(data["member"], client_instance) if data.get("member") else None
)
def __repr__(self) -> str:
return f"<TypingStart channel_id='{self.channel_id}' user_id='{self.user_id}'>"
class VoiceStateUpdate:
"""Represents a VOICE_STATE_UPDATE event."""
def __init__(
self, data: Dict[str, Any], client_instance: Optional["Client"] = None
):
self._client = client_instance
self.guild_id: Optional[str] = data.get("guild_id")
self.channel_id: Optional[str] = data.get("channel_id")
self.user_id: str = data["user_id"]
self.member: Optional[Member] = (
Member(data["member"], client_instance) if data.get("member") else None
)
self.session_id: str = data["session_id"]
self.deaf: bool = data.get("deaf", False)
self.mute: bool = data.get("mute", False)
self.self_deaf: bool = data.get("self_deaf", False)
self.self_mute: bool = data.get("self_mute", False)
self.self_stream: Optional[bool] = data.get("self_stream")
self.self_video: bool = data.get("self_video", False)
self.suppress: bool = data.get("suppress", False)
def __repr__(self) -> str:
return (
f"<VoiceStateUpdate guild_id='{self.guild_id}' user_id='{self.user_id}' "
f"channel_id='{self.channel_id}'>"
)
)
def __repr__(self) -> str:
return f"<TypingStart channel_id='{self.channel_id}' user_id='{self.user_id}'>"
class VoiceStateUpdate:
"""Represents a VOICE_STATE_UPDATE event."""
def __init__(
self, data: Dict[str, Any], client_instance: Optional["Client"] = None
):
self._client = client_instance
self.guild_id: Optional[str] = data.get("guild_id")
self.channel_id: Optional[str] = data.get("channel_id")
self.user_id: str = data["user_id"]
self.member: Optional[Member] = (
Member(data["member"], client_instance) if data.get("member") else None
)
self.session_id: str = data["session_id"]
self.deaf: bool = data.get("deaf", False)
self.mute: bool = data.get("mute", False)
self.self_deaf: bool = data.get("self_deaf", False)
self.self_mute: bool = data.get("self_mute", False)
self.self_stream: Optional[bool] = data.get("self_stream")
self.self_video: bool = data.get("self_video", False)
self.suppress: bool = data.get("suppress", False)
def __repr__(self) -> str:
return (
f"<VoiceStateUpdate guild_id='{self.guild_id}' user_id='{self.user_id}' "
f"channel_id='{self.channel_id}'>"
)
class Reaction:

View File

@ -1,6 +1,6 @@
[project]
name = "disagreement"
version = "0.8.0"
version = "0.8.1"
description = "A Python library for the Discord API."
readme = "README.md"
requires-python = ">=3.10"