diff --git a/disagreement/client.py b/disagreement/client.py index 497dbc0..502af57 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -4,17 +4,18 @@ The main Client class for interacting with the Discord API. import asyncio import signal -from typing import ( - Optional, - Callable, - Any, - TYPE_CHECKING, - Awaitable, - AsyncIterator, - Union, - List, - Dict, -) +from typing import ( + Optional, + Callable, + Any, + TYPE_CHECKING, + Awaitable, + AsyncIterator, + Union, + List, + Dict, + cast, +) from types import ModuleType from .http import HTTPClient @@ -754,7 +755,7 @@ class Client: """Parses user data and returns a User object, updating cache.""" from .models import User # Ensure User model is available - user = User(data) + user = User(data, client_instance=self) self._users.set(user.id, user) # Cache the user return user @@ -1010,10 +1011,10 @@ class Client: # --- API Methods --- - async def send_message( - self, - channel_id: str, - content: Optional[str] = None, + async def send_message( + self, + channel_id: str, + content: Optional[str] = None, *, # Make additional params keyword-only tts: bool = False, embed: Optional["Embed"] = None, @@ -1105,7 +1106,24 @@ class Client: view.message_id = message_id self._views[message_id] = view - return self.parse_message(message_data) + return self.parse_message(message_data) + + async def create_dm(self, user_id: Snowflake) -> "DMChannel": + """|coro| Create or fetch a DM channel with a user.""" + from .models import DMChannel + + dm_data = await self._http.create_dm(user_id) + return cast(DMChannel, self.parse_channel(dm_data)) + + async def send_dm( + self, + user_id: Snowflake, + content: Optional[str] = None, + **kwargs: Any, + ) -> "Message": + """|coro| Convenience method to send a direct message to a user.""" + channel = await self.create_dm(user_id) + return await self.send_message(channel.id, content=content, **kwargs) def typing(self, channel_id: str) -> Typing: """Return a context manager to show a typing indicator in a channel.""" diff --git a/disagreement/enums.py b/disagreement/enums.py index 99418c3..7b4d6c7 100644 --- a/disagreement/enums.py +++ b/disagreement/enums.py @@ -273,7 +273,12 @@ class GuildFeature(str, Enum): # Changed from IntEnum to Enum # This allows GuildFeature("UNKNOWN_FEATURE_STRING") to work @classmethod def _missing_(cls, value): # type: ignore - return str(value) + member = object.__new__(cls) + member._name_ = str(value) + member._value_ = str(value) + cls._value2member_map_[member._value_] = member # pylint: disable=no-member + cls._member_map_[member._name_] = member # pylint: disable=no-member + return member # --- Guild Scheduled Event Enums --- @@ -329,7 +334,12 @@ class VoiceRegion(str, Enum): @classmethod def _missing_(cls, value): # type: ignore - return str(value) + member = object.__new__(cls) + member._name_ = str(value) + member._value_ = str(value) + cls._value2member_map_[member._value_] = member # pylint: disable=no-member + cls._member_map_[member._name_] = member # pylint: disable=no-member + return member # --- Channel Enums --- diff --git a/disagreement/ext/app_commands/handler.py b/disagreement/ext/app_commands/handler.py index 5e67c61..7f0782f 100644 --- a/disagreement/ext/app_commands/handler.py +++ b/disagreement/ext/app_commands/handler.py @@ -587,12 +587,19 @@ class AppCommandHandler: # print(f"Failed to send error message for app command: {send_e}") async def sync_commands( - self, application_id: "Snowflake", guild_id: Optional["Snowflake"] = None + self, + application_id: Optional["Snowflake"] = None, + guild_id: Optional["Snowflake"] = None, ) -> None: """ Synchronizes (registers/updates) all application commands with Discord. If guild_id is provided, syncs commands for that guild. Otherwise, syncs global commands. """ + if application_id is None: + application_id = self.client.application_id + if application_id is None: + raise ValueError("application_id must be provided to sync commands") + cache = self._load_cached_ids() scope_key = str(guild_id) if guild_id else "global" stored = cache.get(scope_key, {}) diff --git a/disagreement/http.py b/disagreement/http.py index 3345eef..543b9ee 100644 --- a/disagreement/http.py +++ b/disagreement/http.py @@ -1361,6 +1361,11 @@ class HTTPClient: """Joins the current user to a thread.""" await self.request("PUT", f"/channels/{channel_id}/thread-members/@me") - async def leave_thread(self, channel_id: "Snowflake") -> None: - """Removes the current user from a thread.""" - await self.request("DELETE", f"/channels/{channel_id}/thread-members/@me") + async def leave_thread(self, channel_id: "Snowflake") -> None: + """Removes the current user from a thread.""" + await self.request("DELETE", f"/channels/{channel_id}/thread-members/@me") + + async def create_dm(self, recipient_id: "Snowflake") -> Dict[str, Any]: + """Creates (or opens) a DM channel with the given user.""" + payload = {"recipient_id": str(recipient_id)} + return await self.request("POST", "/users/@me/channels", payload=payload) diff --git a/disagreement/models.py b/disagreement/models.py index b4502c3..8ac6b66 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -46,23 +46,16 @@ if TYPE_CHECKING: from .components import component_factory -class User: - """Represents a Discord User. - - Attributes: - id (str): The user's unique ID. - username (str): The user's username. - discriminator (str): The user's 4-digit discord-tag. - bot (bool): Whether the user belongs to an OAuth2 application. Defaults to False. - avatar (Optional[str]): The user's avatar hash, if any. - """ - - def __init__(self, data: dict): - self.id: str = data["id"] - self.username: str = data["username"] - self.discriminator: str = data["discriminator"] - self.bot: bool = data.get("bot", False) - self.avatar: Optional[str] = data.get("avatar") +class User: + """Represents a Discord User.""" + + def __init__(self, data: dict, client_instance: Optional["Client"] = None) -> None: + self._client = client_instance + self.id: str = data["id"] + self.username: Optional[str] = data.get("username") + self.discriminator: Optional[str] = data.get("discriminator") + self.bot: bool = data.get("bot", False) + self.avatar: Optional[str] = data.get("avatar") @property def mention(self) -> str: @@ -70,7 +63,23 @@ class User: return f"<@{self.id}>" def __repr__(self) -> str: - return f"" + username = self.username or "Unknown" + disc = self.discriminator or "????" + return f"" + + async def send( + self, + content: Optional[str] = None, + *, + client: Optional["Client"] = None, + **kwargs: Any, + ) -> "Message": + """Send a direct message to this user.""" + + target_client = client or self._client + if target_client is None: + raise DisagreementException("User.send requires a Client instance") + return await target_client.send_dm(self.id, content=content, **kwargs) class Message: @@ -96,7 +105,7 @@ class Message: self.id: str = data["id"] self.channel_id: str = data["channel_id"] self.guild_id: Optional[str] = data.get("guild_id") - self.author: User = User(data["author"]) + self.author: User = User(data["author"], client_instance) self.content: str = data["content"] self.timestamp: str = data["timestamp"] if data.get("components"): @@ -773,10 +782,10 @@ class Member(User): # Member inherits from User return f"" @property - def display_name(self) -> str: - """Return the nickname if set, otherwise the username.""" - - return self.nick or self.username + def display_name(self) -> str: + """Return the nickname if set, otherwise the username.""" + + return self.nick or self.username or "" async def kick(self, *, reason: Optional[str] = None) -> None: if not self.guild_id or not self._client: @@ -1183,13 +1192,13 @@ class Guild: The matching member if found, otherwise ``None``. """ - lowered = name.lower() - for member in self._members.values(): - if member.username.lower() == lowered: - return member - if member.nick and member.nick.lower() == lowered: - return member - return None + lowered = name.lower() + for member in self._members.values(): + if member.username and member.username.lower() == lowered: + return member + if member.nick and member.nick.lower() == lowered: + return member + return None def get_role(self, role_id: str) -> Optional[Role]: return next((role for role in self.roles if role.id == role_id), None) @@ -2471,7 +2480,7 @@ class PresenceUpdate: self, data: Dict[str, Any], client_instance: Optional["Client"] = None ): self._client = client_instance - self.user = User(data["user"]) + self.user = User(data["user"], client_instance) self.guild_id: Optional[str] = data.get("guild_id") self.status: Optional[str] = data.get("status") self.activities: List[Activity] = []