Improve command sync and DM support (#77)
This commit is contained in:
parent
bd92806c4c
commit
c9aec0dc7e
@ -4,17 +4,18 @@ The main Client class for interacting with the Discord API.
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
from typing import (
|
||||
Optional,
|
||||
Callable,
|
||||
Any,
|
||||
TYPE_CHECKING,
|
||||
Awaitable,
|
||||
AsyncIterator,
|
||||
Union,
|
||||
List,
|
||||
Dict,
|
||||
)
|
||||
from typing import (
|
||||
Optional,
|
||||
Callable,
|
||||
Any,
|
||||
TYPE_CHECKING,
|
||||
Awaitable,
|
||||
AsyncIterator,
|
||||
Union,
|
||||
List,
|
||||
Dict,
|
||||
cast,
|
||||
)
|
||||
from types import ModuleType
|
||||
|
||||
from .http import HTTPClient
|
||||
@ -754,7 +755,7 @@ class Client:
|
||||
"""Parses user data and returns a User object, updating cache."""
|
||||
from .models import User # Ensure User model is available
|
||||
|
||||
user = User(data)
|
||||
user = User(data, client_instance=self)
|
||||
self._users.set(user.id, user) # Cache the user
|
||||
return user
|
||||
|
||||
@ -1010,10 +1011,10 @@ class Client:
|
||||
|
||||
# --- API Methods ---
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
channel_id: str,
|
||||
content: Optional[str] = None,
|
||||
async def send_message(
|
||||
self,
|
||||
channel_id: str,
|
||||
content: Optional[str] = None,
|
||||
*, # Make additional params keyword-only
|
||||
tts: bool = False,
|
||||
embed: Optional["Embed"] = None,
|
||||
@ -1105,7 +1106,24 @@ class Client:
|
||||
view.message_id = message_id
|
||||
self._views[message_id] = view
|
||||
|
||||
return self.parse_message(message_data)
|
||||
return self.parse_message(message_data)
|
||||
|
||||
async def create_dm(self, user_id: Snowflake) -> "DMChannel":
|
||||
"""|coro| Create or fetch a DM channel with a user."""
|
||||
from .models import DMChannel
|
||||
|
||||
dm_data = await self._http.create_dm(user_id)
|
||||
return cast(DMChannel, self.parse_channel(dm_data))
|
||||
|
||||
async def send_dm(
|
||||
self,
|
||||
user_id: Snowflake,
|
||||
content: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> "Message":
|
||||
"""|coro| Convenience method to send a direct message to a user."""
|
||||
channel = await self.create_dm(user_id)
|
||||
return await self.send_message(channel.id, content=content, **kwargs)
|
||||
|
||||
def typing(self, channel_id: str) -> Typing:
|
||||
"""Return a context manager to show a typing indicator in a channel."""
|
||||
|
@ -273,7 +273,12 @@ class GuildFeature(str, Enum): # Changed from IntEnum to Enum
|
||||
# This allows GuildFeature("UNKNOWN_FEATURE_STRING") to work
|
||||
@classmethod
|
||||
def _missing_(cls, value): # type: ignore
|
||||
return str(value)
|
||||
member = object.__new__(cls)
|
||||
member._name_ = str(value)
|
||||
member._value_ = str(value)
|
||||
cls._value2member_map_[member._value_] = member # pylint: disable=no-member
|
||||
cls._member_map_[member._name_] = member # pylint: disable=no-member
|
||||
return member
|
||||
|
||||
|
||||
# --- Guild Scheduled Event Enums ---
|
||||
@ -329,7 +334,12 @@ class VoiceRegion(str, Enum):
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value): # type: ignore
|
||||
return str(value)
|
||||
member = object.__new__(cls)
|
||||
member._name_ = str(value)
|
||||
member._value_ = str(value)
|
||||
cls._value2member_map_[member._value_] = member # pylint: disable=no-member
|
||||
cls._member_map_[member._name_] = member # pylint: disable=no-member
|
||||
return member
|
||||
|
||||
|
||||
# --- Channel Enums ---
|
||||
|
@ -587,12 +587,19 @@ class AppCommandHandler:
|
||||
# print(f"Failed to send error message for app command: {send_e}")
|
||||
|
||||
async def sync_commands(
|
||||
self, application_id: "Snowflake", guild_id: Optional["Snowflake"] = None
|
||||
self,
|
||||
application_id: Optional["Snowflake"] = None,
|
||||
guild_id: Optional["Snowflake"] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Synchronizes (registers/updates) all application commands with Discord.
|
||||
If guild_id is provided, syncs commands for that guild. Otherwise, syncs global commands.
|
||||
"""
|
||||
if application_id is None:
|
||||
application_id = self.client.application_id
|
||||
if application_id is None:
|
||||
raise ValueError("application_id must be provided to sync commands")
|
||||
|
||||
cache = self._load_cached_ids()
|
||||
scope_key = str(guild_id) if guild_id else "global"
|
||||
stored = cache.get(scope_key, {})
|
||||
|
@ -1361,6 +1361,11 @@ class HTTPClient:
|
||||
"""Joins the current user to a thread."""
|
||||
await self.request("PUT", f"/channels/{channel_id}/thread-members/@me")
|
||||
|
||||
async def leave_thread(self, channel_id: "Snowflake") -> None:
|
||||
"""Removes the current user from a thread."""
|
||||
await self.request("DELETE", f"/channels/{channel_id}/thread-members/@me")
|
||||
async def leave_thread(self, channel_id: "Snowflake") -> None:
|
||||
"""Removes the current user from a thread."""
|
||||
await self.request("DELETE", f"/channels/{channel_id}/thread-members/@me")
|
||||
|
||||
async def create_dm(self, recipient_id: "Snowflake") -> Dict[str, Any]:
|
||||
"""Creates (or opens) a DM channel with the given user."""
|
||||
payload = {"recipient_id": str(recipient_id)}
|
||||
return await self.request("POST", "/users/@me/channels", payload=payload)
|
||||
|
@ -46,23 +46,16 @@ if TYPE_CHECKING:
|
||||
from .components import component_factory
|
||||
|
||||
|
||||
class User:
|
||||
"""Represents a Discord User.
|
||||
|
||||
Attributes:
|
||||
id (str): The user's unique ID.
|
||||
username (str): The user's username.
|
||||
discriminator (str): The user's 4-digit discord-tag.
|
||||
bot (bool): Whether the user belongs to an OAuth2 application. Defaults to False.
|
||||
avatar (Optional[str]): The user's avatar hash, if any.
|
||||
"""
|
||||
|
||||
def __init__(self, data: dict):
|
||||
self.id: str = data["id"]
|
||||
self.username: str = data["username"]
|
||||
self.discriminator: str = data["discriminator"]
|
||||
self.bot: bool = data.get("bot", False)
|
||||
self.avatar: Optional[str] = data.get("avatar")
|
||||
class User:
|
||||
"""Represents a Discord User."""
|
||||
|
||||
def __init__(self, data: dict, client_instance: Optional["Client"] = None) -> None:
|
||||
self._client = client_instance
|
||||
self.id: str = data["id"]
|
||||
self.username: Optional[str] = data.get("username")
|
||||
self.discriminator: Optional[str] = data.get("discriminator")
|
||||
self.bot: bool = data.get("bot", False)
|
||||
self.avatar: Optional[str] = data.get("avatar")
|
||||
|
||||
@property
|
||||
def mention(self) -> str:
|
||||
@ -70,7 +63,23 @@ class User:
|
||||
return f"<@{self.id}>"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<User id='{self.id}' username='{self.username}' discriminator='{self.discriminator}'>"
|
||||
username = self.username or "Unknown"
|
||||
disc = self.discriminator or "????"
|
||||
return f"<User id='{self.id}' username='{username}' discriminator='{disc}'>"
|
||||
|
||||
async def send(
|
||||
self,
|
||||
content: Optional[str] = None,
|
||||
*,
|
||||
client: Optional["Client"] = None,
|
||||
**kwargs: Any,
|
||||
) -> "Message":
|
||||
"""Send a direct message to this user."""
|
||||
|
||||
target_client = client or self._client
|
||||
if target_client is None:
|
||||
raise DisagreementException("User.send requires a Client instance")
|
||||
return await target_client.send_dm(self.id, content=content, **kwargs)
|
||||
|
||||
|
||||
class Message:
|
||||
@ -96,7 +105,7 @@ class Message:
|
||||
self.id: str = data["id"]
|
||||
self.channel_id: str = data["channel_id"]
|
||||
self.guild_id: Optional[str] = data.get("guild_id")
|
||||
self.author: User = User(data["author"])
|
||||
self.author: User = User(data["author"], client_instance)
|
||||
self.content: str = data["content"]
|
||||
self.timestamp: str = data["timestamp"]
|
||||
if data.get("components"):
|
||||
@ -773,10 +782,10 @@ class Member(User): # Member inherits from User
|
||||
return f"<Member id='{self.id}' username='{self.username}' nick='{self.nick}'>"
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""Return the nickname if set, otherwise the username."""
|
||||
|
||||
return self.nick or self.username
|
||||
def display_name(self) -> str:
|
||||
"""Return the nickname if set, otherwise the username."""
|
||||
|
||||
return self.nick or self.username or ""
|
||||
|
||||
async def kick(self, *, reason: Optional[str] = None) -> None:
|
||||
if not self.guild_id or not self._client:
|
||||
@ -1183,13 +1192,13 @@ class Guild:
|
||||
The matching member if found, otherwise ``None``.
|
||||
"""
|
||||
|
||||
lowered = name.lower()
|
||||
for member in self._members.values():
|
||||
if member.username.lower() == lowered:
|
||||
return member
|
||||
if member.nick and member.nick.lower() == lowered:
|
||||
return member
|
||||
return None
|
||||
lowered = name.lower()
|
||||
for member in self._members.values():
|
||||
if member.username and member.username.lower() == lowered:
|
||||
return member
|
||||
if member.nick and member.nick.lower() == lowered:
|
||||
return member
|
||||
return None
|
||||
|
||||
def get_role(self, role_id: str) -> Optional[Role]:
|
||||
return next((role for role in self.roles if role.id == role_id), None)
|
||||
@ -2471,7 +2480,7 @@ class PresenceUpdate:
|
||||
self, data: Dict[str, Any], client_instance: Optional["Client"] = None
|
||||
):
|
||||
self._client = client_instance
|
||||
self.user = User(data["user"])
|
||||
self.user = User(data["user"], client_instance)
|
||||
self.guild_id: Optional[str] = data.get("guild_id")
|
||||
self.status: Optional[str] = data.get("status")
|
||||
self.activities: List[Activity] = []
|
||||
|
Loading…
x
Reference in New Issue
Block a user