Compare commits
8 Commits
be85444aa0
...
ed83a9da85
Author | SHA1 | Date | |
---|---|---|---|
ed83a9da85 | |||
0151526d07 | |||
17b7ea35a9 | |||
28702fa8a1 | |||
97505948ee | |||
152c0f12be | |||
eb38ecf671 | |||
2bd45c87ca |
@ -114,3 +114,20 @@ class FFmpegAudioSource(AudioSource):
|
||||
if isinstance(self.source, io.IOBase):
|
||||
with contextlib.suppress(Exception):
|
||||
self.source.close()
|
||||
|
||||
class AudioSink:
|
||||
"""Abstract base class for audio sinks."""
|
||||
|
||||
def write(self, user, data):
|
||||
"""Write a chunk of PCM audio.
|
||||
|
||||
Subclasses must implement this. The data is raw PCM at 48kHz
|
||||
stereo.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self) -> None:
|
||||
"""Cleanup the sink when the voice client disconnects."""
|
||||
|
||||
return None
|
||||
|
@ -4,7 +4,8 @@ import time
|
||||
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .models import Channel, Guild
|
||||
from .models import Channel, Guild, Member
|
||||
from .caching import MemberCacheFlags
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@ -53,3 +54,32 @@ class GuildCache(Cache["Guild"]):
|
||||
|
||||
class ChannelCache(Cache["Channel"]):
|
||||
"""Cache specifically for :class:`Channel` objects."""
|
||||
|
||||
|
||||
class MemberCache(Cache["Member"]):
|
||||
"""
|
||||
A cache for :class:`Member` objects that respects :class:`MemberCacheFlags`.
|
||||
"""
|
||||
|
||||
def __init__(self, flags: MemberCacheFlags, ttl: Optional[float] = None) -> None:
|
||||
super().__init__(ttl)
|
||||
self.flags = flags
|
||||
|
||||
def _should_cache(self, member: Member) -> bool:
|
||||
"""Determines if a member should be cached based on the flags."""
|
||||
if self.flags.all:
|
||||
return True
|
||||
if self.flags.none:
|
||||
return False
|
||||
|
||||
if self.flags.online and member.status != "offline":
|
||||
return True
|
||||
if self.flags.voice and member.voice_state is not None:
|
||||
return True
|
||||
if self.flags.joined and getattr(member, "_just_joined", False):
|
||||
return True
|
||||
return False
|
||||
|
||||
def set(self, key: str, value: Member) -> None:
|
||||
if self._should_cache(value):
|
||||
super().set(key, value)
|
||||
|
120
disagreement/caching.py
Normal file
120
disagreement/caching.py
Normal file
@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from typing import Any, Callable, ClassVar, Dict, Iterator, Tuple
|
||||
|
||||
|
||||
class _MemberCacheFlagValue:
|
||||
flag: int
|
||||
|
||||
def __init__(self, func: Callable[[Any], bool]):
|
||||
self.flag = getattr(func, 'flag', 0)
|
||||
self.__doc__ = func.__doc__
|
||||
|
||||
def __get__(self, instance: 'MemberCacheFlags', owner: type) -> Any:
|
||||
if instance is None:
|
||||
return self
|
||||
return instance.value & self.flag != 0
|
||||
|
||||
def __set__(self, instance: Any, value: bool) -> None:
|
||||
if value:
|
||||
instance.value |= self.flag
|
||||
else:
|
||||
instance.value &= ~self.flag
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__} flag={self.flag}>'
|
||||
|
||||
|
||||
def flag_value(flag: int) -> Callable[[Callable[[Any], bool]], _MemberCacheFlagValue]:
|
||||
def decorator(func: Callable[[Any], bool]) -> _MemberCacheFlagValue:
|
||||
setattr(func, 'flag', flag)
|
||||
return _MemberCacheFlagValue(func)
|
||||
return decorator
|
||||
|
||||
|
||||
class MemberCacheFlags:
|
||||
__slots__ = ('value',)
|
||||
|
||||
VALID_FLAGS: ClassVar[Dict[str, int]] = {
|
||||
'joined': 1 << 0,
|
||||
'voice': 1 << 1,
|
||||
'online': 1 << 2,
|
||||
}
|
||||
DEFAULT_FLAGS: ClassVar[int] = 1 | 2 | 4
|
||||
ALL_FLAGS: ClassVar[int] = sum(VALID_FLAGS.values())
|
||||
|
||||
def __init__(self, **kwargs: bool):
|
||||
self.value = self.DEFAULT_FLAGS
|
||||
for key, value in kwargs.items():
|
||||
if key not in self.VALID_FLAGS:
|
||||
raise TypeError(f'{key!r} is not a valid member cache flag.')
|
||||
setattr(self, key, value)
|
||||
|
||||
@classmethod
|
||||
def _from_value(cls, value: int) -> MemberCacheFlags:
|
||||
self = cls.__new__(cls)
|
||||
self.value = value
|
||||
return self
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, MemberCacheFlags) and self.value == other.value
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<MemberCacheFlags value={self.value}>'
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple[str, bool]]:
|
||||
for name in self.VALID_FLAGS:
|
||||
yield name, getattr(self, name)
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.value
|
||||
|
||||
def __index__(self) -> int:
|
||||
return self.value
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> MemberCacheFlags:
|
||||
"""A factory method that creates a :class:`MemberCacheFlags` with all flags enabled."""
|
||||
return cls._from_value(cls.ALL_FLAGS)
|
||||
|
||||
@classmethod
|
||||
def none(cls) -> MemberCacheFlags:
|
||||
"""A factory method that creates a :class:`MemberCacheFlags` with all flags disabled."""
|
||||
return cls._from_value(0)
|
||||
|
||||
@classmethod
|
||||
def only_joined(cls) -> MemberCacheFlags:
|
||||
"""A factory method that creates a :class:`MemberCacheFlags` with only the `joined` flag enabled."""
|
||||
return cls._from_value(cls.VALID_FLAGS['joined'])
|
||||
|
||||
@classmethod
|
||||
def only_voice(cls) -> MemberCacheFlags:
|
||||
"""A factory method that creates a :class:`MemberCacheFlags` with only the `voice` flag enabled."""
|
||||
return cls._from_value(cls.VALID_FLAGS['voice'])
|
||||
|
||||
@classmethod
|
||||
def only_online(cls) -> MemberCacheFlags:
|
||||
"""A factory method that creates a :class:`MemberCacheFlags` with only the `online` flag enabled."""
|
||||
return cls._from_value(cls.VALID_FLAGS['online'])
|
||||
|
||||
@flag_value(1 << 0)
|
||||
def joined(self) -> bool:
|
||||
"""Whether to cache members that have just joined the guild."""
|
||||
return False
|
||||
|
||||
@flag_value(1 << 1)
|
||||
def voice(self) -> bool:
|
||||
"""Whether to cache members that are in a voice channel."""
|
||||
return False
|
||||
|
||||
@flag_value(1 << 2)
|
||||
def online(self) -> bool:
|
||||
"""Whether to cache members that are online."""
|
||||
return False
|
File diff suppressed because it is too large
Load Diff
@ -76,7 +76,7 @@ class EventDispatcher:
|
||||
"""Parses MESSAGE_DELETE and updates message cache."""
|
||||
message_id = data.get("id")
|
||||
if message_id:
|
||||
self._client._messages.pop(message_id, None)
|
||||
self._client._messages.invalidate(message_id)
|
||||
return data
|
||||
|
||||
def _parse_message_reaction_raw(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@ -124,7 +124,7 @@ class EventDispatcher:
|
||||
"""Parses GUILD_MEMBER_ADD into a Member object."""
|
||||
|
||||
guild_id = str(data.get("guild_id"))
|
||||
return self._client.parse_member(data, guild_id)
|
||||
return self._client.parse_member(data, guild_id, just_joined=True)
|
||||
|
||||
def _parse_guild_member_remove(self, data: Dict[str, Any]):
|
||||
"""Parses GUILD_MEMBER_REMOVE into a GuildMemberRemove model."""
|
||||
|
@ -1,61 +1,65 @@
|
||||
# disagreement/ext/commands/__init__.py
|
||||
|
||||
"""
|
||||
disagreement.ext.commands - A command framework extension for the Disagreement library.
|
||||
"""
|
||||
|
||||
from .cog import Cog
|
||||
from .core import (
|
||||
Command,
|
||||
CommandContext,
|
||||
CommandHandler,
|
||||
) # CommandHandler might be internal
|
||||
from .decorators import (
|
||||
command,
|
||||
listener,
|
||||
check,
|
||||
check_any,
|
||||
cooldown,
|
||||
max_concurrency,
|
||||
requires_permissions,
|
||||
)
|
||||
from .errors import (
|
||||
CommandError,
|
||||
CommandNotFound,
|
||||
BadArgument,
|
||||
MissingRequiredArgument,
|
||||
ArgumentParsingError,
|
||||
CheckFailure,
|
||||
CheckAnyFailure,
|
||||
CommandOnCooldown,
|
||||
CommandInvokeError,
|
||||
MaxConcurrencyReached,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Cog
|
||||
"Cog",
|
||||
# Core
|
||||
"Command",
|
||||
"CommandContext",
|
||||
# "CommandHandler", # Usually not part of public API for direct use by bot devs
|
||||
# Decorators
|
||||
"command",
|
||||
"listener",
|
||||
"check",
|
||||
"check_any",
|
||||
"cooldown",
|
||||
"max_concurrency",
|
||||
"requires_permissions",
|
||||
# Errors
|
||||
"CommandError",
|
||||
"CommandNotFound",
|
||||
"BadArgument",
|
||||
"MissingRequiredArgument",
|
||||
"ArgumentParsingError",
|
||||
"CheckFailure",
|
||||
"CheckAnyFailure",
|
||||
"CommandOnCooldown",
|
||||
"CommandInvokeError",
|
||||
"MaxConcurrencyReached",
|
||||
]
|
||||
# disagreement/ext/commands/__init__.py
|
||||
|
||||
"""
|
||||
disagreement.ext.commands - A command framework extension for the Disagreement library.
|
||||
"""
|
||||
|
||||
from .cog import Cog
|
||||
from .core import (
|
||||
Command,
|
||||
CommandContext,
|
||||
CommandHandler,
|
||||
) # CommandHandler might be internal
|
||||
from .decorators import (
|
||||
command,
|
||||
listener,
|
||||
check,
|
||||
check_any,
|
||||
cooldown,
|
||||
max_concurrency,
|
||||
requires_permissions,
|
||||
has_role,
|
||||
has_any_role,
|
||||
)
|
||||
from .errors import (
|
||||
CommandError,
|
||||
CommandNotFound,
|
||||
BadArgument,
|
||||
MissingRequiredArgument,
|
||||
ArgumentParsingError,
|
||||
CheckFailure,
|
||||
CheckAnyFailure,
|
||||
CommandOnCooldown,
|
||||
CommandInvokeError,
|
||||
MaxConcurrencyReached,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Cog
|
||||
"Cog",
|
||||
# Core
|
||||
"Command",
|
||||
"CommandContext",
|
||||
# "CommandHandler", # Usually not part of public API for direct use by bot devs
|
||||
# Decorators
|
||||
"command",
|
||||
"listener",
|
||||
"check",
|
||||
"check_any",
|
||||
"cooldown",
|
||||
"max_concurrency",
|
||||
"requires_permissions",
|
||||
"has_role",
|
||||
"has_any_role",
|
||||
# Errors
|
||||
"CommandError",
|
||||
"CommandNotFound",
|
||||
"BadArgument",
|
||||
"MissingRequiredArgument",
|
||||
"ArgumentParsingError",
|
||||
"CheckFailure",
|
||||
"CheckAnyFailure",
|
||||
"CommandOnCooldown",
|
||||
"CommandInvokeError",
|
||||
"MaxConcurrencyReached",
|
||||
]
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,219 +1,298 @@
|
||||
# disagreement/ext/commands/decorators.py
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import time
|
||||
from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core import Command, CommandContext
|
||||
from disagreement.permissions import Permissions
|
||||
from disagreement.models import Member, Guild, Channel
|
||||
|
||||
|
||||
def command(
|
||||
name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any
|
||||
) -> Callable:
|
||||
"""
|
||||
A decorator that transforms a function into a Command.
|
||||
|
||||
Args:
|
||||
name (Optional[str]): The name of the command. Defaults to the function name.
|
||||
aliases (Optional[List[str]]): Alternative names for the command.
|
||||
**attrs: Additional attributes to pass to the Command constructor
|
||||
(e.g., brief, description, hidden).
|
||||
|
||||
Returns:
|
||||
Callable: A decorator that registers the command.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError("Command callback must be a coroutine function.")
|
||||
|
||||
from .core import Command
|
||||
|
||||
cmd_name = name or func.__name__
|
||||
|
||||
if hasattr(func, "__command_attrs__"):
|
||||
raise TypeError("Function is already a command or has command attributes.")
|
||||
|
||||
cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
|
||||
func.__command_object__ = cmd # type: ignore
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def listener(
|
||||
name: Optional[str] = None,
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""
|
||||
A decorator that marks a function as an event listener within a Cog.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError("Listener callback must be a coroutine function.")
|
||||
|
||||
actual_event_name = name or func.__name__
|
||||
setattr(func, "__listener_name__", actual_event_name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def check(
|
||||
predicate: Callable[["CommandContext"], Awaitable[bool] | bool],
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Decorator to add a check to a command."""
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
checks = getattr(func, "__command_checks__", [])
|
||||
checks.append(predicate)
|
||||
setattr(func, "__command_checks__", checks)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def check_any(
|
||||
*predicates: Callable[["CommandContext"], Awaitable[bool] | bool]
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Decorator that passes if any predicate returns ``True``."""
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CheckAnyFailure, CheckFailure
|
||||
|
||||
errors = []
|
||||
for p in predicates:
|
||||
try:
|
||||
result = p(ctx)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if result:
|
||||
return True
|
||||
except CheckFailure as e:
|
||||
errors.append(e)
|
||||
raise CheckAnyFailure(errors)
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def max_concurrency(
|
||||
number: int, per: str = "user"
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Limit how many concurrent invocations of a command are allowed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
number:
|
||||
The maximum number of concurrent invocations.
|
||||
per:
|
||||
The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``.
|
||||
"""
|
||||
|
||||
if number < 1:
|
||||
raise ValueError("Concurrency number must be at least 1.")
|
||||
if per not in {"user", "guild", "global"}:
|
||||
raise ValueError("per must be 'user', 'guild', or 'global'.")
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
setattr(func, "__max_concurrency__", (number, per))
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def cooldown(
|
||||
rate: int, per: float
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Simple per-user cooldown decorator."""
|
||||
|
||||
buckets: dict[str, dict[str, float]] = {}
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CommandOnCooldown
|
||||
|
||||
now = time.monotonic()
|
||||
user_buckets = buckets.setdefault(ctx.command.name, {})
|
||||
reset = user_buckets.get(ctx.author.id, 0)
|
||||
if now < reset:
|
||||
raise CommandOnCooldown(reset - now)
|
||||
user_buckets[ctx.author.id] = now + per
|
||||
return True
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def _compute_permissions(
|
||||
member: "Member", channel: "Channel", guild: "Guild"
|
||||
) -> "Permissions":
|
||||
"""Compute the effective permissions for a member in a channel."""
|
||||
return channel.permissions_for(member)
|
||||
|
||||
|
||||
def requires_permissions(
|
||||
*perms: "Permissions",
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Check that the invoking member has the given permissions in the channel."""
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CheckFailure
|
||||
from disagreement.permissions import (
|
||||
has_permissions,
|
||||
missing_permissions,
|
||||
)
|
||||
from disagreement.models import Member
|
||||
|
||||
channel = getattr(ctx, "channel", None)
|
||||
if channel is None and hasattr(ctx.bot, "get_channel"):
|
||||
channel = ctx.bot.get_channel(ctx.message.channel_id)
|
||||
if channel is None and hasattr(ctx.bot, "fetch_channel"):
|
||||
channel = await ctx.bot.fetch_channel(ctx.message.channel_id)
|
||||
|
||||
if channel is None:
|
||||
raise CheckFailure("Channel for permission check not found.")
|
||||
|
||||
guild = getattr(channel, "guild", None)
|
||||
if not guild and hasattr(channel, "guild_id") and channel.guild_id:
|
||||
if hasattr(ctx.bot, "get_guild"):
|
||||
guild = ctx.bot.get_guild(channel.guild_id)
|
||||
if not guild and hasattr(ctx.bot, "fetch_guild"):
|
||||
guild = await ctx.bot.fetch_guild(channel.guild_id)
|
||||
|
||||
if not guild:
|
||||
is_dm = not hasattr(channel, "guild_id") or not channel.guild_id
|
||||
if is_dm:
|
||||
if perms:
|
||||
raise CheckFailure("Permission checks are not supported in DMs.")
|
||||
return True
|
||||
raise CheckFailure("Guild for permission check not found.")
|
||||
|
||||
member = ctx.author
|
||||
if not isinstance(member, Member):
|
||||
member = guild.get_member(ctx.author.id)
|
||||
if not member and hasattr(ctx.bot, "fetch_member"):
|
||||
member = await ctx.bot.fetch_member(guild.id, ctx.author.id)
|
||||
|
||||
if not member:
|
||||
raise CheckFailure("Could not resolve author to a guild member.")
|
||||
|
||||
perms_value = _compute_permissions(member, channel, guild)
|
||||
|
||||
if not has_permissions(perms_value, *perms):
|
||||
missing = missing_permissions(perms_value, *perms)
|
||||
missing_names = ", ".join(p.name for p in missing if p.name)
|
||||
raise CheckFailure(f"Missing permissions: {missing_names}")
|
||||
return True
|
||||
|
||||
return check(predicate)
|
||||
# disagreement/ext/commands/decorators.py
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import time
|
||||
from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core import Command, CommandContext
|
||||
from disagreement.permissions import Permissions
|
||||
from disagreement.models import Member, Guild, Channel
|
||||
|
||||
|
||||
def command(
|
||||
name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any
|
||||
) -> Callable:
|
||||
"""
|
||||
A decorator that transforms a function into a Command.
|
||||
|
||||
Args:
|
||||
name (Optional[str]): The name of the command. Defaults to the function name.
|
||||
aliases (Optional[List[str]]): Alternative names for the command.
|
||||
**attrs: Additional attributes to pass to the Command constructor
|
||||
(e.g., brief, description, hidden).
|
||||
|
||||
Returns:
|
||||
Callable: A decorator that registers the command.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError("Command callback must be a coroutine function.")
|
||||
|
||||
from .core import Command
|
||||
|
||||
cmd_name = name or func.__name__
|
||||
|
||||
if hasattr(func, "__command_attrs__"):
|
||||
raise TypeError("Function is already a command or has command attributes.")
|
||||
|
||||
cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
|
||||
func.__command_object__ = cmd # type: ignore
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def listener(
|
||||
name: Optional[str] = None,
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""
|
||||
A decorator that marks a function as an event listener within a Cog.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError("Listener callback must be a coroutine function.")
|
||||
|
||||
actual_event_name = name or func.__name__
|
||||
setattr(func, "__listener_name__", actual_event_name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def check(
|
||||
predicate: Callable[["CommandContext"], Awaitable[bool] | bool],
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Decorator to add a check to a command."""
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
checks = getattr(func, "__command_checks__", [])
|
||||
checks.append(predicate)
|
||||
setattr(func, "__command_checks__", checks)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def check_any(
|
||||
*predicates: Callable[["CommandContext"], Awaitable[bool] | bool]
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Decorator that passes if any predicate returns ``True``."""
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CheckAnyFailure, CheckFailure
|
||||
|
||||
errors = []
|
||||
for p in predicates:
|
||||
try:
|
||||
result = p(ctx)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if result:
|
||||
return True
|
||||
except CheckFailure as e:
|
||||
errors.append(e)
|
||||
raise CheckAnyFailure(errors)
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def max_concurrency(
|
||||
number: int, per: str = "user"
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Limit how many concurrent invocations of a command are allowed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
number:
|
||||
The maximum number of concurrent invocations.
|
||||
per:
|
||||
The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``.
|
||||
"""
|
||||
|
||||
if number < 1:
|
||||
raise ValueError("Concurrency number must be at least 1.")
|
||||
if per not in {"user", "guild", "global"}:
|
||||
raise ValueError("per must be 'user', 'guild', or 'global'.")
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
setattr(func, "__max_concurrency__", (number, per))
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def cooldown(
|
||||
rate: int, per: float
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Simple per-user cooldown decorator."""
|
||||
|
||||
buckets: dict[str, dict[str, float]] = {}
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CommandOnCooldown
|
||||
|
||||
now = time.monotonic()
|
||||
user_buckets = buckets.setdefault(ctx.command.name, {})
|
||||
reset = user_buckets.get(ctx.author.id, 0)
|
||||
if now < reset:
|
||||
raise CommandOnCooldown(reset - now)
|
||||
user_buckets[ctx.author.id] = now + per
|
||||
return True
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def _compute_permissions(
|
||||
member: "Member", channel: "Channel", guild: "Guild"
|
||||
) -> "Permissions":
|
||||
"""Compute the effective permissions for a member in a channel."""
|
||||
return channel.permissions_for(member)
|
||||
|
||||
|
||||
def requires_permissions(
|
||||
*perms: "Permissions",
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Check that the invoking member has the given permissions in the channel."""
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CheckFailure
|
||||
from disagreement.permissions import (
|
||||
has_permissions,
|
||||
missing_permissions,
|
||||
)
|
||||
from disagreement.models import Member
|
||||
|
||||
channel = getattr(ctx, "channel", None)
|
||||
if channel is None and hasattr(ctx.bot, "get_channel"):
|
||||
channel = ctx.bot.get_channel(ctx.message.channel_id)
|
||||
if channel is None and hasattr(ctx.bot, "fetch_channel"):
|
||||
channel = await ctx.bot.fetch_channel(ctx.message.channel_id)
|
||||
|
||||
if channel is None:
|
||||
raise CheckFailure("Channel for permission check not found.")
|
||||
|
||||
guild = getattr(channel, "guild", None)
|
||||
if not guild and hasattr(channel, "guild_id") and channel.guild_id:
|
||||
if hasattr(ctx.bot, "get_guild"):
|
||||
guild = ctx.bot.get_guild(channel.guild_id)
|
||||
if not guild and hasattr(ctx.bot, "fetch_guild"):
|
||||
guild = await ctx.bot.fetch_guild(channel.guild_id)
|
||||
|
||||
if not guild:
|
||||
is_dm = not hasattr(channel, "guild_id") or not channel.guild_id
|
||||
if is_dm:
|
||||
if perms:
|
||||
raise CheckFailure("Permission checks are not supported in DMs.")
|
||||
return True
|
||||
raise CheckFailure("Guild for permission check not found.")
|
||||
|
||||
member = ctx.author
|
||||
if not isinstance(member, Member):
|
||||
member = guild.get_member(ctx.author.id)
|
||||
if not member and hasattr(ctx.bot, "fetch_member"):
|
||||
member = await ctx.bot.fetch_member(guild.id, ctx.author.id)
|
||||
|
||||
if not member:
|
||||
raise CheckFailure("Could not resolve author to a guild member.")
|
||||
|
||||
perms_value = _compute_permissions(member, channel, guild)
|
||||
|
||||
if not has_permissions(perms_value, *perms):
|
||||
missing = missing_permissions(perms_value, *perms)
|
||||
missing_names = ", ".join(p.name for p in missing if p.name)
|
||||
raise CheckFailure(f"Missing permissions: {missing_names}")
|
||||
return True
|
||||
|
||||
return check(predicate)
|
||||
|
||||
def has_role(
|
||||
name_or_id: str | int,
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Check that the invoking member has a role with the given name or ID."""
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CheckFailure
|
||||
from disagreement.models import Member
|
||||
|
||||
if not ctx.guild:
|
||||
raise CheckFailure("This command cannot be used in DMs.")
|
||||
|
||||
author = ctx.author
|
||||
if not isinstance(author, Member):
|
||||
try:
|
||||
author = await ctx.bot.fetch_member(ctx.guild.id, author.id)
|
||||
except Exception:
|
||||
raise CheckFailure("Could not resolve author to a guild member.")
|
||||
|
||||
if not author:
|
||||
raise CheckFailure("Could not resolve author to a guild member.")
|
||||
|
||||
# Create a list of the member's role objects by looking them up in the guild's roles list
|
||||
member_roles = [
|
||||
role for role in ctx.guild.roles if role.id in author.roles
|
||||
]
|
||||
|
||||
if any(
|
||||
role.id == str(name_or_id) or role.name == name_or_id
|
||||
for role in member_roles
|
||||
):
|
||||
return True
|
||||
|
||||
raise CheckFailure(f"You need the '{name_or_id}' role to use this command.")
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def has_any_role(
|
||||
*names_or_ids: str | int,
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Check that the invoking member has any of the roles with the given names or IDs."""
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CheckFailure
|
||||
from disagreement.models import Member
|
||||
|
||||
if not ctx.guild:
|
||||
raise CheckFailure("This command cannot be used in DMs.")
|
||||
|
||||
author = ctx.author
|
||||
if not isinstance(author, Member):
|
||||
try:
|
||||
author = await ctx.bot.fetch_member(ctx.guild.id, author.id)
|
||||
except Exception:
|
||||
raise CheckFailure("Could not resolve author to a guild member.")
|
||||
|
||||
if not author:
|
||||
raise CheckFailure("Could not resolve author to a guild member.")
|
||||
|
||||
member_roles = [
|
||||
role for role in ctx.guild.roles if role.id in author.roles
|
||||
]
|
||||
# Convert names_or_ids to a set for efficient lookup
|
||||
names_or_ids_set = set(map(str, names_or_ids))
|
||||
|
||||
if any(
|
||||
role.id in names_or_ids_set or role.name in names_or_ids_set
|
||||
for role in member_roles
|
||||
):
|
||||
return True
|
||||
|
||||
role_list = ", ".join(f"'{r}'" for r in names_or_ids)
|
||||
raise CheckFailure(
|
||||
f"You need one of the following roles to use this command: {role_list}"
|
||||
)
|
||||
|
||||
return check(predicate)
|
||||
|
File diff suppressed because it is too large
Load Diff
2161
disagreement/http.py
2161
disagreement/http.py
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,165 +1,167 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from ..models import ActionRow
|
||||
from .item import Item
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import Client
|
||||
from ..interactions import Interaction
|
||||
|
||||
|
||||
class View:
|
||||
"""Represents a container for UI components that can be sent with a message.
|
||||
|
||||
Args:
|
||||
timeout (Optional[float]): The number of seconds to wait for an interaction before the view times out.
|
||||
Defaults to 180.
|
||||
"""
|
||||
|
||||
def __init__(self, *, timeout: Optional[float] = 180.0):
|
||||
self.timeout = timeout
|
||||
self.id = str(uuid.uuid4())
|
||||
self.__children: List[Item] = []
|
||||
self.__stopped = asyncio.Event()
|
||||
self._client: Optional[Client] = None
|
||||
self._message_id: Optional[str] = None
|
||||
|
||||
for item in self.__class__.__dict__.values():
|
||||
if isinstance(item, Item):
|
||||
self.add_item(item)
|
||||
|
||||
@property
|
||||
def children(self) -> List[Item]:
|
||||
return self.__children
|
||||
|
||||
def add_item(self, item: Item):
|
||||
"""Adds an item to the view."""
|
||||
if not isinstance(item, Item):
|
||||
raise TypeError("Only instances of 'Item' can be added to a View.")
|
||||
|
||||
if len(self.__children) >= 25:
|
||||
raise ValueError("A view can only have a maximum of 25 components.")
|
||||
|
||||
item._view = self
|
||||
self.__children.append(item)
|
||||
|
||||
@property
|
||||
def message_id(self) -> Optional[str]:
|
||||
return self._message_id
|
||||
|
||||
@message_id.setter
|
||||
def message_id(self, value: str):
|
||||
self._message_id = value
|
||||
|
||||
def to_components(self) -> List[ActionRow]:
|
||||
"""Converts the view's children into a list of ActionRow components.
|
||||
|
||||
This retains the original, simple layout behaviour where each item is
|
||||
placed in its own :class:`ActionRow` to ensure backward compatibility.
|
||||
"""
|
||||
|
||||
rows: List[ActionRow] = []
|
||||
|
||||
for item in self.children:
|
||||
if item.custom_id is None:
|
||||
item.custom_id = (
|
||||
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
|
||||
)
|
||||
|
||||
rows.append(ActionRow(components=[item]))
|
||||
|
||||
return rows
|
||||
|
||||
def layout_components_advanced(self) -> List[ActionRow]:
|
||||
"""Group compatible components into rows following Discord rules."""
|
||||
|
||||
rows: List[ActionRow] = []
|
||||
|
||||
for item in self.children:
|
||||
if item.custom_id is None:
|
||||
item.custom_id = (
|
||||
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
|
||||
)
|
||||
|
||||
target_row = item.row
|
||||
if target_row is not None:
|
||||
if not 0 <= target_row <= 4:
|
||||
raise ValueError("Row index must be between 0 and 4.")
|
||||
|
||||
while len(rows) <= target_row:
|
||||
if len(rows) >= 5:
|
||||
raise ValueError("A view can have at most 5 action rows.")
|
||||
rows.append(ActionRow())
|
||||
|
||||
rows[target_row].add_component(item)
|
||||
continue
|
||||
|
||||
placed = False
|
||||
for row in rows:
|
||||
try:
|
||||
row.add_component(item)
|
||||
placed = True
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if not placed:
|
||||
if len(rows) >= 5:
|
||||
raise ValueError("A view can have at most 5 action rows.")
|
||||
new_row = ActionRow([item])
|
||||
rows.append(new_row)
|
||||
|
||||
return rows
|
||||
|
||||
def to_components_payload(self) -> List[Dict[str, Any]]:
|
||||
"""Converts the view's children into a list of component dictionaries
|
||||
that can be sent to the Discord API."""
|
||||
return [row.to_dict() for row in self.to_components()]
|
||||
|
||||
async def _dispatch(self, interaction: Interaction):
|
||||
"""Called by the client to dispatch an interaction to the correct item."""
|
||||
if self.timeout is not None:
|
||||
self.__stopped.set() # Reset the timeout on each interaction
|
||||
self.__stopped.clear()
|
||||
|
||||
if interaction.data:
|
||||
custom_id = interaction.data.custom_id
|
||||
for child in self.children:
|
||||
if child.custom_id == custom_id:
|
||||
if child.callback:
|
||||
await child.callback(self, interaction)
|
||||
break
|
||||
|
||||
async def wait(self) -> bool:
|
||||
"""Waits until the view has stopped interacting."""
|
||||
return await self.__stopped.wait()
|
||||
|
||||
def stop(self):
|
||||
"""Stops the view from listening to interactions."""
|
||||
if not self.__stopped.is_set():
|
||||
self.__stopped.set()
|
||||
|
||||
async def on_timeout(self):
|
||||
"""Called when the view times out."""
|
||||
pass # User can override this
|
||||
|
||||
async def _start(self, client: Client):
|
||||
"""Starts the view's internal listener."""
|
||||
self._client = client
|
||||
if self.timeout is not None:
|
||||
asyncio.create_task(self._timeout_task())
|
||||
|
||||
async def _timeout_task(self):
|
||||
"""The task that waits for the timeout and then stops the view."""
|
||||
try:
|
||||
await asyncio.wait_for(self.wait(), timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self.stop()
|
||||
await self.on_timeout()
|
||||
if self._client and self._message_id:
|
||||
# Remove the view from the client's listeners
|
||||
self._client._views.pop(self._message_id, None)
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from ..models import ActionRow
|
||||
from .item import Item
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import Client
|
||||
from ..interactions import Interaction
|
||||
|
||||
|
||||
class View:
|
||||
"""Represents a container for UI components that can be sent with a message.
|
||||
|
||||
Args:
|
||||
timeout (Optional[float]): The number of seconds to wait for an interaction before the view times out.
|
||||
Defaults to 180.
|
||||
"""
|
||||
|
||||
def __init__(self, *, timeout: Optional[float] = 180.0):
|
||||
self.timeout = timeout
|
||||
self.id = str(uuid.uuid4())
|
||||
self.__children: List[Item] = []
|
||||
self.__stopped = asyncio.Event()
|
||||
self._client: Optional[Client] = None
|
||||
self._message_id: Optional[str] = None
|
||||
|
||||
# The below is a bit of a hack to support items defined as class members
|
||||
# e.g. button = Button(...)
|
||||
for item in self.__class__.__dict__.values():
|
||||
if isinstance(item, Item):
|
||||
self.add_item(item)
|
||||
|
||||
@property
|
||||
def children(self) -> List[Item]:
|
||||
return self.__children
|
||||
|
||||
def add_item(self, item: Item):
|
||||
"""Adds an item to the view."""
|
||||
if not isinstance(item, Item):
|
||||
raise TypeError("Only instances of 'Item' can be added to a View.")
|
||||
|
||||
if len(self.__children) >= 25:
|
||||
raise ValueError("A view can only have a maximum of 25 components.")
|
||||
|
||||
if self.timeout is None and item.custom_id is None:
|
||||
raise ValueError(
|
||||
"All components in a persistent view must have a 'custom_id'."
|
||||
)
|
||||
|
||||
item._view = self
|
||||
self.__children.append(item)
|
||||
|
||||
@property
|
||||
def message_id(self) -> Optional[str]:
|
||||
return self._message_id
|
||||
|
||||
@message_id.setter
|
||||
def message_id(self, value: str):
|
||||
self._message_id = value
|
||||
|
||||
def to_components(self) -> List[ActionRow]:
|
||||
"""Converts the view's children into a list of ActionRow components.
|
||||
|
||||
This retains the original, simple layout behaviour where each item is
|
||||
placed in its own :class:`ActionRow` to ensure backward compatibility.
|
||||
"""
|
||||
|
||||
rows: List[ActionRow] = []
|
||||
|
||||
for item in self.children:
|
||||
rows.append(ActionRow(components=[item]))
|
||||
|
||||
return rows
|
||||
|
||||
def layout_components_advanced(self) -> List[ActionRow]:
|
||||
"""Group compatible components into rows following Discord rules."""
|
||||
|
||||
rows: List[ActionRow] = []
|
||||
|
||||
for item in self.children:
|
||||
if item.custom_id is None:
|
||||
item.custom_id = (
|
||||
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
|
||||
)
|
||||
|
||||
target_row = item.row
|
||||
if target_row is not None:
|
||||
if not 0 <= target_row <= 4:
|
||||
raise ValueError("Row index must be between 0 and 4.")
|
||||
|
||||
while len(rows) <= target_row:
|
||||
if len(rows) >= 5:
|
||||
raise ValueError("A view can have at most 5 action rows.")
|
||||
rows.append(ActionRow())
|
||||
|
||||
rows[target_row].add_component(item)
|
||||
continue
|
||||
|
||||
placed = False
|
||||
for row in rows:
|
||||
try:
|
||||
row.add_component(item)
|
||||
placed = True
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if not placed:
|
||||
if len(rows) >= 5:
|
||||
raise ValueError("A view can have at most 5 action rows.")
|
||||
new_row = ActionRow([item])
|
||||
rows.append(new_row)
|
||||
|
||||
return rows
|
||||
|
||||
def to_components_payload(self) -> List[Dict[str, Any]]:
|
||||
"""Converts the view's children into a list of component dictionaries
|
||||
that can be sent to the Discord API."""
|
||||
return [row.to_dict() for row in self.to_components()]
|
||||
|
||||
async def _dispatch(self, interaction: Interaction):
|
||||
"""Called by the client to dispatch an interaction to the correct item."""
|
||||
if self.timeout is not None:
|
||||
self.__stopped.set() # Reset the timeout on each interaction
|
||||
self.__stopped.clear()
|
||||
|
||||
if interaction.data:
|
||||
custom_id = interaction.data.custom_id
|
||||
for child in self.children:
|
||||
if child.custom_id == custom_id:
|
||||
if child.callback:
|
||||
await child.callback(self, interaction)
|
||||
break
|
||||
|
||||
async def wait(self) -> bool:
|
||||
"""Waits until the view has stopped interacting."""
|
||||
return await self.__stopped.wait()
|
||||
|
||||
def stop(self):
|
||||
"""Stops the view from listening to interactions."""
|
||||
if not self.__stopped.is_set():
|
||||
self.__stopped.set()
|
||||
|
||||
async def on_timeout(self):
|
||||
"""Called when the view times out."""
|
||||
pass # User can override this
|
||||
|
||||
async def _start(self, client: Client):
|
||||
"""Starts the view's internal listener."""
|
||||
self._client = client
|
||||
if self.timeout is not None:
|
||||
asyncio.create_task(self._timeout_task())
|
||||
|
||||
async def _timeout_task(self):
|
||||
"""The task that waits for the timeout and then stops the view."""
|
||||
try:
|
||||
await asyncio.wait_for(self.wait(), timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self.stop()
|
||||
await self.on_timeout()
|
||||
if self._client and self._message_id:
|
||||
# Remove the view from the client's listeners
|
||||
self._client._views.pop(self._message_id, None)
|
||||
|
@ -1,162 +1,244 @@
|
||||
# disagreement/voice_client.py
|
||||
"""Voice gateway and UDP audio client."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import socket
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .audio import AudioSource, FFmpegAudioSource
|
||||
|
||||
|
||||
class VoiceClient:
|
||||
"""Handles the Discord voice WebSocket connection and UDP streaming."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
session_id: str,
|
||||
token: str,
|
||||
guild_id: int,
|
||||
user_id: int,
|
||||
*,
|
||||
ws=None,
|
||||
udp: Optional[socket.socket] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
self.endpoint = endpoint
|
||||
self.session_id = session_id
|
||||
self.token = token
|
||||
self.guild_id = str(guild_id)
|
||||
self.user_id = str(user_id)
|
||||
self._ws: Optional[aiohttp.ClientWebSocketResponse] = ws
|
||||
self._udp = udp
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self._heartbeat_interval: Optional[float] = None
|
||||
self._loop = loop or asyncio.get_event_loop()
|
||||
self.verbose = verbose
|
||||
self.ssrc: Optional[int] = None
|
||||
self.secret_key: Optional[Sequence[int]] = None
|
||||
self._server_ip: Optional[str] = None
|
||||
self._server_port: Optional[int] = None
|
||||
self._current_source: Optional[AudioSource] = None
|
||||
self._play_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self._ws is None:
|
||||
self._session = aiohttp.ClientSession()
|
||||
self._ws = await self._session.ws_connect(self.endpoint)
|
||||
|
||||
hello = await self._ws.receive_json()
|
||||
self._heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000
|
||||
self._heartbeat_task = self._loop.create_task(self._heartbeat())
|
||||
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"op": 0,
|
||||
"d": {
|
||||
"server_id": self.guild_id,
|
||||
"user_id": self.user_id,
|
||||
"session_id": self.session_id,
|
||||
"token": self.token,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
ready = await self._ws.receive_json()
|
||||
data = ready["d"]
|
||||
self.ssrc = data["ssrc"]
|
||||
self._server_ip = data["ip"]
|
||||
self._server_port = data["port"]
|
||||
|
||||
if self._udp is None:
|
||||
self._udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self._udp.connect((self._server_ip, self._server_port))
|
||||
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"op": 1,
|
||||
"d": {
|
||||
"protocol": "udp",
|
||||
"data": {
|
||||
"address": self._udp.getsockname()[0],
|
||||
"port": self._udp.getsockname()[1],
|
||||
"mode": "xsalsa20_poly1305",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
session_desc = await self._ws.receive_json()
|
||||
self.secret_key = session_desc["d"].get("secret_key")
|
||||
|
||||
async def _heartbeat(self) -> None:
|
||||
assert self._ws is not None
|
||||
assert self._heartbeat_interval is not None
|
||||
try:
|
||||
while True:
|
||||
await self._ws.send_json({"op": 3, "d": int(self._loop.time() * 1000)})
|
||||
await asyncio.sleep(self._heartbeat_interval)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def send_audio_frame(self, frame: bytes) -> None:
|
||||
if not self._udp:
|
||||
raise RuntimeError("UDP socket not initialised")
|
||||
self._udp.send(frame)
|
||||
|
||||
async def _play_loop(self) -> None:
|
||||
assert self._current_source is not None
|
||||
try:
|
||||
while True:
|
||||
data = await self._current_source.read()
|
||||
if not data:
|
||||
break
|
||||
await self.send_audio_frame(data)
|
||||
finally:
|
||||
await self._current_source.close()
|
||||
self._current_source = None
|
||||
self._play_task = None
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._play_task:
|
||||
self._play_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._play_task
|
||||
self._play_task = None
|
||||
if self._current_source:
|
||||
await self._current_source.close()
|
||||
self._current_source = None
|
||||
|
||||
async def play(self, source: AudioSource, *, wait: bool = True) -> None:
|
||||
"""|coro| Play an :class:`AudioSource` on the voice connection."""
|
||||
|
||||
await self.stop()
|
||||
self._current_source = source
|
||||
self._play_task = self._loop.create_task(self._play_loop())
|
||||
if wait:
|
||||
await self._play_task
|
||||
|
||||
async def play_file(self, filename: str, *, wait: bool = True) -> None:
|
||||
"""|coro| Stream an audio file or URL using FFmpeg."""
|
||||
|
||||
await self.play(FFmpegAudioSource(filename), wait=wait)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.stop()
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._heartbeat_task
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
if self._udp:
|
||||
self._udp.close()
|
||||
# disagreement/voice_client.py
|
||||
"""Voice gateway and UDP audio client."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import socket
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Optional, Sequence
|
||||
|
||||
import aiohttp
|
||||
# The following import is correct, but may be flagged by Pylance if the virtual
|
||||
# environment is not configured correctly.
|
||||
from nacl.secret import SecretBox
|
||||
|
||||
from .audio import AudioSink, AudioSource, FFmpegAudioSource
|
||||
from .models import User
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client
|
||||
|
||||
|
||||
class VoiceClient:
|
||||
"""Handles the Discord voice WebSocket connection and UDP streaming."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Client,
|
||||
endpoint: str,
|
||||
session_id: str,
|
||||
token: str,
|
||||
guild_id: int,
|
||||
user_id: int,
|
||||
*,
|
||||
ws=None,
|
||||
udp: Optional[socket.socket] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
self.client = client
|
||||
self.endpoint = endpoint
|
||||
self.session_id = session_id
|
||||
self.token = token
|
||||
self.guild_id = str(guild_id)
|
||||
self.user_id = str(user_id)
|
||||
self._ws: Optional[aiohttp.ClientWebSocketResponse] = ws
|
||||
self._udp = udp
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self._receive_task: Optional[asyncio.Task] = None
|
||||
self._udp_receive_thread: Optional[threading.Thread] = None
|
||||
self._heartbeat_interval: Optional[float] = None
|
||||
try:
|
||||
self._loop = loop or asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self.verbose = verbose
|
||||
self.ssrc: Optional[int] = None
|
||||
self.secret_key: Optional[Sequence[int]] = None
|
||||
self._server_ip: Optional[str] = None
|
||||
self._server_port: Optional[int] = None
|
||||
self._current_source: Optional[AudioSource] = None
|
||||
self._play_task: Optional[asyncio.Task] = None
|
||||
self._sink: Optional[AudioSink] = None
|
||||
self._ssrc_map: dict[int, int] = {}
|
||||
self._ssrc_lock = threading.Lock()
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self._ws is None:
|
||||
self._session = aiohttp.ClientSession()
|
||||
self._ws = await self._session.ws_connect(self.endpoint)
|
||||
|
||||
hello = await self._ws.receive_json()
|
||||
self._heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000
|
||||
self._heartbeat_task = self._loop.create_task(self._heartbeat())
|
||||
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"op": 0,
|
||||
"d": {
|
||||
"server_id": self.guild_id,
|
||||
"user_id": self.user_id,
|
||||
"session_id": self.session_id,
|
||||
"token": self.token,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
ready = await self._ws.receive_json()
|
||||
data = ready["d"]
|
||||
self.ssrc = data["ssrc"]
|
||||
self._server_ip = data["ip"]
|
||||
self._server_port = data["port"]
|
||||
|
||||
if self._udp is None:
|
||||
self._udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self._udp.connect((self._server_ip, self._server_port))
|
||||
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"op": 1,
|
||||
"d": {
|
||||
"protocol": "udp",
|
||||
"data": {
|
||||
"address": self._udp.getsockname()[0],
|
||||
"port": self._udp.getsockname()[1],
|
||||
"mode": "xsalsa20_poly1305",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
session_desc = await self._ws.receive_json()
|
||||
self.secret_key = session_desc["d"].get("secret_key")
|
||||
|
||||
async def _heartbeat(self) -> None:
|
||||
assert self._ws is not None
|
||||
assert self._heartbeat_interval is not None
|
||||
try:
|
||||
while True:
|
||||
await self._ws.send_json({"op": 3, "d": int(self._loop.time() * 1000)})
|
||||
await asyncio.sleep(self._heartbeat_interval)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
assert self._ws is not None
|
||||
while True:
|
||||
try:
|
||||
msg = await self._ws.receive_json()
|
||||
op = msg.get("op")
|
||||
data = msg.get("d")
|
||||
if op == 5: # Speaking
|
||||
user_id = int(data["user_id"])
|
||||
ssrc = data["ssrc"]
|
||||
with self._ssrc_lock:
|
||||
self._ssrc_map[ssrc] = user_id
|
||||
except (asyncio.CancelledError, aiohttp.ClientError):
|
||||
break
|
||||
|
||||
def _udp_receive_loop(self) -> None:
|
||||
assert self._udp is not None
|
||||
assert self.secret_key is not None
|
||||
box = SecretBox(bytes(self.secret_key))
|
||||
while True:
|
||||
try:
|
||||
packet = self._udp.recv(4096)
|
||||
if len(packet) < 12:
|
||||
continue
|
||||
|
||||
ssrc = int.from_bytes(packet[8:12], "big")
|
||||
with self._ssrc_lock:
|
||||
if ssrc not in self._ssrc_map:
|
||||
continue
|
||||
user_id = self._ssrc_map[ssrc]
|
||||
user = self.client._users.get(str(user_id))
|
||||
if not user:
|
||||
continue
|
||||
|
||||
decrypted = box.decrypt(packet[12:])
|
||||
if self._sink:
|
||||
self._sink.write(user, decrypted)
|
||||
except (socket.error, asyncio.CancelledError):
|
||||
break
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
print(f"Error in UDP receive loop: {e}")
|
||||
|
||||
async def send_audio_frame(self, frame: bytes) -> None:
|
||||
if not self._udp:
|
||||
raise RuntimeError("UDP socket not initialised")
|
||||
self._udp.send(frame)
|
||||
|
||||
async def _play_loop(self) -> None:
|
||||
assert self._current_source is not None
|
||||
try:
|
||||
while True:
|
||||
data = await self._current_source.read()
|
||||
if not data:
|
||||
break
|
||||
await self.send_audio_frame(data)
|
||||
finally:
|
||||
await self._current_source.close()
|
||||
self._current_source = None
|
||||
self._play_task = None
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._play_task:
|
||||
self._play_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._play_task
|
||||
self._play_task = None
|
||||
if self._current_source:
|
||||
await self._current_source.close()
|
||||
self._current_source = None
|
||||
|
||||
async def play(self, source: AudioSource, *, wait: bool = True) -> None:
|
||||
"""|coro| Play an :class:`AudioSource` on the voice connection."""
|
||||
|
||||
await self.stop()
|
||||
self._current_source = source
|
||||
self._play_task = self._loop.create_task(self._play_loop())
|
||||
if wait:
|
||||
await self._play_task
|
||||
|
||||
async def play_file(self, filename: str, *, wait: bool = True) -> None:
|
||||
"""|coro| Stream an audio file or URL using FFmpeg."""
|
||||
|
||||
await self.play(FFmpegAudioSource(filename), wait=wait)
|
||||
|
||||
def listen(self, sink: AudioSink) -> None:
|
||||
"""Start listening to voice and routing to a sink."""
|
||||
if not isinstance(sink, AudioSink):
|
||||
raise TypeError("sink must be an AudioSink instance")
|
||||
|
||||
self._sink = sink
|
||||
if not self._udp_receive_thread:
|
||||
self._udp_receive_thread = threading.Thread(
|
||||
target=self._udp_receive_loop, daemon=True
|
||||
)
|
||||
self._udp_receive_thread.start()
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.stop()
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._heartbeat_task
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._receive_task
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
if self._udp:
|
||||
self._udp.close()
|
||||
if self._udp_receive_thread:
|
||||
self._udp_receive_thread.join(timeout=1)
|
||||
if self._sink:
|
||||
self._sink.close()
|
||||
|
113
pyproject.toml
113
pyproject.toml
@ -1,56 +1,57 @@
|
||||
[project]
|
||||
name = "disagreement"
|
||||
version = "0.2.0rc1"
|
||||
description = "A Python library for the Discord API."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
license = {text = "BSD 3-Clause"}
|
||||
authors = [
|
||||
{name = "Slipstream", email = "me@slipstreamm.dev"}
|
||||
]
|
||||
keywords = ["discord", "api", "bot", "async", "aiohttp"]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: BSD License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Topic :: Software Development :: Libraries",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Topic :: Internet",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
"hypothesis>=6.132.0",
|
||||
]
|
||||
dev = [
|
||||
"python-dotenv>=1.0.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/Slipstreamm/disagreement"
|
||||
Issues = "https://github.com/Slipstreamm/disagreement/issues"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
# Optional: for linting/formatting, e.g., Ruff
|
||||
# [tool.ruff]
|
||||
# line-length = 88
|
||||
# select = ["E", "W", "F", "I", "UP", "C4", "B"] # Example rule set
|
||||
# ignore = []
|
||||
|
||||
# [tool.ruff.format]
|
||||
# quote-style = "double"
|
||||
[project]
|
||||
name = "disagreement"
|
||||
version = "0.2.0rc1"
|
||||
description = "A Python library for the Discord API."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
license = {text = "BSD 3-Clause"}
|
||||
authors = [
|
||||
{name = "Slipstream", email = "me@slipstreamm.dev"}
|
||||
]
|
||||
keywords = ["discord", "api", "bot", "async", "aiohttp"]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: BSD License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Topic :: Software Development :: Libraries",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Topic :: Internet",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
"PyNaCl>=1.5.0,<2.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
"hypothesis>=6.132.0",
|
||||
]
|
||||
dev = [
|
||||
"python-dotenv>=1.0.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/Slipstreamm/disagreement"
|
||||
Issues = "https://github.com/Slipstreamm/disagreement/issues"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
# Optional: for linting/formatting, e.g., Ruff
|
||||
# [tool.ruff]
|
||||
# line-length = 88
|
||||
# select = ["E", "W", "F", "I", "UP", "C4", "B"] # Example rule set
|
||||
# ignore = []
|
||||
|
||||
# [tool.ruff.format]
|
||||
# quote-style = "double"
|
||||
|
Loading…
x
Reference in New Issue
Block a user