Compare commits

...

8 Commits

Author SHA1 Message Date
ed83a9da85
Implements caching system with TTL and member filtering
Introduces a flexible caching infrastructure with time-to-live support and configurable member caching based on status, voice state, and join events.

Adds AudioSink abstract base class to support audio output handling in voice connections.

Replaces direct dictionary access with cache objects throughout the client, enabling automatic expiration and intelligent member filtering based on user-defined flags.

Updates guild parsing to incorporate presence and voice state data for more accurate member caching decisions.
2025-06-11 02:11:33 -06:00
0151526d07
chore: Apply code formatting across the codebase
This commit applies consistent code formatting to multiple files. No functional changes are included.
2025-06-11 02:10:33 -06:00
17b7ea35a9
feat: Implement guild.fetch_members 2025-06-11 02:06:19 -06:00
28702fa8a1
feat(voice): Implement voice receiving and audio sinks 2025-06-11 02:06:18 -06:00
97505948ee
feat: Add advanced message, channel, and thread management 2025-06-11 02:06:16 -06:00
152c0f12be
feat(ui): Implement persistent views 2025-06-11 02:06:15 -06:00
eb38ecf671
feat(commands): Add has_role and has_any_role check decorators 2025-06-11 02:06:13 -06:00
2bd45c87ca
feat: Enhance command framework with groups, checks, and events 2025-06-11 02:06:11 -06:00
14 changed files with 7679 additions and 6664 deletions

View File

@ -114,3 +114,20 @@ class FFmpegAudioSource(AudioSource):
if isinstance(self.source, io.IOBase): if isinstance(self.source, io.IOBase):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
self.source.close() self.source.close()
class AudioSink:
"""Abstract base class for audio sinks."""
def write(self, user, data):
"""Write a chunk of PCM audio.
Subclasses must implement this. The data is raw PCM at 48kHz
stereo.
"""
raise NotImplementedError
def close(self) -> None:
"""Cleanup the sink when the voice client disconnects."""
return None

View File

@ -4,7 +4,8 @@ import time
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
from .models import Channel, Guild from .models import Channel, Guild, Member
from .caching import MemberCacheFlags
T = TypeVar("T") T = TypeVar("T")
@ -53,3 +54,32 @@ class GuildCache(Cache["Guild"]):
class ChannelCache(Cache["Channel"]): class ChannelCache(Cache["Channel"]):
"""Cache specifically for :class:`Channel` objects.""" """Cache specifically for :class:`Channel` objects."""
class MemberCache(Cache["Member"]):
"""
A cache for :class:`Member` objects that respects :class:`MemberCacheFlags`.
"""
def __init__(self, flags: MemberCacheFlags, ttl: Optional[float] = None) -> None:
super().__init__(ttl)
self.flags = flags
def _should_cache(self, member: Member) -> bool:
"""Determines if a member should be cached based on the flags."""
if self.flags.all:
return True
if self.flags.none:
return False
if self.flags.online and member.status != "offline":
return True
if self.flags.voice and member.voice_state is not None:
return True
if self.flags.joined and getattr(member, "_just_joined", False):
return True
return False
def set(self, key: str, value: Member) -> None:
if self._should_cache(value):
super().set(key, value)

120
disagreement/caching.py Normal file
View File

@ -0,0 +1,120 @@
from __future__ import annotations
import operator
from typing import Any, Callable, ClassVar, Dict, Iterator, Tuple
class _MemberCacheFlagValue:
flag: int
def __init__(self, func: Callable[[Any], bool]):
self.flag = getattr(func, 'flag', 0)
self.__doc__ = func.__doc__
def __get__(self, instance: 'MemberCacheFlags', owner: type) -> Any:
if instance is None:
return self
return instance.value & self.flag != 0
def __set__(self, instance: Any, value: bool) -> None:
if value:
instance.value |= self.flag
else:
instance.value &= ~self.flag
def __repr__(self) -> str:
return f'<{self.__class__.__name__} flag={self.flag}>'
def flag_value(flag: int) -> Callable[[Callable[[Any], bool]], _MemberCacheFlagValue]:
def decorator(func: Callable[[Any], bool]) -> _MemberCacheFlagValue:
setattr(func, 'flag', flag)
return _MemberCacheFlagValue(func)
return decorator
class MemberCacheFlags:
__slots__ = ('value',)
VALID_FLAGS: ClassVar[Dict[str, int]] = {
'joined': 1 << 0,
'voice': 1 << 1,
'online': 1 << 2,
}
DEFAULT_FLAGS: ClassVar[int] = 1 | 2 | 4
ALL_FLAGS: ClassVar[int] = sum(VALID_FLAGS.values())
def __init__(self, **kwargs: bool):
self.value = self.DEFAULT_FLAGS
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid member cache flag.')
setattr(self, key, value)
@classmethod
def _from_value(cls, value: int) -> MemberCacheFlags:
self = cls.__new__(cls)
self.value = value
return self
def __eq__(self, other: object) -> bool:
return isinstance(other, MemberCacheFlags) and self.value == other.value
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
return hash(self.value)
def __repr__(self) -> str:
return f'<MemberCacheFlags value={self.value}>'
def __iter__(self) -> Iterator[Tuple[str, bool]]:
for name in self.VALID_FLAGS:
yield name, getattr(self, name)
def __int__(self) -> int:
return self.value
def __index__(self) -> int:
return self.value
@classmethod
def all(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with all flags enabled."""
return cls._from_value(cls.ALL_FLAGS)
@classmethod
def none(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with all flags disabled."""
return cls._from_value(0)
@classmethod
def only_joined(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with only the `joined` flag enabled."""
return cls._from_value(cls.VALID_FLAGS['joined'])
@classmethod
def only_voice(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with only the `voice` flag enabled."""
return cls._from_value(cls.VALID_FLAGS['voice'])
@classmethod
def only_online(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with only the `online` flag enabled."""
return cls._from_value(cls.VALID_FLAGS['online'])
@flag_value(1 << 0)
def joined(self) -> bool:
"""Whether to cache members that have just joined the guild."""
return False
@flag_value(1 << 1)
def voice(self) -> bool:
"""Whether to cache members that are in a voice channel."""
return False
@flag_value(1 << 2)
def online(self) -> bool:
"""Whether to cache members that are online."""
return False

File diff suppressed because it is too large Load Diff

View File

@ -76,7 +76,7 @@ class EventDispatcher:
"""Parses MESSAGE_DELETE and updates message cache.""" """Parses MESSAGE_DELETE and updates message cache."""
message_id = data.get("id") message_id = data.get("id")
if message_id: if message_id:
self._client._messages.pop(message_id, None) self._client._messages.invalidate(message_id)
return data return data
def _parse_message_reaction_raw(self, data: Dict[str, Any]) -> Dict[str, Any]: def _parse_message_reaction_raw(self, data: Dict[str, Any]) -> Dict[str, Any]:
@ -124,7 +124,7 @@ class EventDispatcher:
"""Parses GUILD_MEMBER_ADD into a Member object.""" """Parses GUILD_MEMBER_ADD into a Member object."""
guild_id = str(data.get("guild_id")) guild_id = str(data.get("guild_id"))
return self._client.parse_member(data, guild_id) return self._client.parse_member(data, guild_id, just_joined=True)
def _parse_guild_member_remove(self, data: Dict[str, Any]): def _parse_guild_member_remove(self, data: Dict[str, Any]):
"""Parses GUILD_MEMBER_REMOVE into a GuildMemberRemove model.""" """Parses GUILD_MEMBER_REMOVE into a GuildMemberRemove model."""

View File

@ -1,61 +1,65 @@
# disagreement/ext/commands/__init__.py # disagreement/ext/commands/__init__.py
""" """
disagreement.ext.commands - A command framework extension for the Disagreement library. disagreement.ext.commands - A command framework extension for the Disagreement library.
""" """
from .cog import Cog from .cog import Cog
from .core import ( from .core import (
Command, Command,
CommandContext, CommandContext,
CommandHandler, CommandHandler,
) # CommandHandler might be internal ) # CommandHandler might be internal
from .decorators import ( from .decorators import (
command, command,
listener, listener,
check, check,
check_any, check_any,
cooldown, cooldown,
max_concurrency, max_concurrency,
requires_permissions, requires_permissions,
) has_role,
from .errors import ( has_any_role,
CommandError, )
CommandNotFound, from .errors import (
BadArgument, CommandError,
MissingRequiredArgument, CommandNotFound,
ArgumentParsingError, BadArgument,
CheckFailure, MissingRequiredArgument,
CheckAnyFailure, ArgumentParsingError,
CommandOnCooldown, CheckFailure,
CommandInvokeError, CheckAnyFailure,
MaxConcurrencyReached, CommandOnCooldown,
) CommandInvokeError,
MaxConcurrencyReached,
__all__ = [ )
# Cog
"Cog", __all__ = [
# Core # Cog
"Command", "Cog",
"CommandContext", # Core
# "CommandHandler", # Usually not part of public API for direct use by bot devs "Command",
# Decorators "CommandContext",
"command", # "CommandHandler", # Usually not part of public API for direct use by bot devs
"listener", # Decorators
"check", "command",
"check_any", "listener",
"cooldown", "check",
"max_concurrency", "check_any",
"requires_permissions", "cooldown",
# Errors "max_concurrency",
"CommandError", "requires_permissions",
"CommandNotFound", "has_role",
"BadArgument", "has_any_role",
"MissingRequiredArgument", # Errors
"ArgumentParsingError", "CommandError",
"CheckFailure", "CommandNotFound",
"CheckAnyFailure", "BadArgument",
"CommandOnCooldown", "MissingRequiredArgument",
"CommandInvokeError", "ArgumentParsingError",
"MaxConcurrencyReached", "CheckFailure",
] "CheckAnyFailure",
"CommandOnCooldown",
"CommandInvokeError",
"MaxConcurrencyReached",
]

File diff suppressed because it is too large Load Diff

View File

@ -1,219 +1,298 @@
# disagreement/ext/commands/decorators.py # disagreement/ext/commands/decorators.py
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import inspect import inspect
import time import time
from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
if TYPE_CHECKING: if TYPE_CHECKING:
from .core import Command, CommandContext from .core import Command, CommandContext
from disagreement.permissions import Permissions from disagreement.permissions import Permissions
from disagreement.models import Member, Guild, Channel from disagreement.models import Member, Guild, Channel
def command( def command(
name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any
) -> Callable: ) -> Callable:
""" """
A decorator that transforms a function into a Command. A decorator that transforms a function into a Command.
Args: Args:
name (Optional[str]): The name of the command. Defaults to the function name. name (Optional[str]): The name of the command. Defaults to the function name.
aliases (Optional[List[str]]): Alternative names for the command. aliases (Optional[List[str]]): Alternative names for the command.
**attrs: Additional attributes to pass to the Command constructor **attrs: Additional attributes to pass to the Command constructor
(e.g., brief, description, hidden). (e.g., brief, description, hidden).
Returns: Returns:
Callable: A decorator that registers the command. Callable: A decorator that registers the command.
""" """
def decorator( def decorator(
func: Callable[..., Awaitable[None]], func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]: ) -> Callable[..., Awaitable[None]]:
if not asyncio.iscoroutinefunction(func): if not asyncio.iscoroutinefunction(func):
raise TypeError("Command callback must be a coroutine function.") raise TypeError("Command callback must be a coroutine function.")
from .core import Command from .core import Command
cmd_name = name or func.__name__ cmd_name = name or func.__name__
if hasattr(func, "__command_attrs__"): if hasattr(func, "__command_attrs__"):
raise TypeError("Function is already a command or has command attributes.") raise TypeError("Function is already a command or has command attributes.")
cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs) cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
func.__command_object__ = cmd # type: ignore func.__command_object__ = cmd # type: ignore
return func return func
return decorator return decorator
def listener( def listener(
name: Optional[str] = None, name: Optional[str] = None,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
""" """
A decorator that marks a function as an event listener within a Cog. A decorator that marks a function as an event listener within a Cog.
""" """
def decorator( def decorator(
func: Callable[..., Awaitable[None]], func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]: ) -> Callable[..., Awaitable[None]]:
if not asyncio.iscoroutinefunction(func): if not asyncio.iscoroutinefunction(func):
raise TypeError("Listener callback must be a coroutine function.") raise TypeError("Listener callback must be a coroutine function.")
actual_event_name = name or func.__name__ actual_event_name = name or func.__name__
setattr(func, "__listener_name__", actual_event_name) setattr(func, "__listener_name__", actual_event_name)
return func return func
return decorator return decorator
def check( def check(
predicate: Callable[["CommandContext"], Awaitable[bool] | bool], predicate: Callable[["CommandContext"], Awaitable[bool] | bool],
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Decorator to add a check to a command.""" """Decorator to add a check to a command."""
def decorator( def decorator(
func: Callable[..., Awaitable[None]], func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]: ) -> Callable[..., Awaitable[None]]:
checks = getattr(func, "__command_checks__", []) checks = getattr(func, "__command_checks__", [])
checks.append(predicate) checks.append(predicate)
setattr(func, "__command_checks__", checks) setattr(func, "__command_checks__", checks)
return func return func
return decorator return decorator
def check_any( def check_any(
*predicates: Callable[["CommandContext"], Awaitable[bool] | bool] *predicates: Callable[["CommandContext"], Awaitable[bool] | bool]
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Decorator that passes if any predicate returns ``True``.""" """Decorator that passes if any predicate returns ``True``."""
async def predicate(ctx: "CommandContext") -> bool: async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckAnyFailure, CheckFailure from .errors import CheckAnyFailure, CheckFailure
errors = [] errors = []
for p in predicates: for p in predicates:
try: try:
result = p(ctx) result = p(ctx)
if inspect.isawaitable(result): if inspect.isawaitable(result):
result = await result result = await result
if result: if result:
return True return True
except CheckFailure as e: except CheckFailure as e:
errors.append(e) errors.append(e)
raise CheckAnyFailure(errors) raise CheckAnyFailure(errors)
return check(predicate) return check(predicate)
def max_concurrency( def max_concurrency(
number: int, per: str = "user" number: int, per: str = "user"
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Limit how many concurrent invocations of a command are allowed. """Limit how many concurrent invocations of a command are allowed.
Parameters Parameters
---------- ----------
number: number:
The maximum number of concurrent invocations. The maximum number of concurrent invocations.
per: per:
The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``. The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``.
""" """
if number < 1: if number < 1:
raise ValueError("Concurrency number must be at least 1.") raise ValueError("Concurrency number must be at least 1.")
if per not in {"user", "guild", "global"}: if per not in {"user", "guild", "global"}:
raise ValueError("per must be 'user', 'guild', or 'global'.") raise ValueError("per must be 'user', 'guild', or 'global'.")
def decorator( def decorator(
func: Callable[..., Awaitable[None]], func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]: ) -> Callable[..., Awaitable[None]]:
setattr(func, "__max_concurrency__", (number, per)) setattr(func, "__max_concurrency__", (number, per))
return func return func
return decorator return decorator
def cooldown( def cooldown(
rate: int, per: float rate: int, per: float
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Simple per-user cooldown decorator.""" """Simple per-user cooldown decorator."""
buckets: dict[str, dict[str, float]] = {} buckets: dict[str, dict[str, float]] = {}
async def predicate(ctx: "CommandContext") -> bool: async def predicate(ctx: "CommandContext") -> bool:
from .errors import CommandOnCooldown from .errors import CommandOnCooldown
now = time.monotonic() now = time.monotonic()
user_buckets = buckets.setdefault(ctx.command.name, {}) user_buckets = buckets.setdefault(ctx.command.name, {})
reset = user_buckets.get(ctx.author.id, 0) reset = user_buckets.get(ctx.author.id, 0)
if now < reset: if now < reset:
raise CommandOnCooldown(reset - now) raise CommandOnCooldown(reset - now)
user_buckets[ctx.author.id] = now + per user_buckets[ctx.author.id] = now + per
return True return True
return check(predicate) return check(predicate)
def _compute_permissions( def _compute_permissions(
member: "Member", channel: "Channel", guild: "Guild" member: "Member", channel: "Channel", guild: "Guild"
) -> "Permissions": ) -> "Permissions":
"""Compute the effective permissions for a member in a channel.""" """Compute the effective permissions for a member in a channel."""
return channel.permissions_for(member) return channel.permissions_for(member)
def requires_permissions( def requires_permissions(
*perms: "Permissions", *perms: "Permissions",
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Check that the invoking member has the given permissions in the channel.""" """Check that the invoking member has the given permissions in the channel."""
async def predicate(ctx: "CommandContext") -> bool: async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckFailure from .errors import CheckFailure
from disagreement.permissions import ( from disagreement.permissions import (
has_permissions, has_permissions,
missing_permissions, missing_permissions,
) )
from disagreement.models import Member from disagreement.models import Member
channel = getattr(ctx, "channel", None) channel = getattr(ctx, "channel", None)
if channel is None and hasattr(ctx.bot, "get_channel"): if channel is None and hasattr(ctx.bot, "get_channel"):
channel = ctx.bot.get_channel(ctx.message.channel_id) channel = ctx.bot.get_channel(ctx.message.channel_id)
if channel is None and hasattr(ctx.bot, "fetch_channel"): if channel is None and hasattr(ctx.bot, "fetch_channel"):
channel = await ctx.bot.fetch_channel(ctx.message.channel_id) channel = await ctx.bot.fetch_channel(ctx.message.channel_id)
if channel is None: if channel is None:
raise CheckFailure("Channel for permission check not found.") raise CheckFailure("Channel for permission check not found.")
guild = getattr(channel, "guild", None) guild = getattr(channel, "guild", None)
if not guild and hasattr(channel, "guild_id") and channel.guild_id: if not guild and hasattr(channel, "guild_id") and channel.guild_id:
if hasattr(ctx.bot, "get_guild"): if hasattr(ctx.bot, "get_guild"):
guild = ctx.bot.get_guild(channel.guild_id) guild = ctx.bot.get_guild(channel.guild_id)
if not guild and hasattr(ctx.bot, "fetch_guild"): if not guild and hasattr(ctx.bot, "fetch_guild"):
guild = await ctx.bot.fetch_guild(channel.guild_id) guild = await ctx.bot.fetch_guild(channel.guild_id)
if not guild: if not guild:
is_dm = not hasattr(channel, "guild_id") or not channel.guild_id is_dm = not hasattr(channel, "guild_id") or not channel.guild_id
if is_dm: if is_dm:
if perms: if perms:
raise CheckFailure("Permission checks are not supported in DMs.") raise CheckFailure("Permission checks are not supported in DMs.")
return True return True
raise CheckFailure("Guild for permission check not found.") raise CheckFailure("Guild for permission check not found.")
member = ctx.author member = ctx.author
if not isinstance(member, Member): if not isinstance(member, Member):
member = guild.get_member(ctx.author.id) member = guild.get_member(ctx.author.id)
if not member and hasattr(ctx.bot, "fetch_member"): if not member and hasattr(ctx.bot, "fetch_member"):
member = await ctx.bot.fetch_member(guild.id, ctx.author.id) member = await ctx.bot.fetch_member(guild.id, ctx.author.id)
if not member: if not member:
raise CheckFailure("Could not resolve author to a guild member.") raise CheckFailure("Could not resolve author to a guild member.")
perms_value = _compute_permissions(member, channel, guild) perms_value = _compute_permissions(member, channel, guild)
if not has_permissions(perms_value, *perms): if not has_permissions(perms_value, *perms):
missing = missing_permissions(perms_value, *perms) missing = missing_permissions(perms_value, *perms)
missing_names = ", ".join(p.name for p in missing if p.name) missing_names = ", ".join(p.name for p in missing if p.name)
raise CheckFailure(f"Missing permissions: {missing_names}") raise CheckFailure(f"Missing permissions: {missing_names}")
return True return True
return check(predicate) return check(predicate)
def has_role(
name_or_id: str | int,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Check that the invoking member has a role with the given name or ID."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckFailure
from disagreement.models import Member
if not ctx.guild:
raise CheckFailure("This command cannot be used in DMs.")
author = ctx.author
if not isinstance(author, Member):
try:
author = await ctx.bot.fetch_member(ctx.guild.id, author.id)
except Exception:
raise CheckFailure("Could not resolve author to a guild member.")
if not author:
raise CheckFailure("Could not resolve author to a guild member.")
# Create a list of the member's role objects by looking them up in the guild's roles list
member_roles = [
role for role in ctx.guild.roles if role.id in author.roles
]
if any(
role.id == str(name_or_id) or role.name == name_or_id
for role in member_roles
):
return True
raise CheckFailure(f"You need the '{name_or_id}' role to use this command.")
return check(predicate)
def has_any_role(
*names_or_ids: str | int,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Check that the invoking member has any of the roles with the given names or IDs."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckFailure
from disagreement.models import Member
if not ctx.guild:
raise CheckFailure("This command cannot be used in DMs.")
author = ctx.author
if not isinstance(author, Member):
try:
author = await ctx.bot.fetch_member(ctx.guild.id, author.id)
except Exception:
raise CheckFailure("Could not resolve author to a guild member.")
if not author:
raise CheckFailure("Could not resolve author to a guild member.")
member_roles = [
role for role in ctx.guild.roles if role.id in author.roles
]
# Convert names_or_ids to a set for efficient lookup
names_or_ids_set = set(map(str, names_or_ids))
if any(
role.id in names_or_ids_set or role.name in names_or_ids_set
for role in member_roles
):
return True
role_list = ", ".join(f"'{r}'" for r in names_or_ids)
raise CheckFailure(
f"You need one of the following roles to use this command: {role_list}"
)
return check(predicate)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,165 +1,167 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import uuid import uuid
from typing import Any, Callable, Coroutine, Dict, List, Optional, TYPE_CHECKING from typing import Any, Callable, Coroutine, Dict, List, Optional, TYPE_CHECKING
from ..models import ActionRow from ..models import ActionRow
from .item import Item from .item import Item
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import Client from ..client import Client
from ..interactions import Interaction from ..interactions import Interaction
class View: class View:
"""Represents a container for UI components that can be sent with a message. """Represents a container for UI components that can be sent with a message.
Args: Args:
timeout (Optional[float]): The number of seconds to wait for an interaction before the view times out. timeout (Optional[float]): The number of seconds to wait for an interaction before the view times out.
Defaults to 180. Defaults to 180.
""" """
def __init__(self, *, timeout: Optional[float] = 180.0): def __init__(self, *, timeout: Optional[float] = 180.0):
self.timeout = timeout self.timeout = timeout
self.id = str(uuid.uuid4()) self.id = str(uuid.uuid4())
self.__children: List[Item] = [] self.__children: List[Item] = []
self.__stopped = asyncio.Event() self.__stopped = asyncio.Event()
self._client: Optional[Client] = None self._client: Optional[Client] = None
self._message_id: Optional[str] = None self._message_id: Optional[str] = None
for item in self.__class__.__dict__.values(): # The below is a bit of a hack to support items defined as class members
if isinstance(item, Item): # e.g. button = Button(...)
self.add_item(item) for item in self.__class__.__dict__.values():
if isinstance(item, Item):
@property self.add_item(item)
def children(self) -> List[Item]:
return self.__children @property
def children(self) -> List[Item]:
def add_item(self, item: Item): return self.__children
"""Adds an item to the view."""
if not isinstance(item, Item): def add_item(self, item: Item):
raise TypeError("Only instances of 'Item' can be added to a View.") """Adds an item to the view."""
if not isinstance(item, Item):
if len(self.__children) >= 25: raise TypeError("Only instances of 'Item' can be added to a View.")
raise ValueError("A view can only have a maximum of 25 components.")
if len(self.__children) >= 25:
item._view = self raise ValueError("A view can only have a maximum of 25 components.")
self.__children.append(item)
if self.timeout is None and item.custom_id is None:
@property raise ValueError(
def message_id(self) -> Optional[str]: "All components in a persistent view must have a 'custom_id'."
return self._message_id )
@message_id.setter item._view = self
def message_id(self, value: str): self.__children.append(item)
self._message_id = value
@property
def to_components(self) -> List[ActionRow]: def message_id(self) -> Optional[str]:
"""Converts the view's children into a list of ActionRow components. return self._message_id
This retains the original, simple layout behaviour where each item is @message_id.setter
placed in its own :class:`ActionRow` to ensure backward compatibility. def message_id(self, value: str):
""" self._message_id = value
rows: List[ActionRow] = [] def to_components(self) -> List[ActionRow]:
"""Converts the view's children into a list of ActionRow components.
for item in self.children:
if item.custom_id is None: This retains the original, simple layout behaviour where each item is
item.custom_id = ( placed in its own :class:`ActionRow` to ensure backward compatibility.
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}" """
)
rows: List[ActionRow] = []
rows.append(ActionRow(components=[item]))
for item in self.children:
return rows rows.append(ActionRow(components=[item]))
def layout_components_advanced(self) -> List[ActionRow]: return rows
"""Group compatible components into rows following Discord rules."""
def layout_components_advanced(self) -> List[ActionRow]:
rows: List[ActionRow] = [] """Group compatible components into rows following Discord rules."""
for item in self.children: rows: List[ActionRow] = []
if item.custom_id is None:
item.custom_id = ( for item in self.children:
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}" if item.custom_id is None:
) item.custom_id = (
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
target_row = item.row )
if target_row is not None:
if not 0 <= target_row <= 4: target_row = item.row
raise ValueError("Row index must be between 0 and 4.") if target_row is not None:
if not 0 <= target_row <= 4:
while len(rows) <= target_row: raise ValueError("Row index must be between 0 and 4.")
if len(rows) >= 5:
raise ValueError("A view can have at most 5 action rows.") while len(rows) <= target_row:
rows.append(ActionRow()) if len(rows) >= 5:
raise ValueError("A view can have at most 5 action rows.")
rows[target_row].add_component(item) rows.append(ActionRow())
continue
rows[target_row].add_component(item)
placed = False continue
for row in rows:
try: placed = False
row.add_component(item) for row in rows:
placed = True try:
break row.add_component(item)
except ValueError: placed = True
continue break
except ValueError:
if not placed: continue
if len(rows) >= 5:
raise ValueError("A view can have at most 5 action rows.") if not placed:
new_row = ActionRow([item]) if len(rows) >= 5:
rows.append(new_row) raise ValueError("A view can have at most 5 action rows.")
new_row = ActionRow([item])
return rows rows.append(new_row)
def to_components_payload(self) -> List[Dict[str, Any]]: return rows
"""Converts the view's children into a list of component dictionaries
that can be sent to the Discord API.""" def to_components_payload(self) -> List[Dict[str, Any]]:
return [row.to_dict() for row in self.to_components()] """Converts the view's children into a list of component dictionaries
that can be sent to the Discord API."""
async def _dispatch(self, interaction: Interaction): return [row.to_dict() for row in self.to_components()]
"""Called by the client to dispatch an interaction to the correct item."""
if self.timeout is not None: async def _dispatch(self, interaction: Interaction):
self.__stopped.set() # Reset the timeout on each interaction """Called by the client to dispatch an interaction to the correct item."""
self.__stopped.clear() if self.timeout is not None:
self.__stopped.set() # Reset the timeout on each interaction
if interaction.data: self.__stopped.clear()
custom_id = interaction.data.custom_id
for child in self.children: if interaction.data:
if child.custom_id == custom_id: custom_id = interaction.data.custom_id
if child.callback: for child in self.children:
await child.callback(self, interaction) if child.custom_id == custom_id:
break if child.callback:
await child.callback(self, interaction)
async def wait(self) -> bool: break
"""Waits until the view has stopped interacting."""
return await self.__stopped.wait() async def wait(self) -> bool:
"""Waits until the view has stopped interacting."""
def stop(self): return await self.__stopped.wait()
"""Stops the view from listening to interactions."""
if not self.__stopped.is_set(): def stop(self):
self.__stopped.set() """Stops the view from listening to interactions."""
if not self.__stopped.is_set():
async def on_timeout(self): self.__stopped.set()
"""Called when the view times out."""
pass # User can override this async def on_timeout(self):
"""Called when the view times out."""
async def _start(self, client: Client): pass # User can override this
"""Starts the view's internal listener."""
self._client = client async def _start(self, client: Client):
if self.timeout is not None: """Starts the view's internal listener."""
asyncio.create_task(self._timeout_task()) self._client = client
if self.timeout is not None:
async def _timeout_task(self): asyncio.create_task(self._timeout_task())
"""The task that waits for the timeout and then stops the view."""
try: async def _timeout_task(self):
await asyncio.wait_for(self.wait(), timeout=self.timeout) """The task that waits for the timeout and then stops the view."""
except asyncio.TimeoutError: try:
self.stop() await asyncio.wait_for(self.wait(), timeout=self.timeout)
await self.on_timeout() except asyncio.TimeoutError:
if self._client and self._message_id: self.stop()
# Remove the view from the client's listeners await self.on_timeout()
self._client._views.pop(self._message_id, None) if self._client and self._message_id:
# Remove the view from the client's listeners
self._client._views.pop(self._message_id, None)

View File

@ -1,162 +1,244 @@
# disagreement/voice_client.py # disagreement/voice_client.py
"""Voice gateway and UDP audio client.""" """Voice gateway and UDP audio client."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
import socket import socket
from typing import Optional, Sequence import threading
from typing import TYPE_CHECKING, Optional, Sequence
import aiohttp
import aiohttp
from .audio import AudioSource, FFmpegAudioSource # The following import is correct, but may be flagged by Pylance if the virtual
# environment is not configured correctly.
from nacl.secret import SecretBox
class VoiceClient:
"""Handles the Discord voice WebSocket connection and UDP streaming.""" from .audio import AudioSink, AudioSource, FFmpegAudioSource
from .models import User
def __init__(
self, if TYPE_CHECKING:
endpoint: str, from .client import Client
session_id: str,
token: str,
guild_id: int, class VoiceClient:
user_id: int, """Handles the Discord voice WebSocket connection and UDP streaming."""
*,
ws=None, def __init__(
udp: Optional[socket.socket] = None, self,
loop: Optional[asyncio.AbstractEventLoop] = None, client: Client,
verbose: bool = False, endpoint: str,
) -> None: session_id: str,
self.endpoint = endpoint token: str,
self.session_id = session_id guild_id: int,
self.token = token user_id: int,
self.guild_id = str(guild_id) *,
self.user_id = str(user_id) ws=None,
self._ws: Optional[aiohttp.ClientWebSocketResponse] = ws udp: Optional[socket.socket] = None,
self._udp = udp loop: Optional[asyncio.AbstractEventLoop] = None,
self._session: Optional[aiohttp.ClientSession] = None verbose: bool = False,
self._heartbeat_task: Optional[asyncio.Task] = None ) -> None:
self._heartbeat_interval: Optional[float] = None self.client = client
self._loop = loop or asyncio.get_event_loop() self.endpoint = endpoint
self.verbose = verbose self.session_id = session_id
self.ssrc: Optional[int] = None self.token = token
self.secret_key: Optional[Sequence[int]] = None self.guild_id = str(guild_id)
self._server_ip: Optional[str] = None self.user_id = str(user_id)
self._server_port: Optional[int] = None self._ws: Optional[aiohttp.ClientWebSocketResponse] = ws
self._current_source: Optional[AudioSource] = None self._udp = udp
self._play_task: Optional[asyncio.Task] = None self._session: Optional[aiohttp.ClientSession] = None
self._heartbeat_task: Optional[asyncio.Task] = None
async def connect(self) -> None: self._receive_task: Optional[asyncio.Task] = None
if self._ws is None: self._udp_receive_thread: Optional[threading.Thread] = None
self._session = aiohttp.ClientSession() self._heartbeat_interval: Optional[float] = None
self._ws = await self._session.ws_connect(self.endpoint) try:
self._loop = loop or asyncio.get_running_loop()
hello = await self._ws.receive_json() except RuntimeError:
self._heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000 self._loop = asyncio.new_event_loop()
self._heartbeat_task = self._loop.create_task(self._heartbeat()) asyncio.set_event_loop(self._loop)
self.verbose = verbose
await self._ws.send_json( self.ssrc: Optional[int] = None
{ self.secret_key: Optional[Sequence[int]] = None
"op": 0, self._server_ip: Optional[str] = None
"d": { self._server_port: Optional[int] = None
"server_id": self.guild_id, self._current_source: Optional[AudioSource] = None
"user_id": self.user_id, self._play_task: Optional[asyncio.Task] = None
"session_id": self.session_id, self._sink: Optional[AudioSink] = None
"token": self.token, self._ssrc_map: dict[int, int] = {}
}, self._ssrc_lock = threading.Lock()
}
) async def connect(self) -> None:
if self._ws is None:
ready = await self._ws.receive_json() self._session = aiohttp.ClientSession()
data = ready["d"] self._ws = await self._session.ws_connect(self.endpoint)
self.ssrc = data["ssrc"]
self._server_ip = data["ip"] hello = await self._ws.receive_json()
self._server_port = data["port"] self._heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000
self._heartbeat_task = self._loop.create_task(self._heartbeat())
if self._udp is None:
self._udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) await self._ws.send_json(
self._udp.connect((self._server_ip, self._server_port)) {
"op": 0,
await self._ws.send_json( "d": {
{ "server_id": self.guild_id,
"op": 1, "user_id": self.user_id,
"d": { "session_id": self.session_id,
"protocol": "udp", "token": self.token,
"data": { },
"address": self._udp.getsockname()[0], }
"port": self._udp.getsockname()[1], )
"mode": "xsalsa20_poly1305",
}, ready = await self._ws.receive_json()
}, data = ready["d"]
} self.ssrc = data["ssrc"]
) self._server_ip = data["ip"]
self._server_port = data["port"]
session_desc = await self._ws.receive_json()
self.secret_key = session_desc["d"].get("secret_key") if self._udp is None:
self._udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
async def _heartbeat(self) -> None: self._udp.connect((self._server_ip, self._server_port))
assert self._ws is not None
assert self._heartbeat_interval is not None await self._ws.send_json(
try: {
while True: "op": 1,
await self._ws.send_json({"op": 3, "d": int(self._loop.time() * 1000)}) "d": {
await asyncio.sleep(self._heartbeat_interval) "protocol": "udp",
except asyncio.CancelledError: "data": {
pass "address": self._udp.getsockname()[0],
"port": self._udp.getsockname()[1],
async def send_audio_frame(self, frame: bytes) -> None: "mode": "xsalsa20_poly1305",
if not self._udp: },
raise RuntimeError("UDP socket not initialised") },
self._udp.send(frame) }
)
async def _play_loop(self) -> None:
assert self._current_source is not None session_desc = await self._ws.receive_json()
try: self.secret_key = session_desc["d"].get("secret_key")
while True:
data = await self._current_source.read() async def _heartbeat(self) -> None:
if not data: assert self._ws is not None
break assert self._heartbeat_interval is not None
await self.send_audio_frame(data) try:
finally: while True:
await self._current_source.close() await self._ws.send_json({"op": 3, "d": int(self._loop.time() * 1000)})
self._current_source = None await asyncio.sleep(self._heartbeat_interval)
self._play_task = None except asyncio.CancelledError:
pass
async def stop(self) -> None:
if self._play_task: async def _receive_loop(self) -> None:
self._play_task.cancel() assert self._ws is not None
with contextlib.suppress(asyncio.CancelledError): while True:
await self._play_task try:
self._play_task = None msg = await self._ws.receive_json()
if self._current_source: op = msg.get("op")
await self._current_source.close() data = msg.get("d")
self._current_source = None if op == 5: # Speaking
user_id = int(data["user_id"])
async def play(self, source: AudioSource, *, wait: bool = True) -> None: ssrc = data["ssrc"]
"""|coro| Play an :class:`AudioSource` on the voice connection.""" with self._ssrc_lock:
self._ssrc_map[ssrc] = user_id
await self.stop() except (asyncio.CancelledError, aiohttp.ClientError):
self._current_source = source break
self._play_task = self._loop.create_task(self._play_loop())
if wait: def _udp_receive_loop(self) -> None:
await self._play_task assert self._udp is not None
assert self.secret_key is not None
async def play_file(self, filename: str, *, wait: bool = True) -> None: box = SecretBox(bytes(self.secret_key))
"""|coro| Stream an audio file or URL using FFmpeg.""" while True:
try:
await self.play(FFmpegAudioSource(filename), wait=wait) packet = self._udp.recv(4096)
if len(packet) < 12:
async def close(self) -> None: continue
await self.stop()
if self._heartbeat_task: ssrc = int.from_bytes(packet[8:12], "big")
self._heartbeat_task.cancel() with self._ssrc_lock:
with contextlib.suppress(asyncio.CancelledError): if ssrc not in self._ssrc_map:
await self._heartbeat_task continue
if self._ws: user_id = self._ssrc_map[ssrc]
await self._ws.close() user = self.client._users.get(str(user_id))
if self._session: if not user:
await self._session.close() continue
if self._udp:
self._udp.close() decrypted = box.decrypt(packet[12:])
if self._sink:
self._sink.write(user, decrypted)
except (socket.error, asyncio.CancelledError):
break
except Exception as e:
if self.verbose:
print(f"Error in UDP receive loop: {e}")
async def send_audio_frame(self, frame: bytes) -> None:
if not self._udp:
raise RuntimeError("UDP socket not initialised")
self._udp.send(frame)
async def _play_loop(self) -> None:
assert self._current_source is not None
try:
while True:
data = await self._current_source.read()
if not data:
break
await self.send_audio_frame(data)
finally:
await self._current_source.close()
self._current_source = None
self._play_task = None
async def stop(self) -> None:
if self._play_task:
self._play_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._play_task
self._play_task = None
if self._current_source:
await self._current_source.close()
self._current_source = None
async def play(self, source: AudioSource, *, wait: bool = True) -> None:
"""|coro| Play an :class:`AudioSource` on the voice connection."""
await self.stop()
self._current_source = source
self._play_task = self._loop.create_task(self._play_loop())
if wait:
await self._play_task
async def play_file(self, filename: str, *, wait: bool = True) -> None:
"""|coro| Stream an audio file or URL using FFmpeg."""
await self.play(FFmpegAudioSource(filename), wait=wait)
def listen(self, sink: AudioSink) -> None:
"""Start listening to voice and routing to a sink."""
if not isinstance(sink, AudioSink):
raise TypeError("sink must be an AudioSink instance")
self._sink = sink
if not self._udp_receive_thread:
self._udp_receive_thread = threading.Thread(
target=self._udp_receive_loop, daemon=True
)
self._udp_receive_thread.start()
async def close(self) -> None:
await self.stop()
if self._heartbeat_task:
self._heartbeat_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._heartbeat_task
if self._receive_task:
self._receive_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._receive_task
if self._ws:
await self._ws.close()
if self._session:
await self._session.close()
if self._udp:
self._udp.close()
if self._udp_receive_thread:
self._udp_receive_thread.join(timeout=1)
if self._sink:
self._sink.close()

View File

@ -1,56 +1,57 @@
[project] [project]
name = "disagreement" name = "disagreement"
version = "0.2.0rc1" version = "0.2.0rc1"
description = "A Python library for the Discord API." description = "A Python library for the Discord API."
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"
license = {text = "BSD 3-Clause"} license = {text = "BSD 3-Clause"}
authors = [ authors = [
{name = "Slipstream", email = "me@slipstreamm.dev"} {name = "Slipstream", email = "me@slipstreamm.dev"}
] ]
keywords = ["discord", "api", "bot", "async", "aiohttp"] keywords = ["discord", "api", "bot", "async", "aiohttp"]
classifiers = [ classifiers = [
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"Intended Audience :: Developers", "Intended Audience :: Developers",
"License :: OSI Approved :: BSD License", "License :: OSI Approved :: BSD License",
"Operating System :: OS Independent", "Operating System :: OS Independent",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.13",
"Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Internet", "Topic :: Internet",
] ]
dependencies = [ dependencies = [
"aiohttp>=3.9.0,<4.0.0", "aiohttp>=3.9.0,<4.0.0",
] "PyNaCl>=1.5.0,<2.0.0",
]
[project.optional-dependencies]
test = [ [project.optional-dependencies]
"pytest>=8.0.0", test = [
"pytest-asyncio>=1.0.0", "pytest>=8.0.0",
"hypothesis>=6.132.0", "pytest-asyncio>=1.0.0",
] "hypothesis>=6.132.0",
dev = [ ]
"python-dotenv>=1.0.0", dev = [
] "python-dotenv>=1.0.0",
]
[project.urls]
Homepage = "https://github.com/Slipstreamm/disagreement" [project.urls]
Issues = "https://github.com/Slipstreamm/disagreement/issues" Homepage = "https://github.com/Slipstreamm/disagreement"
Issues = "https://github.com/Slipstreamm/disagreement/issues"
[build-system]
requires = ["setuptools>=61.0"] [build-system]
build-backend = "setuptools.build_meta" requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Optional: for linting/formatting, e.g., Ruff
# [tool.ruff] # Optional: for linting/formatting, e.g., Ruff
# line-length = 88 # [tool.ruff]
# select = ["E", "W", "F", "I", "UP", "C4", "B"] # Example rule set # line-length = 88
# ignore = [] # select = ["E", "W", "F", "I", "UP", "C4", "B"] # Example rule set
# ignore = []
# [tool.ruff.format]
# quote-style = "double" # [tool.ruff.format]
# quote-style = "double"