Improve command sync and DM support (#77)

This commit is contained in:
Slipstream 2025-06-14 21:40:52 -06:00 committed by GitHub
parent bd92806c4c
commit c9aec0dc7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 103 additions and 54 deletions

View File

@ -14,6 +14,7 @@ from typing import (
Union, Union,
List, List,
Dict, Dict,
cast,
) )
from types import ModuleType from types import ModuleType
@ -754,7 +755,7 @@ class Client:
"""Parses user data and returns a User object, updating cache.""" """Parses user data and returns a User object, updating cache."""
from .models import User # Ensure User model is available from .models import User # Ensure User model is available
user = User(data) user = User(data, client_instance=self)
self._users.set(user.id, user) # Cache the user self._users.set(user.id, user) # Cache the user
return user return user
@ -1107,6 +1108,23 @@ class Client:
return self.parse_message(message_data) return self.parse_message(message_data)
async def create_dm(self, user_id: Snowflake) -> "DMChannel":
"""|coro| Create or fetch a DM channel with a user."""
from .models import DMChannel
dm_data = await self._http.create_dm(user_id)
return cast(DMChannel, self.parse_channel(dm_data))
async def send_dm(
self,
user_id: Snowflake,
content: Optional[str] = None,
**kwargs: Any,
) -> "Message":
"""|coro| Convenience method to send a direct message to a user."""
channel = await self.create_dm(user_id)
return await self.send_message(channel.id, content=content, **kwargs)
def typing(self, channel_id: str) -> Typing: def typing(self, channel_id: str) -> Typing:
"""Return a context manager to show a typing indicator in a channel.""" """Return a context manager to show a typing indicator in a channel."""

View File

@ -273,7 +273,12 @@ class GuildFeature(str, Enum): # Changed from IntEnum to Enum
# This allows GuildFeature("UNKNOWN_FEATURE_STRING") to work # This allows GuildFeature("UNKNOWN_FEATURE_STRING") to work
@classmethod @classmethod
def _missing_(cls, value): # type: ignore def _missing_(cls, value): # type: ignore
return str(value) member = object.__new__(cls)
member._name_ = str(value)
member._value_ = str(value)
cls._value2member_map_[member._value_] = member # pylint: disable=no-member
cls._member_map_[member._name_] = member # pylint: disable=no-member
return member
# --- Guild Scheduled Event Enums --- # --- Guild Scheduled Event Enums ---
@ -329,7 +334,12 @@ class VoiceRegion(str, Enum):
@classmethod @classmethod
def _missing_(cls, value): # type: ignore def _missing_(cls, value): # type: ignore
return str(value) member = object.__new__(cls)
member._name_ = str(value)
member._value_ = str(value)
cls._value2member_map_[member._value_] = member # pylint: disable=no-member
cls._member_map_[member._name_] = member # pylint: disable=no-member
return member
# --- Channel Enums --- # --- Channel Enums ---

View File

@ -587,12 +587,19 @@ class AppCommandHandler:
# print(f"Failed to send error message for app command: {send_e}") # print(f"Failed to send error message for app command: {send_e}")
async def sync_commands( async def sync_commands(
self, application_id: "Snowflake", guild_id: Optional["Snowflake"] = None self,
application_id: Optional["Snowflake"] = None,
guild_id: Optional["Snowflake"] = None,
) -> None: ) -> None:
""" """
Synchronizes (registers/updates) all application commands with Discord. Synchronizes (registers/updates) all application commands with Discord.
If guild_id is provided, syncs commands for that guild. Otherwise, syncs global commands. If guild_id is provided, syncs commands for that guild. Otherwise, syncs global commands.
""" """
if application_id is None:
application_id = self.client.application_id
if application_id is None:
raise ValueError("application_id must be provided to sync commands")
cache = self._load_cached_ids() cache = self._load_cached_ids()
scope_key = str(guild_id) if guild_id else "global" scope_key = str(guild_id) if guild_id else "global"
stored = cache.get(scope_key, {}) stored = cache.get(scope_key, {})

View File

@ -1364,3 +1364,8 @@ class HTTPClient:
async def leave_thread(self, channel_id: "Snowflake") -> None: async def leave_thread(self, channel_id: "Snowflake") -> None:
"""Removes the current user from a thread.""" """Removes the current user from a thread."""
await self.request("DELETE", f"/channels/{channel_id}/thread-members/@me") await self.request("DELETE", f"/channels/{channel_id}/thread-members/@me")
async def create_dm(self, recipient_id: "Snowflake") -> Dict[str, Any]:
"""Creates (or opens) a DM channel with the given user."""
payload = {"recipient_id": str(recipient_id)}
return await self.request("POST", "/users/@me/channels", payload=payload)

View File

@ -47,20 +47,13 @@ if TYPE_CHECKING:
class User: class User:
"""Represents a Discord User. """Represents a Discord User."""
Attributes: def __init__(self, data: dict, client_instance: Optional["Client"] = None) -> None:
id (str): The user's unique ID. self._client = client_instance
username (str): The user's username.
discriminator (str): The user's 4-digit discord-tag.
bot (bool): Whether the user belongs to an OAuth2 application. Defaults to False.
avatar (Optional[str]): The user's avatar hash, if any.
"""
def __init__(self, data: dict):
self.id: str = data["id"] self.id: str = data["id"]
self.username: str = data["username"] self.username: Optional[str] = data.get("username")
self.discriminator: str = data["discriminator"] self.discriminator: Optional[str] = data.get("discriminator")
self.bot: bool = data.get("bot", False) self.bot: bool = data.get("bot", False)
self.avatar: Optional[str] = data.get("avatar") self.avatar: Optional[str] = data.get("avatar")
@ -70,7 +63,23 @@ class User:
return f"<@{self.id}>" return f"<@{self.id}>"
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<User id='{self.id}' username='{self.username}' discriminator='{self.discriminator}'>" username = self.username or "Unknown"
disc = self.discriminator or "????"
return f"<User id='{self.id}' username='{username}' discriminator='{disc}'>"
async def send(
self,
content: Optional[str] = None,
*,
client: Optional["Client"] = None,
**kwargs: Any,
) -> "Message":
"""Send a direct message to this user."""
target_client = client or self._client
if target_client is None:
raise DisagreementException("User.send requires a Client instance")
return await target_client.send_dm(self.id, content=content, **kwargs)
class Message: class Message:
@ -96,7 +105,7 @@ class Message:
self.id: str = data["id"] self.id: str = data["id"]
self.channel_id: str = data["channel_id"] self.channel_id: str = data["channel_id"]
self.guild_id: Optional[str] = data.get("guild_id") self.guild_id: Optional[str] = data.get("guild_id")
self.author: User = User(data["author"]) self.author: User = User(data["author"], client_instance)
self.content: str = data["content"] self.content: str = data["content"]
self.timestamp: str = data["timestamp"] self.timestamp: str = data["timestamp"]
if data.get("components"): if data.get("components"):
@ -776,7 +785,7 @@ class Member(User): # Member inherits from User
def display_name(self) -> str: def display_name(self) -> str:
"""Return the nickname if set, otherwise the username.""" """Return the nickname if set, otherwise the username."""
return self.nick or self.username return self.nick or self.username or ""
async def kick(self, *, reason: Optional[str] = None) -> None: async def kick(self, *, reason: Optional[str] = None) -> None:
if not self.guild_id or not self._client: if not self.guild_id or not self._client:
@ -1185,7 +1194,7 @@ class Guild:
lowered = name.lower() lowered = name.lower()
for member in self._members.values(): for member in self._members.values():
if member.username.lower() == lowered: if member.username and member.username.lower() == lowered:
return member return member
if member.nick and member.nick.lower() == lowered: if member.nick and member.nick.lower() == lowered:
return member return member
@ -2471,7 +2480,7 @@ class PresenceUpdate:
self, data: Dict[str, Any], client_instance: Optional["Client"] = None self, data: Dict[str, Any], client_instance: Optional["Client"] = None
): ):
self._client = client_instance self._client = client_instance
self.user = User(data["user"]) self.user = User(data["user"], client_instance)
self.guild_id: Optional[str] = data.get("guild_id") self.guild_id: Optional[str] = data.get("guild_id")
self.status: Optional[str] = data.get("status") self.status: Optional[str] = data.get("status")
self.activities: List[Activity] = [] self.activities: List[Activity] = []