Compare commits

..

8 Commits

Author SHA1 Message Date
ed83a9da85
Implements caching system with TTL and member filtering
Introduces a flexible caching infrastructure with time-to-live support and configurable member caching based on status, voice state, and join events.

Adds AudioSink abstract base class to support audio output handling in voice connections.

Replaces direct dictionary access with cache objects throughout the client, enabling automatic expiration and intelligent member filtering based on user-defined flags.

Updates guild parsing to incorporate presence and voice state data for more accurate member caching decisions.
2025-06-11 02:11:33 -06:00
0151526d07
chore: Apply code formatting across the codebase
This commit applies consistent code formatting to multiple files. No functional changes are included.
2025-06-11 02:10:33 -06:00
17b7ea35a9
feat: Implement guild.fetch_members 2025-06-11 02:06:19 -06:00
28702fa8a1
feat(voice): Implement voice receiving and audio sinks 2025-06-11 02:06:18 -06:00
97505948ee
feat: Add advanced message, channel, and thread management 2025-06-11 02:06:16 -06:00
152c0f12be
feat(ui): Implement persistent views 2025-06-11 02:06:15 -06:00
eb38ecf671
feat(commands): Add has_role and has_any_role check decorators 2025-06-11 02:06:13 -06:00
2bd45c87ca
feat: Enhance command framework with groups, checks, and events 2025-06-11 02:06:11 -06:00
14 changed files with 7679 additions and 6664 deletions

View File

@ -114,3 +114,20 @@ class FFmpegAudioSource(AudioSource):
if isinstance(self.source, io.IOBase):
with contextlib.suppress(Exception):
self.source.close()
class AudioSink:
"""Abstract base class for audio sinks."""
def write(self, user, data):
"""Write a chunk of PCM audio.
Subclasses must implement this. The data is raw PCM at 48kHz
stereo.
"""
raise NotImplementedError
def close(self) -> None:
"""Cleanup the sink when the voice client disconnects."""
return None

View File

@ -4,7 +4,8 @@ import time
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
if TYPE_CHECKING:
from .models import Channel, Guild
from .models import Channel, Guild, Member
from .caching import MemberCacheFlags
T = TypeVar("T")
@ -53,3 +54,32 @@ class GuildCache(Cache["Guild"]):
class ChannelCache(Cache["Channel"]):
"""Cache specifically for :class:`Channel` objects."""
class MemberCache(Cache["Member"]):
"""
A cache for :class:`Member` objects that respects :class:`MemberCacheFlags`.
"""
def __init__(self, flags: MemberCacheFlags, ttl: Optional[float] = None) -> None:
super().__init__(ttl)
self.flags = flags
def _should_cache(self, member: Member) -> bool:
"""Determines if a member should be cached based on the flags."""
if self.flags.all:
return True
if self.flags.none:
return False
if self.flags.online and member.status != "offline":
return True
if self.flags.voice and member.voice_state is not None:
return True
if self.flags.joined and getattr(member, "_just_joined", False):
return True
return False
def set(self, key: str, value: Member) -> None:
if self._should_cache(value):
super().set(key, value)

120
disagreement/caching.py Normal file
View File

@ -0,0 +1,120 @@
from __future__ import annotations
import operator
from typing import Any, Callable, ClassVar, Dict, Iterator, Tuple
class _MemberCacheFlagValue:
flag: int
def __init__(self, func: Callable[[Any], bool]):
self.flag = getattr(func, 'flag', 0)
self.__doc__ = func.__doc__
def __get__(self, instance: 'MemberCacheFlags', owner: type) -> Any:
if instance is None:
return self
return instance.value & self.flag != 0
def __set__(self, instance: Any, value: bool) -> None:
if value:
instance.value |= self.flag
else:
instance.value &= ~self.flag
def __repr__(self) -> str:
return f'<{self.__class__.__name__} flag={self.flag}>'
def flag_value(flag: int) -> Callable[[Callable[[Any], bool]], _MemberCacheFlagValue]:
def decorator(func: Callable[[Any], bool]) -> _MemberCacheFlagValue:
setattr(func, 'flag', flag)
return _MemberCacheFlagValue(func)
return decorator
class MemberCacheFlags:
__slots__ = ('value',)
VALID_FLAGS: ClassVar[Dict[str, int]] = {
'joined': 1 << 0,
'voice': 1 << 1,
'online': 1 << 2,
}
DEFAULT_FLAGS: ClassVar[int] = 1 | 2 | 4
ALL_FLAGS: ClassVar[int] = sum(VALID_FLAGS.values())
def __init__(self, **kwargs: bool):
self.value = self.DEFAULT_FLAGS
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid member cache flag.')
setattr(self, key, value)
@classmethod
def _from_value(cls, value: int) -> MemberCacheFlags:
self = cls.__new__(cls)
self.value = value
return self
def __eq__(self, other: object) -> bool:
return isinstance(other, MemberCacheFlags) and self.value == other.value
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
return hash(self.value)
def __repr__(self) -> str:
return f'<MemberCacheFlags value={self.value}>'
def __iter__(self) -> Iterator[Tuple[str, bool]]:
for name in self.VALID_FLAGS:
yield name, getattr(self, name)
def __int__(self) -> int:
return self.value
def __index__(self) -> int:
return self.value
@classmethod
def all(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with all flags enabled."""
return cls._from_value(cls.ALL_FLAGS)
@classmethod
def none(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with all flags disabled."""
return cls._from_value(0)
@classmethod
def only_joined(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with only the `joined` flag enabled."""
return cls._from_value(cls.VALID_FLAGS['joined'])
@classmethod
def only_voice(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with only the `voice` flag enabled."""
return cls._from_value(cls.VALID_FLAGS['voice'])
@classmethod
def only_online(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with only the `online` flag enabled."""
return cls._from_value(cls.VALID_FLAGS['online'])
@flag_value(1 << 0)
def joined(self) -> bool:
"""Whether to cache members that have just joined the guild."""
return False
@flag_value(1 << 1)
def voice(self) -> bool:
"""Whether to cache members that are in a voice channel."""
return False
@flag_value(1 << 2)
def online(self) -> bool:
"""Whether to cache members that are online."""
return False

File diff suppressed because it is too large Load Diff

View File

@ -76,7 +76,7 @@ class EventDispatcher:
"""Parses MESSAGE_DELETE and updates message cache."""
message_id = data.get("id")
if message_id:
self._client._messages.pop(message_id, None)
self._client._messages.invalidate(message_id)
return data
def _parse_message_reaction_raw(self, data: Dict[str, Any]) -> Dict[str, Any]:
@ -124,7 +124,7 @@ class EventDispatcher:
"""Parses GUILD_MEMBER_ADD into a Member object."""
guild_id = str(data.get("guild_id"))
return self._client.parse_member(data, guild_id)
return self._client.parse_member(data, guild_id, just_joined=True)
def _parse_guild_member_remove(self, data: Dict[str, Any]):
"""Parses GUILD_MEMBER_REMOVE into a GuildMemberRemove model."""

View File

@ -1,61 +1,65 @@
# disagreement/ext/commands/__init__.py
"""
disagreement.ext.commands - A command framework extension for the Disagreement library.
"""
from .cog import Cog
from .core import (
Command,
CommandContext,
CommandHandler,
) # CommandHandler might be internal
from .decorators import (
command,
listener,
check,
check_any,
cooldown,
max_concurrency,
requires_permissions,
)
from .errors import (
CommandError,
CommandNotFound,
BadArgument,
MissingRequiredArgument,
ArgumentParsingError,
CheckFailure,
CheckAnyFailure,
CommandOnCooldown,
CommandInvokeError,
MaxConcurrencyReached,
)
__all__ = [
# Cog
"Cog",
# Core
"Command",
"CommandContext",
# "CommandHandler", # Usually not part of public API for direct use by bot devs
# Decorators
"command",
"listener",
"check",
"check_any",
"cooldown",
"max_concurrency",
"requires_permissions",
# Errors
"CommandError",
"CommandNotFound",
"BadArgument",
"MissingRequiredArgument",
"ArgumentParsingError",
"CheckFailure",
"CheckAnyFailure",
"CommandOnCooldown",
"CommandInvokeError",
"MaxConcurrencyReached",
]
# disagreement/ext/commands/__init__.py
"""
disagreement.ext.commands - A command framework extension for the Disagreement library.
"""
from .cog import Cog
from .core import (
Command,
CommandContext,
CommandHandler,
) # CommandHandler might be internal
from .decorators import (
command,
listener,
check,
check_any,
cooldown,
max_concurrency,
requires_permissions,
has_role,
has_any_role,
)
from .errors import (
CommandError,
CommandNotFound,
BadArgument,
MissingRequiredArgument,
ArgumentParsingError,
CheckFailure,
CheckAnyFailure,
CommandOnCooldown,
CommandInvokeError,
MaxConcurrencyReached,
)
__all__ = [
# Cog
"Cog",
# Core
"Command",
"CommandContext",
# "CommandHandler", # Usually not part of public API for direct use by bot devs
# Decorators
"command",
"listener",
"check",
"check_any",
"cooldown",
"max_concurrency",
"requires_permissions",
"has_role",
"has_any_role",
# Errors
"CommandError",
"CommandNotFound",
"BadArgument",
"MissingRequiredArgument",
"ArgumentParsingError",
"CheckFailure",
"CheckAnyFailure",
"CommandOnCooldown",
"CommandInvokeError",
"MaxConcurrencyReached",
]

File diff suppressed because it is too large Load Diff

View File

@ -1,219 +1,298 @@
# disagreement/ext/commands/decorators.py
from __future__ import annotations
import asyncio
import inspect
import time
from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
if TYPE_CHECKING:
from .core import Command, CommandContext
from disagreement.permissions import Permissions
from disagreement.models import Member, Guild, Channel
def command(
name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any
) -> Callable:
"""
A decorator that transforms a function into a Command.
Args:
name (Optional[str]): The name of the command. Defaults to the function name.
aliases (Optional[List[str]]): Alternative names for the command.
**attrs: Additional attributes to pass to the Command constructor
(e.g., brief, description, hidden).
Returns:
Callable: A decorator that registers the command.
"""
def decorator(
func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
if not asyncio.iscoroutinefunction(func):
raise TypeError("Command callback must be a coroutine function.")
from .core import Command
cmd_name = name or func.__name__
if hasattr(func, "__command_attrs__"):
raise TypeError("Function is already a command or has command attributes.")
cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
func.__command_object__ = cmd # type: ignore
return func
return decorator
def listener(
name: Optional[str] = None,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""
A decorator that marks a function as an event listener within a Cog.
"""
def decorator(
func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
if not asyncio.iscoroutinefunction(func):
raise TypeError("Listener callback must be a coroutine function.")
actual_event_name = name or func.__name__
setattr(func, "__listener_name__", actual_event_name)
return func
return decorator
def check(
predicate: Callable[["CommandContext"], Awaitable[bool] | bool],
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Decorator to add a check to a command."""
def decorator(
func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
checks = getattr(func, "__command_checks__", [])
checks.append(predicate)
setattr(func, "__command_checks__", checks)
return func
return decorator
def check_any(
*predicates: Callable[["CommandContext"], Awaitable[bool] | bool]
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Decorator that passes if any predicate returns ``True``."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckAnyFailure, CheckFailure
errors = []
for p in predicates:
try:
result = p(ctx)
if inspect.isawaitable(result):
result = await result
if result:
return True
except CheckFailure as e:
errors.append(e)
raise CheckAnyFailure(errors)
return check(predicate)
def max_concurrency(
number: int, per: str = "user"
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Limit how many concurrent invocations of a command are allowed.
Parameters
----------
number:
The maximum number of concurrent invocations.
per:
The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``.
"""
if number < 1:
raise ValueError("Concurrency number must be at least 1.")
if per not in {"user", "guild", "global"}:
raise ValueError("per must be 'user', 'guild', or 'global'.")
def decorator(
func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
setattr(func, "__max_concurrency__", (number, per))
return func
return decorator
def cooldown(
rate: int, per: float
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Simple per-user cooldown decorator."""
buckets: dict[str, dict[str, float]] = {}
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CommandOnCooldown
now = time.monotonic()
user_buckets = buckets.setdefault(ctx.command.name, {})
reset = user_buckets.get(ctx.author.id, 0)
if now < reset:
raise CommandOnCooldown(reset - now)
user_buckets[ctx.author.id] = now + per
return True
return check(predicate)
def _compute_permissions(
member: "Member", channel: "Channel", guild: "Guild"
) -> "Permissions":
"""Compute the effective permissions for a member in a channel."""
return channel.permissions_for(member)
def requires_permissions(
*perms: "Permissions",
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Check that the invoking member has the given permissions in the channel."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckFailure
from disagreement.permissions import (
has_permissions,
missing_permissions,
)
from disagreement.models import Member
channel = getattr(ctx, "channel", None)
if channel is None and hasattr(ctx.bot, "get_channel"):
channel = ctx.bot.get_channel(ctx.message.channel_id)
if channel is None and hasattr(ctx.bot, "fetch_channel"):
channel = await ctx.bot.fetch_channel(ctx.message.channel_id)
if channel is None:
raise CheckFailure("Channel for permission check not found.")
guild = getattr(channel, "guild", None)
if not guild and hasattr(channel, "guild_id") and channel.guild_id:
if hasattr(ctx.bot, "get_guild"):
guild = ctx.bot.get_guild(channel.guild_id)
if not guild and hasattr(ctx.bot, "fetch_guild"):
guild = await ctx.bot.fetch_guild(channel.guild_id)
if not guild:
is_dm = not hasattr(channel, "guild_id") or not channel.guild_id
if is_dm:
if perms:
raise CheckFailure("Permission checks are not supported in DMs.")
return True
raise CheckFailure("Guild for permission check not found.")
member = ctx.author
if not isinstance(member, Member):
member = guild.get_member(ctx.author.id)
if not member and hasattr(ctx.bot, "fetch_member"):
member = await ctx.bot.fetch_member(guild.id, ctx.author.id)
if not member:
raise CheckFailure("Could not resolve author to a guild member.")
perms_value = _compute_permissions(member, channel, guild)
if not has_permissions(perms_value, *perms):
missing = missing_permissions(perms_value, *perms)
missing_names = ", ".join(p.name for p in missing if p.name)
raise CheckFailure(f"Missing permissions: {missing_names}")
return True
return check(predicate)
# disagreement/ext/commands/decorators.py
from __future__ import annotations
import asyncio
import inspect
import time
from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
if TYPE_CHECKING:
from .core import Command, CommandContext
from disagreement.permissions import Permissions
from disagreement.models import Member, Guild, Channel
def command(
name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any
) -> Callable:
"""
A decorator that transforms a function into a Command.
Args:
name (Optional[str]): The name of the command. Defaults to the function name.
aliases (Optional[List[str]]): Alternative names for the command.
**attrs: Additional attributes to pass to the Command constructor
(e.g., brief, description, hidden).
Returns:
Callable: A decorator that registers the command.
"""
def decorator(
func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
if not asyncio.iscoroutinefunction(func):
raise TypeError("Command callback must be a coroutine function.")
from .core import Command
cmd_name = name or func.__name__
if hasattr(func, "__command_attrs__"):
raise TypeError("Function is already a command or has command attributes.")
cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
func.__command_object__ = cmd # type: ignore
return func
return decorator
def listener(
name: Optional[str] = None,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""
A decorator that marks a function as an event listener within a Cog.
"""
def decorator(
func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
if not asyncio.iscoroutinefunction(func):
raise TypeError("Listener callback must be a coroutine function.")
actual_event_name = name or func.__name__
setattr(func, "__listener_name__", actual_event_name)
return func
return decorator
def check(
predicate: Callable[["CommandContext"], Awaitable[bool] | bool],
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Decorator to add a check to a command."""
def decorator(
func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
checks = getattr(func, "__command_checks__", [])
checks.append(predicate)
setattr(func, "__command_checks__", checks)
return func
return decorator
def check_any(
*predicates: Callable[["CommandContext"], Awaitable[bool] | bool]
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Decorator that passes if any predicate returns ``True``."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckAnyFailure, CheckFailure
errors = []
for p in predicates:
try:
result = p(ctx)
if inspect.isawaitable(result):
result = await result
if result:
return True
except CheckFailure as e:
errors.append(e)
raise CheckAnyFailure(errors)
return check(predicate)
def max_concurrency(
number: int, per: str = "user"
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Limit how many concurrent invocations of a command are allowed.
Parameters
----------
number:
The maximum number of concurrent invocations.
per:
The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``.
"""
if number < 1:
raise ValueError("Concurrency number must be at least 1.")
if per not in {"user", "guild", "global"}:
raise ValueError("per must be 'user', 'guild', or 'global'.")
def decorator(
func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
setattr(func, "__max_concurrency__", (number, per))
return func
return decorator
def cooldown(
rate: int, per: float
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Simple per-user cooldown decorator."""
buckets: dict[str, dict[str, float]] = {}
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CommandOnCooldown
now = time.monotonic()
user_buckets = buckets.setdefault(ctx.command.name, {})
reset = user_buckets.get(ctx.author.id, 0)
if now < reset:
raise CommandOnCooldown(reset - now)
user_buckets[ctx.author.id] = now + per
return True
return check(predicate)
def _compute_permissions(
member: "Member", channel: "Channel", guild: "Guild"
) -> "Permissions":
"""Compute the effective permissions for a member in a channel."""
return channel.permissions_for(member)
def requires_permissions(
*perms: "Permissions",
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Check that the invoking member has the given permissions in the channel."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckFailure
from disagreement.permissions import (
has_permissions,
missing_permissions,
)
from disagreement.models import Member
channel = getattr(ctx, "channel", None)
if channel is None and hasattr(ctx.bot, "get_channel"):
channel = ctx.bot.get_channel(ctx.message.channel_id)
if channel is None and hasattr(ctx.bot, "fetch_channel"):
channel = await ctx.bot.fetch_channel(ctx.message.channel_id)
if channel is None:
raise CheckFailure("Channel for permission check not found.")
guild = getattr(channel, "guild", None)
if not guild and hasattr(channel, "guild_id") and channel.guild_id:
if hasattr(ctx.bot, "get_guild"):
guild = ctx.bot.get_guild(channel.guild_id)
if not guild and hasattr(ctx.bot, "fetch_guild"):
guild = await ctx.bot.fetch_guild(channel.guild_id)
if not guild:
is_dm = not hasattr(channel, "guild_id") or not channel.guild_id
if is_dm:
if perms:
raise CheckFailure("Permission checks are not supported in DMs.")
return True
raise CheckFailure("Guild for permission check not found.")
member = ctx.author
if not isinstance(member, Member):
member = guild.get_member(ctx.author.id)
if not member and hasattr(ctx.bot, "fetch_member"):
member = await ctx.bot.fetch_member(guild.id, ctx.author.id)
if not member:
raise CheckFailure("Could not resolve author to a guild member.")
perms_value = _compute_permissions(member, channel, guild)
if not has_permissions(perms_value, *perms):
missing = missing_permissions(perms_value, *perms)
missing_names = ", ".join(p.name for p in missing if p.name)
raise CheckFailure(f"Missing permissions: {missing_names}")
return True
return check(predicate)
def has_role(
name_or_id: str | int,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Check that the invoking member has a role with the given name or ID."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckFailure
from disagreement.models import Member
if not ctx.guild:
raise CheckFailure("This command cannot be used in DMs.")
author = ctx.author
if not isinstance(author, Member):
try:
author = await ctx.bot.fetch_member(ctx.guild.id, author.id)
except Exception:
raise CheckFailure("Could not resolve author to a guild member.")
if not author:
raise CheckFailure("Could not resolve author to a guild member.")
# Create a list of the member's role objects by looking them up in the guild's roles list
member_roles = [
role for role in ctx.guild.roles if role.id in author.roles
]
if any(
role.id == str(name_or_id) or role.name == name_or_id
for role in member_roles
):
return True
raise CheckFailure(f"You need the '{name_or_id}' role to use this command.")
return check(predicate)
def has_any_role(
*names_or_ids: str | int,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Check that the invoking member has any of the roles with the given names or IDs."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckFailure
from disagreement.models import Member
if not ctx.guild:
raise CheckFailure("This command cannot be used in DMs.")
author = ctx.author
if not isinstance(author, Member):
try:
author = await ctx.bot.fetch_member(ctx.guild.id, author.id)
except Exception:
raise CheckFailure("Could not resolve author to a guild member.")
if not author:
raise CheckFailure("Could not resolve author to a guild member.")
member_roles = [
role for role in ctx.guild.roles if role.id in author.roles
]
# Convert names_or_ids to a set for efficient lookup
names_or_ids_set = set(map(str, names_or_ids))
if any(
role.id in names_or_ids_set or role.name in names_or_ids_set
for role in member_roles
):
return True
role_list = ", ".join(f"'{r}'" for r in names_or_ids)
raise CheckFailure(
f"You need one of the following roles to use this command: {role_list}"
)
return check(predicate)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,165 +1,167 @@
from __future__ import annotations
import asyncio
import uuid
from typing import Any, Callable, Coroutine, Dict, List, Optional, TYPE_CHECKING
from ..models import ActionRow
from .item import Item
if TYPE_CHECKING:
from ..client import Client
from ..interactions import Interaction
class View:
"""Represents a container for UI components that can be sent with a message.
Args:
timeout (Optional[float]): The number of seconds to wait for an interaction before the view times out.
Defaults to 180.
"""
def __init__(self, *, timeout: Optional[float] = 180.0):
self.timeout = timeout
self.id = str(uuid.uuid4())
self.__children: List[Item] = []
self.__stopped = asyncio.Event()
self._client: Optional[Client] = None
self._message_id: Optional[str] = None
for item in self.__class__.__dict__.values():
if isinstance(item, Item):
self.add_item(item)
@property
def children(self) -> List[Item]:
return self.__children
def add_item(self, item: Item):
"""Adds an item to the view."""
if not isinstance(item, Item):
raise TypeError("Only instances of 'Item' can be added to a View.")
if len(self.__children) >= 25:
raise ValueError("A view can only have a maximum of 25 components.")
item._view = self
self.__children.append(item)
@property
def message_id(self) -> Optional[str]:
return self._message_id
@message_id.setter
def message_id(self, value: str):
self._message_id = value
def to_components(self) -> List[ActionRow]:
"""Converts the view's children into a list of ActionRow components.
This retains the original, simple layout behaviour where each item is
placed in its own :class:`ActionRow` to ensure backward compatibility.
"""
rows: List[ActionRow] = []
for item in self.children:
if item.custom_id is None:
item.custom_id = (
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
)
rows.append(ActionRow(components=[item]))
return rows
def layout_components_advanced(self) -> List[ActionRow]:
"""Group compatible components into rows following Discord rules."""
rows: List[ActionRow] = []
for item in self.children:
if item.custom_id is None:
item.custom_id = (
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
)
target_row = item.row
if target_row is not None:
if not 0 <= target_row <= 4:
raise ValueError("Row index must be between 0 and 4.")
while len(rows) <= target_row:
if len(rows) >= 5:
raise ValueError("A view can have at most 5 action rows.")
rows.append(ActionRow())
rows[target_row].add_component(item)
continue
placed = False
for row in rows:
try:
row.add_component(item)
placed = True
break
except ValueError:
continue
if not placed:
if len(rows) >= 5:
raise ValueError("A view can have at most 5 action rows.")
new_row = ActionRow([item])
rows.append(new_row)
return rows
def to_components_payload(self) -> List[Dict[str, Any]]:
"""Converts the view's children into a list of component dictionaries
that can be sent to the Discord API."""
return [row.to_dict() for row in self.to_components()]
async def _dispatch(self, interaction: Interaction):
"""Called by the client to dispatch an interaction to the correct item."""
if self.timeout is not None:
self.__stopped.set() # Reset the timeout on each interaction
self.__stopped.clear()
if interaction.data:
custom_id = interaction.data.custom_id
for child in self.children:
if child.custom_id == custom_id:
if child.callback:
await child.callback(self, interaction)
break
async def wait(self) -> bool:
"""Waits until the view has stopped interacting."""
return await self.__stopped.wait()
def stop(self):
"""Stops the view from listening to interactions."""
if not self.__stopped.is_set():
self.__stopped.set()
async def on_timeout(self):
"""Called when the view times out."""
pass # User can override this
async def _start(self, client: Client):
"""Starts the view's internal listener."""
self._client = client
if self.timeout is not None:
asyncio.create_task(self._timeout_task())
async def _timeout_task(self):
"""The task that waits for the timeout and then stops the view."""
try:
await asyncio.wait_for(self.wait(), timeout=self.timeout)
except asyncio.TimeoutError:
self.stop()
await self.on_timeout()
if self._client and self._message_id:
# Remove the view from the client's listeners
self._client._views.pop(self._message_id, None)
from __future__ import annotations
import asyncio
import uuid
from typing import Any, Callable, Coroutine, Dict, List, Optional, TYPE_CHECKING
from ..models import ActionRow
from .item import Item
if TYPE_CHECKING:
from ..client import Client
from ..interactions import Interaction
class View:
"""Represents a container for UI components that can be sent with a message.
Args:
timeout (Optional[float]): The number of seconds to wait for an interaction before the view times out.
Defaults to 180.
"""
def __init__(self, *, timeout: Optional[float] = 180.0):
self.timeout = timeout
self.id = str(uuid.uuid4())
self.__children: List[Item] = []
self.__stopped = asyncio.Event()
self._client: Optional[Client] = None
self._message_id: Optional[str] = None
# The below is a bit of a hack to support items defined as class members
# e.g. button = Button(...)
for item in self.__class__.__dict__.values():
if isinstance(item, Item):
self.add_item(item)
@property
def children(self) -> List[Item]:
return self.__children
def add_item(self, item: Item):
"""Adds an item to the view."""
if not isinstance(item, Item):
raise TypeError("Only instances of 'Item' can be added to a View.")
if len(self.__children) >= 25:
raise ValueError("A view can only have a maximum of 25 components.")
if self.timeout is None and item.custom_id is None:
raise ValueError(
"All components in a persistent view must have a 'custom_id'."
)
item._view = self
self.__children.append(item)
@property
def message_id(self) -> Optional[str]:
return self._message_id
@message_id.setter
def message_id(self, value: str):
self._message_id = value
def to_components(self) -> List[ActionRow]:
"""Converts the view's children into a list of ActionRow components.
This retains the original, simple layout behaviour where each item is
placed in its own :class:`ActionRow` to ensure backward compatibility.
"""
rows: List[ActionRow] = []
for item in self.children:
rows.append(ActionRow(components=[item]))
return rows
def layout_components_advanced(self) -> List[ActionRow]:
"""Group compatible components into rows following Discord rules."""
rows: List[ActionRow] = []
for item in self.children:
if item.custom_id is None:
item.custom_id = (
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
)
target_row = item.row
if target_row is not None:
if not 0 <= target_row <= 4:
raise ValueError("Row index must be between 0 and 4.")
while len(rows) <= target_row:
if len(rows) >= 5:
raise ValueError("A view can have at most 5 action rows.")
rows.append(ActionRow())
rows[target_row].add_component(item)
continue
placed = False
for row in rows:
try:
row.add_component(item)
placed = True
break
except ValueError:
continue
if not placed:
if len(rows) >= 5:
raise ValueError("A view can have at most 5 action rows.")
new_row = ActionRow([item])
rows.append(new_row)
return rows
def to_components_payload(self) -> List[Dict[str, Any]]:
"""Converts the view's children into a list of component dictionaries
that can be sent to the Discord API."""
return [row.to_dict() for row in self.to_components()]
async def _dispatch(self, interaction: Interaction):
"""Called by the client to dispatch an interaction to the correct item."""
if self.timeout is not None:
self.__stopped.set() # Reset the timeout on each interaction
self.__stopped.clear()
if interaction.data:
custom_id = interaction.data.custom_id
for child in self.children:
if child.custom_id == custom_id:
if child.callback:
await child.callback(self, interaction)
break
async def wait(self) -> bool:
"""Waits until the view has stopped interacting."""
return await self.__stopped.wait()
def stop(self):
"""Stops the view from listening to interactions."""
if not self.__stopped.is_set():
self.__stopped.set()
async def on_timeout(self):
"""Called when the view times out."""
pass # User can override this
async def _start(self, client: Client):
"""Starts the view's internal listener."""
self._client = client
if self.timeout is not None:
asyncio.create_task(self._timeout_task())
async def _timeout_task(self):
"""The task that waits for the timeout and then stops the view."""
try:
await asyncio.wait_for(self.wait(), timeout=self.timeout)
except asyncio.TimeoutError:
self.stop()
await self.on_timeout()
if self._client and self._message_id:
# Remove the view from the client's listeners
self._client._views.pop(self._message_id, None)

View File

@ -1,162 +1,244 @@
# disagreement/voice_client.py
"""Voice gateway and UDP audio client."""
from __future__ import annotations
import asyncio
import contextlib
import socket
from typing import Optional, Sequence
import aiohttp
from .audio import AudioSource, FFmpegAudioSource
class VoiceClient:
"""Handles the Discord voice WebSocket connection and UDP streaming."""
def __init__(
self,
endpoint: str,
session_id: str,
token: str,
guild_id: int,
user_id: int,
*,
ws=None,
udp: Optional[socket.socket] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
verbose: bool = False,
) -> None:
self.endpoint = endpoint
self.session_id = session_id
self.token = token
self.guild_id = str(guild_id)
self.user_id = str(user_id)
self._ws: Optional[aiohttp.ClientWebSocketResponse] = ws
self._udp = udp
self._session: Optional[aiohttp.ClientSession] = None
self._heartbeat_task: Optional[asyncio.Task] = None
self._heartbeat_interval: Optional[float] = None
self._loop = loop or asyncio.get_event_loop()
self.verbose = verbose
self.ssrc: Optional[int] = None
self.secret_key: Optional[Sequence[int]] = None
self._server_ip: Optional[str] = None
self._server_port: Optional[int] = None
self._current_source: Optional[AudioSource] = None
self._play_task: Optional[asyncio.Task] = None
async def connect(self) -> None:
if self._ws is None:
self._session = aiohttp.ClientSession()
self._ws = await self._session.ws_connect(self.endpoint)
hello = await self._ws.receive_json()
self._heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000
self._heartbeat_task = self._loop.create_task(self._heartbeat())
await self._ws.send_json(
{
"op": 0,
"d": {
"server_id": self.guild_id,
"user_id": self.user_id,
"session_id": self.session_id,
"token": self.token,
},
}
)
ready = await self._ws.receive_json()
data = ready["d"]
self.ssrc = data["ssrc"]
self._server_ip = data["ip"]
self._server_port = data["port"]
if self._udp is None:
self._udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self._udp.connect((self._server_ip, self._server_port))
await self._ws.send_json(
{
"op": 1,
"d": {
"protocol": "udp",
"data": {
"address": self._udp.getsockname()[0],
"port": self._udp.getsockname()[1],
"mode": "xsalsa20_poly1305",
},
},
}
)
session_desc = await self._ws.receive_json()
self.secret_key = session_desc["d"].get("secret_key")
async def _heartbeat(self) -> None:
assert self._ws is not None
assert self._heartbeat_interval is not None
try:
while True:
await self._ws.send_json({"op": 3, "d": int(self._loop.time() * 1000)})
await asyncio.sleep(self._heartbeat_interval)
except asyncio.CancelledError:
pass
async def send_audio_frame(self, frame: bytes) -> None:
if not self._udp:
raise RuntimeError("UDP socket not initialised")
self._udp.send(frame)
async def _play_loop(self) -> None:
assert self._current_source is not None
try:
while True:
data = await self._current_source.read()
if not data:
break
await self.send_audio_frame(data)
finally:
await self._current_source.close()
self._current_source = None
self._play_task = None
async def stop(self) -> None:
if self._play_task:
self._play_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._play_task
self._play_task = None
if self._current_source:
await self._current_source.close()
self._current_source = None
async def play(self, source: AudioSource, *, wait: bool = True) -> None:
"""|coro| Play an :class:`AudioSource` on the voice connection."""
await self.stop()
self._current_source = source
self._play_task = self._loop.create_task(self._play_loop())
if wait:
await self._play_task
async def play_file(self, filename: str, *, wait: bool = True) -> None:
"""|coro| Stream an audio file or URL using FFmpeg."""
await self.play(FFmpegAudioSource(filename), wait=wait)
async def close(self) -> None:
await self.stop()
if self._heartbeat_task:
self._heartbeat_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._heartbeat_task
if self._ws:
await self._ws.close()
if self._session:
await self._session.close()
if self._udp:
self._udp.close()
# disagreement/voice_client.py
"""Voice gateway and UDP audio client."""
from __future__ import annotations
import asyncio
import contextlib
import socket
import threading
from typing import TYPE_CHECKING, Optional, Sequence
import aiohttp
# The following import is correct, but may be flagged by Pylance if the virtual
# environment is not configured correctly.
from nacl.secret import SecretBox
from .audio import AudioSink, AudioSource, FFmpegAudioSource
from .models import User
if TYPE_CHECKING:
from .client import Client
class VoiceClient:
"""Handles the Discord voice WebSocket connection and UDP streaming."""
def __init__(
self,
client: Client,
endpoint: str,
session_id: str,
token: str,
guild_id: int,
user_id: int,
*,
ws=None,
udp: Optional[socket.socket] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
verbose: bool = False,
) -> None:
self.client = client
self.endpoint = endpoint
self.session_id = session_id
self.token = token
self.guild_id = str(guild_id)
self.user_id = str(user_id)
self._ws: Optional[aiohttp.ClientWebSocketResponse] = ws
self._udp = udp
self._session: Optional[aiohttp.ClientSession] = None
self._heartbeat_task: Optional[asyncio.Task] = None
self._receive_task: Optional[asyncio.Task] = None
self._udp_receive_thread: Optional[threading.Thread] = None
self._heartbeat_interval: Optional[float] = None
try:
self._loop = loop or asyncio.get_running_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self.verbose = verbose
self.ssrc: Optional[int] = None
self.secret_key: Optional[Sequence[int]] = None
self._server_ip: Optional[str] = None
self._server_port: Optional[int] = None
self._current_source: Optional[AudioSource] = None
self._play_task: Optional[asyncio.Task] = None
self._sink: Optional[AudioSink] = None
self._ssrc_map: dict[int, int] = {}
self._ssrc_lock = threading.Lock()
async def connect(self) -> None:
if self._ws is None:
self._session = aiohttp.ClientSession()
self._ws = await self._session.ws_connect(self.endpoint)
hello = await self._ws.receive_json()
self._heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000
self._heartbeat_task = self._loop.create_task(self._heartbeat())
await self._ws.send_json(
{
"op": 0,
"d": {
"server_id": self.guild_id,
"user_id": self.user_id,
"session_id": self.session_id,
"token": self.token,
},
}
)
ready = await self._ws.receive_json()
data = ready["d"]
self.ssrc = data["ssrc"]
self._server_ip = data["ip"]
self._server_port = data["port"]
if self._udp is None:
self._udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self._udp.connect((self._server_ip, self._server_port))
await self._ws.send_json(
{
"op": 1,
"d": {
"protocol": "udp",
"data": {
"address": self._udp.getsockname()[0],
"port": self._udp.getsockname()[1],
"mode": "xsalsa20_poly1305",
},
},
}
)
session_desc = await self._ws.receive_json()
self.secret_key = session_desc["d"].get("secret_key")
async def _heartbeat(self) -> None:
assert self._ws is not None
assert self._heartbeat_interval is not None
try:
while True:
await self._ws.send_json({"op": 3, "d": int(self._loop.time() * 1000)})
await asyncio.sleep(self._heartbeat_interval)
except asyncio.CancelledError:
pass
async def _receive_loop(self) -> None:
assert self._ws is not None
while True:
try:
msg = await self._ws.receive_json()
op = msg.get("op")
data = msg.get("d")
if op == 5: # Speaking
user_id = int(data["user_id"])
ssrc = data["ssrc"]
with self._ssrc_lock:
self._ssrc_map[ssrc] = user_id
except (asyncio.CancelledError, aiohttp.ClientError):
break
def _udp_receive_loop(self) -> None:
assert self._udp is not None
assert self.secret_key is not None
box = SecretBox(bytes(self.secret_key))
while True:
try:
packet = self._udp.recv(4096)
if len(packet) < 12:
continue
ssrc = int.from_bytes(packet[8:12], "big")
with self._ssrc_lock:
if ssrc not in self._ssrc_map:
continue
user_id = self._ssrc_map[ssrc]
user = self.client._users.get(str(user_id))
if not user:
continue
decrypted = box.decrypt(packet[12:])
if self._sink:
self._sink.write(user, decrypted)
except (socket.error, asyncio.CancelledError):
break
except Exception as e:
if self.verbose:
print(f"Error in UDP receive loop: {e}")
async def send_audio_frame(self, frame: bytes) -> None:
if not self._udp:
raise RuntimeError("UDP socket not initialised")
self._udp.send(frame)
async def _play_loop(self) -> None:
assert self._current_source is not None
try:
while True:
data = await self._current_source.read()
if not data:
break
await self.send_audio_frame(data)
finally:
await self._current_source.close()
self._current_source = None
self._play_task = None
async def stop(self) -> None:
if self._play_task:
self._play_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._play_task
self._play_task = None
if self._current_source:
await self._current_source.close()
self._current_source = None
async def play(self, source: AudioSource, *, wait: bool = True) -> None:
"""|coro| Play an :class:`AudioSource` on the voice connection."""
await self.stop()
self._current_source = source
self._play_task = self._loop.create_task(self._play_loop())
if wait:
await self._play_task
async def play_file(self, filename: str, *, wait: bool = True) -> None:
"""|coro| Stream an audio file or URL using FFmpeg."""
await self.play(FFmpegAudioSource(filename), wait=wait)
def listen(self, sink: AudioSink) -> None:
"""Start listening to voice and routing to a sink."""
if not isinstance(sink, AudioSink):
raise TypeError("sink must be an AudioSink instance")
self._sink = sink
if not self._udp_receive_thread:
self._udp_receive_thread = threading.Thread(
target=self._udp_receive_loop, daemon=True
)
self._udp_receive_thread.start()
async def close(self) -> None:
await self.stop()
if self._heartbeat_task:
self._heartbeat_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._heartbeat_task
if self._receive_task:
self._receive_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._receive_task
if self._ws:
await self._ws.close()
if self._session:
await self._session.close()
if self._udp:
self._udp.close()
if self._udp_receive_thread:
self._udp_receive_thread.join(timeout=1)
if self._sink:
self._sink.close()

View File

@ -1,56 +1,57 @@
[project]
name = "disagreement"
version = "0.2.0rc1"
description = "A Python library for the Discord API."
readme = "README.md"
requires-python = ">=3.10"
license = {text = "BSD 3-Clause"}
authors = [
{name = "Slipstream", email = "me@slipstreamm.dev"}
]
keywords = ["discord", "api", "bot", "async", "aiohttp"]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Internet",
]
dependencies = [
"aiohttp>=3.9.0,<4.0.0",
]
[project.optional-dependencies]
test = [
"pytest>=8.0.0",
"pytest-asyncio>=1.0.0",
"hypothesis>=6.132.0",
]
dev = [
"python-dotenv>=1.0.0",
]
[project.urls]
Homepage = "https://github.com/Slipstreamm/disagreement"
Issues = "https://github.com/Slipstreamm/disagreement/issues"
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Optional: for linting/formatting, e.g., Ruff
# [tool.ruff]
# line-length = 88
# select = ["E", "W", "F", "I", "UP", "C4", "B"] # Example rule set
# ignore = []
# [tool.ruff.format]
# quote-style = "double"
[project]
name = "disagreement"
version = "0.2.0rc1"
description = "A Python library for the Discord API."
readme = "README.md"
requires-python = ">=3.10"
license = {text = "BSD 3-Clause"}
authors = [
{name = "Slipstream", email = "me@slipstreamm.dev"}
]
keywords = ["discord", "api", "bot", "async", "aiohttp"]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Internet",
]
dependencies = [
"aiohttp>=3.9.0,<4.0.0",
"PyNaCl>=1.5.0,<2.0.0",
]
[project.optional-dependencies]
test = [
"pytest>=8.0.0",
"pytest-asyncio>=1.0.0",
"hypothesis>=6.132.0",
]
dev = [
"python-dotenv>=1.0.0",
]
[project.urls]
Homepage = "https://github.com/Slipstreamm/disagreement"
Issues = "https://github.com/Slipstreamm/disagreement/issues"
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Optional: for linting/formatting, e.g., Ruff
# [tool.ruff]
# line-length = 88
# select = ["E", "W", "F", "I", "UP", "C4", "B"] # Example rule set
# ignore = []
# [tool.ruff.format]
# quote-style = "double"