This commit is contained in:
Slipstream 2025-06-11 14:52:49 -06:00
commit a702c66603
Signed by: slipstream
GPG Key ID: 13E498CE010AC6FD
37 changed files with 831 additions and 232 deletions

View File

@ -23,6 +23,13 @@ pip install -e .
Requires Python 3.10 or newer. Requires Python 3.10 or newer.
To run the example scripts, you'll need the `python-dotenv` package to load
environment variables. Install the development extras with:
```bash
pip install "disagreement[dev]"
```
## Basic Usage ## Basic Usage
```python ```python
@ -102,6 +109,20 @@ These options are forwarded to ``HTTPClient`` when it creates the underlying
``aiohttp.ClientSession``. You can specify a custom ``connector`` or any other ``aiohttp.ClientSession``. You can specify a custom ``connector`` or any other
session parameter supported by ``aiohttp``. session parameter supported by ``aiohttp``.
### Default Allowed Mentions
Specify default mention behaviour for all outgoing messages when constructing the client:
```python
client = disagreement.Client(
token=token,
allowed_mentions={"parse": [], "replied_user": False},
)
```
This dictionary is used whenever ``send_message`` is called without an explicit
``allowed_mentions`` argument.
### Defining Subcommands with `AppCommandGroup` ### Defining Subcommands with `AppCommandGroup`
```python ```python
@ -120,6 +141,7 @@ async def show(ctx: AppCommandContext, key: str):
@slash_command(name="set", description="Update a setting.", parent=admin_group) @slash_command(name="set", description="Update a setting.", parent=admin_group)
async def set_setting(ctx: AppCommandContext, key: str, value: str): async def set_setting(ctx: AppCommandContext, key: str, value: str):
... ...
```
## Fetching Guilds ## Fetching Guilds
Use `Client.fetch_guild` to retrieve a guild from the Discord API if it Use `Client.fetch_guild` to retrieve a guild from the Discord API if it

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
import io import io
import shlex
from typing import Optional, Union from typing import Optional, Union
@ -35,15 +36,27 @@ class FFmpegAudioSource(AudioSource):
A filename, URL, or file-like object to read from. A filename, URL, or file-like object to read from.
""" """
def __init__(self, source: Union[str, io.BufferedIOBase]): def __init__(
self,
source: Union[str, io.BufferedIOBase],
*,
before_options: Optional[str] = None,
options: Optional[str] = None,
volume: float = 1.0,
):
self.source = source self.source = source
self.before_options = before_options
self.options = options
self.volume = volume
self.process: Optional[asyncio.subprocess.Process] = None self.process: Optional[asyncio.subprocess.Process] = None
self._feeder: Optional[asyncio.Task] = None self._feeder: Optional[asyncio.Task] = None
async def _spawn(self) -> None: async def _spawn(self) -> None:
if isinstance(self.source, str): if isinstance(self.source, str):
args = [ args = ["ffmpeg"]
"ffmpeg", if self.before_options:
args += shlex.split(self.before_options)
args += [
"-i", "-i",
self.source, self.source,
"-f", "-f",
@ -54,14 +67,18 @@ class FFmpegAudioSource(AudioSource):
"2", "2",
"pipe:1", "pipe:1",
] ]
if self.options:
args += shlex.split(self.options)
self.process = await asyncio.create_subprocess_exec( self.process = await asyncio.create_subprocess_exec(
*args, *args,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.DEVNULL,
) )
else: else:
args = [ args = ["ffmpeg"]
"ffmpeg", if self.before_options:
args += shlex.split(self.before_options)
args += [
"-i", "-i",
"pipe:0", "pipe:0",
"-f", "-f",
@ -72,6 +89,8 @@ class FFmpegAudioSource(AudioSource):
"2", "2",
"pipe:1", "pipe:1",
] ]
if self.options:
args += shlex.split(self.options)
self.process = await asyncio.create_subprocess_exec( self.process = await asyncio.create_subprocess_exec(
*args, *args,
stdin=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE,
@ -115,6 +134,7 @@ class FFmpegAudioSource(AudioSource):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
self.source.close() self.source.close()
class AudioSink: class AudioSink:
"""Abstract base class for audio sinks.""" """Abstract base class for audio sinks."""

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import time import time
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
from collections import OrderedDict
if TYPE_CHECKING: if TYPE_CHECKING:
from .models import Channel, Guild, Member from .models import Channel, Guild, Member
@ -11,15 +12,22 @@ T = TypeVar("T")
class Cache(Generic[T]): class Cache(Generic[T]):
"""Simple in-memory cache with optional TTL support.""" """Simple in-memory cache with optional TTL and max size support."""
def __init__(self, ttl: Optional[float] = None) -> None: def __init__(
self, ttl: Optional[float] = None, maxlen: Optional[int] = None
) -> None:
self.ttl = ttl self.ttl = ttl
self._data: Dict[str, tuple[T, Optional[float]]] = {} self.maxlen = maxlen
self._data: "OrderedDict[str, tuple[T, Optional[float]]]" = OrderedDict()
def set(self, key: str, value: T) -> None: def set(self, key: str, value: T) -> None:
expiry = time.monotonic() + self.ttl if self.ttl is not None else None expiry = time.monotonic() + self.ttl if self.ttl is not None else None
if key in self._data:
self._data.move_to_end(key)
self._data[key] = (value, expiry) self._data[key] = (value, expiry)
if self.maxlen is not None and len(self._data) > self.maxlen:
self._data.popitem(last=False)
def get(self, key: str) -> Optional[T]: def get(self, key: str) -> Optional[T]:
item = self._data.get(key) item = self._data.get(key)
@ -29,6 +37,7 @@ class Cache(Generic[T]):
if expiry is not None and expiry < time.monotonic(): if expiry is not None and expiry < time.monotonic():
self.invalidate(key) self.invalidate(key)
return None return None
self._data.move_to_end(key)
return value return value
def invalidate(self, key: str) -> None: def invalidate(self, key: str) -> None:

View File

@ -8,10 +8,10 @@ class _MemberCacheFlagValue:
flag: int flag: int
def __init__(self, func: Callable[[Any], bool]): def __init__(self, func: Callable[[Any], bool]):
self.flag = getattr(func, 'flag', 0) self.flag = getattr(func, "flag", 0)
self.__doc__ = func.__doc__ self.__doc__ = func.__doc__
def __get__(self, instance: 'MemberCacheFlags', owner: type) -> Any: def __get__(self, instance: "MemberCacheFlags", owner: type) -> Any:
if instance is None: if instance is None:
return self return self
return instance.value & self.flag != 0 return instance.value & self.flag != 0
@ -23,23 +23,24 @@ class _MemberCacheFlagValue:
instance.value &= ~self.flag instance.value &= ~self.flag
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<{self.__class__.__name__} flag={self.flag}>' return f"<{self.__class__.__name__} flag={self.flag}>"
def flag_value(flag: int) -> Callable[[Callable[[Any], bool]], _MemberCacheFlagValue]: def flag_value(flag: int) -> Callable[[Callable[[Any], bool]], _MemberCacheFlagValue]:
def decorator(func: Callable[[Any], bool]) -> _MemberCacheFlagValue: def decorator(func: Callable[[Any], bool]) -> _MemberCacheFlagValue:
setattr(func, 'flag', flag) setattr(func, "flag", flag)
return _MemberCacheFlagValue(func) return _MemberCacheFlagValue(func)
return decorator return decorator
class MemberCacheFlags: class MemberCacheFlags:
__slots__ = ('value',) __slots__ = ("value",)
VALID_FLAGS: ClassVar[Dict[str, int]] = { VALID_FLAGS: ClassVar[Dict[str, int]] = {
'joined': 1 << 0, "joined": 1 << 0,
'voice': 1 << 1, "voice": 1 << 1,
'online': 1 << 2, "online": 1 << 2,
} }
DEFAULT_FLAGS: ClassVar[int] = 1 | 2 | 4 DEFAULT_FLAGS: ClassVar[int] = 1 | 2 | 4
ALL_FLAGS: ClassVar[int] = sum(VALID_FLAGS.values()) ALL_FLAGS: ClassVar[int] = sum(VALID_FLAGS.values())
@ -48,7 +49,7 @@ class MemberCacheFlags:
self.value = self.DEFAULT_FLAGS self.value = self.DEFAULT_FLAGS
for key, value in kwargs.items(): for key, value in kwargs.items():
if key not in self.VALID_FLAGS: if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid member cache flag.') raise TypeError(f"{key!r} is not a valid member cache flag.")
setattr(self, key, value) setattr(self, key, value)
@classmethod @classmethod
@ -67,7 +68,7 @@ class MemberCacheFlags:
return hash(self.value) return hash(self.value)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<MemberCacheFlags value={self.value}>' return f"<MemberCacheFlags value={self.value}>"
def __iter__(self) -> Iterator[Tuple[str, bool]]: def __iter__(self) -> Iterator[Tuple[str, bool]]:
for name in self.VALID_FLAGS: for name in self.VALID_FLAGS:
@ -92,17 +93,17 @@ class MemberCacheFlags:
@classmethod @classmethod
def only_joined(cls) -> MemberCacheFlags: def only_joined(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with only the `joined` flag enabled.""" """A factory method that creates a :class:`MemberCacheFlags` with only the `joined` flag enabled."""
return cls._from_value(cls.VALID_FLAGS['joined']) return cls._from_value(cls.VALID_FLAGS["joined"])
@classmethod @classmethod
def only_voice(cls) -> MemberCacheFlags: def only_voice(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with only the `voice` flag enabled.""" """A factory method that creates a :class:`MemberCacheFlags` with only the `voice` flag enabled."""
return cls._from_value(cls.VALID_FLAGS['voice']) return cls._from_value(cls.VALID_FLAGS["voice"])
@classmethod @classmethod
def only_online(cls) -> MemberCacheFlags: def only_online(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with only the `online` flag enabled.""" """A factory method that creates a :class:`MemberCacheFlags` with only the `online` flag enabled."""
return cls._from_value(cls.VALID_FLAGS['online']) return cls._from_value(cls.VALID_FLAGS["online"])
@flag_value(1 << 0) @flag_value(1 << 0)
def joined(self) -> bool: def joined(self) -> bool:

View File

@ -36,6 +36,7 @@ from .ext import loader as ext_loader
from .interactions import Interaction, Snowflake from .interactions import Interaction, Snowflake
from .error_handler import setup_global_error_handler from .error_handler import setup_global_error_handler
from .voice_client import VoiceClient from .voice_client import VoiceClient
from .models import Activity
if TYPE_CHECKING: if TYPE_CHECKING:
from .models import ( from .models import (
@ -75,13 +76,21 @@ class Client:
intents (Optional[int]): The Gateway Intents to use. Defaults to `GatewayIntent.default()`. intents (Optional[int]): The Gateway Intents to use. Defaults to `GatewayIntent.default()`.
You might need to enable privileged intents in your bot's application page. You might need to enable privileged intents in your bot's application page.
loop (Optional[asyncio.AbstractEventLoop]): The event loop to use for asynchronous operations. loop (Optional[asyncio.AbstractEventLoop]): The event loop to use for asynchronous operations.
Defaults to `asyncio.get_event_loop()`. Defaults to the running loop
via `asyncio.get_running_loop()`,
or a new loop from
`asyncio.new_event_loop()` if
none is running.
command_prefix (Union[str, List[str], Callable[['Client', Message], Union[str, List[str]]]]): command_prefix (Union[str, List[str], Callable[['Client', Message], Union[str, List[str]]]]):
The prefix(es) for commands. Defaults to '!'. The prefix(es) for commands. Defaults to '!'.
verbose (bool): If True, print raw HTTP and Gateway traffic for debugging. verbose (bool): If True, print raw HTTP and Gateway traffic for debugging.
mention_replies (bool): Whether replies mention the author by default.
allowed_mentions (Optional[Dict[str, Any]]): Default allowed mentions for messages.
http_options (Optional[Dict[str, Any]]): Extra options passed to http_options (Optional[Dict[str, Any]]): Extra options passed to
:class:`HTTPClient` for creating the internal :class:`HTTPClient` for creating the internal
:class:`aiohttp.ClientSession`. :class:`aiohttp.ClientSession`.
message_cache_maxlen (Optional[int]): Maximum number of messages to keep
in the cache. When ``None``, the cache size is unlimited.
""" """
def __init__( def __init__(
@ -95,10 +104,12 @@ class Client:
application_id: Optional[Union[str, int]] = None, application_id: Optional[Union[str, int]] = None,
verbose: bool = False, verbose: bool = False,
mention_replies: bool = False, mention_replies: bool = False,
allowed_mentions: Optional[Dict[str, Any]] = None,
shard_count: Optional[int] = None, shard_count: Optional[int] = None,
gateway_max_retries: int = 5, gateway_max_retries: int = 5,
gateway_max_backoff: float = 60.0, gateway_max_backoff: float = 60.0,
member_cache_flags: Optional[MemberCacheFlags] = None, member_cache_flags: Optional[MemberCacheFlags] = None,
message_cache_maxlen: Optional[int] = None,
http_options: Optional[Dict[str, Any]] = None, http_options: Optional[Dict[str, Any]] = None,
): ):
if not token: if not token:
@ -108,6 +119,7 @@ class Client:
self.member_cache_flags: MemberCacheFlags = ( self.member_cache_flags: MemberCacheFlags = (
member_cache_flags if member_cache_flags is not None else MemberCacheFlags() member_cache_flags if member_cache_flags is not None else MemberCacheFlags()
) )
self.message_cache_maxlen: Optional[int] = message_cache_maxlen
self.intents: int = intents if intents is not None else GatewayIntent.default() self.intents: int = intents if intents is not None else GatewayIntent.default()
if loop: if loop:
self.loop: asyncio.AbstractEventLoop = loop self.loop: asyncio.AbstractEventLoop = loop
@ -157,7 +169,7 @@ class Client:
self._guilds: GuildCache = GuildCache() self._guilds: GuildCache = GuildCache()
self._channels: ChannelCache = ChannelCache() self._channels: ChannelCache = ChannelCache()
self._users: Cache["User"] = Cache() self._users: Cache["User"] = Cache()
self._messages: Cache["Message"] = Cache(ttl=3600) # Cache messages for an hour self._messages: Cache["Message"] = Cache(ttl=3600, maxlen=message_cache_maxlen)
self._views: Dict[Snowflake, "View"] = {} self._views: Dict[Snowflake, "View"] = {}
self._persistent_views: Dict[str, "View"] = {} self._persistent_views: Dict[str, "View"] = {}
self._voice_clients: Dict[Snowflake, VoiceClient] = {} self._voice_clients: Dict[Snowflake, VoiceClient] = {}
@ -165,6 +177,7 @@ class Client:
# Default whether replies mention the user # Default whether replies mention the user
self.mention_replies: bool = mention_replies self.mention_replies: bool = mention_replies
self.allowed_mentions: Optional[Dict[str, Any]] = allowed_mentions
# Basic signal handling for graceful shutdown # Basic signal handling for graceful shutdown
# This might be better handled by the user's application code, but can be a nice default. # This might be better handled by the user's application code, but can be a nice default.
@ -435,8 +448,7 @@ class Client:
async def change_presence( async def change_presence(
self, self,
status: str, status: str,
activity_name: Optional[str] = None, activity: Optional[Activity] = None,
activity_type: int = 0,
since: int = 0, since: int = 0,
afk: bool = False, afk: bool = False,
): ):
@ -445,8 +457,7 @@ class Client:
Args: Args:
status (str): The new status for the client (e.g., "online", "idle", "dnd", "invisible"). status (str): The new status for the client (e.g., "online", "idle", "dnd", "invisible").
activity_name (Optional[str]): The name of the activity. activity (Optional[Activity]): Activity instance describing what the bot is doing.
activity_type (int): The type of the activity.
since (int): The timestamp (in milliseconds) of when the client went idle. since (int): The timestamp (in milliseconds) of when the client went idle.
afk (bool): Whether the client is AFK. afk (bool): Whether the client is AFK.
""" """
@ -456,8 +467,7 @@ class Client:
if self._gateway: if self._gateway:
await self._gateway.update_presence( await self._gateway.update_presence(
status=status, status=status,
activity_name=activity_name, activity=activity,
activity_type=activity_type,
since=since, since=since,
afk=afk, afk=afk,
) )
@ -1010,7 +1020,7 @@ class Client:
embeds (Optional[List[Embed]]): A list of embeds to send. Cannot be used with `embed`. embeds (Optional[List[Embed]]): A list of embeds to send. Cannot be used with `embed`.
Discord supports up to 10 embeds per message. Discord supports up to 10 embeds per message.
components (Optional[List[ActionRow]]): A list of ActionRow components to include. components (Optional[List[ActionRow]]): A list of ActionRow components to include.
allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message. allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message. Defaults to :attr:`Client.allowed_mentions`.
message_reference (Optional[Dict[str, Any]]): Message reference for replying. message_reference (Optional[Dict[str, Any]]): Message reference for replying.
attachments (Optional[List[Any]]): Attachments to include with the message. attachments (Optional[List[Any]]): Attachments to include with the message.
files (Optional[List[Any]]): Files to upload with the message. files (Optional[List[Any]]): Files to upload with the message.
@ -1057,6 +1067,9 @@ class Client:
if isinstance(comp, ComponentModel) if isinstance(comp, ComponentModel)
] ]
if allowed_mentions is None:
allowed_mentions = self.allowed_mentions
message_data = await self._http.send_message( message_data = await self._http.send_message(
channel_id=channel_id, channel_id=channel_id,
content=content, content=content,
@ -1428,6 +1441,24 @@ class Client:
await self._http.delete_guild_template(guild_id, template_code) await self._http.delete_guild_template(guild_id, template_code)
async def fetch_widget(self, guild_id: Snowflake) -> Dict[str, Any]:
"""|coro| Fetch a guild's widget settings."""
if self._closed:
raise DisagreementException("Client is closed.")
return await self._http.get_guild_widget(guild_id)
async def edit_widget(
self, guild_id: Snowflake, payload: Dict[str, Any]
) -> Dict[str, Any]:
"""|coro| Edit a guild's widget settings."""
if self._closed:
raise DisagreementException("Client is closed.")
return await self._http.edit_guild_widget(guild_id, payload)
async def fetch_scheduled_events( async def fetch_scheduled_events(
self, guild_id: Snowflake self, guild_id: Snowflake
) -> List["ScheduledEvent"]: ) -> List["ScheduledEvent"]:
@ -1514,35 +1545,35 @@ class Client:
return [self.parse_invite(inv) for inv in data] return [self.parse_invite(inv) for inv in data]
def add_persistent_view(self, view: "View") -> None: def add_persistent_view(self, view: "View") -> None:
""" """
Registers a persistent view with the client. Registers a persistent view with the client.
Persistent views have a timeout of `None` and their components must have a `custom_id`. Persistent views have a timeout of `None` and their components must have a `custom_id`.
This allows the view to be re-instantiated across bot restarts. This allows the view to be re-instantiated across bot restarts.
Args: Args:
view (View): The view instance to register. view (View): The view instance to register.
Raises: Raises:
ValueError: If the view is not persistent (timeout is not None) or if a component's ValueError: If the view is not persistent (timeout is not None) or if a component's
custom_id is already registered. custom_id is already registered.
""" """
if self.is_ready(): if self.is_ready():
print( print(
"Warning: Adding a persistent view after the client is ready. " "Warning: Adding a persistent view after the client is ready. "
"This view will only be available for interactions on this session." "This view will only be available for interactions on this session."
) )
if view.timeout is not None: if view.timeout is not None:
raise ValueError("Persistent views must have a timeout of None.") raise ValueError("Persistent views must have a timeout of None.")
for item in view.children: for item in view.children:
if item.custom_id: # Ensure custom_id is not None if item.custom_id: # Ensure custom_id is not None
if item.custom_id in self._persistent_views: if item.custom_id in self._persistent_views:
raise ValueError( raise ValueError(
f"A component with custom_id '{item.custom_id}' is already registered." f"A component with custom_id '{item.custom_id}' is already registered."
) )
self._persistent_views[item.custom_id] = view self._persistent_views[item.custom_id] = view
# --- Application Command Methods --- # --- Application Command Methods ---
async def process_interaction(self, interaction: Interaction) -> None: async def process_interaction(self, interaction: Interaction) -> None:

View File

@ -375,6 +375,15 @@ class OverwriteType(IntEnum):
MEMBER = 1 MEMBER = 1
class AutoArchiveDuration(IntEnum):
"""Thread auto-archive duration in minutes."""
HOUR = 60
DAY = 1440
THREE_DAYS = 4320
WEEK = 10080
# --- Component Enums --- # --- Component Enums ---

View File

@ -14,7 +14,11 @@ def setup_global_error_handler(
The handler logs unhandled exceptions so they don't crash the bot. The handler logs unhandled exceptions so they don't crash the bot.
""" """
if loop is None: if loop is None:
loop = asyncio.get_event_loop() try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if not logging.getLogger().hasHandlers(): if not logging.getLogger().hasHandlers():
setup_logging(logging.ERROR) setup_logging(logging.ERROR)

View File

@ -1,8 +1,10 @@
# disagreement/ext/commands/converters.py # disagreement/ext/commands/converters.py
# pyright: reportIncompatibleMethodOverride=false
from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, Generic from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, Generic
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import re import re
import inspect
from .errors import BadArgument from .errors import BadArgument
from disagreement.models import Member, Guild, Role from disagreement.models import Member, Guild, Role
@ -36,6 +38,20 @@ class Converter(ABC, Generic[T]):
raise NotImplementedError("Converter subclass must implement convert method.") raise NotImplementedError("Converter subclass must implement convert method.")
class Greedy(list):
"""Type hint helper to greedily consume arguments."""
converter: Any = None
def __class_getitem__(cls, param: Any) -> type: # pyright: ignore[override]
if isinstance(param, tuple):
if len(param) != 1:
raise TypeError("Greedy[...] expects a single parameter")
param = param[0]
name = f"Greedy[{getattr(param, '__name__', str(param))}]"
return type(name, (Greedy,), {"converter": param})
# --- Built-in Type Converters --- # --- Built-in Type Converters ---
@ -169,7 +185,3 @@ async def run_converters(ctx: "CommandContext", annotation: Any, argument: str)
raise BadArgument(f"No converter found for type annotation '{annotation}'.") raise BadArgument(f"No converter found for type annotation '{annotation}'.")
return argument # Default to string if no annotation or annotation is str return argument # Default to string if no annotation or annotation is str
# Need to import inspect for the run_converters function
import inspect

View File

@ -29,7 +29,7 @@ from .errors import (
CheckFailure, CheckFailure,
CommandInvokeError, CommandInvokeError,
) )
from .converters import run_converters, DEFAULT_CONVERTERS, Converter from .converters import Greedy, run_converters, DEFAULT_CONVERTERS, Converter
from disagreement.typing import Typing from disagreement.typing import Typing
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -46,29 +46,39 @@ class GroupMixin:
self.commands: Dict[str, "Command"] = {} self.commands: Dict[str, "Command"] = {}
self.name: str = "" self.name: str = ""
def command(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Command"]: def command(
self, **attrs: Any
) -> Callable[[Callable[..., Awaitable[None]]], "Command"]:
def decorator(func: Callable[..., Awaitable[None]]) -> "Command": def decorator(func: Callable[..., Awaitable[None]]) -> "Command":
cmd = Command(func, **attrs) cmd = Command(func, **attrs)
cmd.cog = getattr(self, "cog", None) cmd.cog = getattr(self, "cog", None)
self.add_command(cmd) self.add_command(cmd)
return cmd return cmd
return decorator return decorator
def group(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Group"]: def group(
self, **attrs: Any
) -> Callable[[Callable[..., Awaitable[None]]], "Group"]:
def decorator(func: Callable[..., Awaitable[None]]) -> "Group": def decorator(func: Callable[..., Awaitable[None]]) -> "Group":
cmd = Group(func, **attrs) cmd = Group(func, **attrs)
cmd.cog = getattr(self, "cog", None) cmd.cog = getattr(self, "cog", None)
self.add_command(cmd) self.add_command(cmd)
return cmd return cmd
return decorator return decorator
def add_command(self, command: "Command") -> None: def add_command(self, command: "Command") -> None:
if command.name in self.commands: if command.name in self.commands:
raise ValueError(f"Command '{command.name}' is already registered in group '{self.name}'.") raise ValueError(
f"Command '{command.name}' is already registered in group '{self.name}'."
)
self.commands[command.name.lower()] = command self.commands[command.name.lower()] = command
for alias in command.aliases: for alias in command.aliases:
if alias in self.commands: if alias in self.commands:
logger.warning(f"Alias '{alias}' for command '{command.name}' in group '{self.name}' conflicts with an existing command or alias.") logger.warning(
f"Alias '{alias}' for command '{command.name}' in group '{self.name}' conflicts with an existing command or alias."
)
self.commands[alias.lower()] = command self.commands[alias.lower()] = command
def get_command(self, name: str) -> Optional["Command"]: def get_command(self, name: str) -> Optional["Command"]:
@ -181,6 +191,7 @@ class Command(GroupMixin):
class Group(Command): class Group(Command):
"""A command that can have subcommands.""" """A command that can have subcommands."""
def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any): def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any):
super().__init__(callback, **attrs) super().__init__(callback, **attrs)
@ -494,7 +505,34 @@ class CommandHandler:
None # Holds the raw string for current param None # Holds the raw string for current param
) )
if view.eof: # No more input string annotation = param.annotation
if inspect.isclass(annotation) and issubclass(annotation, Greedy):
greedy_values = []
converter_type = annotation.converter
while not view.eof:
view.skip_whitespace()
if view.eof:
break
start = view.index
if view.buffer[view.index] == '"':
arg_str_value = view.get_quoted_string()
if arg_str_value == "" and view.buffer[view.index] == '"':
raise BadArgument(
f"Unterminated quoted string for argument '{param.name}'."
)
else:
arg_str_value = view.get_word()
try:
converted = await run_converters(
ctx, converter_type, arg_str_value
)
except BadArgument:
view.index = start
break
greedy_values.append(converted)
final_value_for_param = greedy_values
arg_str_value = None
elif view.eof: # No more input string
if param.default is not inspect.Parameter.empty: if param.default is not inspect.Parameter.empty:
final_value_for_param = param.default final_value_for_param = param.default
elif param.kind != inspect.Parameter.VAR_KEYWORD: elif param.kind != inspect.Parameter.VAR_KEYWORD:
@ -656,7 +694,9 @@ class CommandHandler:
elif command.invoke_without_command: elif command.invoke_without_command:
view.index -= len(potential_subcommand) + view.previous view.index -= len(potential_subcommand) + view.previous
else: else:
raise CommandNotFound(f"Subcommand '{potential_subcommand}' not found.") raise CommandNotFound(
f"Subcommand '{potential_subcommand}' not found."
)
ctx = CommandContext( ctx = CommandContext(
message=message, message=message,
@ -681,7 +721,9 @@ class CommandHandler:
if hasattr(self.client, "on_command_error"): if hasattr(self.client, "on_command_error"):
await self.client.on_command_error(ctx, e) await self.client.on_command_error(ctx, e)
except Exception as e: except Exception as e:
logger.error("Unexpected error invoking command '%s': %s", original_command.name, e) logger.error(
"Unexpected error invoking command '%s': %s", original_command.name, e
)
exc = CommandInvokeError(e) exc = CommandInvokeError(e)
if hasattr(self.client, "on_command_error"): if hasattr(self.client, "on_command_error"):
await self.client.on_command_error(ctx, exc) await self.client.on_command_error(ctx, exc)

View File

@ -218,6 +218,7 @@ def requires_permissions(
return check(predicate) return check(predicate)
def has_role( def has_role(
name_or_id: str | int, name_or_id: str | int,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
@ -241,9 +242,7 @@ def has_role(
raise CheckFailure("Could not resolve author to a guild member.") raise CheckFailure("Could not resolve author to a guild member.")
# Create a list of the member's role objects by looking them up in the guild's roles list # Create a list of the member's role objects by looking them up in the guild's roles list
member_roles = [ member_roles = [role for role in ctx.guild.roles if role.id in author.roles]
role for role in ctx.guild.roles if role.id in author.roles
]
if any( if any(
role.id == str(name_or_id) or role.name == name_or_id role.id == str(name_or_id) or role.name == name_or_id
@ -278,9 +277,7 @@ def has_any_role(
if not author: if not author:
raise CheckFailure("Could not resolve author to a guild member.") raise CheckFailure("Could not resolve author to a guild member.")
member_roles = [ member_roles = [role for role in ctx.guild.roles if role.id in author.roles]
role for role in ctx.guild.roles if role.id in author.roles
]
# Convert names_or_ids to a set for efficient lookup # Convert names_or_ids to a set for efficient lookup
names_or_ids_set = set(map(str, names_or_ids)) names_or_ids_set = set(map(str, names_or_ids))

View File

@ -14,6 +14,8 @@ import time
import random import random
from typing import Optional, TYPE_CHECKING, Any, Dict from typing import Optional, TYPE_CHECKING, Any, Dict
from .models import Activity
from .enums import GatewayOpcode, GatewayIntent from .enums import GatewayOpcode, GatewayIntent
from .errors import GatewayException, DisagreementException, AuthenticationError from .errors import GatewayException, DisagreementException, AuthenticationError
from .interactions import Interaction from .interactions import Interaction
@ -63,7 +65,11 @@ class GatewayClient:
self._max_backoff: float = max_backoff self._max_backoff: float = max_backoff
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() try:
self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._heartbeat_interval: Optional[float] = None self._heartbeat_interval: Optional[float] = None
self._last_sequence: Optional[int] = None self._last_sequence: Optional[int] = None
self._session_id: Optional[str] = None self._session_id: Optional[str] = None
@ -213,26 +219,17 @@ class GatewayClient:
async def update_presence( async def update_presence(
self, self,
status: str, status: str,
activity_name: Optional[str] = None, activity: Optional[Activity] = None,
activity_type: int = 0, *,
since: int = 0, since: int = 0,
afk: bool = False, afk: bool = False,
): ) -> None:
"""Sends the presence update payload to the Gateway.""" """Sends the presence update payload to the Gateway."""
payload = { payload = {
"op": GatewayOpcode.PRESENCE_UPDATE, "op": GatewayOpcode.PRESENCE_UPDATE,
"d": { "d": {
"since": since, "since": since,
"activities": ( "activities": [activity.to_dict()] if activity else [],
[
{
"name": activity_name,
"type": activity_type,
}
]
if activity_name
else []
),
"status": status, "status": status,
"afk": afk, "afk": afk,
}, },
@ -353,7 +350,10 @@ class GatewayClient:
future._members.extend(raw_event_d_payload.get("members", [])) # type: ignore future._members.extend(raw_event_d_payload.get("members", [])) # type: ignore
# If this is the last chunk, resolve the future # If this is the last chunk, resolve the future
if raw_event_d_payload.get("chunk_index") == raw_event_d_payload.get("chunk_count", 1) - 1: if (
raw_event_d_payload.get("chunk_index")
== raw_event_d_payload.get("chunk_count", 1) - 1
):
future.set_result(future._members) # type: ignore future.set_result(future._members) # type: ignore
del self._member_chunk_requests[nonce] del self._member_chunk_requests[nonce]

View File

@ -601,18 +601,18 @@ class HTTPClient:
) )
async def delete_user_reaction( async def delete_user_reaction(
self, self,
channel_id: "Snowflake", channel_id: "Snowflake",
message_id: "Snowflake", message_id: "Snowflake",
emoji: str, emoji: str,
user_id: "Snowflake", user_id: "Snowflake",
) -> None: ) -> None:
"""Removes another user's reaction from a message.""" """Removes another user's reaction from a message."""
encoded = quote(emoji) encoded = quote(emoji)
await self.request( await self.request(
"DELETE", "DELETE",
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/{user_id}", f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/{user_id}",
) )
async def get_reactions( async def get_reactions(
self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
@ -910,6 +910,20 @@ class HTTPClient:
"""Fetches a guild object for a given guild ID.""" """Fetches a guild object for a given guild ID."""
return await self.request("GET", f"/guilds/{guild_id}") return await self.request("GET", f"/guilds/{guild_id}")
async def get_guild_widget(self, guild_id: "Snowflake") -> Dict[str, Any]:
"""Fetches the guild widget settings."""
return await self.request("GET", f"/guilds/{guild_id}/widget")
async def edit_guild_widget(
self, guild_id: "Snowflake", payload: Dict[str, Any]
) -> Dict[str, Any]:
"""Edits the guild widget settings."""
return await self.request(
"PATCH", f"/guilds/{guild_id}/widget", payload=payload
)
async def get_guild_templates(self, guild_id: "Snowflake") -> List[Dict[str, Any]]: async def get_guild_templates(self, guild_id: "Snowflake") -> List[Dict[str, Any]]:
"""Fetches all templates for the given guild.""" """Fetches all templates for the given guild."""
return await self.request("GET", f"/guilds/{guild_id}/templates") return await self.request("GET", f"/guilds/{guild_id}/templates")

View File

@ -6,6 +6,7 @@ Data models for Discord objects.
import asyncio import asyncio
import json import json
import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast
@ -24,6 +25,7 @@ from .enums import ( # These enums will need to be defined in disagreement/enum
PremiumTier, PremiumTier,
GuildFeature, GuildFeature,
ChannelType, ChannelType,
AutoArchiveDuration,
ComponentType, ComponentType,
ButtonStyle, # Added for Button ButtonStyle, # Added for Button
GuildScheduledEventPrivacyLevel, GuildScheduledEventPrivacyLevel,
@ -39,6 +41,7 @@ if TYPE_CHECKING:
from .enums import OverwriteType # For PermissionOverwrite model from .enums import OverwriteType # For PermissionOverwrite model
from .ui.view import View from .ui.view import View
from .interactions import Snowflake from .interactions import Snowflake
from .typing import Typing
# Forward reference Message if it were used in type hints before its definition # Forward reference Message if it were used in type hints before its definition
# from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc. # from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc.
@ -114,31 +117,39 @@ class Message:
# self.mention_roles: List[str] = data.get("mention_roles", []) # self.mention_roles: List[str] = data.get("mention_roles", [])
# self.mention_everyone: bool = data.get("mention_everyone", False) # self.mention_everyone: bool = data.get("mention_everyone", False)
@property
def clean_content(self) -> str:
"""Returns message content without user, role, or channel mentions."""
pattern = re.compile(r"<@!?\d+>|<#\d+>|<@&\d+>")
cleaned = pattern.sub("", self.content)
return " ".join(cleaned.split())
async def pin(self) -> None: async def pin(self) -> None:
"""|coro| """|coro|
Pins this message to its channel. Pins this message to its channel.
Raises Raises
------ ------
HTTPException HTTPException
Pinning the message failed. Pinning the message failed.
""" """
await self._client._http.pin_message(self.channel_id, self.id) await self._client._http.pin_message(self.channel_id, self.id)
self.pinned = True self.pinned = True
async def unpin(self) -> None: async def unpin(self) -> None:
"""|coro| """|coro|
Unpins this message from its channel. Unpins this message from its channel.
Raises Raises
------ ------
HTTPException HTTPException
Unpinning the message failed. Unpinning the message failed.
""" """
await self._client._http.unpin_message(self.channel_id, self.id) await self._client._http.unpin_message(self.channel_id, self.id)
self.pinned = False self.pinned = False
async def reply( async def reply(
self, self,
@ -241,16 +252,16 @@ class Message:
await self._client.add_reaction(self.channel_id, self.id, emoji) await self._client.add_reaction(self.channel_id, self.id, emoji)
async def remove_reaction(self, emoji: str, member: Optional[User] = None) -> None: async def remove_reaction(self, emoji: str, member: Optional[User] = None) -> None:
"""|coro| """|coro|
Removes a reaction from this message. Removes a reaction from this message.
If no ``member`` is provided, removes the bot's own reaction. If no ``member`` is provided, removes the bot's own reaction.
""" """
if member: if member:
await self._client._http.delete_user_reaction( await self._client._http.delete_user_reaction(
self.channel_id, self.id, emoji, member.id self.channel_id, self.id, emoji, member.id
) )
else: else:
await self._client.remove_reaction(self.channel_id, self.id, emoji) await self._client.remove_reaction(self.channel_id, self.id, emoji)
async def clear_reactions(self) -> None: async def clear_reactions(self) -> None:
"""|coro| Remove all reactions from this message.""" """|coro| Remove all reactions from this message."""
@ -280,7 +291,7 @@ class Message:
self, self,
name: str, name: str,
*, *,
auto_archive_duration: Optional[int] = None, auto_archive_duration: Optional[AutoArchiveDuration] = None,
rate_limit_per_user: Optional[int] = None, rate_limit_per_user: Optional[int] = None,
reason: Optional[str] = None, reason: Optional[str] = None,
) -> "Thread": ) -> "Thread":
@ -292,9 +303,9 @@ class Message:
---------- ----------
name: str name: str
The name of the thread. The name of the thread.
auto_archive_duration: Optional[int] auto_archive_duration: Optional[AutoArchiveDuration]
The duration in minutes to automatically archive the thread after recent activity. How long before the thread is automatically archived after recent activity.
Can be one of 60, 1440, 4320, 10080. See :class:`AutoArchiveDuration` for allowed values.
rate_limit_per_user: Optional[int] rate_limit_per_user: Optional[int]
The number of seconds a user has to wait before sending another message. The number of seconds a user has to wait before sending another message.
reason: Optional[str] reason: Optional[str]
@ -307,7 +318,7 @@ class Message:
""" """
payload: Dict[str, Any] = {"name": name} payload: Dict[str, Any] = {"name": name}
if auto_archive_duration is not None: if auto_archive_duration is not None:
payload["auto_archive_duration"] = auto_archive_duration payload["auto_archive_duration"] = int(auto_archive_duration)
if rate_limit_per_user is not None: if rate_limit_per_user is not None:
payload["rate_limit_per_user"] = rate_limit_per_user payload["rate_limit_per_user"] = rate_limit_per_user
@ -530,8 +541,42 @@ class Embed:
payload["fields"] = [f.to_dict() for f in self.fields] payload["fields"] = [f.to_dict() for f in self.fields]
return payload return payload
# Convenience methods for building embeds can be added here # Convenience methods mirroring ``discord.py``'s ``Embed`` API
# e.g., set_author, add_field, set_footer, set_image, etc.
def set_author(
self, *, name: str, url: Optional[str] = None, icon_url: Optional[str] = None
) -> "Embed":
"""Set the embed author and return ``self`` for chaining."""
data: Dict[str, Any] = {"name": name}
if url:
data["url"] = url
if icon_url:
data["icon_url"] = icon_url
self.author = EmbedAuthor(data)
return self
def add_field(self, *, name: str, value: str, inline: bool = False) -> "Embed":
"""Add a field to the embed."""
field = EmbedField({"name": name, "value": value, "inline": inline})
self.fields.append(field)
return self
def set_footer(self, *, text: str, icon_url: Optional[str] = None) -> "Embed":
"""Set the embed footer."""
data: Dict[str, Any] = {"text": text}
if icon_url:
data["icon_url"] = icon_url
self.footer = EmbedFooter(data)
return self
def set_image(self, url: str) -> "Embed":
"""Set the embed image."""
self.image = EmbedImage({"url": url})
return self
class Attachment: class Attachment:
@ -1088,7 +1133,9 @@ class Guild:
# Internal caches, populated by events or specific fetches # Internal caches, populated by events or specific fetches
self._channels: ChannelCache = ChannelCache() self._channels: ChannelCache = ChannelCache()
self._members: MemberCache = MemberCache(getattr(client_instance, "member_cache_flags", MemberCacheFlags())) self._members: MemberCache = MemberCache(
getattr(client_instance, "member_cache_flags", MemberCacheFlags())
)
self._threads: Dict[str, "Thread"] = {} self._threads: Dict[str, "Thread"] = {}
def get_channel(self, channel_id: str) -> Optional["Channel"]: def get_channel(self, channel_id: str) -> Optional["Channel"]:
@ -1128,6 +1175,16 @@ class Guild:
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Guild id='{self.id}' name='{self.name}'>" return f"<Guild id='{self.id}' name='{self.name}'>"
async def fetch_widget(self) -> Dict[str, Any]:
"""|coro| Fetch this guild's widget settings."""
return await self._client.fetch_widget(self.id)
async def edit_widget(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""|coro| Edit this guild's widget settings."""
return await self._client.edit_widget(self.id, payload)
async def fetch_members(self, *, limit: Optional[int] = None) -> List["Member"]: async def fetch_members(self, *, limit: Optional[int] = None) -> List["Member"]:
"""|coro| """|coro|
@ -1278,7 +1335,45 @@ class Channel:
return base return base
class TextChannel(Channel): class Messageable:
"""Mixin for channels that can send messages and show typing."""
_client: "Client"
id: str
async def send(
self,
content: Optional[str] = None,
*,
embed: Optional["Embed"] = None,
embeds: Optional[List["Embed"]] = None,
components: Optional[List["ActionRow"]] = None,
) -> "Message":
if not hasattr(self._client, "send_message"):
raise NotImplementedError(
"Client.send_message is required for Messageable.send"
)
return await self._client.send_message(
channel_id=self.id,
content=content,
embed=embed,
embeds=embeds,
components=components,
)
async def trigger_typing(self) -> None:
await self._client._http.trigger_typing(self.id)
def typing(self) -> "Typing":
if not hasattr(self._client, "typing"):
raise NotImplementedError(
"Client.typing is required for Messageable.typing"
)
return self._client.typing(self.id)
class TextChannel(Channel, Messageable):
"""Represents a guild text channel or announcement channel.""" """Represents a guild text channel or announcement channel."""
def __init__(self, data: Dict[str, Any], client_instance: "Client"): def __init__(self, data: Dict[str, Any], client_instance: "Client"):
@ -1304,27 +1399,6 @@ class TextChannel(Channel):
return message_pager(self, limit=limit, before=before, after=after) return message_pager(self, limit=limit, before=before, after=after)
async def send(
self,
content: Optional[str] = None,
*,
embed: Optional[Embed] = None,
embeds: Optional[List[Embed]] = None,
components: Optional[List["ActionRow"]] = None, # Added components
) -> "Message": # Forward reference Message
if not hasattr(self._client, "send_message"):
raise NotImplementedError(
"Client.send_message is required for TextChannel.send"
)
return await self._client.send_message(
channel_id=self.id,
content=content,
embed=embed,
embeds=embeds,
components=components,
)
async def purge( async def purge(
self, limit: int, *, before: "Snowflake | None" = None self, limit: int, *, before: "Snowflake | None" = None
) -> List["Snowflake"]: ) -> List["Snowflake"]:
@ -1347,21 +1421,21 @@ class TextChannel(Channel):
return ids return ids
def get_partial_message(self, id: int) -> "PartialMessage": def get_partial_message(self, id: int) -> "PartialMessage":
"""Returns a :class:`PartialMessage` for the given ID. """Returns a :class:`PartialMessage` for the given ID.
This allows performing actions on a message without fetching it first. This allows performing actions on a message without fetching it first.
Parameters Parameters
---------- ----------
id: int id: int
The ID of the message to get a partial instance of. The ID of the message to get a partial instance of.
Returns Returns
------- -------
PartialMessage PartialMessage
The partial message instance. The partial message instance.
""" """
return PartialMessage(id=str(id), channel=self) return PartialMessage(id=str(id), channel=self)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<TextChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>" return f"<TextChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>"
@ -1390,7 +1464,7 @@ class TextChannel(Channel):
name: str, name: str,
*, *,
type: ChannelType = ChannelType.PUBLIC_THREAD, type: ChannelType = ChannelType.PUBLIC_THREAD,
auto_archive_duration: Optional[int] = None, auto_archive_duration: Optional[AutoArchiveDuration] = None,
invitable: Optional[bool] = None, invitable: Optional[bool] = None,
rate_limit_per_user: Optional[int] = None, rate_limit_per_user: Optional[int] = None,
reason: Optional[str] = None, reason: Optional[str] = None,
@ -1406,8 +1480,8 @@ class TextChannel(Channel):
type: ChannelType type: ChannelType
The type of thread to create. Defaults to PUBLIC_THREAD. The type of thread to create. Defaults to PUBLIC_THREAD.
Can be PUBLIC_THREAD, PRIVATE_THREAD, or ANNOUNCEMENT_THREAD. Can be PUBLIC_THREAD, PRIVATE_THREAD, or ANNOUNCEMENT_THREAD.
auto_archive_duration: Optional[int] auto_archive_duration: Optional[AutoArchiveDuration]
The duration in minutes to automatically archive the thread after recent activity. How long before the thread is automatically archived after recent activity.
invitable: Optional[bool] invitable: Optional[bool]
Whether non-moderators can invite other non-moderators to a private thread. Whether non-moderators can invite other non-moderators to a private thread.
Only applicable to private threads. Only applicable to private threads.
@ -1426,7 +1500,7 @@ class TextChannel(Channel):
"type": type.value, "type": type.value,
} }
if auto_archive_duration is not None: if auto_archive_duration is not None:
payload["auto_archive_duration"] = auto_archive_duration payload["auto_archive_duration"] = int(auto_archive_duration)
if invitable is not None and type == ChannelType.PRIVATE_THREAD: if invitable is not None and type == ChannelType.PRIVATE_THREAD:
payload["invitable"] = invitable payload["invitable"] = invitable
if rate_limit_per_user is not None: if rate_limit_per_user is not None:
@ -1606,7 +1680,9 @@ class Thread(TextChannel): # Threads are a specialized TextChannel
""" """
await self._client._http.leave_thread(self.id) await self._client._http.leave_thread(self.id)
async def archive(self, locked: bool = False, *, reason: Optional[str] = None) -> "Thread": async def archive(
self, locked: bool = False, *, reason: Optional[str] = None
) -> "Thread":
"""|coro| """|coro|
Archives this thread. Archives this thread.
@ -1631,7 +1707,7 @@ class Thread(TextChannel): # Threads are a specialized TextChannel
return cast("Thread", self._client.parse_channel(data)) return cast("Thread", self._client.parse_channel(data))
class DMChannel(Channel): class DMChannel(Channel, Messageable):
"""Represents a Direct Message channel.""" """Represents a Direct Message channel."""
def __init__(self, data: Dict[str, Any], client_instance: "Client"): def __init__(self, data: Dict[str, Any], client_instance: "Client"):
@ -1645,27 +1721,6 @@ class DMChannel(Channel):
def recipient(self) -> Optional[User]: def recipient(self) -> Optional[User]:
return self.recipients[0] if self.recipients else None return self.recipients[0] if self.recipients else None
async def send(
self,
content: Optional[str] = None,
*,
embed: Optional[Embed] = None,
embeds: Optional[List[Embed]] = None,
components: Optional[List["ActionRow"]] = None, # Added components
) -> "Message":
if not hasattr(self._client, "send_message"):
raise NotImplementedError(
"Client.send_message is required for DMChannel.send"
)
return await self._client.send_message(
channel_id=self.id,
content=content,
embed=embed,
embeds=embeds,
components=components,
)
async def history( async def history(
self, self,
*, *,
@ -2356,6 +2411,37 @@ class ThreadMember:
return f"<ThreadMember user_id='{self.user_id}' thread_id='{self.id}'>" return f"<ThreadMember user_id='{self.user_id}' thread_id='{self.id}'>"
class Activity:
"""Represents a user's presence activity."""
def __init__(self, name: str, type: int) -> None:
self.name = name
self.type = type
def to_dict(self) -> Dict[str, Any]:
return {"name": self.name, "type": self.type}
class Game(Activity):
"""Represents a playing activity."""
def __init__(self, name: str) -> None:
super().__init__(name, 0)
class Streaming(Activity):
"""Represents a streaming activity."""
def __init__(self, name: str, url: str) -> None:
super().__init__(name, 1)
self.url = url
def to_dict(self) -> Dict[str, Any]:
payload = super().to_dict()
payload["url"] = self.url
return payload
class PresenceUpdate: class PresenceUpdate:
"""Represents a PRESENCE_UPDATE event.""" """Represents a PRESENCE_UPDATE event."""
@ -2366,7 +2452,17 @@ class PresenceUpdate:
self.user = User(data["user"]) self.user = User(data["user"])
self.guild_id: Optional[str] = data.get("guild_id") self.guild_id: Optional[str] = data.get("guild_id")
self.status: Optional[str] = data.get("status") self.status: Optional[str] = data.get("status")
self.activities: List[Dict[str, Any]] = data.get("activities", []) self.activities: List[Activity] = []
for activity in data.get("activities", []):
act_type = activity.get("type", 0)
name = activity.get("name", "")
if act_type == 0:
obj = Game(name)
elif act_type == 1:
obj = Streaming(name, activity.get("url", ""))
else:
obj = Activity(name, act_type)
self.activities.append(obj)
self.client_status: Dict[str, Any] = data.get("client_status", {}) self.client_status: Dict[str, Any] = data.get("client_status", {})
def __repr__(self) -> str: def __repr__(self) -> str:

View File

@ -72,7 +72,7 @@ class View:
rows: List[ActionRow] = [] rows: List[ActionRow] = []
for item in self.children: for item in self.children:
rows.append(ActionRow(components=[item])) rows.append(ActionRow(components=[item]))
return rows return rows

View File

@ -7,9 +7,26 @@ import asyncio
import contextlib import contextlib
import socket import socket
import threading import threading
from array import array
def _apply_volume(data: bytes, volume: float) -> bytes:
samples = array("h")
samples.frombytes(data)
for i, sample in enumerate(samples):
scaled = int(sample * volume)
if scaled > 32767:
scaled = 32767
elif scaled < -32768:
scaled = -32768
samples[i] = scaled
return samples.tobytes()
from typing import TYPE_CHECKING, Optional, Sequence from typing import TYPE_CHECKING, Optional, Sequence
import aiohttp import aiohttp
# The following import is correct, but may be flagged by Pylance if the virtual # The following import is correct, but may be flagged by Pylance if the virtual
# environment is not configured correctly. # environment is not configured correctly.
from nacl.secret import SecretBox from nacl.secret import SecretBox
@ -180,6 +197,9 @@ class VoiceClient:
data = await self._current_source.read() data = await self._current_source.read()
if not data: if not data:
break break
volume = getattr(self._current_source, "volume", 1.0)
if volume != 1.0:
data = _apply_volume(data, volume)
await self.send_audio_frame(data) await self.send_audio_frame(data)
finally: finally:
await self._current_source.close() await self._current_source.close()

22
docs/embeds.md Normal file
View File

@ -0,0 +1,22 @@
# Embeds
`Embed` objects can be constructed piece by piece much like in `discord.py`.
These helper methods return the embed instance so you can chain calls.
```python
from disagreement.models import Embed
embed = (
Embed()
.set_author(name="Disagreement", url="https://example.com", icon_url="https://cdn.example.com/bot.png")
.add_field(name="Info", value="Some details")
.set_footer(text="Made with Disagreement")
.set_image(url="https://cdn.example.com/image.png")
)
```
Call `to_dict()` to convert the embed back to a payload dictionary before sending:
```python
payload = embed.to_dict()
```

23
docs/mentions.md Normal file
View File

@ -0,0 +1,23 @@
# Controlling Mentions
The client exposes settings to control how mentions behave in outgoing messages.
## Default Allowed Mentions
Use the ``allowed_mentions`` parameter of :class:`disagreement.Client` to set a
default for all messages:
```python
client = disagreement.Client(
token="YOUR_TOKEN",
allowed_mentions={"parse": [], "replied_user": False},
)
```
When ``Client.send_message`` is called without an explicit ``allowed_mentions``
argument this value will be used.
## Next Steps
- [Commands](commands.md)
- [HTTP Client Options](http_client.md)

View File

@ -1,6 +1,7 @@
# Updating Presence # Updating Presence
The `Client.change_presence` method allows you to update the bot's status and displayed activity. The `Client.change_presence` method allows you to update the bot's status and displayed activity.
Pass an :class:`~disagreement.models.Activity` (such as :class:`~disagreement.models.Game` or :class:`~disagreement.models.Streaming`) to describe what your bot is doing.
## Status Strings ## Status Strings
@ -22,8 +23,18 @@ An activity dictionary must include a `name` and a `type` field. The type value
| `4` | Custom | | `4` | Custom |
| `5` | Competing | | `5` | Competing |
Example: Example using the provided activity classes:
```python ```python
await client.change_presence(status="idle", activity={"name": "with Discord", "type": 0}) from disagreement.models import Game
await client.change_presence(status="idle", activity=Game("with Discord"))
```
You can also specify a streaming URL:
```python
from disagreement.models import Streaming
await client.change_presence(status="online", activity=Streaming("My Stream", "https://twitch.tv/someone"))
``` ```

18
docs/threads.md Normal file
View File

@ -0,0 +1,18 @@
# Threads
`Message.create_thread` and `TextChannel.create_thread` let you start new threads.
Use :class:`AutoArchiveDuration` to control when a thread is automatically archived.
```python
from disagreement.enums import AutoArchiveDuration
await message.create_thread(
"discussion",
auto_archive_duration=AutoArchiveDuration.DAY,
)
```
## Next Steps
- [Message History](message_history.md)
- [Caching](caching.md)

View File

@ -39,9 +39,14 @@ except ImportError:
) )
sys.exit(1) sys.exit(1)
from dotenv import load_dotenv try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
load_dotenv() if load_dotenv:
load_dotenv()
# Optional: Configure logging for more insight, especially for gateway events # Optional: Configure logging for more insight, especially for gateway events
# logging.basicConfig(level=logging.DEBUG) # For very verbose output # logging.basicConfig(level=logging.DEBUG) # For very verbose output

View File

@ -37,9 +37,15 @@ from disagreement.interactions import (
InteractionResponsePayload, InteractionResponsePayload,
InteractionCallbackData, InteractionCallbackData,
) )
from dotenv import load_dotenv
load_dotenv() try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
if load_dotenv:
load_dotenv()
# Get the bot token and application ID from the environment variables # Get the bot token and application ID from the environment variables
token = os.getenv("DISCORD_BOT_TOKEN") token = os.getenv("DISCORD_BOT_TOKEN")

View File

@ -15,9 +15,14 @@ from disagreement.ext.app_commands import (
) )
from disagreement.models import User, Message from disagreement.models import User, Message
from dotenv import load_dotenv try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
load_dotenv() if load_dotenv:
load_dotenv()
BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "") BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "")
APP_ID = os.environ.get("DISCORD_APPLICATION_ID", "") APP_ID = os.environ.get("DISCORD_APPLICATION_ID", "")

View File

@ -4,7 +4,11 @@ import asyncio
import os import os
import sys import sys
from dotenv import load_dotenv try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
# Allow running from the examples folder without installing # Allow running from the examples folder without installing
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)): if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
@ -12,7 +16,8 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi
from disagreement import Client from disagreement import Client
load_dotenv() if load_dotenv:
load_dotenv()
TOKEN = os.environ.get("DISCORD_BOT_TOKEN") TOKEN = os.environ.get("DISCORD_BOT_TOKEN")

View File

@ -36,9 +36,14 @@ from disagreement.enums import (
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from dotenv import load_dotenv try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
load_dotenv() if load_dotenv:
load_dotenv()
# --- Define a Test Cog --- # --- Define a Test Cog ---

View File

@ -10,9 +10,15 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi
from disagreement.client import Client from disagreement.client import Client
from disagreement.models import TextChannel from disagreement.models import TextChannel
from dotenv import load_dotenv
load_dotenv() try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
if load_dotenv:
load_dotenv()
BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "") BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "")
CHANNEL_ID = os.environ.get("DISCORD_CHANNEL_ID", "") CHANNEL_ID = os.environ.get("DISCORD_CHANNEL_ID", "")

View File

@ -2,14 +2,20 @@
import os import os
import asyncio import asyncio
from dotenv import load_dotenv
try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
from disagreement import Client, ui from disagreement import Client, ui
from disagreement.enums import GatewayIntent, TextInputStyle from disagreement.enums import GatewayIntent, TextInputStyle
from disagreement.ext.app_commands.decorators import slash_command from disagreement.ext.app_commands.decorators import slash_command
from disagreement.ext.app_commands.context import AppCommandContext from disagreement.ext.app_commands.context import AppCommandContext
load_dotenv() if load_dotenv:
load_dotenv()
token = os.getenv("DISCORD_BOT_TOKEN", "") token = os.getenv("DISCORD_BOT_TOKEN", "")
application_id = os.getenv("DISCORD_APPLICATION_ID", "") application_id = os.getenv("DISCORD_APPLICATION_ID", "")

View File

@ -3,7 +3,11 @@
import os import os
import sys import sys
from dotenv import load_dotenv try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@ -11,7 +15,8 @@ from disagreement import Client, GatewayIntent, ui # type: ignore
from disagreement.ext.app_commands.decorators import slash_command from disagreement.ext.app_commands.decorators import slash_command
from disagreement.ext.app_commands.context import AppCommandContext from disagreement.ext.app_commands.context import AppCommandContext
load_dotenv() if load_dotenv:
load_dotenv()
TOKEN = os.getenv("DISCORD_BOT_TOKEN", "") TOKEN = os.getenv("DISCORD_BOT_TOKEN", "")
APP_ID = os.getenv("DISCORD_APPLICATION_ID", "") APP_ID = os.getenv("DISCORD_APPLICATION_ID", "")

View File

@ -9,9 +9,15 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import disagreement import disagreement
from dotenv import load_dotenv
load_dotenv() try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
if load_dotenv:
load_dotenv()
TOKEN = os.environ.get("DISCORD_BOT_TOKEN") TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
if not TOKEN: if not TOKEN:

View File

@ -10,11 +10,16 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi
from typing import cast from typing import cast
from dotenv import load_dotenv try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
import disagreement import disagreement
load_dotenv() if load_dotenv:
load_dotenv()
_TOKEN = os.getenv("DISCORD_BOT_TOKEN") _TOKEN = os.getenv("DISCORD_BOT_TOKEN")
_GUILD_ID = os.getenv("DISCORD_GUILD_ID") _GUILD_ID = os.getenv("DISCORD_GUILD_ID")

View File

@ -15,3 +15,14 @@ def test_cache_ttl_expiry():
assert cache.get("b") == 1 assert cache.get("b") == 1
time.sleep(0.02) time.sleep(0.02)
assert cache.get("b") is None assert cache.get("b") is None
def test_cache_lru_eviction():
cache = Cache(maxlen=2)
cache.set("a", 1)
cache.set("b", 2)
assert cache.get("a") == 1
cache.set("c", 3)
assert cache.get("b") is None
assert cache.get("a") == 1
assert cache.get("c") == 3

View File

@ -0,0 +1,23 @@
import pytest
from disagreement.client import Client
def _add_message(client: Client, message_id: str) -> None:
data = {
"id": message_id,
"channel_id": "c",
"author": {"id": "u", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
client.parse_message(data)
def test_client_message_cache_size():
client = Client(token="t", message_cache_maxlen=1)
_add_message(client, "1")
assert client._messages.get("1").id == "1"
_add_message(client, "2")
assert client._messages.get("1") is None
assert client._messages.get("2").id == "2"

View File

@ -0,0 +1,18 @@
from disagreement.models import Embed
def test_embed_helper_methods():
embed = (
Embed()
.set_author(name="name", url="url", icon_url="icon")
.add_field(name="n", value="v")
.set_footer(text="footer", icon_url="icon")
.set_image(url="https://example.com/image.png")
)
assert embed.author.name == "name"
assert embed.author.url == "url"
assert embed.author.icon_url == "icon"
assert len(embed.fields) == 1 and embed.fields[0].name == "n"
assert embed.footer.text == "footer"
assert embed.image.url == "https://example.com/image.png"

View File

@ -24,7 +24,7 @@ class DummyDispatcher:
class DummyClient: class DummyClient:
def __init__(self): def __init__(self):
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_running_loop()
self.application_id = None # Mock application_id for Client.connect self.application_id = None # Mock application_id for Client.connect
@ -39,7 +39,7 @@ async def test_client_connect_backoff(monkeypatch):
client = Client( client = Client(
token="test_token", token="test_token",
intents=0, intents=0,
loop=asyncio.get_event_loop(), loop=asyncio.get_running_loop(),
command_prefix="!", command_prefix="!",
verbose=False, verbose=False,
mention_replies=False, mention_replies=False,

View File

@ -0,0 +1,23 @@
import types
from disagreement.models import Message
def make_message(content: str) -> Message:
data = {
"id": "1",
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": content,
"timestamp": "t",
}
return Message(data, client_instance=types.SimpleNamespace())
def test_clean_content_removes_mentions():
msg = make_message("Hello <@123> <#456> <@&789> world")
assert msg.clean_content == "Hello world"
def test_clean_content_no_mentions():
msg = make_message("Just text")
assert msg.clean_content == "Just text"

View File

@ -2,6 +2,7 @@ import pytest
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
from disagreement.client import Client from disagreement.client import Client
from disagreement.models import Game
from disagreement.errors import DisagreementException from disagreement.errors import DisagreementException
@ -18,11 +19,11 @@ class DummyGateway(MagicMock):
async def test_change_presence_passes_arguments(): async def test_change_presence_passes_arguments():
client = Client(token="t") client = Client(token="t")
client._gateway = DummyGateway() client._gateway = DummyGateway()
game = Game("hi")
await client.change_presence(status="idle", activity_name="hi", activity_type=0) await client.change_presence(status="idle", activity=game)
client._gateway.update_presence.assert_awaited_once_with( client._gateway.update_presence.assert_awaited_once_with(
status="idle", activity_name="hi", activity_type=0, since=0, afk=False status="idle", activity=game, since=0, afk=False
) )

View File

@ -1,8 +1,11 @@
import asyncio import asyncio
import io
from array import array
import pytest import pytest
from disagreement.audio import AudioSource, FFmpegAudioSource
from disagreement.voice_client import VoiceClient from disagreement.voice_client import VoiceClient
from disagreement.audio import AudioSource
from disagreement.client import Client from disagreement.client import Client
@ -137,3 +140,68 @@ async def test_play_and_switch_sources():
await vc.play(DummySource([b"c"])) await vc.play(DummySource([b"c"]))
assert udp.sent == [b"a", b"b", b"c"] assert udp.sent == [b"a", b"b", b"c"]
@pytest.mark.asyncio
async def test_ffmpeg_source_custom_options(monkeypatch):
captured = {}
class DummyProcess:
def __init__(self):
self.stdout = io.BytesIO(b"")
async def wait(self):
return 0
async def fake_exec(*args, **kwargs):
captured["args"] = args
return DummyProcess()
monkeypatch.setattr(asyncio, "create_subprocess_exec", fake_exec)
src = FFmpegAudioSource(
"file.mp3", before_options="-reconnect 1", options="-vn", volume=0.5
)
await src._spawn()
cmd = captured["args"]
assert "-reconnect" in cmd
assert "-vn" in cmd
assert src.volume == 0.5
@pytest.mark.asyncio
async def test_voice_client_volume_scaling(monkeypatch):
ws = DummyWebSocket(
[
{"d": {"heartbeat_interval": 50}},
{"d": {"ssrc": 1, "ip": "127.0.0.1", "port": 4000}},
{"d": {"secret_key": []}},
]
)
udp = DummyUDP()
vc = VoiceClient(
client=DummyVoiceClient(),
endpoint="ws://localhost",
session_id="sess",
token="tok",
guild_id=1,
user_id=2,
ws=ws,
udp=udp,
)
await vc.connect()
vc._heartbeat_task.cancel()
chunk = b"\x10\x00\x10\x00"
src = DummySource([chunk])
src.volume = 0.5
await vc.play(src)
samples = array("h")
samples.frombytes(chunk)
samples[0] = int(samples[0] * 0.5)
samples[1] = int(samples[1] * 0.5)
expected = samples.tobytes()
assert udp.sent == [expected]

50
tests/test_widget.py Normal file
View File

@ -0,0 +1,50 @@
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock
from disagreement.http import HTTPClient
from disagreement.client import Client
@pytest.mark.asyncio
async def test_get_guild_widget_calls_request():
http = HTTPClient(token="t")
http.request = AsyncMock(return_value={})
await http.get_guild_widget("1")
http.request.assert_called_once_with("GET", "/guilds/1/widget")
@pytest.mark.asyncio
async def test_edit_guild_widget_calls_request():
http = HTTPClient(token="t")
http.request = AsyncMock(return_value={})
payload = {"enabled": True}
await http.edit_guild_widget("1", payload)
http.request.assert_called_once_with("PATCH", "/guilds/1/widget", payload=payload)
@pytest.mark.asyncio
async def test_client_fetch_widget_returns_data():
http = SimpleNamespace(get_guild_widget=AsyncMock(return_value={"enabled": True}))
client = Client.__new__(Client)
client._http = http
client._closed = False
data = await client.fetch_widget("1")
http.get_guild_widget.assert_awaited_once_with("1")
assert data == {"enabled": True}
@pytest.mark.asyncio
async def test_client_edit_widget_returns_data():
http = SimpleNamespace(edit_guild_widget=AsyncMock(return_value={"enabled": False}))
client = Client.__new__(Client)
client._http = http
client._closed = False
payload = {"enabled": False}
data = await client.edit_widget("1", payload)
http.edit_guild_widget.assert_awaited_once_with("1", payload)
assert data == {"enabled": False}