Compare commits

...

8 Commits

Author SHA1 Message Date
ed83a9da85
Implements caching system with TTL and member filtering
Introduces a flexible caching infrastructure with time-to-live support and configurable member caching based on status, voice state, and join events.

Adds AudioSink abstract base class to support audio output handling in voice connections.

Replaces direct dictionary access with cache objects throughout the client, enabling automatic expiration and intelligent member filtering based on user-defined flags.

Updates guild parsing to incorporate presence and voice state data for more accurate member caching decisions.
2025-06-11 02:11:33 -06:00
0151526d07
chore: Apply code formatting across the codebase
This commit applies consistent code formatting to multiple files. No functional changes are included.
2025-06-11 02:10:33 -06:00
17b7ea35a9
feat: Implement guild.fetch_members 2025-06-11 02:06:19 -06:00
28702fa8a1
feat(voice): Implement voice receiving and audio sinks 2025-06-11 02:06:18 -06:00
97505948ee
feat: Add advanced message, channel, and thread management 2025-06-11 02:06:16 -06:00
152c0f12be
feat(ui): Implement persistent views 2025-06-11 02:06:15 -06:00
eb38ecf671
feat(commands): Add has_role and has_any_role check decorators 2025-06-11 02:06:13 -06:00
2bd45c87ca
feat: Enhance command framework with groups, checks, and events 2025-06-11 02:06:11 -06:00
14 changed files with 7679 additions and 6664 deletions

View File

@ -114,3 +114,20 @@ class FFmpegAudioSource(AudioSource):
if isinstance(self.source, io.IOBase):
with contextlib.suppress(Exception):
self.source.close()
class AudioSink:
"""Abstract base class for audio sinks."""
def write(self, user, data):
"""Write a chunk of PCM audio.
Subclasses must implement this. The data is raw PCM at 48kHz
stereo.
"""
raise NotImplementedError
def close(self) -> None:
"""Cleanup the sink when the voice client disconnects."""
return None

View File

@ -4,7 +4,8 @@ import time
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
if TYPE_CHECKING:
from .models import Channel, Guild
from .models import Channel, Guild, Member
from .caching import MemberCacheFlags
T = TypeVar("T")
@ -53,3 +54,32 @@ class GuildCache(Cache["Guild"]):
class ChannelCache(Cache["Channel"]):
"""Cache specifically for :class:`Channel` objects."""
class MemberCache(Cache["Member"]):
"""
A cache for :class:`Member` objects that respects :class:`MemberCacheFlags`.
"""
def __init__(self, flags: MemberCacheFlags, ttl: Optional[float] = None) -> None:
super().__init__(ttl)
self.flags = flags
def _should_cache(self, member: Member) -> bool:
"""Determines if a member should be cached based on the flags."""
if self.flags.all:
return True
if self.flags.none:
return False
if self.flags.online and member.status != "offline":
return True
if self.flags.voice and member.voice_state is not None:
return True
if self.flags.joined and getattr(member, "_just_joined", False):
return True
return False
def set(self, key: str, value: Member) -> None:
if self._should_cache(value):
super().set(key, value)

120
disagreement/caching.py Normal file
View File

@ -0,0 +1,120 @@
from __future__ import annotations
import operator
from typing import Any, Callable, ClassVar, Dict, Iterator, Tuple
class _MemberCacheFlagValue:
flag: int
def __init__(self, func: Callable[[Any], bool]):
self.flag = getattr(func, 'flag', 0)
self.__doc__ = func.__doc__
def __get__(self, instance: 'MemberCacheFlags', owner: type) -> Any:
if instance is None:
return self
return instance.value & self.flag != 0
def __set__(self, instance: Any, value: bool) -> None:
if value:
instance.value |= self.flag
else:
instance.value &= ~self.flag
def __repr__(self) -> str:
return f'<{self.__class__.__name__} flag={self.flag}>'
def flag_value(flag: int) -> Callable[[Callable[[Any], bool]], _MemberCacheFlagValue]:
def decorator(func: Callable[[Any], bool]) -> _MemberCacheFlagValue:
setattr(func, 'flag', flag)
return _MemberCacheFlagValue(func)
return decorator
class MemberCacheFlags:
__slots__ = ('value',)
VALID_FLAGS: ClassVar[Dict[str, int]] = {
'joined': 1 << 0,
'voice': 1 << 1,
'online': 1 << 2,
}
DEFAULT_FLAGS: ClassVar[int] = 1 | 2 | 4
ALL_FLAGS: ClassVar[int] = sum(VALID_FLAGS.values())
def __init__(self, **kwargs: bool):
self.value = self.DEFAULT_FLAGS
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid member cache flag.')
setattr(self, key, value)
@classmethod
def _from_value(cls, value: int) -> MemberCacheFlags:
self = cls.__new__(cls)
self.value = value
return self
def __eq__(self, other: object) -> bool:
return isinstance(other, MemberCacheFlags) and self.value == other.value
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
return hash(self.value)
def __repr__(self) -> str:
return f'<MemberCacheFlags value={self.value}>'
def __iter__(self) -> Iterator[Tuple[str, bool]]:
for name in self.VALID_FLAGS:
yield name, getattr(self, name)
def __int__(self) -> int:
return self.value
def __index__(self) -> int:
return self.value
@classmethod
def all(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with all flags enabled."""
return cls._from_value(cls.ALL_FLAGS)
@classmethod
def none(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with all flags disabled."""
return cls._from_value(0)
@classmethod
def only_joined(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with only the `joined` flag enabled."""
return cls._from_value(cls.VALID_FLAGS['joined'])
@classmethod
def only_voice(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with only the `voice` flag enabled."""
return cls._from_value(cls.VALID_FLAGS['voice'])
@classmethod
def only_online(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with only the `online` flag enabled."""
return cls._from_value(cls.VALID_FLAGS['online'])
@flag_value(1 << 0)
def joined(self) -> bool:
"""Whether to cache members that have just joined the guild."""
return False
@flag_value(1 << 1)
def voice(self) -> bool:
"""Whether to cache members that are in a voice channel."""
return False
@flag_value(1 << 2)
def online(self) -> bool:
"""Whether to cache members that are online."""
return False

View File

@ -26,7 +26,9 @@ from .event_dispatcher import EventDispatcher
from .enums import GatewayIntent, InteractionType, GatewayOpcode, VoiceRegion
from .errors import DisagreementException, AuthenticationError
from .typing import Typing
from .ext.commands.core import CommandHandler
from .caching import MemberCacheFlags
from .cache import Cache, GuildCache, ChannelCache, MemberCache
from .ext.commands.core import Command, CommandHandler, Group
from .ext.commands.cog import Cog
from .ext.app_commands.handler import AppCommandHandler
from .ext.app_commands.context import AppCommandContext
@ -96,12 +98,16 @@ class Client:
shard_count: Optional[int] = None,
gateway_max_retries: int = 5,
gateway_max_backoff: float = 60.0,
member_cache_flags: Optional[MemberCacheFlags] = None,
http_options: Optional[Dict[str, Any]] = None,
):
if not token:
raise ValueError("A bot token must be provided.")
self.token: str = token
self.member_cache_flags: MemberCacheFlags = (
member_cache_flags if member_cache_flags is not None else MemberCacheFlags()
)
self.intents: int = intents if intents is not None else GatewayIntent.default()
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
self.application_id: Optional[Snowflake] = (
@ -141,15 +147,12 @@ class Client:
)
# Internal Caches
self._guilds: Dict[Snowflake, "Guild"] = {}
self._channels: Dict[Snowflake, "Channel"] = (
{}
) # Stores all channel types by ID
self._users: Dict[Snowflake, Any] = (
{}
) # Placeholder for User model cache if needed
self._messages: Dict[Snowflake, "Message"] = {}
self._guilds: GuildCache = GuildCache()
self._channels: ChannelCache = ChannelCache()
self._users: Cache["User"] = Cache()
self._messages: Cache["Message"] = Cache(ttl=3600) # Cache messages for an hour
self._views: Dict[Snowflake, "View"] = {}
self._persistent_views: Dict[str, "View"] = {}
self._voice_clients: Dict[Snowflake, VoiceClient] = {}
self._webhooks: Dict[Snowflake, "Webhook"] = {}
@ -579,6 +582,41 @@ class Client:
# For now, assuming name is sufficient for removal from the handler's flat list.
return removed_cog
def check(self, coro: Callable[["CommandContext"], Awaitable[bool]]):
"""
A decorator that adds a global check to the bot.
This check will be called for every command before it's executed.
Example:
@bot.check
async def block_dms(ctx):
return ctx.guild is not None
"""
self.command_handler.add_check(coro)
return coro
def command(
self, **attrs: Any
) -> Callable[[Callable[..., Awaitable[None]]], Command]:
"""A decorator that transforms a function into a Command."""
def decorator(func: Callable[..., Awaitable[None]]) -> Command:
cmd = Command(func, **attrs)
self.command_handler.add_command(cmd)
return cmd
return decorator
def group(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], Group]:
"""A decorator that transforms a function into a Group command."""
def decorator(func: Callable[..., Awaitable[None]]) -> Group:
cmd = Group(func, **attrs)
self.command_handler.add_command(cmd)
return cmd
return decorator
def add_app_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None:
"""
Adds a standalone application command or group to the bot.
@ -649,6 +687,16 @@ class Client:
# import traceback
# traceback.print_exception(type(error.original), error.original, error.original.__traceback__)
async def on_command_completion(self, ctx: "CommandContext") -> None:
"""
Default command completion handler. Called when a command has successfully completed.
Users can override this method in a subclass of Client.
Args:
ctx (CommandContext): The context of the command that completed.
"""
pass
# --- Extension Management Methods ---
def load_extension(self, name: str) -> ModuleType:
@ -673,7 +721,7 @@ class Client:
from .models import User # Ensure User model is available
user = User(data)
self._users[user.id] = user # Cache the user
self._users.set(user.id, user) # Cache the user
return user
def parse_channel(self, data: Dict[str, Any]) -> "Channel":
@ -682,11 +730,11 @@ class Client:
from .models import channel_factory
channel = channel_factory(data, self)
self._channels[channel.id] = channel
self._channels.set(channel.id, channel)
if channel.guild_id:
guild = self._guilds.get(channel.guild_id)
if guild:
guild._channels[channel.id] = channel
guild._channels.set(channel.id, channel)
return channel
def parse_message(self, data: Dict[str, Any]) -> "Message":
@ -695,7 +743,7 @@ class Client:
from .models import Message
message = Message(data, client_instance=self)
self._messages[message.id] = message
self._messages.set(message.id, message)
return message
def parse_webhook(self, data: Union[Dict[str, Any], "Webhook"]) -> "Webhook":
@ -752,7 +800,7 @@ class Client:
cached_user = self._users.get(user_id)
if cached_user:
return cached_user # Return cached if available, though fetch implies wanting fresh
return cached_user
try:
user_data = await self._http.get_user(user_id)
@ -782,23 +830,26 @@ class Client:
)
return None
def parse_member(self, data: Dict[str, Any], guild_id: Snowflake) -> "Member":
def parse_member(
self, data: Dict[str, Any], guild_id: Snowflake, *, just_joined: bool = False
) -> "Member":
"""Parses member data and returns a Member object, updating relevant caches."""
from .models import Member # Ensure Member model is available
from .models import Member
# Member's __init__ should handle the nested 'user' data.
member = Member(data, client_instance=self)
member.guild_id = str(guild_id)
# Cache the member in the guild's member cache
if just_joined:
setattr(member, "_just_joined", True)
guild = self._guilds.get(guild_id)
if guild:
guild._members[member.id] = member # Assuming Guild has _members dict
guild._members.set(member.id, member)
# Also cache the user part if not already cached or if this is newer
# Since Member inherits from User, the member object itself is the user.
self._users[member.id] = member
# If 'user' was in data and Member.__init__ used it, it's already part of 'member'.
if just_joined and hasattr(member, "_just_joined"):
delattr(member, "_just_joined")
self._users.set(member.id, member)
return member
async def fetch_member(
@ -842,20 +893,29 @@ class Client:
def parse_guild(self, data: Dict[str, Any]) -> "Guild":
"""Parses guild data and returns a Guild object, updating cache."""
from .models import Guild
guild = Guild(data, client_instance=self)
self._guilds[guild.id] = guild
self._guilds.set(guild.id, guild)
# Populate channel and member caches if provided
for ch in data.get("channels", []):
channel_obj = self.parse_channel(ch)
guild._channels[channel_obj.id] = channel_obj
presences = {p["user"]["id"]: p for p in data.get("presences", [])}
voice_states = {vs["user_id"]: vs for vs in data.get("voice_states", [])}
for member in data.get("members", []):
member_obj = self.parse_member(member, guild.id)
guild._members[member_obj.id] = member_obj
for ch_data in data.get("channels", []):
self.parse_channel(ch_data)
for member_data in data.get("members", []):
user_id = member_data.get("user", {}).get("id")
if user_id:
presence = presences.get(user_id)
if presence:
member_data["status"] = presence.get("status", "offline")
voice_state = voice_states.get(user_id)
if voice_state:
member_data["voice_state"] = voice_state
self.parse_member(member_data, guild.id)
return guild
@ -1067,6 +1127,7 @@ class Client:
session_id = state["session_id"]
voice = VoiceClient(
self,
endpoint,
session_id,
token,
@ -1259,7 +1320,7 @@ class Client:
channel = channel_factory(channel_data, self)
self._channels[channel.id] = channel
self._channels.set(channel.id, channel)
return channel
except DisagreementException as e: # Includes HTTPException
@ -1445,6 +1506,37 @@ class Client:
data = await self._http.get_channel_invites(channel_id)
return [self.parse_invite(inv) for inv in data]
def add_persistent_view(self, view: "View") -> None:
"""
Registers a persistent view with the client.
Persistent views have a timeout of `None` and their components must have a `custom_id`.
This allows the view to be re-instantiated across bot restarts.
Args:
view (View): The view instance to register.
Raises:
ValueError: If the view is not persistent (timeout is not None) or if a component's
custom_id is already registered.
"""
if self.is_ready():
print(
"Warning: Adding a persistent view after the client is ready. "
"This view will only be available for interactions on this session."
)
if view.timeout is not None:
raise ValueError("Persistent views must have a timeout of None.")
for item in view.children:
if item.custom_id: # Ensure custom_id is not None
if item.custom_id in self._persistent_views:
raise ValueError(
f"A component with custom_id '{item.custom_id}' is already registered."
)
self._persistent_views[item.custom_id] = view
# --- Application Command Methods ---
async def process_interaction(self, interaction: Interaction) -> None:
"""Internal method to process an interaction from the gateway."""
@ -1455,11 +1547,25 @@ class Client:
if (
interaction.type == InteractionType.MESSAGE_COMPONENT
and interaction.message
and interaction.data
):
view = self._views.get(interaction.message.id)
if view:
asyncio.create_task(view._dispatch(interaction))
return
else:
# No active view found, check for persistent views
custom_id = interaction.data.custom_id
if custom_id:
registered_view = self._persistent_views.get(custom_id)
if registered_view:
# Create a new instance of the persistent view
new_view = registered_view.__class__()
await new_view._start(self)
new_view.message_id = interaction.message.id
self._views[interaction.message.id] = new_view
asyncio.create_task(new_view._dispatch(interaction))
return
await self.app_command_handler.process_interaction(interaction)

View File

@ -76,7 +76,7 @@ class EventDispatcher:
"""Parses MESSAGE_DELETE and updates message cache."""
message_id = data.get("id")
if message_id:
self._client._messages.pop(message_id, None)
self._client._messages.invalidate(message_id)
return data
def _parse_message_reaction_raw(self, data: Dict[str, Any]) -> Dict[str, Any]:
@ -124,7 +124,7 @@ class EventDispatcher:
"""Parses GUILD_MEMBER_ADD into a Member object."""
guild_id = str(data.get("guild_id"))
return self._client.parse_member(data, guild_id)
return self._client.parse_member(data, guild_id, just_joined=True)
def _parse_guild_member_remove(self, data: Dict[str, Any]):
"""Parses GUILD_MEMBER_REMOVE into a GuildMemberRemove model."""

View File

@ -18,6 +18,8 @@ from .decorators import (
cooldown,
max_concurrency,
requires_permissions,
has_role,
has_any_role,
)
from .errors import (
CommandError,
@ -47,6 +49,8 @@ __all__ = [
"cooldown",
"max_concurrency",
"requires_permissions",
"has_role",
"has_any_role",
# Errors
"CommandError",
"CommandNotFound",

View File

@ -40,7 +40,42 @@ if TYPE_CHECKING:
from disagreement.models import Message, User
class Command:
class GroupMixin:
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.commands: Dict[str, "Command"] = {}
self.name: str = ""
def command(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Command"]:
def decorator(func: Callable[..., Awaitable[None]]) -> "Command":
cmd = Command(func, **attrs)
cmd.cog = getattr(self, "cog", None)
self.add_command(cmd)
return cmd
return decorator
def group(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Group"]:
def decorator(func: Callable[..., Awaitable[None]]) -> "Group":
cmd = Group(func, **attrs)
cmd.cog = getattr(self, "cog", None)
self.add_command(cmd)
return cmd
return decorator
def add_command(self, command: "Command") -> None:
if command.name in self.commands:
raise ValueError(f"Command '{command.name}' is already registered in group '{self.name}'.")
self.commands[command.name.lower()] = command
for alias in command.aliases:
if alias in self.commands:
logger.warning(f"Alias '{alias}' for command '{command.name}' in group '{self.name}' conflicts with an existing command or alias.")
self.commands[alias.lower()] = command
def get_command(self, name: str) -> Optional["Command"]:
return self.commands.get(name.lower())
class Command(GroupMixin):
"""
Represents a bot command.
@ -58,12 +93,14 @@ class Command:
if not asyncio.iscoroutinefunction(callback):
raise TypeError("Command callback must be a coroutine function.")
super().__init__(**attrs)
self.callback: Callable[..., Awaitable[None]] = callback
self.name: str = attrs.get("name", callback.__name__)
self.aliases: List[str] = attrs.get("aliases", [])
self.brief: Optional[str] = attrs.get("brief")
self.description: Optional[str] = attrs.get("description") or callback.__doc__
self.cog: Optional["Cog"] = attrs.get("cog")
self.invoke_without_command: bool = attrs.get("invoke_without_command", False)
self.params = inspect.signature(callback).parameters
self.checks: List[Callable[["CommandContext"], Awaitable[bool] | bool]] = []
@ -79,20 +116,73 @@ class Command:
) -> None:
self.checks.append(predicate)
async def invoke(self, ctx: "CommandContext", *args: Any, **kwargs: Any) -> None:
async def _run_checks(self, ctx: "CommandContext") -> None:
"""Runs all cog, local and global checks for the command."""
from .errors import CheckFailure
# Run cog-level check first
if self.cog:
cog_check = getattr(self.cog, "cog_check", None)
if cog_check:
try:
result = cog_check(ctx)
if inspect.isawaitable(result):
result = await result
if not result:
raise CheckFailure(
f"The cog-level check for command '{self.name}' failed."
)
except CheckFailure:
raise
except Exception as e:
raise CommandInvokeError(e) from e
# Run local checks
for predicate in self.checks:
result = predicate(ctx)
if inspect.isawaitable(result):
result = await result
if not result:
raise CheckFailure("Check predicate failed.")
raise CheckFailure(f"A local check for command '{self.name}' failed.")
# Then run global checks from the handler
if hasattr(ctx.bot, "command_handler"):
for predicate in ctx.bot.command_handler._global_checks:
result = predicate(ctx)
if inspect.isawaitable(result):
result = await result
if not result:
raise CheckFailure(
f"A global check failed for command '{self.name}'."
)
async def invoke(self, ctx: "CommandContext", *args: Any, **kwargs: Any) -> None:
await self._run_checks(ctx)
before_invoke = None
after_invoke = None
if self.cog:
await self.callback(self.cog, ctx, *args, **kwargs)
else:
await self.callback(ctx, *args, **kwargs)
before_invoke = getattr(self.cog, "cog_before_invoke", None)
after_invoke = getattr(self.cog, "cog_after_invoke", None)
if before_invoke:
await before_invoke(ctx)
try:
if self.cog:
await self.callback(self.cog, ctx, *args, **kwargs)
else:
await self.callback(ctx, *args, **kwargs)
finally:
if after_invoke:
await after_invoke(ctx)
class Group(Command):
"""A command that can have subcommands."""
def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any):
super().__init__(callback, **attrs)
PrefixCommand = Command # Alias for clarity in hybrid commands
@ -220,11 +310,20 @@ class CommandHandler:
self.commands: Dict[str, Command] = {}
self.cogs: Dict[str, "Cog"] = {}
self._concurrency: Dict[str, Dict[str, int]] = {}
self._global_checks: List[
Callable[["CommandContext"], Awaitable[bool] | bool]
] = []
from .help import HelpCommand
self.add_command(HelpCommand(self))
def add_check(
self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool]
) -> None:
"""Adds a global check to the command handler."""
self._global_checks.append(predicate)
def add_command(self, command: Command) -> None:
if command.name in self.commands:
raise ValueError(f"Command '{command.name}' is already registered.")
@ -239,6 +338,15 @@ class CommandHandler:
)
self.commands[alias.lower()] = command
if isinstance(command, Group):
for sub_cmd in command.commands.values():
if sub_cmd.name in self.commands:
logger.warning(
"Subcommand '%s' of group '%s' conflicts with a top-level command.",
sub_cmd.name,
command.name,
)
def remove_command(self, name: str) -> Optional[Command]:
command = self.commands.pop(name.lower(), None)
if command:
@ -534,12 +642,28 @@ class CommandHandler:
if not command:
return
invoked_with = command_name
original_command = command
if isinstance(command, Group):
view.skip_whitespace()
potential_subcommand = view.get_word()
if potential_subcommand:
subcommand = command.get_command(potential_subcommand)
if subcommand:
command = subcommand
invoked_with += f" {potential_subcommand}"
elif command.invoke_without_command:
view.index -= len(potential_subcommand) + view.previous
else:
raise CommandNotFound(f"Subcommand '{potential_subcommand}' not found.")
ctx = CommandContext(
message=message,
bot=self.client,
prefix=actual_prefix,
command=command,
invoked_with=command_name,
invoked_with=invoked_with,
cog=command.cog,
)
@ -553,11 +677,14 @@ class CommandHandler:
finally:
self._release_concurrency(ctx)
except CommandError as e:
logger.error("Command error for '%s': %s", command.name, e)
logger.error("Command error for '%s': %s", original_command.name, e)
if hasattr(self.client, "on_command_error"):
await self.client.on_command_error(ctx, e)
except Exception as e:
logger.error("Unexpected error invoking command '%s': %s", command.name, e)
logger.error("Unexpected error invoking command '%s': %s", original_command.name, e)
exc = CommandInvokeError(e)
if hasattr(self.client, "on_command_error"):
await self.client.on_command_error(ctx, exc)
else:
if hasattr(self.client, "on_command_completion"):
await self.client.on_command_completion(ctx)

View File

@ -217,3 +217,82 @@ def requires_permissions(
return True
return check(predicate)
def has_role(
name_or_id: str | int,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Check that the invoking member has a role with the given name or ID."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckFailure
from disagreement.models import Member
if not ctx.guild:
raise CheckFailure("This command cannot be used in DMs.")
author = ctx.author
if not isinstance(author, Member):
try:
author = await ctx.bot.fetch_member(ctx.guild.id, author.id)
except Exception:
raise CheckFailure("Could not resolve author to a guild member.")
if not author:
raise CheckFailure("Could not resolve author to a guild member.")
# Create a list of the member's role objects by looking them up in the guild's roles list
member_roles = [
role for role in ctx.guild.roles if role.id in author.roles
]
if any(
role.id == str(name_or_id) or role.name == name_or_id
for role in member_roles
):
return True
raise CheckFailure(f"You need the '{name_or_id}' role to use this command.")
return check(predicate)
def has_any_role(
*names_or_ids: str | int,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Check that the invoking member has any of the roles with the given names or IDs."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckFailure
from disagreement.models import Member
if not ctx.guild:
raise CheckFailure("This command cannot be used in DMs.")
author = ctx.author
if not isinstance(author, Member):
try:
author = await ctx.bot.fetch_member(ctx.guild.id, author.id)
except Exception:
raise CheckFailure("Could not resolve author to a guild member.")
if not author:
raise CheckFailure("Could not resolve author to a guild member.")
member_roles = [
role for role in ctx.guild.roles if role.id in author.roles
]
# Convert names_or_ids to a set for efficient lookup
names_or_ids_set = set(map(str, names_or_ids))
if any(
role.id in names_or_ids_set or role.name in names_or_ids_set
for role in member_roles
):
return True
role_list = ", ".join(f"'{r}'" for r in names_or_ids)
raise CheckFailure(
f"You need one of the following roles to use this command: {role_list}"
)
return check(predicate)

View File

@ -79,6 +79,8 @@ class GatewayClient:
self._buffer = bytearray()
self._inflator = zlib.decompressobj()
self._member_chunk_requests: Dict[str, asyncio.Future] = {}
async def _reconnect(self) -> None:
"""Attempts to reconnect using exponential backoff with jitter."""
delay = 1.0
@ -237,6 +239,32 @@ class GatewayClient:
}
await self._send_json(payload)
async def request_guild_members(
self,
guild_id: str,
query: str = "",
limit: int = 0,
presences: bool = False,
user_ids: Optional[list[str]] = None,
nonce: Optional[str] = None,
):
"""Sends the request guild members payload to the Gateway."""
payload = {
"op": GatewayOpcode.REQUEST_GUILD_MEMBERS,
"d": {
"guild_id": guild_id,
"query": query,
"limit": limit,
"presences": presences,
},
}
if user_ids:
payload["d"]["user_ids"] = user_ids
if nonce:
payload["d"]["nonce"] = nonce
await self._send_json(payload)
async def _handle_dispatch(self, data: Dict[str, Any]):
"""Handles DISPATCH events (actual Discord events)."""
event_name = data.get("t")
@ -313,6 +341,22 @@ class GatewayClient:
)
await self._dispatcher.dispatch(event_name, raw_event_d_payload)
elif event_name == "GUILD_MEMBERS_CHUNK":
if isinstance(raw_event_d_payload, dict):
nonce = raw_event_d_payload.get("nonce")
if nonce and nonce in self._member_chunk_requests:
future = self._member_chunk_requests[nonce]
if not future.done():
# Append members to a temporary list stored on the future object
if not hasattr(future, "_members"):
future._members = [] # type: ignore
future._members.extend(raw_event_d_payload.get("members", [])) # type: ignore
# If this is the last chunk, resolve the future
if raw_event_d_payload.get("chunk_index") == raw_event_d_payload.get("chunk_count", 1) - 1:
future.set_result(future._members) # type: ignore
del self._member_chunk_requests[nonce]
elif event_name == "INTERACTION_CREATE":
# print(f"GATEWAY RECV INTERACTION_CREATE: {raw_event_d_payload}")
if isinstance(raw_event_d_payload, dict):

View File

@ -368,6 +368,20 @@ class HTTPClient:
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/@me",
)
async def delete_user_reaction(
self,
channel_id: "Snowflake",
message_id: "Snowflake",
emoji: str,
user_id: "Snowflake",
) -> None:
"""Removes another user's reaction from a message."""
encoded = quote(emoji)
await self.request(
"DELETE",
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/{user_id}",
)
async def get_reactions(
self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
) -> List[Dict[str, Any]]:
@ -400,6 +414,27 @@ class HTTPClient:
)
return messages
async def get_pinned_messages(
self, channel_id: "Snowflake"
) -> List[Dict[str, Any]]:
"""Fetches all pinned messages in a channel."""
return await self.request("GET", f"/channels/{channel_id}/pins")
async def pin_message(
self, channel_id: "Snowflake", message_id: "Snowflake"
) -> None:
"""Pins a message in a channel."""
await self.request("PUT", f"/channels/{channel_id}/pins/{message_id}")
async def unpin_message(
self, channel_id: "Snowflake", message_id: "Snowflake"
) -> None:
"""Unpins a message from a channel."""
await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}")
async def delete_channel(
self, channel_id: str, reason: Optional[str] = None
) -> None:
@ -420,6 +455,21 @@ class HTTPClient:
custom_headers=custom_headers if custom_headers else None,
)
async def edit_channel(
self,
channel_id: "Snowflake",
payload: Dict[str, Any],
reason: Optional[str] = None,
) -> Dict[str, Any]:
"""Edits a channel."""
headers = {"X-Audit-Log-Reason": reason} if reason else None
return await self.request(
"PATCH",
f"/channels/{channel_id}",
payload=payload,
custom_headers=headers,
)
async def get_channel(self, channel_id: str) -> Dict[str, Any]:
"""Fetches a channel by ID."""
return await self.request("GET", f"/channels/{channel_id}")
@ -1039,3 +1089,32 @@ class HTTPClient:
async def get_voice_regions(self) -> List[Dict[str, Any]]:
"""Returns available voice regions."""
return await self.request("GET", "/voice/regions")
async def start_thread_from_message(
self,
channel_id: "Snowflake",
message_id: "Snowflake",
payload: Dict[str, Any],
) -> Dict[str, Any]:
"""Starts a new thread from an existing message."""
return await self.request(
"POST",
f"/channels/{channel_id}/messages/{message_id}/threads",
payload=payload,
)
async def start_thread_without_message(
self, channel_id: "Snowflake", payload: Dict[str, Any]
) -> Dict[str, Any]:
"""Starts a new thread that is not attached to a message."""
return await self.request(
"POST", f"/channels/{channel_id}/threads", payload=payload
)
async def join_thread(self, channel_id: "Snowflake") -> None:
"""Joins the current user to a thread."""
await self.request("PUT", f"/channels/{channel_id}/thread-members/@me")
async def leave_thread(self, channel_id: "Snowflake") -> None:
"""Removes the current user from a thread."""
await self.request("DELETE", f"/channels/{channel_id}/thread-members/@me")

View File

@ -7,7 +7,9 @@ Data models for Discord objects.
import asyncio
import json
from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast
from .cache import ChannelCache, MemberCache
import aiohttp # pylint: disable=import-error
from .color import Color
@ -105,11 +107,38 @@ class Message:
self.attachments: List[Attachment] = [
Attachment(a) for a in data.get("attachments", [])
]
self.pinned: bool = data.get("pinned", False)
# Add other fields as needed, e.g., attachments, embeds, reactions, etc.
# self.mentions: List[User] = [User(u) for u in data.get("mentions", [])]
# self.mention_roles: List[str] = data.get("mention_roles", [])
# self.mention_everyone: bool = data.get("mention_everyone", False)
async def pin(self) -> None:
"""|coro|
Pins this message to its channel.
Raises
------
HTTPException
Pinning the message failed.
"""
await self._client._http.pin_message(self.channel_id, self.id)
self.pinned = True
async def unpin(self) -> None:
"""|coro|
Unpins this message from its channel.
Raises
------
HTTPException
Unpinning the message failed.
"""
await self._client._http.unpin_message(self.channel_id, self.id)
self.pinned = False
async def reply(
self,
content: Optional[str] = None,
@ -210,10 +239,17 @@ class Message:
await self._client.add_reaction(self.channel_id, self.id, emoji)
async def remove_reaction(self, emoji: str) -> None:
"""|coro| Remove the bot's reaction from this message."""
await self._client.remove_reaction(self.channel_id, self.id, emoji)
async def remove_reaction(self, emoji: str, member: Optional[User] = None) -> None:
"""|coro|
Removes a reaction from this message.
If no ``member`` is provided, removes the bot's own reaction.
"""
if member:
await self._client._http.delete_user_reaction(
self.channel_id, self.id, emoji, member.id
)
else:
await self._client.remove_reaction(self.channel_id, self.id, emoji)
async def clear_reactions(self) -> None:
"""|coro| Remove all reactions from this message."""
@ -239,6 +275,125 @@ class Message:
def __repr__(self) -> str:
return f"<Message id='{self.id}' channel_id='{self.channel_id}' author='{self.author!r}'>"
async def create_thread(
self,
name: str,
*,
auto_archive_duration: Optional[int] = None,
rate_limit_per_user: Optional[int] = None,
reason: Optional[str] = None,
) -> "Thread":
"""|coro|
Creates a new thread from this message.
Parameters
----------
name: str
The name of the thread.
auto_archive_duration: Optional[int]
The duration in minutes to automatically archive the thread after recent activity.
Can be one of 60, 1440, 4320, 10080.
rate_limit_per_user: Optional[int]
The number of seconds a user has to wait before sending another message.
reason: Optional[str]
The reason for creating the thread.
Returns
-------
Thread
The created thread.
"""
payload: Dict[str, Any] = {"name": name}
if auto_archive_duration is not None:
payload["auto_archive_duration"] = auto_archive_duration
if rate_limit_per_user is not None:
payload["rate_limit_per_user"] = rate_limit_per_user
data = await self._client._http.start_thread_from_message(
self.channel_id, self.id, payload
)
return cast("Thread", self._client.parse_channel(data))
class PartialMessage:
"""Represents a partial message, identified by its ID and channel.
This model is used to perform actions on a message without having the
full message object in the cache.
Attributes:
id (str): The message's unique ID.
channel (TextChannel): The text channel this message belongs to.
"""
def __init__(self, *, id: str, channel: "TextChannel"):
self.id = id
self.channel = channel
self._client = channel._client
async def fetch(self) -> "Message":
"""|coro|
Fetches the full message data from Discord.
Returns
-------
Message
The complete message object.
"""
data = await self._client._http.get_message(self.channel.id, self.id)
return self._client.parse_message(data)
async def delete(self, *, delay: Optional[float] = None) -> None:
"""|coro|
Deletes this message.
Parameters
----------
delay: Optional[float]
If provided, wait this many seconds before deleting.
"""
if delay is not None:
await asyncio.sleep(delay)
await self._client._http.delete_message(self.channel.id, self.id)
async def pin(self) -> None:
"""|coro|
Pins this message to its channel.
"""
await self._client._http.pin_message(self.channel.id, self.id)
async def unpin(self) -> None:
"""|coro|
Unpins this message from its channel.
"""
await self._client._http.unpin_message(self.channel.id, self.id)
async def add_reaction(self, emoji: str) -> None:
"""|coro|
Adds a reaction to this message.
"""
await self._client._http.create_reaction(self.channel.id, self.id, emoji)
async def remove_reaction(self, emoji: str, member: Optional[User] = None) -> None:
"""|coro|
Removes a reaction from this message.
If no ``member`` is provided, removes the bot's own reaction.
"""
if member:
await self._client._http.delete_user_reaction(
self.channel.id, self.id, emoji, member.id
)
else:
await self._client._http.delete_reaction(self.channel.id, self.id, emoji)
class EmbedFooter:
"""Represents an embed footer."""
@ -503,6 +658,8 @@ class Member(User): # Member inherits from User
):
self._client: Optional["Client"] = client_instance
self.guild_id: Optional[str] = None
self.status: Optional[str] = None
self.voice_state: Optional[Dict[str, Any]] = None
# User part is nested under 'user' key in member data from gateway/API
user_data = data.get("user", {})
# If 'id' is not in user_data but is top-level (e.g. from interaction resolved member without user object)
@ -929,8 +1086,8 @@ class Guild:
)
# Internal caches, populated by events or specific fetches
self._channels: Dict[str, "Channel"] = {}
self._members: Dict[str, Member] = {}
self._channels: ChannelCache = ChannelCache()
self._members: MemberCache = MemberCache(client_instance.member_cache_flags)
self._threads: Dict[str, "Thread"] = {}
def get_channel(self, channel_id: str) -> Optional["Channel"]:
@ -970,6 +1127,49 @@ class Guild:
def __repr__(self) -> str:
return f"<Guild id='{self.id}' name='{self.name}'>"
async def fetch_members(self, *, limit: Optional[int] = None) -> List["Member"]:
"""|coro|
Fetches all members for this guild.
This requires the ``GUILD_MEMBERS`` intent.
Parameters
----------
limit: Optional[int]
The maximum number of members to fetch. If ``None``, all members
are fetched.
Returns
-------
List[Member]
A list of all members in the guild.
Raises
------
DisagreementException
The gateway is not available to make the request.
asyncio.TimeoutError
The request timed out.
"""
if not self._client._gateway:
raise DisagreementException("Gateway not available for member fetching.")
nonce = str(asyncio.get_running_loop().time())
future = self._client._gateway._loop.create_future()
self._client._gateway._member_chunk_requests[nonce] = future
try:
await self._client._gateway.request_guild_members(
self.id, limit=limit or 0, nonce=nonce
)
member_data = await asyncio.wait_for(future, timeout=60.0)
return [Member(m, self._client) for m in member_data]
except asyncio.TimeoutError:
if nonce in self._client._gateway._member_chunk_requests:
del self._client._gateway._member_chunk_requests[nonce]
raise
class Channel:
"""Base class for Discord channels."""
@ -1142,12 +1342,98 @@ class TextChannel(Channel):
await self._client._http.bulk_delete_messages(self.id, ids)
for mid in ids:
self._client._messages.pop(mid, None)
self._client._messages.invalidate(mid)
return ids
def get_partial_message(self, id: int) -> "PartialMessage":
"""Returns a :class:`PartialMessage` for the given ID.
This allows performing actions on a message without fetching it first.
Parameters
----------
id: int
The ID of the message to get a partial instance of.
Returns
-------
PartialMessage
The partial message instance.
"""
return PartialMessage(id=str(id), channel=self)
def __repr__(self) -> str:
return f"<TextChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>"
async def pins(self) -> List["Message"]:
"""|coro|
Fetches all pinned messages in this channel.
Returns
-------
List[Message]
The pinned messages.
Raises
------
HTTPException
Fetching the pinned messages failed.
"""
messages_data = await self._client._http.get_pinned_messages(self.id)
return [self._client.parse_message(m) for m in messages_data]
async def create_thread(
self,
name: str,
*,
type: ChannelType = ChannelType.PUBLIC_THREAD,
auto_archive_duration: Optional[int] = None,
invitable: Optional[bool] = None,
rate_limit_per_user: Optional[int] = None,
reason: Optional[str] = None,
) -> "Thread":
"""|coro|
Creates a new thread in this channel.
Parameters
----------
name: str
The name of the thread.
type: ChannelType
The type of thread to create. Defaults to PUBLIC_THREAD.
Can be PUBLIC_THREAD, PRIVATE_THREAD, or ANNOUNCEMENT_THREAD.
auto_archive_duration: Optional[int]
The duration in minutes to automatically archive the thread after recent activity.
invitable: Optional[bool]
Whether non-moderators can invite other non-moderators to a private thread.
Only applicable to private threads.
rate_limit_per_user: Optional[int]
The number of seconds a user has to wait before sending another message.
reason: Optional[str]
The reason for creating the thread.
Returns
-------
Thread
The created thread.
"""
payload: Dict[str, Any] = {
"name": name,
"type": type.value,
}
if auto_archive_duration is not None:
payload["auto_archive_duration"] = auto_archive_duration
if invitable is not None and type == ChannelType.PRIVATE_THREAD:
payload["invitable"] = invitable
if rate_limit_per_user is not None:
payload["rate_limit_per_user"] = rate_limit_per_user
data = await self._client._http.start_thread_without_message(self.id, payload)
return cast("Thread", self._client.parse_channel(data))
class VoiceChannel(Channel):
"""Represents a guild voice channel or stage voice channel."""
@ -1305,6 +1591,44 @@ class Thread(TextChannel): # Threads are a specialized TextChannel
f"<Thread id='{self.id}' name='{self.name}' parent_id='{self.parent_id}'>"
)
async def join(self) -> None:
"""|coro|
Joins this thread.
"""
await self._client._http.join_thread(self.id)
async def leave(self) -> None:
"""|coro|
Leaves this thread.
"""
await self._client._http.leave_thread(self.id)
async def archive(self, locked: bool = False, *, reason: Optional[str] = None) -> "Thread":
"""|coro|
Archives this thread.
Parameters
----------
locked: bool
Whether to lock the thread.
reason: Optional[str]
The reason for archiving the thread.
Returns
-------
Thread
The updated thread.
"""
payload = {
"archived": True,
"locked": locked,
}
data = await self._client._http.edit_channel(self.id, payload, reason=reason)
return cast("Thread", self._client.parse_channel(data))
class DMChannel(Channel):
"""Represents a Direct Message channel."""

View File

@ -28,6 +28,8 @@ class View:
self._client: Optional[Client] = None
self._message_id: Optional[str] = None
# The below is a bit of a hack to support items defined as class members
# e.g. button = Button(...)
for item in self.__class__.__dict__.values():
if isinstance(item, Item):
self.add_item(item)
@ -44,6 +46,11 @@ class View:
if len(self.__children) >= 25:
raise ValueError("A view can only have a maximum of 25 components.")
if self.timeout is None and item.custom_id is None:
raise ValueError(
"All components in a persistent view must have a 'custom_id'."
)
item._view = self
self.__children.append(item)
@ -65,12 +72,7 @@ class View:
rows: List[ActionRow] = []
for item in self.children:
if item.custom_id is None:
item.custom_id = (
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
)
rows.append(ActionRow(components=[item]))
rows.append(ActionRow(components=[item]))
return rows

View File

@ -6,11 +6,19 @@ from __future__ import annotations
import asyncio
import contextlib
import socket
from typing import Optional, Sequence
import threading
from typing import TYPE_CHECKING, Optional, Sequence
import aiohttp
# The following import is correct, but may be flagged by Pylance if the virtual
# environment is not configured correctly.
from nacl.secret import SecretBox
from .audio import AudioSource, FFmpegAudioSource
from .audio import AudioSink, AudioSource, FFmpegAudioSource
from .models import User
if TYPE_CHECKING:
from .client import Client
class VoiceClient:
@ -18,6 +26,7 @@ class VoiceClient:
def __init__(
self,
client: Client,
endpoint: str,
session_id: str,
token: str,
@ -29,6 +38,7 @@ class VoiceClient:
loop: Optional[asyncio.AbstractEventLoop] = None,
verbose: bool = False,
) -> None:
self.client = client
self.endpoint = endpoint
self.session_id = session_id
self.token = token
@ -38,8 +48,14 @@ class VoiceClient:
self._udp = udp
self._session: Optional[aiohttp.ClientSession] = None
self._heartbeat_task: Optional[asyncio.Task] = None
self._receive_task: Optional[asyncio.Task] = None
self._udp_receive_thread: Optional[threading.Thread] = None
self._heartbeat_interval: Optional[float] = None
self._loop = loop or asyncio.get_event_loop()
try:
self._loop = loop or asyncio.get_running_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self.verbose = verbose
self.ssrc: Optional[int] = None
self.secret_key: Optional[Sequence[int]] = None
@ -47,6 +63,9 @@ class VoiceClient:
self._server_port: Optional[int] = None
self._current_source: Optional[AudioSource] = None
self._play_task: Optional[asyncio.Task] = None
self._sink: Optional[AudioSink] = None
self._ssrc_map: dict[int, int] = {}
self._ssrc_lock = threading.Lock()
async def connect(self) -> None:
if self._ws is None:
@ -106,6 +125,49 @@ class VoiceClient:
except asyncio.CancelledError:
pass
async def _receive_loop(self) -> None:
assert self._ws is not None
while True:
try:
msg = await self._ws.receive_json()
op = msg.get("op")
data = msg.get("d")
if op == 5: # Speaking
user_id = int(data["user_id"])
ssrc = data["ssrc"]
with self._ssrc_lock:
self._ssrc_map[ssrc] = user_id
except (asyncio.CancelledError, aiohttp.ClientError):
break
def _udp_receive_loop(self) -> None:
assert self._udp is not None
assert self.secret_key is not None
box = SecretBox(bytes(self.secret_key))
while True:
try:
packet = self._udp.recv(4096)
if len(packet) < 12:
continue
ssrc = int.from_bytes(packet[8:12], "big")
with self._ssrc_lock:
if ssrc not in self._ssrc_map:
continue
user_id = self._ssrc_map[ssrc]
user = self.client._users.get(str(user_id))
if not user:
continue
decrypted = box.decrypt(packet[12:])
if self._sink:
self._sink.write(user, decrypted)
except (socket.error, asyncio.CancelledError):
break
except Exception as e:
if self.verbose:
print(f"Error in UDP receive loop: {e}")
async def send_audio_frame(self, frame: bytes) -> None:
if not self._udp:
raise RuntimeError("UDP socket not initialised")
@ -148,15 +210,35 @@ class VoiceClient:
await self.play(FFmpegAudioSource(filename), wait=wait)
def listen(self, sink: AudioSink) -> None:
"""Start listening to voice and routing to a sink."""
if not isinstance(sink, AudioSink):
raise TypeError("sink must be an AudioSink instance")
self._sink = sink
if not self._udp_receive_thread:
self._udp_receive_thread = threading.Thread(
target=self._udp_receive_loop, daemon=True
)
self._udp_receive_thread.start()
async def close(self) -> None:
await self.stop()
if self._heartbeat_task:
self._heartbeat_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._heartbeat_task
if self._receive_task:
self._receive_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._receive_task
if self._ws:
await self._ws.close()
if self._session:
await self._session.close()
if self._udp:
self._udp.close()
if self._udp_receive_thread:
self._udp_receive_thread.join(timeout=1)
if self._sink:
self._sink.close()

View File

@ -26,6 +26,7 @@ classifiers = [
dependencies = [
"aiohttp>=3.9.0,<4.0.0",
"PyNaCl>=1.5.0,<2.0.0",
]
[project.optional-dependencies]