Compare commits
8 Commits
be85444aa0
...
ed83a9da85
Author | SHA1 | Date | |
---|---|---|---|
ed83a9da85 | |||
0151526d07 | |||
17b7ea35a9 | |||
28702fa8a1 | |||
97505948ee | |||
152c0f12be | |||
eb38ecf671 | |||
2bd45c87ca |
@ -114,3 +114,20 @@ class FFmpegAudioSource(AudioSource):
|
||||
if isinstance(self.source, io.IOBase):
|
||||
with contextlib.suppress(Exception):
|
||||
self.source.close()
|
||||
|
||||
class AudioSink:
|
||||
"""Abstract base class for audio sinks."""
|
||||
|
||||
def write(self, user, data):
|
||||
"""Write a chunk of PCM audio.
|
||||
|
||||
Subclasses must implement this. The data is raw PCM at 48kHz
|
||||
stereo.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self) -> None:
|
||||
"""Cleanup the sink when the voice client disconnects."""
|
||||
|
||||
return None
|
||||
|
@ -4,7 +4,8 @@ import time
|
||||
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .models import Channel, Guild
|
||||
from .models import Channel, Guild, Member
|
||||
from .caching import MemberCacheFlags
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@ -53,3 +54,32 @@ class GuildCache(Cache["Guild"]):
|
||||
|
||||
class ChannelCache(Cache["Channel"]):
|
||||
"""Cache specifically for :class:`Channel` objects."""
|
||||
|
||||
|
||||
class MemberCache(Cache["Member"]):
|
||||
"""
|
||||
A cache for :class:`Member` objects that respects :class:`MemberCacheFlags`.
|
||||
"""
|
||||
|
||||
def __init__(self, flags: MemberCacheFlags, ttl: Optional[float] = None) -> None:
|
||||
super().__init__(ttl)
|
||||
self.flags = flags
|
||||
|
||||
def _should_cache(self, member: Member) -> bool:
|
||||
"""Determines if a member should be cached based on the flags."""
|
||||
if self.flags.all:
|
||||
return True
|
||||
if self.flags.none:
|
||||
return False
|
||||
|
||||
if self.flags.online and member.status != "offline":
|
||||
return True
|
||||
if self.flags.voice and member.voice_state is not None:
|
||||
return True
|
||||
if self.flags.joined and getattr(member, "_just_joined", False):
|
||||
return True
|
||||
return False
|
||||
|
||||
def set(self, key: str, value: Member) -> None:
|
||||
if self._should_cache(value):
|
||||
super().set(key, value)
|
||||
|
120
disagreement/caching.py
Normal file
120
disagreement/caching.py
Normal file
@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from typing import Any, Callable, ClassVar, Dict, Iterator, Tuple
|
||||
|
||||
|
||||
class _MemberCacheFlagValue:
|
||||
flag: int
|
||||
|
||||
def __init__(self, func: Callable[[Any], bool]):
|
||||
self.flag = getattr(func, 'flag', 0)
|
||||
self.__doc__ = func.__doc__
|
||||
|
||||
def __get__(self, instance: 'MemberCacheFlags', owner: type) -> Any:
|
||||
if instance is None:
|
||||
return self
|
||||
return instance.value & self.flag != 0
|
||||
|
||||
def __set__(self, instance: Any, value: bool) -> None:
|
||||
if value:
|
||||
instance.value |= self.flag
|
||||
else:
|
||||
instance.value &= ~self.flag
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__} flag={self.flag}>'
|
||||
|
||||
|
||||
def flag_value(flag: int) -> Callable[[Callable[[Any], bool]], _MemberCacheFlagValue]:
|
||||
def decorator(func: Callable[[Any], bool]) -> _MemberCacheFlagValue:
|
||||
setattr(func, 'flag', flag)
|
||||
return _MemberCacheFlagValue(func)
|
||||
return decorator
|
||||
|
||||
|
||||
class MemberCacheFlags:
|
||||
__slots__ = ('value',)
|
||||
|
||||
VALID_FLAGS: ClassVar[Dict[str, int]] = {
|
||||
'joined': 1 << 0,
|
||||
'voice': 1 << 1,
|
||||
'online': 1 << 2,
|
||||
}
|
||||
DEFAULT_FLAGS: ClassVar[int] = 1 | 2 | 4
|
||||
ALL_FLAGS: ClassVar[int] = sum(VALID_FLAGS.values())
|
||||
|
||||
def __init__(self, **kwargs: bool):
|
||||
self.value = self.DEFAULT_FLAGS
|
||||
for key, value in kwargs.items():
|
||||
if key not in self.VALID_FLAGS:
|
||||
raise TypeError(f'{key!r} is not a valid member cache flag.')
|
||||
setattr(self, key, value)
|
||||
|
||||
@classmethod
|
||||
def _from_value(cls, value: int) -> MemberCacheFlags:
|
||||
self = cls.__new__(cls)
|
||||
self.value = value
|
||||
return self
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, MemberCacheFlags) and self.value == other.value
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<MemberCacheFlags value={self.value}>'
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple[str, bool]]:
|
||||
for name in self.VALID_FLAGS:
|
||||
yield name, getattr(self, name)
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.value
|
||||
|
||||
def __index__(self) -> int:
|
||||
return self.value
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> MemberCacheFlags:
|
||||
"""A factory method that creates a :class:`MemberCacheFlags` with all flags enabled."""
|
||||
return cls._from_value(cls.ALL_FLAGS)
|
||||
|
||||
@classmethod
|
||||
def none(cls) -> MemberCacheFlags:
|
||||
"""A factory method that creates a :class:`MemberCacheFlags` with all flags disabled."""
|
||||
return cls._from_value(0)
|
||||
|
||||
@classmethod
|
||||
def only_joined(cls) -> MemberCacheFlags:
|
||||
"""A factory method that creates a :class:`MemberCacheFlags` with only the `joined` flag enabled."""
|
||||
return cls._from_value(cls.VALID_FLAGS['joined'])
|
||||
|
||||
@classmethod
|
||||
def only_voice(cls) -> MemberCacheFlags:
|
||||
"""A factory method that creates a :class:`MemberCacheFlags` with only the `voice` flag enabled."""
|
||||
return cls._from_value(cls.VALID_FLAGS['voice'])
|
||||
|
||||
@classmethod
|
||||
def only_online(cls) -> MemberCacheFlags:
|
||||
"""A factory method that creates a :class:`MemberCacheFlags` with only the `online` flag enabled."""
|
||||
return cls._from_value(cls.VALID_FLAGS['online'])
|
||||
|
||||
@flag_value(1 << 0)
|
||||
def joined(self) -> bool:
|
||||
"""Whether to cache members that have just joined the guild."""
|
||||
return False
|
||||
|
||||
@flag_value(1 << 1)
|
||||
def voice(self) -> bool:
|
||||
"""Whether to cache members that are in a voice channel."""
|
||||
return False
|
||||
|
||||
@flag_value(1 << 2)
|
||||
def online(self) -> bool:
|
||||
"""Whether to cache members that are online."""
|
||||
return False
|
@ -26,7 +26,9 @@ from .event_dispatcher import EventDispatcher
|
||||
from .enums import GatewayIntent, InteractionType, GatewayOpcode, VoiceRegion
|
||||
from .errors import DisagreementException, AuthenticationError
|
||||
from .typing import Typing
|
||||
from .ext.commands.core import CommandHandler
|
||||
from .caching import MemberCacheFlags
|
||||
from .cache import Cache, GuildCache, ChannelCache, MemberCache
|
||||
from .ext.commands.core import Command, CommandHandler, Group
|
||||
from .ext.commands.cog import Cog
|
||||
from .ext.app_commands.handler import AppCommandHandler
|
||||
from .ext.app_commands.context import AppCommandContext
|
||||
@ -96,12 +98,16 @@ class Client:
|
||||
shard_count: Optional[int] = None,
|
||||
gateway_max_retries: int = 5,
|
||||
gateway_max_backoff: float = 60.0,
|
||||
member_cache_flags: Optional[MemberCacheFlags] = None,
|
||||
http_options: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
if not token:
|
||||
raise ValueError("A bot token must be provided.")
|
||||
|
||||
self.token: str = token
|
||||
self.member_cache_flags: MemberCacheFlags = (
|
||||
member_cache_flags if member_cache_flags is not None else MemberCacheFlags()
|
||||
)
|
||||
self.intents: int = intents if intents is not None else GatewayIntent.default()
|
||||
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
|
||||
self.application_id: Optional[Snowflake] = (
|
||||
@ -141,15 +147,12 @@ class Client:
|
||||
)
|
||||
|
||||
# Internal Caches
|
||||
self._guilds: Dict[Snowflake, "Guild"] = {}
|
||||
self._channels: Dict[Snowflake, "Channel"] = (
|
||||
{}
|
||||
) # Stores all channel types by ID
|
||||
self._users: Dict[Snowflake, Any] = (
|
||||
{}
|
||||
) # Placeholder for User model cache if needed
|
||||
self._messages: Dict[Snowflake, "Message"] = {}
|
||||
self._guilds: GuildCache = GuildCache()
|
||||
self._channels: ChannelCache = ChannelCache()
|
||||
self._users: Cache["User"] = Cache()
|
||||
self._messages: Cache["Message"] = Cache(ttl=3600) # Cache messages for an hour
|
||||
self._views: Dict[Snowflake, "View"] = {}
|
||||
self._persistent_views: Dict[str, "View"] = {}
|
||||
self._voice_clients: Dict[Snowflake, VoiceClient] = {}
|
||||
self._webhooks: Dict[Snowflake, "Webhook"] = {}
|
||||
|
||||
@ -579,6 +582,41 @@ class Client:
|
||||
# For now, assuming name is sufficient for removal from the handler's flat list.
|
||||
return removed_cog
|
||||
|
||||
def check(self, coro: Callable[["CommandContext"], Awaitable[bool]]):
|
||||
"""
|
||||
A decorator that adds a global check to the bot.
|
||||
This check will be called for every command before it's executed.
|
||||
|
||||
Example:
|
||||
@bot.check
|
||||
async def block_dms(ctx):
|
||||
return ctx.guild is not None
|
||||
"""
|
||||
self.command_handler.add_check(coro)
|
||||
return coro
|
||||
|
||||
def command(
|
||||
self, **attrs: Any
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Command]:
|
||||
"""A decorator that transforms a function into a Command."""
|
||||
|
||||
def decorator(func: Callable[..., Awaitable[None]]) -> Command:
|
||||
cmd = Command(func, **attrs)
|
||||
self.command_handler.add_command(cmd)
|
||||
return cmd
|
||||
|
||||
return decorator
|
||||
|
||||
def group(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], Group]:
|
||||
"""A decorator that transforms a function into a Group command."""
|
||||
|
||||
def decorator(func: Callable[..., Awaitable[None]]) -> Group:
|
||||
cmd = Group(func, **attrs)
|
||||
self.command_handler.add_command(cmd)
|
||||
return cmd
|
||||
|
||||
return decorator
|
||||
|
||||
def add_app_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None:
|
||||
"""
|
||||
Adds a standalone application command or group to the bot.
|
||||
@ -649,6 +687,16 @@ class Client:
|
||||
# import traceback
|
||||
# traceback.print_exception(type(error.original), error.original, error.original.__traceback__)
|
||||
|
||||
async def on_command_completion(self, ctx: "CommandContext") -> None:
|
||||
"""
|
||||
Default command completion handler. Called when a command has successfully completed.
|
||||
Users can override this method in a subclass of Client.
|
||||
|
||||
Args:
|
||||
ctx (CommandContext): The context of the command that completed.
|
||||
"""
|
||||
pass
|
||||
|
||||
# --- Extension Management Methods ---
|
||||
|
||||
def load_extension(self, name: str) -> ModuleType:
|
||||
@ -673,7 +721,7 @@ class Client:
|
||||
from .models import User # Ensure User model is available
|
||||
|
||||
user = User(data)
|
||||
self._users[user.id] = user # Cache the user
|
||||
self._users.set(user.id, user) # Cache the user
|
||||
return user
|
||||
|
||||
def parse_channel(self, data: Dict[str, Any]) -> "Channel":
|
||||
@ -682,11 +730,11 @@ class Client:
|
||||
from .models import channel_factory
|
||||
|
||||
channel = channel_factory(data, self)
|
||||
self._channels[channel.id] = channel
|
||||
self._channels.set(channel.id, channel)
|
||||
if channel.guild_id:
|
||||
guild = self._guilds.get(channel.guild_id)
|
||||
if guild:
|
||||
guild._channels[channel.id] = channel
|
||||
guild._channels.set(channel.id, channel)
|
||||
return channel
|
||||
|
||||
def parse_message(self, data: Dict[str, Any]) -> "Message":
|
||||
@ -695,7 +743,7 @@ class Client:
|
||||
from .models import Message
|
||||
|
||||
message = Message(data, client_instance=self)
|
||||
self._messages[message.id] = message
|
||||
self._messages.set(message.id, message)
|
||||
return message
|
||||
|
||||
def parse_webhook(self, data: Union[Dict[str, Any], "Webhook"]) -> "Webhook":
|
||||
@ -752,7 +800,7 @@ class Client:
|
||||
|
||||
cached_user = self._users.get(user_id)
|
||||
if cached_user:
|
||||
return cached_user # Return cached if available, though fetch implies wanting fresh
|
||||
return cached_user
|
||||
|
||||
try:
|
||||
user_data = await self._http.get_user(user_id)
|
||||
@ -782,23 +830,26 @@ class Client:
|
||||
)
|
||||
return None
|
||||
|
||||
def parse_member(self, data: Dict[str, Any], guild_id: Snowflake) -> "Member":
|
||||
def parse_member(
|
||||
self, data: Dict[str, Any], guild_id: Snowflake, *, just_joined: bool = False
|
||||
) -> "Member":
|
||||
"""Parses member data and returns a Member object, updating relevant caches."""
|
||||
from .models import Member # Ensure Member model is available
|
||||
from .models import Member
|
||||
|
||||
# Member's __init__ should handle the nested 'user' data.
|
||||
member = Member(data, client_instance=self)
|
||||
member.guild_id = str(guild_id)
|
||||
|
||||
# Cache the member in the guild's member cache
|
||||
if just_joined:
|
||||
setattr(member, "_just_joined", True)
|
||||
|
||||
guild = self._guilds.get(guild_id)
|
||||
if guild:
|
||||
guild._members[member.id] = member # Assuming Guild has _members dict
|
||||
guild._members.set(member.id, member)
|
||||
|
||||
# Also cache the user part if not already cached or if this is newer
|
||||
# Since Member inherits from User, the member object itself is the user.
|
||||
self._users[member.id] = member
|
||||
# If 'user' was in data and Member.__init__ used it, it's already part of 'member'.
|
||||
if just_joined and hasattr(member, "_just_joined"):
|
||||
delattr(member, "_just_joined")
|
||||
|
||||
self._users.set(member.id, member)
|
||||
return member
|
||||
|
||||
async def fetch_member(
|
||||
@ -842,20 +893,29 @@ class Client:
|
||||
|
||||
def parse_guild(self, data: Dict[str, Any]) -> "Guild":
|
||||
"""Parses guild data and returns a Guild object, updating cache."""
|
||||
|
||||
from .models import Guild
|
||||
|
||||
guild = Guild(data, client_instance=self)
|
||||
self._guilds[guild.id] = guild
|
||||
self._guilds.set(guild.id, guild)
|
||||
|
||||
# Populate channel and member caches if provided
|
||||
for ch in data.get("channels", []):
|
||||
channel_obj = self.parse_channel(ch)
|
||||
guild._channels[channel_obj.id] = channel_obj
|
||||
presences = {p["user"]["id"]: p for p in data.get("presences", [])}
|
||||
voice_states = {vs["user_id"]: vs for vs in data.get("voice_states", [])}
|
||||
|
||||
for member in data.get("members", []):
|
||||
member_obj = self.parse_member(member, guild.id)
|
||||
guild._members[member_obj.id] = member_obj
|
||||
for ch_data in data.get("channels", []):
|
||||
self.parse_channel(ch_data)
|
||||
|
||||
for member_data in data.get("members", []):
|
||||
user_id = member_data.get("user", {}).get("id")
|
||||
if user_id:
|
||||
presence = presences.get(user_id)
|
||||
if presence:
|
||||
member_data["status"] = presence.get("status", "offline")
|
||||
|
||||
voice_state = voice_states.get(user_id)
|
||||
if voice_state:
|
||||
member_data["voice_state"] = voice_state
|
||||
|
||||
self.parse_member(member_data, guild.id)
|
||||
|
||||
return guild
|
||||
|
||||
@ -1067,6 +1127,7 @@ class Client:
|
||||
session_id = state["session_id"]
|
||||
|
||||
voice = VoiceClient(
|
||||
self,
|
||||
endpoint,
|
||||
session_id,
|
||||
token,
|
||||
@ -1259,7 +1320,7 @@ class Client:
|
||||
|
||||
channel = channel_factory(channel_data, self)
|
||||
|
||||
self._channels[channel.id] = channel
|
||||
self._channels.set(channel.id, channel)
|
||||
return channel
|
||||
|
||||
except DisagreementException as e: # Includes HTTPException
|
||||
@ -1445,6 +1506,37 @@ class Client:
|
||||
data = await self._http.get_channel_invites(channel_id)
|
||||
return [self.parse_invite(inv) for inv in data]
|
||||
|
||||
def add_persistent_view(self, view: "View") -> None:
|
||||
"""
|
||||
Registers a persistent view with the client.
|
||||
|
||||
Persistent views have a timeout of `None` and their components must have a `custom_id`.
|
||||
This allows the view to be re-instantiated across bot restarts.
|
||||
|
||||
Args:
|
||||
view (View): The view instance to register.
|
||||
|
||||
Raises:
|
||||
ValueError: If the view is not persistent (timeout is not None) or if a component's
|
||||
custom_id is already registered.
|
||||
"""
|
||||
if self.is_ready():
|
||||
print(
|
||||
"Warning: Adding a persistent view after the client is ready. "
|
||||
"This view will only be available for interactions on this session."
|
||||
)
|
||||
|
||||
if view.timeout is not None:
|
||||
raise ValueError("Persistent views must have a timeout of None.")
|
||||
|
||||
for item in view.children:
|
||||
if item.custom_id: # Ensure custom_id is not None
|
||||
if item.custom_id in self._persistent_views:
|
||||
raise ValueError(
|
||||
f"A component with custom_id '{item.custom_id}' is already registered."
|
||||
)
|
||||
self._persistent_views[item.custom_id] = view
|
||||
|
||||
# --- Application Command Methods ---
|
||||
async def process_interaction(self, interaction: Interaction) -> None:
|
||||
"""Internal method to process an interaction from the gateway."""
|
||||
@ -1455,11 +1547,25 @@ class Client:
|
||||
if (
|
||||
interaction.type == InteractionType.MESSAGE_COMPONENT
|
||||
and interaction.message
|
||||
and interaction.data
|
||||
):
|
||||
view = self._views.get(interaction.message.id)
|
||||
if view:
|
||||
asyncio.create_task(view._dispatch(interaction))
|
||||
return
|
||||
else:
|
||||
# No active view found, check for persistent views
|
||||
custom_id = interaction.data.custom_id
|
||||
if custom_id:
|
||||
registered_view = self._persistent_views.get(custom_id)
|
||||
if registered_view:
|
||||
# Create a new instance of the persistent view
|
||||
new_view = registered_view.__class__()
|
||||
await new_view._start(self)
|
||||
new_view.message_id = interaction.message.id
|
||||
self._views[interaction.message.id] = new_view
|
||||
asyncio.create_task(new_view._dispatch(interaction))
|
||||
return
|
||||
|
||||
await self.app_command_handler.process_interaction(interaction)
|
||||
|
||||
|
@ -76,7 +76,7 @@ class EventDispatcher:
|
||||
"""Parses MESSAGE_DELETE and updates message cache."""
|
||||
message_id = data.get("id")
|
||||
if message_id:
|
||||
self._client._messages.pop(message_id, None)
|
||||
self._client._messages.invalidate(message_id)
|
||||
return data
|
||||
|
||||
def _parse_message_reaction_raw(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@ -124,7 +124,7 @@ class EventDispatcher:
|
||||
"""Parses GUILD_MEMBER_ADD into a Member object."""
|
||||
|
||||
guild_id = str(data.get("guild_id"))
|
||||
return self._client.parse_member(data, guild_id)
|
||||
return self._client.parse_member(data, guild_id, just_joined=True)
|
||||
|
||||
def _parse_guild_member_remove(self, data: Dict[str, Any]):
|
||||
"""Parses GUILD_MEMBER_REMOVE into a GuildMemberRemove model."""
|
||||
|
@ -18,6 +18,8 @@ from .decorators import (
|
||||
cooldown,
|
||||
max_concurrency,
|
||||
requires_permissions,
|
||||
has_role,
|
||||
has_any_role,
|
||||
)
|
||||
from .errors import (
|
||||
CommandError,
|
||||
@ -47,6 +49,8 @@ __all__ = [
|
||||
"cooldown",
|
||||
"max_concurrency",
|
||||
"requires_permissions",
|
||||
"has_role",
|
||||
"has_any_role",
|
||||
# Errors
|
||||
"CommandError",
|
||||
"CommandNotFound",
|
||||
|
@ -40,7 +40,42 @@ if TYPE_CHECKING:
|
||||
from disagreement.models import Message, User
|
||||
|
||||
|
||||
class Command:
|
||||
class GroupMixin:
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.commands: Dict[str, "Command"] = {}
|
||||
self.name: str = ""
|
||||
|
||||
def command(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Command"]:
|
||||
def decorator(func: Callable[..., Awaitable[None]]) -> "Command":
|
||||
cmd = Command(func, **attrs)
|
||||
cmd.cog = getattr(self, "cog", None)
|
||||
self.add_command(cmd)
|
||||
return cmd
|
||||
return decorator
|
||||
|
||||
def group(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Group"]:
|
||||
def decorator(func: Callable[..., Awaitable[None]]) -> "Group":
|
||||
cmd = Group(func, **attrs)
|
||||
cmd.cog = getattr(self, "cog", None)
|
||||
self.add_command(cmd)
|
||||
return cmd
|
||||
return decorator
|
||||
|
||||
def add_command(self, command: "Command") -> None:
|
||||
if command.name in self.commands:
|
||||
raise ValueError(f"Command '{command.name}' is already registered in group '{self.name}'.")
|
||||
self.commands[command.name.lower()] = command
|
||||
for alias in command.aliases:
|
||||
if alias in self.commands:
|
||||
logger.warning(f"Alias '{alias}' for command '{command.name}' in group '{self.name}' conflicts with an existing command or alias.")
|
||||
self.commands[alias.lower()] = command
|
||||
|
||||
def get_command(self, name: str) -> Optional["Command"]:
|
||||
return self.commands.get(name.lower())
|
||||
|
||||
|
||||
class Command(GroupMixin):
|
||||
"""
|
||||
Represents a bot command.
|
||||
|
||||
@ -58,12 +93,14 @@ class Command:
|
||||
if not asyncio.iscoroutinefunction(callback):
|
||||
raise TypeError("Command callback must be a coroutine function.")
|
||||
|
||||
super().__init__(**attrs)
|
||||
self.callback: Callable[..., Awaitable[None]] = callback
|
||||
self.name: str = attrs.get("name", callback.__name__)
|
||||
self.aliases: List[str] = attrs.get("aliases", [])
|
||||
self.brief: Optional[str] = attrs.get("brief")
|
||||
self.description: Optional[str] = attrs.get("description") or callback.__doc__
|
||||
self.cog: Optional["Cog"] = attrs.get("cog")
|
||||
self.invoke_without_command: bool = attrs.get("invoke_without_command", False)
|
||||
|
||||
self.params = inspect.signature(callback).parameters
|
||||
self.checks: List[Callable[["CommandContext"], Awaitable[bool] | bool]] = []
|
||||
@ -79,20 +116,73 @@ class Command:
|
||||
) -> None:
|
||||
self.checks.append(predicate)
|
||||
|
||||
async def invoke(self, ctx: "CommandContext", *args: Any, **kwargs: Any) -> None:
|
||||
async def _run_checks(self, ctx: "CommandContext") -> None:
|
||||
"""Runs all cog, local and global checks for the command."""
|
||||
from .errors import CheckFailure
|
||||
|
||||
# Run cog-level check first
|
||||
if self.cog:
|
||||
cog_check = getattr(self.cog, "cog_check", None)
|
||||
if cog_check:
|
||||
try:
|
||||
result = cog_check(ctx)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if not result:
|
||||
raise CheckFailure(
|
||||
f"The cog-level check for command '{self.name}' failed."
|
||||
)
|
||||
except CheckFailure:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise CommandInvokeError(e) from e
|
||||
|
||||
# Run local checks
|
||||
for predicate in self.checks:
|
||||
result = predicate(ctx)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if not result:
|
||||
raise CheckFailure("Check predicate failed.")
|
||||
raise CheckFailure(f"A local check for command '{self.name}' failed.")
|
||||
|
||||
# Then run global checks from the handler
|
||||
if hasattr(ctx.bot, "command_handler"):
|
||||
for predicate in ctx.bot.command_handler._global_checks:
|
||||
result = predicate(ctx)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if not result:
|
||||
raise CheckFailure(
|
||||
f"A global check failed for command '{self.name}'."
|
||||
)
|
||||
|
||||
async def invoke(self, ctx: "CommandContext", *args: Any, **kwargs: Any) -> None:
|
||||
await self._run_checks(ctx)
|
||||
|
||||
before_invoke = None
|
||||
after_invoke = None
|
||||
|
||||
if self.cog:
|
||||
await self.callback(self.cog, ctx, *args, **kwargs)
|
||||
else:
|
||||
await self.callback(ctx, *args, **kwargs)
|
||||
before_invoke = getattr(self.cog, "cog_before_invoke", None)
|
||||
after_invoke = getattr(self.cog, "cog_after_invoke", None)
|
||||
|
||||
if before_invoke:
|
||||
await before_invoke(ctx)
|
||||
|
||||
try:
|
||||
if self.cog:
|
||||
await self.callback(self.cog, ctx, *args, **kwargs)
|
||||
else:
|
||||
await self.callback(ctx, *args, **kwargs)
|
||||
finally:
|
||||
if after_invoke:
|
||||
await after_invoke(ctx)
|
||||
|
||||
|
||||
class Group(Command):
|
||||
"""A command that can have subcommands."""
|
||||
def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any):
|
||||
super().__init__(callback, **attrs)
|
||||
|
||||
|
||||
PrefixCommand = Command # Alias for clarity in hybrid commands
|
||||
@ -220,11 +310,20 @@ class CommandHandler:
|
||||
self.commands: Dict[str, Command] = {}
|
||||
self.cogs: Dict[str, "Cog"] = {}
|
||||
self._concurrency: Dict[str, Dict[str, int]] = {}
|
||||
self._global_checks: List[
|
||||
Callable[["CommandContext"], Awaitable[bool] | bool]
|
||||
] = []
|
||||
|
||||
from .help import HelpCommand
|
||||
|
||||
self.add_command(HelpCommand(self))
|
||||
|
||||
def add_check(
|
||||
self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool]
|
||||
) -> None:
|
||||
"""Adds a global check to the command handler."""
|
||||
self._global_checks.append(predicate)
|
||||
|
||||
def add_command(self, command: Command) -> None:
|
||||
if command.name in self.commands:
|
||||
raise ValueError(f"Command '{command.name}' is already registered.")
|
||||
@ -239,6 +338,15 @@ class CommandHandler:
|
||||
)
|
||||
self.commands[alias.lower()] = command
|
||||
|
||||
if isinstance(command, Group):
|
||||
for sub_cmd in command.commands.values():
|
||||
if sub_cmd.name in self.commands:
|
||||
logger.warning(
|
||||
"Subcommand '%s' of group '%s' conflicts with a top-level command.",
|
||||
sub_cmd.name,
|
||||
command.name,
|
||||
)
|
||||
|
||||
def remove_command(self, name: str) -> Optional[Command]:
|
||||
command = self.commands.pop(name.lower(), None)
|
||||
if command:
|
||||
@ -534,12 +642,28 @@ class CommandHandler:
|
||||
if not command:
|
||||
return
|
||||
|
||||
invoked_with = command_name
|
||||
original_command = command
|
||||
|
||||
if isinstance(command, Group):
|
||||
view.skip_whitespace()
|
||||
potential_subcommand = view.get_word()
|
||||
if potential_subcommand:
|
||||
subcommand = command.get_command(potential_subcommand)
|
||||
if subcommand:
|
||||
command = subcommand
|
||||
invoked_with += f" {potential_subcommand}"
|
||||
elif command.invoke_without_command:
|
||||
view.index -= len(potential_subcommand) + view.previous
|
||||
else:
|
||||
raise CommandNotFound(f"Subcommand '{potential_subcommand}' not found.")
|
||||
|
||||
ctx = CommandContext(
|
||||
message=message,
|
||||
bot=self.client,
|
||||
prefix=actual_prefix,
|
||||
command=command,
|
||||
invoked_with=command_name,
|
||||
invoked_with=invoked_with,
|
||||
cog=command.cog,
|
||||
)
|
||||
|
||||
@ -553,11 +677,14 @@ class CommandHandler:
|
||||
finally:
|
||||
self._release_concurrency(ctx)
|
||||
except CommandError as e:
|
||||
logger.error("Command error for '%s': %s", command.name, e)
|
||||
logger.error("Command error for '%s': %s", original_command.name, e)
|
||||
if hasattr(self.client, "on_command_error"):
|
||||
await self.client.on_command_error(ctx, e)
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error invoking command '%s': %s", command.name, e)
|
||||
logger.error("Unexpected error invoking command '%s': %s", original_command.name, e)
|
||||
exc = CommandInvokeError(e)
|
||||
if hasattr(self.client, "on_command_error"):
|
||||
await self.client.on_command_error(ctx, exc)
|
||||
else:
|
||||
if hasattr(self.client, "on_command_completion"):
|
||||
await self.client.on_command_completion(ctx)
|
||||
|
@ -217,3 +217,82 @@ def requires_permissions(
|
||||
return True
|
||||
|
||||
return check(predicate)
|
||||
|
||||
def has_role(
|
||||
name_or_id: str | int,
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Check that the invoking member has a role with the given name or ID."""
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CheckFailure
|
||||
from disagreement.models import Member
|
||||
|
||||
if not ctx.guild:
|
||||
raise CheckFailure("This command cannot be used in DMs.")
|
||||
|
||||
author = ctx.author
|
||||
if not isinstance(author, Member):
|
||||
try:
|
||||
author = await ctx.bot.fetch_member(ctx.guild.id, author.id)
|
||||
except Exception:
|
||||
raise CheckFailure("Could not resolve author to a guild member.")
|
||||
|
||||
if not author:
|
||||
raise CheckFailure("Could not resolve author to a guild member.")
|
||||
|
||||
# Create a list of the member's role objects by looking them up in the guild's roles list
|
||||
member_roles = [
|
||||
role for role in ctx.guild.roles if role.id in author.roles
|
||||
]
|
||||
|
||||
if any(
|
||||
role.id == str(name_or_id) or role.name == name_or_id
|
||||
for role in member_roles
|
||||
):
|
||||
return True
|
||||
|
||||
raise CheckFailure(f"You need the '{name_or_id}' role to use this command.")
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def has_any_role(
|
||||
*names_or_ids: str | int,
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Check that the invoking member has any of the roles with the given names or IDs."""
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CheckFailure
|
||||
from disagreement.models import Member
|
||||
|
||||
if not ctx.guild:
|
||||
raise CheckFailure("This command cannot be used in DMs.")
|
||||
|
||||
author = ctx.author
|
||||
if not isinstance(author, Member):
|
||||
try:
|
||||
author = await ctx.bot.fetch_member(ctx.guild.id, author.id)
|
||||
except Exception:
|
||||
raise CheckFailure("Could not resolve author to a guild member.")
|
||||
|
||||
if not author:
|
||||
raise CheckFailure("Could not resolve author to a guild member.")
|
||||
|
||||
member_roles = [
|
||||
role for role in ctx.guild.roles if role.id in author.roles
|
||||
]
|
||||
# Convert names_or_ids to a set for efficient lookup
|
||||
names_or_ids_set = set(map(str, names_or_ids))
|
||||
|
||||
if any(
|
||||
role.id in names_or_ids_set or role.name in names_or_ids_set
|
||||
for role in member_roles
|
||||
):
|
||||
return True
|
||||
|
||||
role_list = ", ".join(f"'{r}'" for r in names_or_ids)
|
||||
raise CheckFailure(
|
||||
f"You need one of the following roles to use this command: {role_list}"
|
||||
)
|
||||
|
||||
return check(predicate)
|
||||
|
@ -79,6 +79,8 @@ class GatewayClient:
|
||||
self._buffer = bytearray()
|
||||
self._inflator = zlib.decompressobj()
|
||||
|
||||
self._member_chunk_requests: Dict[str, asyncio.Future] = {}
|
||||
|
||||
async def _reconnect(self) -> None:
|
||||
"""Attempts to reconnect using exponential backoff with jitter."""
|
||||
delay = 1.0
|
||||
@ -237,6 +239,32 @@ class GatewayClient:
|
||||
}
|
||||
await self._send_json(payload)
|
||||
|
||||
async def request_guild_members(
|
||||
self,
|
||||
guild_id: str,
|
||||
query: str = "",
|
||||
limit: int = 0,
|
||||
presences: bool = False,
|
||||
user_ids: Optional[list[str]] = None,
|
||||
nonce: Optional[str] = None,
|
||||
):
|
||||
"""Sends the request guild members payload to the Gateway."""
|
||||
payload = {
|
||||
"op": GatewayOpcode.REQUEST_GUILD_MEMBERS,
|
||||
"d": {
|
||||
"guild_id": guild_id,
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
"presences": presences,
|
||||
},
|
||||
}
|
||||
if user_ids:
|
||||
payload["d"]["user_ids"] = user_ids
|
||||
if nonce:
|
||||
payload["d"]["nonce"] = nonce
|
||||
|
||||
await self._send_json(payload)
|
||||
|
||||
async def _handle_dispatch(self, data: Dict[str, Any]):
|
||||
"""Handles DISPATCH events (actual Discord events)."""
|
||||
event_name = data.get("t")
|
||||
@ -313,6 +341,22 @@ class GatewayClient:
|
||||
)
|
||||
|
||||
await self._dispatcher.dispatch(event_name, raw_event_d_payload)
|
||||
elif event_name == "GUILD_MEMBERS_CHUNK":
|
||||
if isinstance(raw_event_d_payload, dict):
|
||||
nonce = raw_event_d_payload.get("nonce")
|
||||
if nonce and nonce in self._member_chunk_requests:
|
||||
future = self._member_chunk_requests[nonce]
|
||||
if not future.done():
|
||||
# Append members to a temporary list stored on the future object
|
||||
if not hasattr(future, "_members"):
|
||||
future._members = [] # type: ignore
|
||||
future._members.extend(raw_event_d_payload.get("members", [])) # type: ignore
|
||||
|
||||
# If this is the last chunk, resolve the future
|
||||
if raw_event_d_payload.get("chunk_index") == raw_event_d_payload.get("chunk_count", 1) - 1:
|
||||
future.set_result(future._members) # type: ignore
|
||||
del self._member_chunk_requests[nonce]
|
||||
|
||||
elif event_name == "INTERACTION_CREATE":
|
||||
# print(f"GATEWAY RECV INTERACTION_CREATE: {raw_event_d_payload}")
|
||||
if isinstance(raw_event_d_payload, dict):
|
||||
|
@ -368,6 +368,20 @@ class HTTPClient:
|
||||
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/@me",
|
||||
)
|
||||
|
||||
async def delete_user_reaction(
|
||||
self,
|
||||
channel_id: "Snowflake",
|
||||
message_id: "Snowflake",
|
||||
emoji: str,
|
||||
user_id: "Snowflake",
|
||||
) -> None:
|
||||
"""Removes another user's reaction from a message."""
|
||||
encoded = quote(emoji)
|
||||
await self.request(
|
||||
"DELETE",
|
||||
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/{user_id}",
|
||||
)
|
||||
|
||||
async def get_reactions(
|
||||
self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
@ -400,6 +414,27 @@ class HTTPClient:
|
||||
)
|
||||
return messages
|
||||
|
||||
async def get_pinned_messages(
|
||||
self, channel_id: "Snowflake"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Fetches all pinned messages in a channel."""
|
||||
|
||||
return await self.request("GET", f"/channels/{channel_id}/pins")
|
||||
|
||||
async def pin_message(
|
||||
self, channel_id: "Snowflake", message_id: "Snowflake"
|
||||
) -> None:
|
||||
"""Pins a message in a channel."""
|
||||
|
||||
await self.request("PUT", f"/channels/{channel_id}/pins/{message_id}")
|
||||
|
||||
async def unpin_message(
|
||||
self, channel_id: "Snowflake", message_id: "Snowflake"
|
||||
) -> None:
|
||||
"""Unpins a message from a channel."""
|
||||
|
||||
await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}")
|
||||
|
||||
async def delete_channel(
|
||||
self, channel_id: str, reason: Optional[str] = None
|
||||
) -> None:
|
||||
@ -420,6 +455,21 @@ class HTTPClient:
|
||||
custom_headers=custom_headers if custom_headers else None,
|
||||
)
|
||||
|
||||
async def edit_channel(
|
||||
self,
|
||||
channel_id: "Snowflake",
|
||||
payload: Dict[str, Any],
|
||||
reason: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Edits a channel."""
|
||||
headers = {"X-Audit-Log-Reason": reason} if reason else None
|
||||
return await self.request(
|
||||
"PATCH",
|
||||
f"/channels/{channel_id}",
|
||||
payload=payload,
|
||||
custom_headers=headers,
|
||||
)
|
||||
|
||||
async def get_channel(self, channel_id: str) -> Dict[str, Any]:
|
||||
"""Fetches a channel by ID."""
|
||||
return await self.request("GET", f"/channels/{channel_id}")
|
||||
@ -1039,3 +1089,32 @@ class HTTPClient:
|
||||
async def get_voice_regions(self) -> List[Dict[str, Any]]:
|
||||
"""Returns available voice regions."""
|
||||
return await self.request("GET", "/voice/regions")
|
||||
|
||||
async def start_thread_from_message(
|
||||
self,
|
||||
channel_id: "Snowflake",
|
||||
message_id: "Snowflake",
|
||||
payload: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Starts a new thread from an existing message."""
|
||||
return await self.request(
|
||||
"POST",
|
||||
f"/channels/{channel_id}/messages/{message_id}/threads",
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
async def start_thread_without_message(
|
||||
self, channel_id: "Snowflake", payload: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Starts a new thread that is not attached to a message."""
|
||||
return await self.request(
|
||||
"POST", f"/channels/{channel_id}/threads", payload=payload
|
||||
)
|
||||
|
||||
async def join_thread(self, channel_id: "Snowflake") -> None:
|
||||
"""Joins the current user to a thread."""
|
||||
await self.request("PUT", f"/channels/{channel_id}/thread-members/@me")
|
||||
|
||||
async def leave_thread(self, channel_id: "Snowflake") -> None:
|
||||
"""Removes the current user from a thread."""
|
||||
await self.request("DELETE", f"/channels/{channel_id}/thread-members/@me")
|
||||
|
@ -7,7 +7,9 @@ Data models for Discord objects.
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast
|
||||
|
||||
from .cache import ChannelCache, MemberCache
|
||||
|
||||
import aiohttp # pylint: disable=import-error
|
||||
from .color import Color
|
||||
@ -105,11 +107,38 @@ class Message:
|
||||
self.attachments: List[Attachment] = [
|
||||
Attachment(a) for a in data.get("attachments", [])
|
||||
]
|
||||
self.pinned: bool = data.get("pinned", False)
|
||||
# Add other fields as needed, e.g., attachments, embeds, reactions, etc.
|
||||
# self.mentions: List[User] = [User(u) for u in data.get("mentions", [])]
|
||||
# self.mention_roles: List[str] = data.get("mention_roles", [])
|
||||
# self.mention_everyone: bool = data.get("mention_everyone", False)
|
||||
|
||||
async def pin(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Pins this message to its channel.
|
||||
|
||||
Raises
|
||||
------
|
||||
HTTPException
|
||||
Pinning the message failed.
|
||||
"""
|
||||
await self._client._http.pin_message(self.channel_id, self.id)
|
||||
self.pinned = True
|
||||
|
||||
async def unpin(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Unpins this message from its channel.
|
||||
|
||||
Raises
|
||||
------
|
||||
HTTPException
|
||||
Unpinning the message failed.
|
||||
"""
|
||||
await self._client._http.unpin_message(self.channel_id, self.id)
|
||||
self.pinned = False
|
||||
|
||||
async def reply(
|
||||
self,
|
||||
content: Optional[str] = None,
|
||||
@ -210,10 +239,17 @@ class Message:
|
||||
|
||||
await self._client.add_reaction(self.channel_id, self.id, emoji)
|
||||
|
||||
async def remove_reaction(self, emoji: str) -> None:
|
||||
"""|coro| Remove the bot's reaction from this message."""
|
||||
|
||||
await self._client.remove_reaction(self.channel_id, self.id, emoji)
|
||||
async def remove_reaction(self, emoji: str, member: Optional[User] = None) -> None:
|
||||
"""|coro|
|
||||
Removes a reaction from this message.
|
||||
If no ``member`` is provided, removes the bot's own reaction.
|
||||
"""
|
||||
if member:
|
||||
await self._client._http.delete_user_reaction(
|
||||
self.channel_id, self.id, emoji, member.id
|
||||
)
|
||||
else:
|
||||
await self._client.remove_reaction(self.channel_id, self.id, emoji)
|
||||
|
||||
async def clear_reactions(self) -> None:
|
||||
"""|coro| Remove all reactions from this message."""
|
||||
@ -239,6 +275,125 @@ class Message:
|
||||
def __repr__(self) -> str:
|
||||
return f"<Message id='{self.id}' channel_id='{self.channel_id}' author='{self.author!r}'>"
|
||||
|
||||
async def create_thread(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
auto_archive_duration: Optional[int] = None,
|
||||
rate_limit_per_user: Optional[int] = None,
|
||||
reason: Optional[str] = None,
|
||||
) -> "Thread":
|
||||
"""|coro|
|
||||
|
||||
Creates a new thread from this message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the thread.
|
||||
auto_archive_duration: Optional[int]
|
||||
The duration in minutes to automatically archive the thread after recent activity.
|
||||
Can be one of 60, 1440, 4320, 10080.
|
||||
rate_limit_per_user: Optional[int]
|
||||
The number of seconds a user has to wait before sending another message.
|
||||
reason: Optional[str]
|
||||
The reason for creating the thread.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Thread
|
||||
The created thread.
|
||||
"""
|
||||
payload: Dict[str, Any] = {"name": name}
|
||||
if auto_archive_duration is not None:
|
||||
payload["auto_archive_duration"] = auto_archive_duration
|
||||
if rate_limit_per_user is not None:
|
||||
payload["rate_limit_per_user"] = rate_limit_per_user
|
||||
|
||||
data = await self._client._http.start_thread_from_message(
|
||||
self.channel_id, self.id, payload
|
||||
)
|
||||
return cast("Thread", self._client.parse_channel(data))
|
||||
|
||||
|
||||
class PartialMessage:
|
||||
"""Represents a partial message, identified by its ID and channel.
|
||||
|
||||
This model is used to perform actions on a message without having the
|
||||
full message object in the cache.
|
||||
|
||||
Attributes:
|
||||
id (str): The message's unique ID.
|
||||
channel (TextChannel): The text channel this message belongs to.
|
||||
"""
|
||||
|
||||
def __init__(self, *, id: str, channel: "TextChannel"):
|
||||
self.id = id
|
||||
self.channel = channel
|
||||
self._client = channel._client
|
||||
|
||||
async def fetch(self) -> "Message":
|
||||
"""|coro|
|
||||
|
||||
Fetches the full message data from Discord.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Message
|
||||
The complete message object.
|
||||
"""
|
||||
data = await self._client._http.get_message(self.channel.id, self.id)
|
||||
return self._client.parse_message(data)
|
||||
|
||||
async def delete(self, *, delay: Optional[float] = None) -> None:
|
||||
"""|coro|
|
||||
|
||||
Deletes this message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
delay: Optional[float]
|
||||
If provided, wait this many seconds before deleting.
|
||||
"""
|
||||
if delay is not None:
|
||||
await asyncio.sleep(delay)
|
||||
await self._client._http.delete_message(self.channel.id, self.id)
|
||||
|
||||
async def pin(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Pins this message to its channel.
|
||||
"""
|
||||
await self._client._http.pin_message(self.channel.id, self.id)
|
||||
|
||||
async def unpin(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Unpins this message from its channel.
|
||||
"""
|
||||
await self._client._http.unpin_message(self.channel.id, self.id)
|
||||
|
||||
async def add_reaction(self, emoji: str) -> None:
|
||||
"""|coro|
|
||||
|
||||
Adds a reaction to this message.
|
||||
"""
|
||||
await self._client._http.create_reaction(self.channel.id, self.id, emoji)
|
||||
|
||||
async def remove_reaction(self, emoji: str, member: Optional[User] = None) -> None:
|
||||
"""|coro|
|
||||
|
||||
Removes a reaction from this message.
|
||||
|
||||
If no ``member`` is provided, removes the bot's own reaction.
|
||||
"""
|
||||
if member:
|
||||
await self._client._http.delete_user_reaction(
|
||||
self.channel.id, self.id, emoji, member.id
|
||||
)
|
||||
else:
|
||||
await self._client._http.delete_reaction(self.channel.id, self.id, emoji)
|
||||
|
||||
|
||||
class EmbedFooter:
|
||||
"""Represents an embed footer."""
|
||||
@ -503,6 +658,8 @@ class Member(User): # Member inherits from User
|
||||
):
|
||||
self._client: Optional["Client"] = client_instance
|
||||
self.guild_id: Optional[str] = None
|
||||
self.status: Optional[str] = None
|
||||
self.voice_state: Optional[Dict[str, Any]] = None
|
||||
# User part is nested under 'user' key in member data from gateway/API
|
||||
user_data = data.get("user", {})
|
||||
# If 'id' is not in user_data but is top-level (e.g. from interaction resolved member without user object)
|
||||
@ -929,8 +1086,8 @@ class Guild:
|
||||
)
|
||||
|
||||
# Internal caches, populated by events or specific fetches
|
||||
self._channels: Dict[str, "Channel"] = {}
|
||||
self._members: Dict[str, Member] = {}
|
||||
self._channels: ChannelCache = ChannelCache()
|
||||
self._members: MemberCache = MemberCache(client_instance.member_cache_flags)
|
||||
self._threads: Dict[str, "Thread"] = {}
|
||||
|
||||
def get_channel(self, channel_id: str) -> Optional["Channel"]:
|
||||
@ -970,6 +1127,49 @@ class Guild:
|
||||
def __repr__(self) -> str:
|
||||
return f"<Guild id='{self.id}' name='{self.name}'>"
|
||||
|
||||
async def fetch_members(self, *, limit: Optional[int] = None) -> List["Member"]:
|
||||
"""|coro|
|
||||
|
||||
Fetches all members for this guild.
|
||||
|
||||
This requires the ``GUILD_MEMBERS`` intent.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
limit: Optional[int]
|
||||
The maximum number of members to fetch. If ``None``, all members
|
||||
are fetched.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[Member]
|
||||
A list of all members in the guild.
|
||||
|
||||
Raises
|
||||
------
|
||||
DisagreementException
|
||||
The gateway is not available to make the request.
|
||||
asyncio.TimeoutError
|
||||
The request timed out.
|
||||
"""
|
||||
if not self._client._gateway:
|
||||
raise DisagreementException("Gateway not available for member fetching.")
|
||||
|
||||
nonce = str(asyncio.get_running_loop().time())
|
||||
future = self._client._gateway._loop.create_future()
|
||||
self._client._gateway._member_chunk_requests[nonce] = future
|
||||
|
||||
try:
|
||||
await self._client._gateway.request_guild_members(
|
||||
self.id, limit=limit or 0, nonce=nonce
|
||||
)
|
||||
member_data = await asyncio.wait_for(future, timeout=60.0)
|
||||
return [Member(m, self._client) for m in member_data]
|
||||
except asyncio.TimeoutError:
|
||||
if nonce in self._client._gateway._member_chunk_requests:
|
||||
del self._client._gateway._member_chunk_requests[nonce]
|
||||
raise
|
||||
|
||||
|
||||
class Channel:
|
||||
"""Base class for Discord channels."""
|
||||
@ -1142,12 +1342,98 @@ class TextChannel(Channel):
|
||||
|
||||
await self._client._http.bulk_delete_messages(self.id, ids)
|
||||
for mid in ids:
|
||||
self._client._messages.pop(mid, None)
|
||||
self._client._messages.invalidate(mid)
|
||||
return ids
|
||||
|
||||
def get_partial_message(self, id: int) -> "PartialMessage":
|
||||
"""Returns a :class:`PartialMessage` for the given ID.
|
||||
|
||||
This allows performing actions on a message without fetching it first.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
id: int
|
||||
The ID of the message to get a partial instance of.
|
||||
|
||||
Returns
|
||||
-------
|
||||
PartialMessage
|
||||
The partial message instance.
|
||||
"""
|
||||
return PartialMessage(id=str(id), channel=self)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<TextChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>"
|
||||
|
||||
async def pins(self) -> List["Message"]:
|
||||
"""|coro|
|
||||
|
||||
Fetches all pinned messages in this channel.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[Message]
|
||||
The pinned messages.
|
||||
|
||||
Raises
|
||||
------
|
||||
HTTPException
|
||||
Fetching the pinned messages failed.
|
||||
"""
|
||||
|
||||
messages_data = await self._client._http.get_pinned_messages(self.id)
|
||||
return [self._client.parse_message(m) for m in messages_data]
|
||||
|
||||
async def create_thread(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
type: ChannelType = ChannelType.PUBLIC_THREAD,
|
||||
auto_archive_duration: Optional[int] = None,
|
||||
invitable: Optional[bool] = None,
|
||||
rate_limit_per_user: Optional[int] = None,
|
||||
reason: Optional[str] = None,
|
||||
) -> "Thread":
|
||||
"""|coro|
|
||||
|
||||
Creates a new thread in this channel.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the thread.
|
||||
type: ChannelType
|
||||
The type of thread to create. Defaults to PUBLIC_THREAD.
|
||||
Can be PUBLIC_THREAD, PRIVATE_THREAD, or ANNOUNCEMENT_THREAD.
|
||||
auto_archive_duration: Optional[int]
|
||||
The duration in minutes to automatically archive the thread after recent activity.
|
||||
invitable: Optional[bool]
|
||||
Whether non-moderators can invite other non-moderators to a private thread.
|
||||
Only applicable to private threads.
|
||||
rate_limit_per_user: Optional[int]
|
||||
The number of seconds a user has to wait before sending another message.
|
||||
reason: Optional[str]
|
||||
The reason for creating the thread.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Thread
|
||||
The created thread.
|
||||
"""
|
||||
payload: Dict[str, Any] = {
|
||||
"name": name,
|
||||
"type": type.value,
|
||||
}
|
||||
if auto_archive_duration is not None:
|
||||
payload["auto_archive_duration"] = auto_archive_duration
|
||||
if invitable is not None and type == ChannelType.PRIVATE_THREAD:
|
||||
payload["invitable"] = invitable
|
||||
if rate_limit_per_user is not None:
|
||||
payload["rate_limit_per_user"] = rate_limit_per_user
|
||||
|
||||
data = await self._client._http.start_thread_without_message(self.id, payload)
|
||||
return cast("Thread", self._client.parse_channel(data))
|
||||
|
||||
|
||||
class VoiceChannel(Channel):
|
||||
"""Represents a guild voice channel or stage voice channel."""
|
||||
@ -1305,6 +1591,44 @@ class Thread(TextChannel): # Threads are a specialized TextChannel
|
||||
f"<Thread id='{self.id}' name='{self.name}' parent_id='{self.parent_id}'>"
|
||||
)
|
||||
|
||||
async def join(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Joins this thread.
|
||||
"""
|
||||
await self._client._http.join_thread(self.id)
|
||||
|
||||
async def leave(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Leaves this thread.
|
||||
"""
|
||||
await self._client._http.leave_thread(self.id)
|
||||
|
||||
async def archive(self, locked: bool = False, *, reason: Optional[str] = None) -> "Thread":
|
||||
"""|coro|
|
||||
|
||||
Archives this thread.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
locked: bool
|
||||
Whether to lock the thread.
|
||||
reason: Optional[str]
|
||||
The reason for archiving the thread.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Thread
|
||||
The updated thread.
|
||||
"""
|
||||
payload = {
|
||||
"archived": True,
|
||||
"locked": locked,
|
||||
}
|
||||
data = await self._client._http.edit_channel(self.id, payload, reason=reason)
|
||||
return cast("Thread", self._client.parse_channel(data))
|
||||
|
||||
|
||||
class DMChannel(Channel):
|
||||
"""Represents a Direct Message channel."""
|
||||
|
@ -28,6 +28,8 @@ class View:
|
||||
self._client: Optional[Client] = None
|
||||
self._message_id: Optional[str] = None
|
||||
|
||||
# The below is a bit of a hack to support items defined as class members
|
||||
# e.g. button = Button(...)
|
||||
for item in self.__class__.__dict__.values():
|
||||
if isinstance(item, Item):
|
||||
self.add_item(item)
|
||||
@ -44,6 +46,11 @@ class View:
|
||||
if len(self.__children) >= 25:
|
||||
raise ValueError("A view can only have a maximum of 25 components.")
|
||||
|
||||
if self.timeout is None and item.custom_id is None:
|
||||
raise ValueError(
|
||||
"All components in a persistent view must have a 'custom_id'."
|
||||
)
|
||||
|
||||
item._view = self
|
||||
self.__children.append(item)
|
||||
|
||||
@ -65,12 +72,7 @@ class View:
|
||||
rows: List[ActionRow] = []
|
||||
|
||||
for item in self.children:
|
||||
if item.custom_id is None:
|
||||
item.custom_id = (
|
||||
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
|
||||
)
|
||||
|
||||
rows.append(ActionRow(components=[item]))
|
||||
rows.append(ActionRow(components=[item]))
|
||||
|
||||
return rows
|
||||
|
||||
|
@ -6,11 +6,19 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import contextlib
|
||||
import socket
|
||||
from typing import Optional, Sequence
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Optional, Sequence
|
||||
|
||||
import aiohttp
|
||||
# The following import is correct, but may be flagged by Pylance if the virtual
|
||||
# environment is not configured correctly.
|
||||
from nacl.secret import SecretBox
|
||||
|
||||
from .audio import AudioSource, FFmpegAudioSource
|
||||
from .audio import AudioSink, AudioSource, FFmpegAudioSource
|
||||
from .models import User
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client
|
||||
|
||||
|
||||
class VoiceClient:
|
||||
@ -18,6 +26,7 @@ class VoiceClient:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Client,
|
||||
endpoint: str,
|
||||
session_id: str,
|
||||
token: str,
|
||||
@ -29,6 +38,7 @@ class VoiceClient:
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
self.client = client
|
||||
self.endpoint = endpoint
|
||||
self.session_id = session_id
|
||||
self.token = token
|
||||
@ -38,8 +48,14 @@ class VoiceClient:
|
||||
self._udp = udp
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self._receive_task: Optional[asyncio.Task] = None
|
||||
self._udp_receive_thread: Optional[threading.Thread] = None
|
||||
self._heartbeat_interval: Optional[float] = None
|
||||
self._loop = loop or asyncio.get_event_loop()
|
||||
try:
|
||||
self._loop = loop or asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self.verbose = verbose
|
||||
self.ssrc: Optional[int] = None
|
||||
self.secret_key: Optional[Sequence[int]] = None
|
||||
@ -47,6 +63,9 @@ class VoiceClient:
|
||||
self._server_port: Optional[int] = None
|
||||
self._current_source: Optional[AudioSource] = None
|
||||
self._play_task: Optional[asyncio.Task] = None
|
||||
self._sink: Optional[AudioSink] = None
|
||||
self._ssrc_map: dict[int, int] = {}
|
||||
self._ssrc_lock = threading.Lock()
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self._ws is None:
|
||||
@ -106,6 +125,49 @@ class VoiceClient:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
assert self._ws is not None
|
||||
while True:
|
||||
try:
|
||||
msg = await self._ws.receive_json()
|
||||
op = msg.get("op")
|
||||
data = msg.get("d")
|
||||
if op == 5: # Speaking
|
||||
user_id = int(data["user_id"])
|
||||
ssrc = data["ssrc"]
|
||||
with self._ssrc_lock:
|
||||
self._ssrc_map[ssrc] = user_id
|
||||
except (asyncio.CancelledError, aiohttp.ClientError):
|
||||
break
|
||||
|
||||
def _udp_receive_loop(self) -> None:
|
||||
assert self._udp is not None
|
||||
assert self.secret_key is not None
|
||||
box = SecretBox(bytes(self.secret_key))
|
||||
while True:
|
||||
try:
|
||||
packet = self._udp.recv(4096)
|
||||
if len(packet) < 12:
|
||||
continue
|
||||
|
||||
ssrc = int.from_bytes(packet[8:12], "big")
|
||||
with self._ssrc_lock:
|
||||
if ssrc not in self._ssrc_map:
|
||||
continue
|
||||
user_id = self._ssrc_map[ssrc]
|
||||
user = self.client._users.get(str(user_id))
|
||||
if not user:
|
||||
continue
|
||||
|
||||
decrypted = box.decrypt(packet[12:])
|
||||
if self._sink:
|
||||
self._sink.write(user, decrypted)
|
||||
except (socket.error, asyncio.CancelledError):
|
||||
break
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
print(f"Error in UDP receive loop: {e}")
|
||||
|
||||
async def send_audio_frame(self, frame: bytes) -> None:
|
||||
if not self._udp:
|
||||
raise RuntimeError("UDP socket not initialised")
|
||||
@ -148,15 +210,35 @@ class VoiceClient:
|
||||
|
||||
await self.play(FFmpegAudioSource(filename), wait=wait)
|
||||
|
||||
def listen(self, sink: AudioSink) -> None:
|
||||
"""Start listening to voice and routing to a sink."""
|
||||
if not isinstance(sink, AudioSink):
|
||||
raise TypeError("sink must be an AudioSink instance")
|
||||
|
||||
self._sink = sink
|
||||
if not self._udp_receive_thread:
|
||||
self._udp_receive_thread = threading.Thread(
|
||||
target=self._udp_receive_loop, daemon=True
|
||||
)
|
||||
self._udp_receive_thread.start()
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.stop()
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._heartbeat_task
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._receive_task
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
if self._udp:
|
||||
self._udp.close()
|
||||
if self._udp_receive_thread:
|
||||
self._udp_receive_thread.join(timeout=1)
|
||||
if self._sink:
|
||||
self._sink.close()
|
||||
|
@ -26,6 +26,7 @@ classifiers = [
|
||||
|
||||
dependencies = [
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
"PyNaCl>=1.5.0,<2.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
Loading…
x
Reference in New Issue
Block a user