Improve command sync and DM support (#77)

This commit is contained in:
Slipstream 2025-06-14 21:40:52 -06:00 committed by GitHub
parent bd92806c4c
commit c9aec0dc7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 103 additions and 54 deletions

View File

@ -4,17 +4,18 @@ The main Client class for interacting with the Discord API.
import asyncio
import signal
from typing import (
Optional,
Callable,
Any,
TYPE_CHECKING,
Awaitable,
AsyncIterator,
Union,
List,
Dict,
)
from typing import (
Optional,
Callable,
Any,
TYPE_CHECKING,
Awaitable,
AsyncIterator,
Union,
List,
Dict,
cast,
)
from types import ModuleType
from .http import HTTPClient
@ -754,7 +755,7 @@ class Client:
"""Parses user data and returns a User object, updating cache."""
from .models import User # Ensure User model is available
user = User(data)
user = User(data, client_instance=self)
self._users.set(user.id, user) # Cache the user
return user
@ -1010,10 +1011,10 @@ class Client:
# --- API Methods ---
async def send_message(
self,
channel_id: str,
content: Optional[str] = None,
async def send_message(
self,
channel_id: str,
content: Optional[str] = None,
*, # Make additional params keyword-only
tts: bool = False,
embed: Optional["Embed"] = None,
@ -1105,7 +1106,24 @@ class Client:
view.message_id = message_id
self._views[message_id] = view
return self.parse_message(message_data)
return self.parse_message(message_data)
async def create_dm(self, user_id: Snowflake) -> "DMChannel":
"""|coro| Create or fetch a DM channel with a user."""
from .models import DMChannel
dm_data = await self._http.create_dm(user_id)
return cast(DMChannel, self.parse_channel(dm_data))
async def send_dm(
self,
user_id: Snowflake,
content: Optional[str] = None,
**kwargs: Any,
) -> "Message":
"""|coro| Convenience method to send a direct message to a user."""
channel = await self.create_dm(user_id)
return await self.send_message(channel.id, content=content, **kwargs)
def typing(self, channel_id: str) -> Typing:
"""Return a context manager to show a typing indicator in a channel."""

View File

@ -273,7 +273,12 @@ class GuildFeature(str, Enum): # Changed from IntEnum to Enum
# This allows GuildFeature("UNKNOWN_FEATURE_STRING") to work
@classmethod
def _missing_(cls, value): # type: ignore
return str(value)
member = object.__new__(cls)
member._name_ = str(value)
member._value_ = str(value)
cls._value2member_map_[member._value_] = member # pylint: disable=no-member
cls._member_map_[member._name_] = member # pylint: disable=no-member
return member
# --- Guild Scheduled Event Enums ---
@ -329,7 +334,12 @@ class VoiceRegion(str, Enum):
@classmethod
def _missing_(cls, value): # type: ignore
return str(value)
member = object.__new__(cls)
member._name_ = str(value)
member._value_ = str(value)
cls._value2member_map_[member._value_] = member # pylint: disable=no-member
cls._member_map_[member._name_] = member # pylint: disable=no-member
return member
# --- Channel Enums ---

View File

@ -587,12 +587,19 @@ class AppCommandHandler:
# print(f"Failed to send error message for app command: {send_e}")
async def sync_commands(
self, application_id: "Snowflake", guild_id: Optional["Snowflake"] = None
self,
application_id: Optional["Snowflake"] = None,
guild_id: Optional["Snowflake"] = None,
) -> None:
"""
Synchronizes (registers/updates) all application commands with Discord.
If guild_id is provided, syncs commands for that guild. Otherwise, syncs global commands.
"""
if application_id is None:
application_id = self.client.application_id
if application_id is None:
raise ValueError("application_id must be provided to sync commands")
cache = self._load_cached_ids()
scope_key = str(guild_id) if guild_id else "global"
stored = cache.get(scope_key, {})

View File

@ -1361,6 +1361,11 @@ class HTTPClient:
"""Joins the current user to a thread."""
await self.request("PUT", f"/channels/{channel_id}/thread-members/@me")
async def leave_thread(self, channel_id: "Snowflake") -> None:
"""Removes the current user from a thread."""
await self.request("DELETE", f"/channels/{channel_id}/thread-members/@me")
async def leave_thread(self, channel_id: "Snowflake") -> None:
"""Removes the current user from a thread."""
await self.request("DELETE", f"/channels/{channel_id}/thread-members/@me")
async def create_dm(self, recipient_id: "Snowflake") -> Dict[str, Any]:
"""Creates (or opens) a DM channel with the given user."""
payload = {"recipient_id": str(recipient_id)}
return await self.request("POST", "/users/@me/channels", payload=payload)

View File

@ -46,23 +46,16 @@ if TYPE_CHECKING:
from .components import component_factory
class User:
"""Represents a Discord User.
Attributes:
id (str): The user's unique ID.
username (str): The user's username.
discriminator (str): The user's 4-digit discord-tag.
bot (bool): Whether the user belongs to an OAuth2 application. Defaults to False.
avatar (Optional[str]): The user's avatar hash, if any.
"""
def __init__(self, data: dict):
self.id: str = data["id"]
self.username: str = data["username"]
self.discriminator: str = data["discriminator"]
self.bot: bool = data.get("bot", False)
self.avatar: Optional[str] = data.get("avatar")
class User:
"""Represents a Discord User."""
def __init__(self, data: dict, client_instance: Optional["Client"] = None) -> None:
self._client = client_instance
self.id: str = data["id"]
self.username: Optional[str] = data.get("username")
self.discriminator: Optional[str] = data.get("discriminator")
self.bot: bool = data.get("bot", False)
self.avatar: Optional[str] = data.get("avatar")
@property
def mention(self) -> str:
@ -70,7 +63,23 @@ class User:
return f"<@{self.id}>"
def __repr__(self) -> str:
return f"<User id='{self.id}' username='{self.username}' discriminator='{self.discriminator}'>"
username = self.username or "Unknown"
disc = self.discriminator or "????"
return f"<User id='{self.id}' username='{username}' discriminator='{disc}'>"
async def send(
self,
content: Optional[str] = None,
*,
client: Optional["Client"] = None,
**kwargs: Any,
) -> "Message":
"""Send a direct message to this user."""
target_client = client or self._client
if target_client is None:
raise DisagreementException("User.send requires a Client instance")
return await target_client.send_dm(self.id, content=content, **kwargs)
class Message:
@ -96,7 +105,7 @@ class Message:
self.id: str = data["id"]
self.channel_id: str = data["channel_id"]
self.guild_id: Optional[str] = data.get("guild_id")
self.author: User = User(data["author"])
self.author: User = User(data["author"], client_instance)
self.content: str = data["content"]
self.timestamp: str = data["timestamp"]
if data.get("components"):
@ -773,10 +782,10 @@ class Member(User): # Member inherits from User
return f"<Member id='{self.id}' username='{self.username}' nick='{self.nick}'>"
@property
def display_name(self) -> str:
"""Return the nickname if set, otherwise the username."""
return self.nick or self.username
def display_name(self) -> str:
"""Return the nickname if set, otherwise the username."""
return self.nick or self.username or ""
async def kick(self, *, reason: Optional[str] = None) -> None:
if not self.guild_id or not self._client:
@ -1183,13 +1192,13 @@ class Guild:
The matching member if found, otherwise ``None``.
"""
lowered = name.lower()
for member in self._members.values():
if member.username.lower() == lowered:
return member
if member.nick and member.nick.lower() == lowered:
return member
return None
lowered = name.lower()
for member in self._members.values():
if member.username and member.username.lower() == lowered:
return member
if member.nick and member.nick.lower() == lowered:
return member
return None
def get_role(self, role_id: str) -> Optional[Role]:
return next((role for role in self.roles if role.id == role_id), None)
@ -2471,7 +2480,7 @@ class PresenceUpdate:
self, data: Dict[str, Any], client_instance: Optional["Client"] = None
):
self._client = client_instance
self.user = User(data["user"])
self.user = User(data["user"], client_instance)
self.guild_id: Optional[str] = data.get("guild_id")
self.status: Optional[str] = data.get("status")
self.activities: List[Activity] = []