Compare commits
8 Commits
be85444aa0
...
ed83a9da85
Author | SHA1 | Date | |
---|---|---|---|
ed83a9da85 | |||
0151526d07 | |||
17b7ea35a9 | |||
28702fa8a1 | |||
97505948ee | |||
152c0f12be | |||
eb38ecf671 | |||
2bd45c87ca |
@ -114,3 +114,20 @@ class FFmpegAudioSource(AudioSource):
|
|||||||
if isinstance(self.source, io.IOBase):
|
if isinstance(self.source, io.IOBase):
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
self.source.close()
|
self.source.close()
|
||||||
|
|
||||||
|
class AudioSink:
|
||||||
|
"""Abstract base class for audio sinks."""
|
||||||
|
|
||||||
|
def write(self, user, data):
|
||||||
|
"""Write a chunk of PCM audio.
|
||||||
|
|
||||||
|
Subclasses must implement this. The data is raw PCM at 48kHz
|
||||||
|
stereo.
|
||||||
|
"""
|
||||||
|
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Cleanup the sink when the voice client disconnects."""
|
||||||
|
|
||||||
|
return None
|
||||||
|
@ -4,7 +4,8 @@ import time
|
|||||||
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
|
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .models import Channel, Guild
|
from .models import Channel, Guild, Member
|
||||||
|
from .caching import MemberCacheFlags
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
@ -53,3 +54,32 @@ class GuildCache(Cache["Guild"]):
|
|||||||
|
|
||||||
class ChannelCache(Cache["Channel"]):
|
class ChannelCache(Cache["Channel"]):
|
||||||
"""Cache specifically for :class:`Channel` objects."""
|
"""Cache specifically for :class:`Channel` objects."""
|
||||||
|
|
||||||
|
|
||||||
|
class MemberCache(Cache["Member"]):
|
||||||
|
"""
|
||||||
|
A cache for :class:`Member` objects that respects :class:`MemberCacheFlags`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, flags: MemberCacheFlags, ttl: Optional[float] = None) -> None:
|
||||||
|
super().__init__(ttl)
|
||||||
|
self.flags = flags
|
||||||
|
|
||||||
|
def _should_cache(self, member: Member) -> bool:
|
||||||
|
"""Determines if a member should be cached based on the flags."""
|
||||||
|
if self.flags.all:
|
||||||
|
return True
|
||||||
|
if self.flags.none:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.flags.online and member.status != "offline":
|
||||||
|
return True
|
||||||
|
if self.flags.voice and member.voice_state is not None:
|
||||||
|
return True
|
||||||
|
if self.flags.joined and getattr(member, "_just_joined", False):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def set(self, key: str, value: Member) -> None:
|
||||||
|
if self._should_cache(value):
|
||||||
|
super().set(key, value)
|
||||||
|
120
disagreement/caching.py
Normal file
120
disagreement/caching.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import operator
|
||||||
|
from typing import Any, Callable, ClassVar, Dict, Iterator, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class _MemberCacheFlagValue:
|
||||||
|
flag: int
|
||||||
|
|
||||||
|
def __init__(self, func: Callable[[Any], bool]):
|
||||||
|
self.flag = getattr(func, 'flag', 0)
|
||||||
|
self.__doc__ = func.__doc__
|
||||||
|
|
||||||
|
def __get__(self, instance: 'MemberCacheFlags', owner: type) -> Any:
|
||||||
|
if instance is None:
|
||||||
|
return self
|
||||||
|
return instance.value & self.flag != 0
|
||||||
|
|
||||||
|
def __set__(self, instance: Any, value: bool) -> None:
|
||||||
|
if value:
|
||||||
|
instance.value |= self.flag
|
||||||
|
else:
|
||||||
|
instance.value &= ~self.flag
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f'<{self.__class__.__name__} flag={self.flag}>'
|
||||||
|
|
||||||
|
|
||||||
|
def flag_value(flag: int) -> Callable[[Callable[[Any], bool]], _MemberCacheFlagValue]:
|
||||||
|
def decorator(func: Callable[[Any], bool]) -> _MemberCacheFlagValue:
|
||||||
|
setattr(func, 'flag', flag)
|
||||||
|
return _MemberCacheFlagValue(func)
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class MemberCacheFlags:
|
||||||
|
__slots__ = ('value',)
|
||||||
|
|
||||||
|
VALID_FLAGS: ClassVar[Dict[str, int]] = {
|
||||||
|
'joined': 1 << 0,
|
||||||
|
'voice': 1 << 1,
|
||||||
|
'online': 1 << 2,
|
||||||
|
}
|
||||||
|
DEFAULT_FLAGS: ClassVar[int] = 1 | 2 | 4
|
||||||
|
ALL_FLAGS: ClassVar[int] = sum(VALID_FLAGS.values())
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: bool):
|
||||||
|
self.value = self.DEFAULT_FLAGS
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if key not in self.VALID_FLAGS:
|
||||||
|
raise TypeError(f'{key!r} is not a valid member cache flag.')
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _from_value(cls, value: int) -> MemberCacheFlags:
|
||||||
|
self = cls.__new__(cls)
|
||||||
|
self.value = value
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
return isinstance(other, MemberCacheFlags) and self.value == other.value
|
||||||
|
|
||||||
|
def __ne__(self, other: object) -> bool:
|
||||||
|
return not self.__eq__(other)
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return hash(self.value)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f'<MemberCacheFlags value={self.value}>'
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[Tuple[str, bool]]:
|
||||||
|
for name in self.VALID_FLAGS:
|
||||||
|
yield name, getattr(self, name)
|
||||||
|
|
||||||
|
def __int__(self) -> int:
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
def __index__(self) -> int:
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def all(cls) -> MemberCacheFlags:
|
||||||
|
"""A factory method that creates a :class:`MemberCacheFlags` with all flags enabled."""
|
||||||
|
return cls._from_value(cls.ALL_FLAGS)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def none(cls) -> MemberCacheFlags:
|
||||||
|
"""A factory method that creates a :class:`MemberCacheFlags` with all flags disabled."""
|
||||||
|
return cls._from_value(0)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def only_joined(cls) -> MemberCacheFlags:
|
||||||
|
"""A factory method that creates a :class:`MemberCacheFlags` with only the `joined` flag enabled."""
|
||||||
|
return cls._from_value(cls.VALID_FLAGS['joined'])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def only_voice(cls) -> MemberCacheFlags:
|
||||||
|
"""A factory method that creates a :class:`MemberCacheFlags` with only the `voice` flag enabled."""
|
||||||
|
return cls._from_value(cls.VALID_FLAGS['voice'])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def only_online(cls) -> MemberCacheFlags:
|
||||||
|
"""A factory method that creates a :class:`MemberCacheFlags` with only the `online` flag enabled."""
|
||||||
|
return cls._from_value(cls.VALID_FLAGS['online'])
|
||||||
|
|
||||||
|
@flag_value(1 << 0)
|
||||||
|
def joined(self) -> bool:
|
||||||
|
"""Whether to cache members that have just joined the guild."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
@flag_value(1 << 1)
|
||||||
|
def voice(self) -> bool:
|
||||||
|
"""Whether to cache members that are in a voice channel."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
@flag_value(1 << 2)
|
||||||
|
def online(self) -> bool:
|
||||||
|
"""Whether to cache members that are online."""
|
||||||
|
return False
|
File diff suppressed because it is too large
Load Diff
@ -76,7 +76,7 @@ class EventDispatcher:
|
|||||||
"""Parses MESSAGE_DELETE and updates message cache."""
|
"""Parses MESSAGE_DELETE and updates message cache."""
|
||||||
message_id = data.get("id")
|
message_id = data.get("id")
|
||||||
if message_id:
|
if message_id:
|
||||||
self._client._messages.pop(message_id, None)
|
self._client._messages.invalidate(message_id)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _parse_message_reaction_raw(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
def _parse_message_reaction_raw(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
@ -124,7 +124,7 @@ class EventDispatcher:
|
|||||||
"""Parses GUILD_MEMBER_ADD into a Member object."""
|
"""Parses GUILD_MEMBER_ADD into a Member object."""
|
||||||
|
|
||||||
guild_id = str(data.get("guild_id"))
|
guild_id = str(data.get("guild_id"))
|
||||||
return self._client.parse_member(data, guild_id)
|
return self._client.parse_member(data, guild_id, just_joined=True)
|
||||||
|
|
||||||
def _parse_guild_member_remove(self, data: Dict[str, Any]):
|
def _parse_guild_member_remove(self, data: Dict[str, Any]):
|
||||||
"""Parses GUILD_MEMBER_REMOVE into a GuildMemberRemove model."""
|
"""Parses GUILD_MEMBER_REMOVE into a GuildMemberRemove model."""
|
||||||
|
@ -1,61 +1,65 @@
|
|||||||
# disagreement/ext/commands/__init__.py
|
# disagreement/ext/commands/__init__.py
|
||||||
|
|
||||||
"""
|
"""
|
||||||
disagreement.ext.commands - A command framework extension for the Disagreement library.
|
disagreement.ext.commands - A command framework extension for the Disagreement library.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .cog import Cog
|
from .cog import Cog
|
||||||
from .core import (
|
from .core import (
|
||||||
Command,
|
Command,
|
||||||
CommandContext,
|
CommandContext,
|
||||||
CommandHandler,
|
CommandHandler,
|
||||||
) # CommandHandler might be internal
|
) # CommandHandler might be internal
|
||||||
from .decorators import (
|
from .decorators import (
|
||||||
command,
|
command,
|
||||||
listener,
|
listener,
|
||||||
check,
|
check,
|
||||||
check_any,
|
check_any,
|
||||||
cooldown,
|
cooldown,
|
||||||
max_concurrency,
|
max_concurrency,
|
||||||
requires_permissions,
|
requires_permissions,
|
||||||
)
|
has_role,
|
||||||
from .errors import (
|
has_any_role,
|
||||||
CommandError,
|
)
|
||||||
CommandNotFound,
|
from .errors import (
|
||||||
BadArgument,
|
CommandError,
|
||||||
MissingRequiredArgument,
|
CommandNotFound,
|
||||||
ArgumentParsingError,
|
BadArgument,
|
||||||
CheckFailure,
|
MissingRequiredArgument,
|
||||||
CheckAnyFailure,
|
ArgumentParsingError,
|
||||||
CommandOnCooldown,
|
CheckFailure,
|
||||||
CommandInvokeError,
|
CheckAnyFailure,
|
||||||
MaxConcurrencyReached,
|
CommandOnCooldown,
|
||||||
)
|
CommandInvokeError,
|
||||||
|
MaxConcurrencyReached,
|
||||||
__all__ = [
|
)
|
||||||
# Cog
|
|
||||||
"Cog",
|
__all__ = [
|
||||||
# Core
|
# Cog
|
||||||
"Command",
|
"Cog",
|
||||||
"CommandContext",
|
# Core
|
||||||
# "CommandHandler", # Usually not part of public API for direct use by bot devs
|
"Command",
|
||||||
# Decorators
|
"CommandContext",
|
||||||
"command",
|
# "CommandHandler", # Usually not part of public API for direct use by bot devs
|
||||||
"listener",
|
# Decorators
|
||||||
"check",
|
"command",
|
||||||
"check_any",
|
"listener",
|
||||||
"cooldown",
|
"check",
|
||||||
"max_concurrency",
|
"check_any",
|
||||||
"requires_permissions",
|
"cooldown",
|
||||||
# Errors
|
"max_concurrency",
|
||||||
"CommandError",
|
"requires_permissions",
|
||||||
"CommandNotFound",
|
"has_role",
|
||||||
"BadArgument",
|
"has_any_role",
|
||||||
"MissingRequiredArgument",
|
# Errors
|
||||||
"ArgumentParsingError",
|
"CommandError",
|
||||||
"CheckFailure",
|
"CommandNotFound",
|
||||||
"CheckAnyFailure",
|
"BadArgument",
|
||||||
"CommandOnCooldown",
|
"MissingRequiredArgument",
|
||||||
"CommandInvokeError",
|
"ArgumentParsingError",
|
||||||
"MaxConcurrencyReached",
|
"CheckFailure",
|
||||||
]
|
"CheckAnyFailure",
|
||||||
|
"CommandOnCooldown",
|
||||||
|
"CommandInvokeError",
|
||||||
|
"MaxConcurrencyReached",
|
||||||
|
]
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,219 +1,298 @@
|
|||||||
# disagreement/ext/commands/decorators.py
|
# disagreement/ext/commands/decorators.py
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
|
from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .core import Command, CommandContext
|
from .core import Command, CommandContext
|
||||||
from disagreement.permissions import Permissions
|
from disagreement.permissions import Permissions
|
||||||
from disagreement.models import Member, Guild, Channel
|
from disagreement.models import Member, Guild, Channel
|
||||||
|
|
||||||
|
|
||||||
def command(
|
def command(
|
||||||
name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any
|
name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
"""
|
"""
|
||||||
A decorator that transforms a function into a Command.
|
A decorator that transforms a function into a Command.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name (Optional[str]): The name of the command. Defaults to the function name.
|
name (Optional[str]): The name of the command. Defaults to the function name.
|
||||||
aliases (Optional[List[str]]): Alternative names for the command.
|
aliases (Optional[List[str]]): Alternative names for the command.
|
||||||
**attrs: Additional attributes to pass to the Command constructor
|
**attrs: Additional attributes to pass to the Command constructor
|
||||||
(e.g., brief, description, hidden).
|
(e.g., brief, description, hidden).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Callable: A decorator that registers the command.
|
Callable: A decorator that registers the command.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(
|
def decorator(
|
||||||
func: Callable[..., Awaitable[None]],
|
func: Callable[..., Awaitable[None]],
|
||||||
) -> Callable[..., Awaitable[None]]:
|
) -> Callable[..., Awaitable[None]]:
|
||||||
if not asyncio.iscoroutinefunction(func):
|
if not asyncio.iscoroutinefunction(func):
|
||||||
raise TypeError("Command callback must be a coroutine function.")
|
raise TypeError("Command callback must be a coroutine function.")
|
||||||
|
|
||||||
from .core import Command
|
from .core import Command
|
||||||
|
|
||||||
cmd_name = name or func.__name__
|
cmd_name = name or func.__name__
|
||||||
|
|
||||||
if hasattr(func, "__command_attrs__"):
|
if hasattr(func, "__command_attrs__"):
|
||||||
raise TypeError("Function is already a command or has command attributes.")
|
raise TypeError("Function is already a command or has command attributes.")
|
||||||
|
|
||||||
cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
|
cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
|
||||||
func.__command_object__ = cmd # type: ignore
|
func.__command_object__ = cmd # type: ignore
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def listener(
|
def listener(
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||||
"""
|
"""
|
||||||
A decorator that marks a function as an event listener within a Cog.
|
A decorator that marks a function as an event listener within a Cog.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(
|
def decorator(
|
||||||
func: Callable[..., Awaitable[None]],
|
func: Callable[..., Awaitable[None]],
|
||||||
) -> Callable[..., Awaitable[None]]:
|
) -> Callable[..., Awaitable[None]]:
|
||||||
if not asyncio.iscoroutinefunction(func):
|
if not asyncio.iscoroutinefunction(func):
|
||||||
raise TypeError("Listener callback must be a coroutine function.")
|
raise TypeError("Listener callback must be a coroutine function.")
|
||||||
|
|
||||||
actual_event_name = name or func.__name__
|
actual_event_name = name or func.__name__
|
||||||
setattr(func, "__listener_name__", actual_event_name)
|
setattr(func, "__listener_name__", actual_event_name)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def check(
|
def check(
|
||||||
predicate: Callable[["CommandContext"], Awaitable[bool] | bool],
|
predicate: Callable[["CommandContext"], Awaitable[bool] | bool],
|
||||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||||
"""Decorator to add a check to a command."""
|
"""Decorator to add a check to a command."""
|
||||||
|
|
||||||
def decorator(
|
def decorator(
|
||||||
func: Callable[..., Awaitable[None]],
|
func: Callable[..., Awaitable[None]],
|
||||||
) -> Callable[..., Awaitable[None]]:
|
) -> Callable[..., Awaitable[None]]:
|
||||||
checks = getattr(func, "__command_checks__", [])
|
checks = getattr(func, "__command_checks__", [])
|
||||||
checks.append(predicate)
|
checks.append(predicate)
|
||||||
setattr(func, "__command_checks__", checks)
|
setattr(func, "__command_checks__", checks)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def check_any(
|
def check_any(
|
||||||
*predicates: Callable[["CommandContext"], Awaitable[bool] | bool]
|
*predicates: Callable[["CommandContext"], Awaitable[bool] | bool]
|
||||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||||
"""Decorator that passes if any predicate returns ``True``."""
|
"""Decorator that passes if any predicate returns ``True``."""
|
||||||
|
|
||||||
async def predicate(ctx: "CommandContext") -> bool:
|
async def predicate(ctx: "CommandContext") -> bool:
|
||||||
from .errors import CheckAnyFailure, CheckFailure
|
from .errors import CheckAnyFailure, CheckFailure
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
for p in predicates:
|
for p in predicates:
|
||||||
try:
|
try:
|
||||||
result = p(ctx)
|
result = p(ctx)
|
||||||
if inspect.isawaitable(result):
|
if inspect.isawaitable(result):
|
||||||
result = await result
|
result = await result
|
||||||
if result:
|
if result:
|
||||||
return True
|
return True
|
||||||
except CheckFailure as e:
|
except CheckFailure as e:
|
||||||
errors.append(e)
|
errors.append(e)
|
||||||
raise CheckAnyFailure(errors)
|
raise CheckAnyFailure(errors)
|
||||||
|
|
||||||
return check(predicate)
|
return check(predicate)
|
||||||
|
|
||||||
|
|
||||||
def max_concurrency(
|
def max_concurrency(
|
||||||
number: int, per: str = "user"
|
number: int, per: str = "user"
|
||||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||||
"""Limit how many concurrent invocations of a command are allowed.
|
"""Limit how many concurrent invocations of a command are allowed.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
number:
|
number:
|
||||||
The maximum number of concurrent invocations.
|
The maximum number of concurrent invocations.
|
||||||
per:
|
per:
|
||||||
The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``.
|
The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if number < 1:
|
if number < 1:
|
||||||
raise ValueError("Concurrency number must be at least 1.")
|
raise ValueError("Concurrency number must be at least 1.")
|
||||||
if per not in {"user", "guild", "global"}:
|
if per not in {"user", "guild", "global"}:
|
||||||
raise ValueError("per must be 'user', 'guild', or 'global'.")
|
raise ValueError("per must be 'user', 'guild', or 'global'.")
|
||||||
|
|
||||||
def decorator(
|
def decorator(
|
||||||
func: Callable[..., Awaitable[None]],
|
func: Callable[..., Awaitable[None]],
|
||||||
) -> Callable[..., Awaitable[None]]:
|
) -> Callable[..., Awaitable[None]]:
|
||||||
setattr(func, "__max_concurrency__", (number, per))
|
setattr(func, "__max_concurrency__", (number, per))
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def cooldown(
|
def cooldown(
|
||||||
rate: int, per: float
|
rate: int, per: float
|
||||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||||
"""Simple per-user cooldown decorator."""
|
"""Simple per-user cooldown decorator."""
|
||||||
|
|
||||||
buckets: dict[str, dict[str, float]] = {}
|
buckets: dict[str, dict[str, float]] = {}
|
||||||
|
|
||||||
async def predicate(ctx: "CommandContext") -> bool:
|
async def predicate(ctx: "CommandContext") -> bool:
|
||||||
from .errors import CommandOnCooldown
|
from .errors import CommandOnCooldown
|
||||||
|
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
user_buckets = buckets.setdefault(ctx.command.name, {})
|
user_buckets = buckets.setdefault(ctx.command.name, {})
|
||||||
reset = user_buckets.get(ctx.author.id, 0)
|
reset = user_buckets.get(ctx.author.id, 0)
|
||||||
if now < reset:
|
if now < reset:
|
||||||
raise CommandOnCooldown(reset - now)
|
raise CommandOnCooldown(reset - now)
|
||||||
user_buckets[ctx.author.id] = now + per
|
user_buckets[ctx.author.id] = now + per
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return check(predicate)
|
return check(predicate)
|
||||||
|
|
||||||
|
|
||||||
def _compute_permissions(
|
def _compute_permissions(
|
||||||
member: "Member", channel: "Channel", guild: "Guild"
|
member: "Member", channel: "Channel", guild: "Guild"
|
||||||
) -> "Permissions":
|
) -> "Permissions":
|
||||||
"""Compute the effective permissions for a member in a channel."""
|
"""Compute the effective permissions for a member in a channel."""
|
||||||
return channel.permissions_for(member)
|
return channel.permissions_for(member)
|
||||||
|
|
||||||
|
|
||||||
def requires_permissions(
|
def requires_permissions(
|
||||||
*perms: "Permissions",
|
*perms: "Permissions",
|
||||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||||
"""Check that the invoking member has the given permissions in the channel."""
|
"""Check that the invoking member has the given permissions in the channel."""
|
||||||
|
|
||||||
async def predicate(ctx: "CommandContext") -> bool:
|
async def predicate(ctx: "CommandContext") -> bool:
|
||||||
from .errors import CheckFailure
|
from .errors import CheckFailure
|
||||||
from disagreement.permissions import (
|
from disagreement.permissions import (
|
||||||
has_permissions,
|
has_permissions,
|
||||||
missing_permissions,
|
missing_permissions,
|
||||||
)
|
)
|
||||||
from disagreement.models import Member
|
from disagreement.models import Member
|
||||||
|
|
||||||
channel = getattr(ctx, "channel", None)
|
channel = getattr(ctx, "channel", None)
|
||||||
if channel is None and hasattr(ctx.bot, "get_channel"):
|
if channel is None and hasattr(ctx.bot, "get_channel"):
|
||||||
channel = ctx.bot.get_channel(ctx.message.channel_id)
|
channel = ctx.bot.get_channel(ctx.message.channel_id)
|
||||||
if channel is None and hasattr(ctx.bot, "fetch_channel"):
|
if channel is None and hasattr(ctx.bot, "fetch_channel"):
|
||||||
channel = await ctx.bot.fetch_channel(ctx.message.channel_id)
|
channel = await ctx.bot.fetch_channel(ctx.message.channel_id)
|
||||||
|
|
||||||
if channel is None:
|
if channel is None:
|
||||||
raise CheckFailure("Channel for permission check not found.")
|
raise CheckFailure("Channel for permission check not found.")
|
||||||
|
|
||||||
guild = getattr(channel, "guild", None)
|
guild = getattr(channel, "guild", None)
|
||||||
if not guild and hasattr(channel, "guild_id") and channel.guild_id:
|
if not guild and hasattr(channel, "guild_id") and channel.guild_id:
|
||||||
if hasattr(ctx.bot, "get_guild"):
|
if hasattr(ctx.bot, "get_guild"):
|
||||||
guild = ctx.bot.get_guild(channel.guild_id)
|
guild = ctx.bot.get_guild(channel.guild_id)
|
||||||
if not guild and hasattr(ctx.bot, "fetch_guild"):
|
if not guild and hasattr(ctx.bot, "fetch_guild"):
|
||||||
guild = await ctx.bot.fetch_guild(channel.guild_id)
|
guild = await ctx.bot.fetch_guild(channel.guild_id)
|
||||||
|
|
||||||
if not guild:
|
if not guild:
|
||||||
is_dm = not hasattr(channel, "guild_id") or not channel.guild_id
|
is_dm = not hasattr(channel, "guild_id") or not channel.guild_id
|
||||||
if is_dm:
|
if is_dm:
|
||||||
if perms:
|
if perms:
|
||||||
raise CheckFailure("Permission checks are not supported in DMs.")
|
raise CheckFailure("Permission checks are not supported in DMs.")
|
||||||
return True
|
return True
|
||||||
raise CheckFailure("Guild for permission check not found.")
|
raise CheckFailure("Guild for permission check not found.")
|
||||||
|
|
||||||
member = ctx.author
|
member = ctx.author
|
||||||
if not isinstance(member, Member):
|
if not isinstance(member, Member):
|
||||||
member = guild.get_member(ctx.author.id)
|
member = guild.get_member(ctx.author.id)
|
||||||
if not member and hasattr(ctx.bot, "fetch_member"):
|
if not member and hasattr(ctx.bot, "fetch_member"):
|
||||||
member = await ctx.bot.fetch_member(guild.id, ctx.author.id)
|
member = await ctx.bot.fetch_member(guild.id, ctx.author.id)
|
||||||
|
|
||||||
if not member:
|
if not member:
|
||||||
raise CheckFailure("Could not resolve author to a guild member.")
|
raise CheckFailure("Could not resolve author to a guild member.")
|
||||||
|
|
||||||
perms_value = _compute_permissions(member, channel, guild)
|
perms_value = _compute_permissions(member, channel, guild)
|
||||||
|
|
||||||
if not has_permissions(perms_value, *perms):
|
if not has_permissions(perms_value, *perms):
|
||||||
missing = missing_permissions(perms_value, *perms)
|
missing = missing_permissions(perms_value, *perms)
|
||||||
missing_names = ", ".join(p.name for p in missing if p.name)
|
missing_names = ", ".join(p.name for p in missing if p.name)
|
||||||
raise CheckFailure(f"Missing permissions: {missing_names}")
|
raise CheckFailure(f"Missing permissions: {missing_names}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return check(predicate)
|
return check(predicate)
|
||||||
|
|
||||||
|
def has_role(
|
||||||
|
name_or_id: str | int,
|
||||||
|
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||||
|
"""Check that the invoking member has a role with the given name or ID."""
|
||||||
|
|
||||||
|
async def predicate(ctx: "CommandContext") -> bool:
|
||||||
|
from .errors import CheckFailure
|
||||||
|
from disagreement.models import Member
|
||||||
|
|
||||||
|
if not ctx.guild:
|
||||||
|
raise CheckFailure("This command cannot be used in DMs.")
|
||||||
|
|
||||||
|
author = ctx.author
|
||||||
|
if not isinstance(author, Member):
|
||||||
|
try:
|
||||||
|
author = await ctx.bot.fetch_member(ctx.guild.id, author.id)
|
||||||
|
except Exception:
|
||||||
|
raise CheckFailure("Could not resolve author to a guild member.")
|
||||||
|
|
||||||
|
if not author:
|
||||||
|
raise CheckFailure("Could not resolve author to a guild member.")
|
||||||
|
|
||||||
|
# Create a list of the member's role objects by looking them up in the guild's roles list
|
||||||
|
member_roles = [
|
||||||
|
role for role in ctx.guild.roles if role.id in author.roles
|
||||||
|
]
|
||||||
|
|
||||||
|
if any(
|
||||||
|
role.id == str(name_or_id) or role.name == name_or_id
|
||||||
|
for role in member_roles
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
raise CheckFailure(f"You need the '{name_or_id}' role to use this command.")
|
||||||
|
|
||||||
|
return check(predicate)
|
||||||
|
|
||||||
|
|
||||||
|
def has_any_role(
|
||||||
|
*names_or_ids: str | int,
|
||||||
|
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||||
|
"""Check that the invoking member has any of the roles with the given names or IDs."""
|
||||||
|
|
||||||
|
async def predicate(ctx: "CommandContext") -> bool:
|
||||||
|
from .errors import CheckFailure
|
||||||
|
from disagreement.models import Member
|
||||||
|
|
||||||
|
if not ctx.guild:
|
||||||
|
raise CheckFailure("This command cannot be used in DMs.")
|
||||||
|
|
||||||
|
author = ctx.author
|
||||||
|
if not isinstance(author, Member):
|
||||||
|
try:
|
||||||
|
author = await ctx.bot.fetch_member(ctx.guild.id, author.id)
|
||||||
|
except Exception:
|
||||||
|
raise CheckFailure("Could not resolve author to a guild member.")
|
||||||
|
|
||||||
|
if not author:
|
||||||
|
raise CheckFailure("Could not resolve author to a guild member.")
|
||||||
|
|
||||||
|
member_roles = [
|
||||||
|
role for role in ctx.guild.roles if role.id in author.roles
|
||||||
|
]
|
||||||
|
# Convert names_or_ids to a set for efficient lookup
|
||||||
|
names_or_ids_set = set(map(str, names_or_ids))
|
||||||
|
|
||||||
|
if any(
|
||||||
|
role.id in names_or_ids_set or role.name in names_or_ids_set
|
||||||
|
for role in member_roles
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
role_list = ", ".join(f"'{r}'" for r in names_or_ids)
|
||||||
|
raise CheckFailure(
|
||||||
|
f"You need one of the following roles to use this command: {role_list}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return check(predicate)
|
||||||
|
File diff suppressed because it is too large
Load Diff
2161
disagreement/http.py
2161
disagreement/http.py
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,165 +1,167 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, TYPE_CHECKING
|
from typing import Any, Callable, Coroutine, Dict, List, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
from ..models import ActionRow
|
from ..models import ActionRow
|
||||||
from .item import Item
|
from .item import Item
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..client import Client
|
from ..client import Client
|
||||||
from ..interactions import Interaction
|
from ..interactions import Interaction
|
||||||
|
|
||||||
|
|
||||||
class View:
|
class View:
|
||||||
"""Represents a container for UI components that can be sent with a message.
|
"""Represents a container for UI components that can be sent with a message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
timeout (Optional[float]): The number of seconds to wait for an interaction before the view times out.
|
timeout (Optional[float]): The number of seconds to wait for an interaction before the view times out.
|
||||||
Defaults to 180.
|
Defaults to 180.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *, timeout: Optional[float] = 180.0):
|
def __init__(self, *, timeout: Optional[float] = 180.0):
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.id = str(uuid.uuid4())
|
self.id = str(uuid.uuid4())
|
||||||
self.__children: List[Item] = []
|
self.__children: List[Item] = []
|
||||||
self.__stopped = asyncio.Event()
|
self.__stopped = asyncio.Event()
|
||||||
self._client: Optional[Client] = None
|
self._client: Optional[Client] = None
|
||||||
self._message_id: Optional[str] = None
|
self._message_id: Optional[str] = None
|
||||||
|
|
||||||
for item in self.__class__.__dict__.values():
|
# The below is a bit of a hack to support items defined as class members
|
||||||
if isinstance(item, Item):
|
# e.g. button = Button(...)
|
||||||
self.add_item(item)
|
for item in self.__class__.__dict__.values():
|
||||||
|
if isinstance(item, Item):
|
||||||
@property
|
self.add_item(item)
|
||||||
def children(self) -> List[Item]:
|
|
||||||
return self.__children
|
@property
|
||||||
|
def children(self) -> List[Item]:
|
||||||
def add_item(self, item: Item):
|
return self.__children
|
||||||
"""Adds an item to the view."""
|
|
||||||
if not isinstance(item, Item):
|
def add_item(self, item: Item):
|
||||||
raise TypeError("Only instances of 'Item' can be added to a View.")
|
"""Adds an item to the view."""
|
||||||
|
if not isinstance(item, Item):
|
||||||
if len(self.__children) >= 25:
|
raise TypeError("Only instances of 'Item' can be added to a View.")
|
||||||
raise ValueError("A view can only have a maximum of 25 components.")
|
|
||||||
|
if len(self.__children) >= 25:
|
||||||
item._view = self
|
raise ValueError("A view can only have a maximum of 25 components.")
|
||||||
self.__children.append(item)
|
|
||||||
|
if self.timeout is None and item.custom_id is None:
|
||||||
@property
|
raise ValueError(
|
||||||
def message_id(self) -> Optional[str]:
|
"All components in a persistent view must have a 'custom_id'."
|
||||||
return self._message_id
|
)
|
||||||
|
|
||||||
@message_id.setter
|
item._view = self
|
||||||
def message_id(self, value: str):
|
self.__children.append(item)
|
||||||
self._message_id = value
|
|
||||||
|
@property
|
||||||
def to_components(self) -> List[ActionRow]:
|
def message_id(self) -> Optional[str]:
|
||||||
"""Converts the view's children into a list of ActionRow components.
|
return self._message_id
|
||||||
|
|
||||||
This retains the original, simple layout behaviour where each item is
|
@message_id.setter
|
||||||
placed in its own :class:`ActionRow` to ensure backward compatibility.
|
def message_id(self, value: str):
|
||||||
"""
|
self._message_id = value
|
||||||
|
|
||||||
rows: List[ActionRow] = []
|
def to_components(self) -> List[ActionRow]:
|
||||||
|
"""Converts the view's children into a list of ActionRow components.
|
||||||
for item in self.children:
|
|
||||||
if item.custom_id is None:
|
This retains the original, simple layout behaviour where each item is
|
||||||
item.custom_id = (
|
placed in its own :class:`ActionRow` to ensure backward compatibility.
|
||||||
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
|
"""
|
||||||
)
|
|
||||||
|
rows: List[ActionRow] = []
|
||||||
rows.append(ActionRow(components=[item]))
|
|
||||||
|
for item in self.children:
|
||||||
return rows
|
rows.append(ActionRow(components=[item]))
|
||||||
|
|
||||||
def layout_components_advanced(self) -> List[ActionRow]:
|
return rows
|
||||||
"""Group compatible components into rows following Discord rules."""
|
|
||||||
|
def layout_components_advanced(self) -> List[ActionRow]:
|
||||||
rows: List[ActionRow] = []
|
"""Group compatible components into rows following Discord rules."""
|
||||||
|
|
||||||
for item in self.children:
|
rows: List[ActionRow] = []
|
||||||
if item.custom_id is None:
|
|
||||||
item.custom_id = (
|
for item in self.children:
|
||||||
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
|
if item.custom_id is None:
|
||||||
)
|
item.custom_id = (
|
||||||
|
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
|
||||||
target_row = item.row
|
)
|
||||||
if target_row is not None:
|
|
||||||
if not 0 <= target_row <= 4:
|
target_row = item.row
|
||||||
raise ValueError("Row index must be between 0 and 4.")
|
if target_row is not None:
|
||||||
|
if not 0 <= target_row <= 4:
|
||||||
while len(rows) <= target_row:
|
raise ValueError("Row index must be between 0 and 4.")
|
||||||
if len(rows) >= 5:
|
|
||||||
raise ValueError("A view can have at most 5 action rows.")
|
while len(rows) <= target_row:
|
||||||
rows.append(ActionRow())
|
if len(rows) >= 5:
|
||||||
|
raise ValueError("A view can have at most 5 action rows.")
|
||||||
rows[target_row].add_component(item)
|
rows.append(ActionRow())
|
||||||
continue
|
|
||||||
|
rows[target_row].add_component(item)
|
||||||
placed = False
|
continue
|
||||||
for row in rows:
|
|
||||||
try:
|
placed = False
|
||||||
row.add_component(item)
|
for row in rows:
|
||||||
placed = True
|
try:
|
||||||
break
|
row.add_component(item)
|
||||||
except ValueError:
|
placed = True
|
||||||
continue
|
break
|
||||||
|
except ValueError:
|
||||||
if not placed:
|
continue
|
||||||
if len(rows) >= 5:
|
|
||||||
raise ValueError("A view can have at most 5 action rows.")
|
if not placed:
|
||||||
new_row = ActionRow([item])
|
if len(rows) >= 5:
|
||||||
rows.append(new_row)
|
raise ValueError("A view can have at most 5 action rows.")
|
||||||
|
new_row = ActionRow([item])
|
||||||
return rows
|
rows.append(new_row)
|
||||||
|
|
||||||
def to_components_payload(self) -> List[Dict[str, Any]]:
|
return rows
|
||||||
"""Converts the view's children into a list of component dictionaries
|
|
||||||
that can be sent to the Discord API."""
|
def to_components_payload(self) -> List[Dict[str, Any]]:
|
||||||
return [row.to_dict() for row in self.to_components()]
|
"""Converts the view's children into a list of component dictionaries
|
||||||
|
that can be sent to the Discord API."""
|
||||||
async def _dispatch(self, interaction: Interaction):
|
return [row.to_dict() for row in self.to_components()]
|
||||||
"""Called by the client to dispatch an interaction to the correct item."""
|
|
||||||
if self.timeout is not None:
|
async def _dispatch(self, interaction: Interaction):
|
||||||
self.__stopped.set() # Reset the timeout on each interaction
|
"""Called by the client to dispatch an interaction to the correct item."""
|
||||||
self.__stopped.clear()
|
if self.timeout is not None:
|
||||||
|
self.__stopped.set() # Reset the timeout on each interaction
|
||||||
if interaction.data:
|
self.__stopped.clear()
|
||||||
custom_id = interaction.data.custom_id
|
|
||||||
for child in self.children:
|
if interaction.data:
|
||||||
if child.custom_id == custom_id:
|
custom_id = interaction.data.custom_id
|
||||||
if child.callback:
|
for child in self.children:
|
||||||
await child.callback(self, interaction)
|
if child.custom_id == custom_id:
|
||||||
break
|
if child.callback:
|
||||||
|
await child.callback(self, interaction)
|
||||||
async def wait(self) -> bool:
|
break
|
||||||
"""Waits until the view has stopped interacting."""
|
|
||||||
return await self.__stopped.wait()
|
async def wait(self) -> bool:
|
||||||
|
"""Waits until the view has stopped interacting."""
|
||||||
def stop(self):
|
return await self.__stopped.wait()
|
||||||
"""Stops the view from listening to interactions."""
|
|
||||||
if not self.__stopped.is_set():
|
def stop(self):
|
||||||
self.__stopped.set()
|
"""Stops the view from listening to interactions."""
|
||||||
|
if not self.__stopped.is_set():
|
||||||
async def on_timeout(self):
|
self.__stopped.set()
|
||||||
"""Called when the view times out."""
|
|
||||||
pass # User can override this
|
async def on_timeout(self):
|
||||||
|
"""Called when the view times out."""
|
||||||
async def _start(self, client: Client):
|
pass # User can override this
|
||||||
"""Starts the view's internal listener."""
|
|
||||||
self._client = client
|
async def _start(self, client: Client):
|
||||||
if self.timeout is not None:
|
"""Starts the view's internal listener."""
|
||||||
asyncio.create_task(self._timeout_task())
|
self._client = client
|
||||||
|
if self.timeout is not None:
|
||||||
async def _timeout_task(self):
|
asyncio.create_task(self._timeout_task())
|
||||||
"""The task that waits for the timeout and then stops the view."""
|
|
||||||
try:
|
async def _timeout_task(self):
|
||||||
await asyncio.wait_for(self.wait(), timeout=self.timeout)
|
"""The task that waits for the timeout and then stops the view."""
|
||||||
except asyncio.TimeoutError:
|
try:
|
||||||
self.stop()
|
await asyncio.wait_for(self.wait(), timeout=self.timeout)
|
||||||
await self.on_timeout()
|
except asyncio.TimeoutError:
|
||||||
if self._client and self._message_id:
|
self.stop()
|
||||||
# Remove the view from the client's listeners
|
await self.on_timeout()
|
||||||
self._client._views.pop(self._message_id, None)
|
if self._client and self._message_id:
|
||||||
|
# Remove the view from the client's listeners
|
||||||
|
self._client._views.pop(self._message_id, None)
|
||||||
|
@ -1,162 +1,244 @@
|
|||||||
# disagreement/voice_client.py
|
# disagreement/voice_client.py
|
||||||
"""Voice gateway and UDP audio client."""
|
"""Voice gateway and UDP audio client."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import socket
|
import socket
|
||||||
from typing import Optional, Sequence
|
import threading
|
||||||
|
from typing import TYPE_CHECKING, Optional, Sequence
|
||||||
import aiohttp
|
|
||||||
|
import aiohttp
|
||||||
from .audio import AudioSource, FFmpegAudioSource
|
# The following import is correct, but may be flagged by Pylance if the virtual
|
||||||
|
# environment is not configured correctly.
|
||||||
|
from nacl.secret import SecretBox
|
||||||
class VoiceClient:
|
|
||||||
"""Handles the Discord voice WebSocket connection and UDP streaming."""
|
from .audio import AudioSink, AudioSource, FFmpegAudioSource
|
||||||
|
from .models import User
|
||||||
def __init__(
|
|
||||||
self,
|
if TYPE_CHECKING:
|
||||||
endpoint: str,
|
from .client import Client
|
||||||
session_id: str,
|
|
||||||
token: str,
|
|
||||||
guild_id: int,
|
class VoiceClient:
|
||||||
user_id: int,
|
"""Handles the Discord voice WebSocket connection and UDP streaming."""
|
||||||
*,
|
|
||||||
ws=None,
|
def __init__(
|
||||||
udp: Optional[socket.socket] = None,
|
self,
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
client: Client,
|
||||||
verbose: bool = False,
|
endpoint: str,
|
||||||
) -> None:
|
session_id: str,
|
||||||
self.endpoint = endpoint
|
token: str,
|
||||||
self.session_id = session_id
|
guild_id: int,
|
||||||
self.token = token
|
user_id: int,
|
||||||
self.guild_id = str(guild_id)
|
*,
|
||||||
self.user_id = str(user_id)
|
ws=None,
|
||||||
self._ws: Optional[aiohttp.ClientWebSocketResponse] = ws
|
udp: Optional[socket.socket] = None,
|
||||||
self._udp = udp
|
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||||
self._session: Optional[aiohttp.ClientSession] = None
|
verbose: bool = False,
|
||||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
) -> None:
|
||||||
self._heartbeat_interval: Optional[float] = None
|
self.client = client
|
||||||
self._loop = loop or asyncio.get_event_loop()
|
self.endpoint = endpoint
|
||||||
self.verbose = verbose
|
self.session_id = session_id
|
||||||
self.ssrc: Optional[int] = None
|
self.token = token
|
||||||
self.secret_key: Optional[Sequence[int]] = None
|
self.guild_id = str(guild_id)
|
||||||
self._server_ip: Optional[str] = None
|
self.user_id = str(user_id)
|
||||||
self._server_port: Optional[int] = None
|
self._ws: Optional[aiohttp.ClientWebSocketResponse] = ws
|
||||||
self._current_source: Optional[AudioSource] = None
|
self._udp = udp
|
||||||
self._play_task: Optional[asyncio.Task] = None
|
self._session: Optional[aiohttp.ClientSession] = None
|
||||||
|
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||||
async def connect(self) -> None:
|
self._receive_task: Optional[asyncio.Task] = None
|
||||||
if self._ws is None:
|
self._udp_receive_thread: Optional[threading.Thread] = None
|
||||||
self._session = aiohttp.ClientSession()
|
self._heartbeat_interval: Optional[float] = None
|
||||||
self._ws = await self._session.ws_connect(self.endpoint)
|
try:
|
||||||
|
self._loop = loop or asyncio.get_running_loop()
|
||||||
hello = await self._ws.receive_json()
|
except RuntimeError:
|
||||||
self._heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000
|
self._loop = asyncio.new_event_loop()
|
||||||
self._heartbeat_task = self._loop.create_task(self._heartbeat())
|
asyncio.set_event_loop(self._loop)
|
||||||
|
self.verbose = verbose
|
||||||
await self._ws.send_json(
|
self.ssrc: Optional[int] = None
|
||||||
{
|
self.secret_key: Optional[Sequence[int]] = None
|
||||||
"op": 0,
|
self._server_ip: Optional[str] = None
|
||||||
"d": {
|
self._server_port: Optional[int] = None
|
||||||
"server_id": self.guild_id,
|
self._current_source: Optional[AudioSource] = None
|
||||||
"user_id": self.user_id,
|
self._play_task: Optional[asyncio.Task] = None
|
||||||
"session_id": self.session_id,
|
self._sink: Optional[AudioSink] = None
|
||||||
"token": self.token,
|
self._ssrc_map: dict[int, int] = {}
|
||||||
},
|
self._ssrc_lock = threading.Lock()
|
||||||
}
|
|
||||||
)
|
async def connect(self) -> None:
|
||||||
|
if self._ws is None:
|
||||||
ready = await self._ws.receive_json()
|
self._session = aiohttp.ClientSession()
|
||||||
data = ready["d"]
|
self._ws = await self._session.ws_connect(self.endpoint)
|
||||||
self.ssrc = data["ssrc"]
|
|
||||||
self._server_ip = data["ip"]
|
hello = await self._ws.receive_json()
|
||||||
self._server_port = data["port"]
|
self._heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000
|
||||||
|
self._heartbeat_task = self._loop.create_task(self._heartbeat())
|
||||||
if self._udp is None:
|
|
||||||
self._udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
await self._ws.send_json(
|
||||||
self._udp.connect((self._server_ip, self._server_port))
|
{
|
||||||
|
"op": 0,
|
||||||
await self._ws.send_json(
|
"d": {
|
||||||
{
|
"server_id": self.guild_id,
|
||||||
"op": 1,
|
"user_id": self.user_id,
|
||||||
"d": {
|
"session_id": self.session_id,
|
||||||
"protocol": "udp",
|
"token": self.token,
|
||||||
"data": {
|
},
|
||||||
"address": self._udp.getsockname()[0],
|
}
|
||||||
"port": self._udp.getsockname()[1],
|
)
|
||||||
"mode": "xsalsa20_poly1305",
|
|
||||||
},
|
ready = await self._ws.receive_json()
|
||||||
},
|
data = ready["d"]
|
||||||
}
|
self.ssrc = data["ssrc"]
|
||||||
)
|
self._server_ip = data["ip"]
|
||||||
|
self._server_port = data["port"]
|
||||||
session_desc = await self._ws.receive_json()
|
|
||||||
self.secret_key = session_desc["d"].get("secret_key")
|
if self._udp is None:
|
||||||
|
self._udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
async def _heartbeat(self) -> None:
|
self._udp.connect((self._server_ip, self._server_port))
|
||||||
assert self._ws is not None
|
|
||||||
assert self._heartbeat_interval is not None
|
await self._ws.send_json(
|
||||||
try:
|
{
|
||||||
while True:
|
"op": 1,
|
||||||
await self._ws.send_json({"op": 3, "d": int(self._loop.time() * 1000)})
|
"d": {
|
||||||
await asyncio.sleep(self._heartbeat_interval)
|
"protocol": "udp",
|
||||||
except asyncio.CancelledError:
|
"data": {
|
||||||
pass
|
"address": self._udp.getsockname()[0],
|
||||||
|
"port": self._udp.getsockname()[1],
|
||||||
async def send_audio_frame(self, frame: bytes) -> None:
|
"mode": "xsalsa20_poly1305",
|
||||||
if not self._udp:
|
},
|
||||||
raise RuntimeError("UDP socket not initialised")
|
},
|
||||||
self._udp.send(frame)
|
}
|
||||||
|
)
|
||||||
async def _play_loop(self) -> None:
|
|
||||||
assert self._current_source is not None
|
session_desc = await self._ws.receive_json()
|
||||||
try:
|
self.secret_key = session_desc["d"].get("secret_key")
|
||||||
while True:
|
|
||||||
data = await self._current_source.read()
|
async def _heartbeat(self) -> None:
|
||||||
if not data:
|
assert self._ws is not None
|
||||||
break
|
assert self._heartbeat_interval is not None
|
||||||
await self.send_audio_frame(data)
|
try:
|
||||||
finally:
|
while True:
|
||||||
await self._current_source.close()
|
await self._ws.send_json({"op": 3, "d": int(self._loop.time() * 1000)})
|
||||||
self._current_source = None
|
await asyncio.sleep(self._heartbeat_interval)
|
||||||
self._play_task = None
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
async def stop(self) -> None:
|
|
||||||
if self._play_task:
|
async def _receive_loop(self) -> None:
|
||||||
self._play_task.cancel()
|
assert self._ws is not None
|
||||||
with contextlib.suppress(asyncio.CancelledError):
|
while True:
|
||||||
await self._play_task
|
try:
|
||||||
self._play_task = None
|
msg = await self._ws.receive_json()
|
||||||
if self._current_source:
|
op = msg.get("op")
|
||||||
await self._current_source.close()
|
data = msg.get("d")
|
||||||
self._current_source = None
|
if op == 5: # Speaking
|
||||||
|
user_id = int(data["user_id"])
|
||||||
async def play(self, source: AudioSource, *, wait: bool = True) -> None:
|
ssrc = data["ssrc"]
|
||||||
"""|coro| Play an :class:`AudioSource` on the voice connection."""
|
with self._ssrc_lock:
|
||||||
|
self._ssrc_map[ssrc] = user_id
|
||||||
await self.stop()
|
except (asyncio.CancelledError, aiohttp.ClientError):
|
||||||
self._current_source = source
|
break
|
||||||
self._play_task = self._loop.create_task(self._play_loop())
|
|
||||||
if wait:
|
def _udp_receive_loop(self) -> None:
|
||||||
await self._play_task
|
assert self._udp is not None
|
||||||
|
assert self.secret_key is not None
|
||||||
async def play_file(self, filename: str, *, wait: bool = True) -> None:
|
box = SecretBox(bytes(self.secret_key))
|
||||||
"""|coro| Stream an audio file or URL using FFmpeg."""
|
while True:
|
||||||
|
try:
|
||||||
await self.play(FFmpegAudioSource(filename), wait=wait)
|
packet = self._udp.recv(4096)
|
||||||
|
if len(packet) < 12:
|
||||||
async def close(self) -> None:
|
continue
|
||||||
await self.stop()
|
|
||||||
if self._heartbeat_task:
|
ssrc = int.from_bytes(packet[8:12], "big")
|
||||||
self._heartbeat_task.cancel()
|
with self._ssrc_lock:
|
||||||
with contextlib.suppress(asyncio.CancelledError):
|
if ssrc not in self._ssrc_map:
|
||||||
await self._heartbeat_task
|
continue
|
||||||
if self._ws:
|
user_id = self._ssrc_map[ssrc]
|
||||||
await self._ws.close()
|
user = self.client._users.get(str(user_id))
|
||||||
if self._session:
|
if not user:
|
||||||
await self._session.close()
|
continue
|
||||||
if self._udp:
|
|
||||||
self._udp.close()
|
decrypted = box.decrypt(packet[12:])
|
||||||
|
if self._sink:
|
||||||
|
self._sink.write(user, decrypted)
|
||||||
|
except (socket.error, asyncio.CancelledError):
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
if self.verbose:
|
||||||
|
print(f"Error in UDP receive loop: {e}")
|
||||||
|
|
||||||
|
async def send_audio_frame(self, frame: bytes) -> None:
|
||||||
|
if not self._udp:
|
||||||
|
raise RuntimeError("UDP socket not initialised")
|
||||||
|
self._udp.send(frame)
|
||||||
|
|
||||||
|
async def _play_loop(self) -> None:
|
||||||
|
assert self._current_source is not None
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
data = await self._current_source.read()
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
await self.send_audio_frame(data)
|
||||||
|
finally:
|
||||||
|
await self._current_source.close()
|
||||||
|
self._current_source = None
|
||||||
|
self._play_task = None
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
if self._play_task:
|
||||||
|
self._play_task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await self._play_task
|
||||||
|
self._play_task = None
|
||||||
|
if self._current_source:
|
||||||
|
await self._current_source.close()
|
||||||
|
self._current_source = None
|
||||||
|
|
||||||
|
async def play(self, source: AudioSource, *, wait: bool = True) -> None:
|
||||||
|
"""|coro| Play an :class:`AudioSource` on the voice connection."""
|
||||||
|
|
||||||
|
await self.stop()
|
||||||
|
self._current_source = source
|
||||||
|
self._play_task = self._loop.create_task(self._play_loop())
|
||||||
|
if wait:
|
||||||
|
await self._play_task
|
||||||
|
|
||||||
|
async def play_file(self, filename: str, *, wait: bool = True) -> None:
|
||||||
|
"""|coro| Stream an audio file or URL using FFmpeg."""
|
||||||
|
|
||||||
|
await self.play(FFmpegAudioSource(filename), wait=wait)
|
||||||
|
|
||||||
|
def listen(self, sink: AudioSink) -> None:
|
||||||
|
"""Start listening to voice and routing to a sink."""
|
||||||
|
if not isinstance(sink, AudioSink):
|
||||||
|
raise TypeError("sink must be an AudioSink instance")
|
||||||
|
|
||||||
|
self._sink = sink
|
||||||
|
if not self._udp_receive_thread:
|
||||||
|
self._udp_receive_thread = threading.Thread(
|
||||||
|
target=self._udp_receive_loop, daemon=True
|
||||||
|
)
|
||||||
|
self._udp_receive_thread.start()
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
await self.stop()
|
||||||
|
if self._heartbeat_task:
|
||||||
|
self._heartbeat_task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await self._heartbeat_task
|
||||||
|
if self._receive_task:
|
||||||
|
self._receive_task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await self._receive_task
|
||||||
|
if self._ws:
|
||||||
|
await self._ws.close()
|
||||||
|
if self._session:
|
||||||
|
await self._session.close()
|
||||||
|
if self._udp:
|
||||||
|
self._udp.close()
|
||||||
|
if self._udp_receive_thread:
|
||||||
|
self._udp_receive_thread.join(timeout=1)
|
||||||
|
if self._sink:
|
||||||
|
self._sink.close()
|
||||||
|
113
pyproject.toml
113
pyproject.toml
@ -1,56 +1,57 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "disagreement"
|
name = "disagreement"
|
||||||
version = "0.2.0rc1"
|
version = "0.2.0rc1"
|
||||||
description = "A Python library for the Discord API."
|
description = "A Python library for the Discord API."
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
license = {text = "BSD 3-Clause"}
|
license = {text = "BSD 3-Clause"}
|
||||||
authors = [
|
authors = [
|
||||||
{name = "Slipstream", email = "me@slipstreamm.dev"}
|
{name = "Slipstream", email = "me@slipstreamm.dev"}
|
||||||
]
|
]
|
||||||
keywords = ["discord", "api", "bot", "async", "aiohttp"]
|
keywords = ["discord", "api", "bot", "async", "aiohttp"]
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Development Status :: 4 - Beta",
|
"Development Status :: 4 - Beta",
|
||||||
"Intended Audience :: Developers",
|
"Intended Audience :: Developers",
|
||||||
"License :: OSI Approved :: BSD License",
|
"License :: OSI Approved :: BSD License",
|
||||||
"Operating System :: OS Independent",
|
"Operating System :: OS Independent",
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.10",
|
||||||
"Programming Language :: Python :: 3.11",
|
"Programming Language :: Python :: 3.11",
|
||||||
"Programming Language :: Python :: 3.12",
|
"Programming Language :: Python :: 3.12",
|
||||||
"Programming Language :: Python :: 3.13",
|
"Programming Language :: Python :: 3.13",
|
||||||
"Topic :: Software Development :: Libraries",
|
"Topic :: Software Development :: Libraries",
|
||||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||||
"Topic :: Internet",
|
"Topic :: Internet",
|
||||||
]
|
]
|
||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aiohttp>=3.9.0,<4.0.0",
|
"aiohttp>=3.9.0,<4.0.0",
|
||||||
]
|
"PyNaCl>=1.5.0,<2.0.0",
|
||||||
|
]
|
||||||
[project.optional-dependencies]
|
|
||||||
test = [
|
[project.optional-dependencies]
|
||||||
"pytest>=8.0.0",
|
test = [
|
||||||
"pytest-asyncio>=1.0.0",
|
"pytest>=8.0.0",
|
||||||
"hypothesis>=6.132.0",
|
"pytest-asyncio>=1.0.0",
|
||||||
]
|
"hypothesis>=6.132.0",
|
||||||
dev = [
|
]
|
||||||
"python-dotenv>=1.0.0",
|
dev = [
|
||||||
]
|
"python-dotenv>=1.0.0",
|
||||||
|
]
|
||||||
[project.urls]
|
|
||||||
Homepage = "https://github.com/Slipstreamm/disagreement"
|
[project.urls]
|
||||||
Issues = "https://github.com/Slipstreamm/disagreement/issues"
|
Homepage = "https://github.com/Slipstreamm/disagreement"
|
||||||
|
Issues = "https://github.com/Slipstreamm/disagreement/issues"
|
||||||
[build-system]
|
|
||||||
requires = ["setuptools>=61.0"]
|
[build-system]
|
||||||
build-backend = "setuptools.build_meta"
|
requires = ["setuptools>=61.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
# Optional: for linting/formatting, e.g., Ruff
|
|
||||||
# [tool.ruff]
|
# Optional: for linting/formatting, e.g., Ruff
|
||||||
# line-length = 88
|
# [tool.ruff]
|
||||||
# select = ["E", "W", "F", "I", "UP", "C4", "B"] # Example rule set
|
# line-length = 88
|
||||||
# ignore = []
|
# select = ["E", "W", "F", "I", "UP", "C4", "B"] # Example rule set
|
||||||
|
# ignore = []
|
||||||
# [tool.ruff.format]
|
|
||||||
# quote-style = "double"
|
# [tool.ruff.format]
|
||||||
|
# quote-style = "double"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user