This commit is contained in:
Slipstream 2025-06-11 14:52:49 -06:00
commit a702c66603
Signed by: slipstream
GPG Key ID: 13E498CE010AC6FD
37 changed files with 831 additions and 232 deletions

View File

@ -23,6 +23,13 @@ pip install -e .
Requires Python 3.10 or newer.
To run the example scripts, you'll need the `python-dotenv` package to load
environment variables. Install the development extras with:
```bash
pip install "disagreement[dev]"
```
## Basic Usage
```python
@ -102,6 +109,20 @@ These options are forwarded to ``HTTPClient`` when it creates the underlying
``aiohttp.ClientSession``. You can specify a custom ``connector`` or any other
session parameter supported by ``aiohttp``.
### Default Allowed Mentions
Specify default mention behaviour for all outgoing messages when constructing the client:
```python
client = disagreement.Client(
token=token,
allowed_mentions={"parse": [], "replied_user": False},
)
```
This dictionary is used whenever ``send_message`` is called without an explicit
``allowed_mentions`` argument.
### Defining Subcommands with `AppCommandGroup`
```python
@ -120,6 +141,7 @@ async def show(ctx: AppCommandContext, key: str):
@slash_command(name="set", description="Update a setting.", parent=admin_group)
async def set_setting(ctx: AppCommandContext, key: str, value: str):
...
```
## Fetching Guilds
Use `Client.fetch_guild` to retrieve a guild from the Discord API if it

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
import contextlib
import io
import shlex
from typing import Optional, Union
@ -35,15 +36,27 @@ class FFmpegAudioSource(AudioSource):
A filename, URL, or file-like object to read from.
"""
def __init__(self, source: Union[str, io.BufferedIOBase]):
def __init__(
self,
source: Union[str, io.BufferedIOBase],
*,
before_options: Optional[str] = None,
options: Optional[str] = None,
volume: float = 1.0,
):
self.source = source
self.before_options = before_options
self.options = options
self.volume = volume
self.process: Optional[asyncio.subprocess.Process] = None
self._feeder: Optional[asyncio.Task] = None
async def _spawn(self) -> None:
if isinstance(self.source, str):
args = [
"ffmpeg",
args = ["ffmpeg"]
if self.before_options:
args += shlex.split(self.before_options)
args += [
"-i",
self.source,
"-f",
@ -54,14 +67,18 @@ class FFmpegAudioSource(AudioSource):
"2",
"pipe:1",
]
if self.options:
args += shlex.split(self.options)
self.process = await asyncio.create_subprocess_exec(
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.DEVNULL,
)
else:
args = [
"ffmpeg",
args = ["ffmpeg"]
if self.before_options:
args += shlex.split(self.before_options)
args += [
"-i",
"pipe:0",
"-f",
@ -72,6 +89,8 @@ class FFmpegAudioSource(AudioSource):
"2",
"pipe:1",
]
if self.options:
args += shlex.split(self.options)
self.process = await asyncio.create_subprocess_exec(
*args,
stdin=asyncio.subprocess.PIPE,
@ -115,6 +134,7 @@ class FFmpegAudioSource(AudioSource):
with contextlib.suppress(Exception):
self.source.close()
class AudioSink:
"""Abstract base class for audio sinks."""

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import time
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
from collections import OrderedDict
if TYPE_CHECKING:
from .models import Channel, Guild, Member
@ -11,15 +12,22 @@ T = TypeVar("T")
class Cache(Generic[T]):
"""Simple in-memory cache with optional TTL support."""
"""Simple in-memory cache with optional TTL and max size support."""
def __init__(self, ttl: Optional[float] = None) -> None:
def __init__(
self, ttl: Optional[float] = None, maxlen: Optional[int] = None
) -> None:
self.ttl = ttl
self._data: Dict[str, tuple[T, Optional[float]]] = {}
self.maxlen = maxlen
self._data: "OrderedDict[str, tuple[T, Optional[float]]]" = OrderedDict()
def set(self, key: str, value: T) -> None:
expiry = time.monotonic() + self.ttl if self.ttl is not None else None
if key in self._data:
self._data.move_to_end(key)
self._data[key] = (value, expiry)
if self.maxlen is not None and len(self._data) > self.maxlen:
self._data.popitem(last=False)
def get(self, key: str) -> Optional[T]:
item = self._data.get(key)
@ -29,6 +37,7 @@ class Cache(Generic[T]):
if expiry is not None and expiry < time.monotonic():
self.invalidate(key)
return None
self._data.move_to_end(key)
return value
def invalidate(self, key: str) -> None:

View File

@ -8,10 +8,10 @@ class _MemberCacheFlagValue:
flag: int
def __init__(self, func: Callable[[Any], bool]):
self.flag = getattr(func, 'flag', 0)
self.flag = getattr(func, "flag", 0)
self.__doc__ = func.__doc__
def __get__(self, instance: 'MemberCacheFlags', owner: type) -> Any:
def __get__(self, instance: "MemberCacheFlags", owner: type) -> Any:
if instance is None:
return self
return instance.value & self.flag != 0
@ -23,23 +23,24 @@ class _MemberCacheFlagValue:
instance.value &= ~self.flag
def __repr__(self) -> str:
return f'<{self.__class__.__name__} flag={self.flag}>'
return f"<{self.__class__.__name__} flag={self.flag}>"
def flag_value(flag: int) -> Callable[[Callable[[Any], bool]], _MemberCacheFlagValue]:
def decorator(func: Callable[[Any], bool]) -> _MemberCacheFlagValue:
setattr(func, 'flag', flag)
setattr(func, "flag", flag)
return _MemberCacheFlagValue(func)
return decorator
class MemberCacheFlags:
__slots__ = ('value',)
__slots__ = ("value",)
VALID_FLAGS: ClassVar[Dict[str, int]] = {
'joined': 1 << 0,
'voice': 1 << 1,
'online': 1 << 2,
"joined": 1 << 0,
"voice": 1 << 1,
"online": 1 << 2,
}
DEFAULT_FLAGS: ClassVar[int] = 1 | 2 | 4
ALL_FLAGS: ClassVar[int] = sum(VALID_FLAGS.values())
@ -48,7 +49,7 @@ class MemberCacheFlags:
self.value = self.DEFAULT_FLAGS
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid member cache flag.')
raise TypeError(f"{key!r} is not a valid member cache flag.")
setattr(self, key, value)
@classmethod
@ -67,7 +68,7 @@ class MemberCacheFlags:
return hash(self.value)
def __repr__(self) -> str:
return f'<MemberCacheFlags value={self.value}>'
return f"<MemberCacheFlags value={self.value}>"
def __iter__(self) -> Iterator[Tuple[str, bool]]:
for name in self.VALID_FLAGS:
@ -92,17 +93,17 @@ class MemberCacheFlags:
@classmethod
def only_joined(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with only the `joined` flag enabled."""
return cls._from_value(cls.VALID_FLAGS['joined'])
return cls._from_value(cls.VALID_FLAGS["joined"])
@classmethod
def only_voice(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with only the `voice` flag enabled."""
return cls._from_value(cls.VALID_FLAGS['voice'])
return cls._from_value(cls.VALID_FLAGS["voice"])
@classmethod
def only_online(cls) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with only the `online` flag enabled."""
return cls._from_value(cls.VALID_FLAGS['online'])
return cls._from_value(cls.VALID_FLAGS["online"])
@flag_value(1 << 0)
def joined(self) -> bool:

View File

@ -36,6 +36,7 @@ from .ext import loader as ext_loader
from .interactions import Interaction, Snowflake
from .error_handler import setup_global_error_handler
from .voice_client import VoiceClient
from .models import Activity
if TYPE_CHECKING:
from .models import (
@ -75,13 +76,21 @@ class Client:
intents (Optional[int]): The Gateway Intents to use. Defaults to `GatewayIntent.default()`.
You might need to enable privileged intents in your bot's application page.
loop (Optional[asyncio.AbstractEventLoop]): The event loop to use for asynchronous operations.
Defaults to `asyncio.get_event_loop()`.
Defaults to the running loop
via `asyncio.get_running_loop()`,
or a new loop from
`asyncio.new_event_loop()` if
none is running.
command_prefix (Union[str, List[str], Callable[['Client', Message], Union[str, List[str]]]]):
The prefix(es) for commands. Defaults to '!'.
verbose (bool): If True, print raw HTTP and Gateway traffic for debugging.
mention_replies (bool): Whether replies mention the author by default.
allowed_mentions (Optional[Dict[str, Any]]): Default allowed mentions for messages.
http_options (Optional[Dict[str, Any]]): Extra options passed to
:class:`HTTPClient` for creating the internal
:class:`aiohttp.ClientSession`.
message_cache_maxlen (Optional[int]): Maximum number of messages to keep
in the cache. When ``None``, the cache size is unlimited.
"""
def __init__(
@ -95,10 +104,12 @@ class Client:
application_id: Optional[Union[str, int]] = None,
verbose: bool = False,
mention_replies: bool = False,
allowed_mentions: Optional[Dict[str, Any]] = None,
shard_count: Optional[int] = None,
gateway_max_retries: int = 5,
gateway_max_backoff: float = 60.0,
member_cache_flags: Optional[MemberCacheFlags] = None,
message_cache_maxlen: Optional[int] = None,
http_options: Optional[Dict[str, Any]] = None,
):
if not token:
@ -108,6 +119,7 @@ class Client:
self.member_cache_flags: MemberCacheFlags = (
member_cache_flags if member_cache_flags is not None else MemberCacheFlags()
)
self.message_cache_maxlen: Optional[int] = message_cache_maxlen
self.intents: int = intents if intents is not None else GatewayIntent.default()
if loop:
self.loop: asyncio.AbstractEventLoop = loop
@ -157,7 +169,7 @@ class Client:
self._guilds: GuildCache = GuildCache()
self._channels: ChannelCache = ChannelCache()
self._users: Cache["User"] = Cache()
self._messages: Cache["Message"] = Cache(ttl=3600) # Cache messages for an hour
self._messages: Cache["Message"] = Cache(ttl=3600, maxlen=message_cache_maxlen)
self._views: Dict[Snowflake, "View"] = {}
self._persistent_views: Dict[str, "View"] = {}
self._voice_clients: Dict[Snowflake, VoiceClient] = {}
@ -165,6 +177,7 @@ class Client:
# Default whether replies mention the user
self.mention_replies: bool = mention_replies
self.allowed_mentions: Optional[Dict[str, Any]] = allowed_mentions
# Basic signal handling for graceful shutdown
# This might be better handled by the user's application code, but can be a nice default.
@ -435,8 +448,7 @@ class Client:
async def change_presence(
self,
status: str,
activity_name: Optional[str] = None,
activity_type: int = 0,
activity: Optional[Activity] = None,
since: int = 0,
afk: bool = False,
):
@ -445,8 +457,7 @@ class Client:
Args:
status (str): The new status for the client (e.g., "online", "idle", "dnd", "invisible").
activity_name (Optional[str]): The name of the activity.
activity_type (int): The type of the activity.
activity (Optional[Activity]): Activity instance describing what the bot is doing.
since (int): The timestamp (in milliseconds) of when the client went idle.
afk (bool): Whether the client is AFK.
"""
@ -456,8 +467,7 @@ class Client:
if self._gateway:
await self._gateway.update_presence(
status=status,
activity_name=activity_name,
activity_type=activity_type,
activity=activity,
since=since,
afk=afk,
)
@ -693,7 +703,7 @@ class Client:
)
# import traceback
# traceback.print_exception(type(error.original), error.original, error.original.__traceback__)
async def on_command_completion(self, ctx: "CommandContext") -> None:
"""
Default command completion handler. Called when a command has successfully completed.
@ -1010,7 +1020,7 @@ class Client:
embeds (Optional[List[Embed]]): A list of embeds to send. Cannot be used with `embed`.
Discord supports up to 10 embeds per message.
components (Optional[List[ActionRow]]): A list of ActionRow components to include.
allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message.
allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message. Defaults to :attr:`Client.allowed_mentions`.
message_reference (Optional[Dict[str, Any]]): Message reference for replying.
attachments (Optional[List[Any]]): Attachments to include with the message.
files (Optional[List[Any]]): Files to upload with the message.
@ -1057,6 +1067,9 @@ class Client:
if isinstance(comp, ComponentModel)
]
if allowed_mentions is None:
allowed_mentions = self.allowed_mentions
message_data = await self._http.send_message(
channel_id=channel_id,
content=content,
@ -1428,6 +1441,24 @@ class Client:
await self._http.delete_guild_template(guild_id, template_code)
async def fetch_widget(self, guild_id: Snowflake) -> Dict[str, Any]:
"""|coro| Fetch a guild's widget settings."""
if self._closed:
raise DisagreementException("Client is closed.")
return await self._http.get_guild_widget(guild_id)
async def edit_widget(
self, guild_id: Snowflake, payload: Dict[str, Any]
) -> Dict[str, Any]:
"""|coro| Edit a guild's widget settings."""
if self._closed:
raise DisagreementException("Client is closed.")
return await self._http.edit_guild_widget(guild_id, payload)
async def fetch_scheduled_events(
self, guild_id: Snowflake
) -> List["ScheduledEvent"]:
@ -1514,35 +1545,35 @@ class Client:
return [self.parse_invite(inv) for inv in data]
def add_persistent_view(self, view: "View") -> None:
"""
Registers a persistent view with the client.
"""
Registers a persistent view with the client.
Persistent views have a timeout of `None` and their components must have a `custom_id`.
This allows the view to be re-instantiated across bot restarts.
Persistent views have a timeout of `None` and their components must have a `custom_id`.
This allows the view to be re-instantiated across bot restarts.
Args:
view (View): The view instance to register.
Args:
view (View): The view instance to register.
Raises:
ValueError: If the view is not persistent (timeout is not None) or if a component's
custom_id is already registered.
"""
if self.is_ready():
print(
"Warning: Adding a persistent view after the client is ready. "
"This view will only be available for interactions on this session."
)
Raises:
ValueError: If the view is not persistent (timeout is not None) or if a component's
custom_id is already registered.
"""
if self.is_ready():
print(
"Warning: Adding a persistent view after the client is ready. "
"This view will only be available for interactions on this session."
)
if view.timeout is not None:
raise ValueError("Persistent views must have a timeout of None.")
if view.timeout is not None:
raise ValueError("Persistent views must have a timeout of None.")
for item in view.children:
if item.custom_id: # Ensure custom_id is not None
if item.custom_id in self._persistent_views:
raise ValueError(
f"A component with custom_id '{item.custom_id}' is already registered."
)
self._persistent_views[item.custom_id] = view
for item in view.children:
if item.custom_id: # Ensure custom_id is not None
if item.custom_id in self._persistent_views:
raise ValueError(
f"A component with custom_id '{item.custom_id}' is already registered."
)
self._persistent_views[item.custom_id] = view
# --- Application Command Methods ---
async def process_interaction(self, interaction: Interaction) -> None:

View File

@ -375,6 +375,15 @@ class OverwriteType(IntEnum):
MEMBER = 1
class AutoArchiveDuration(IntEnum):
"""Thread auto-archive duration in minutes."""
HOUR = 60
DAY = 1440
THREE_DAYS = 4320
WEEK = 10080
# --- Component Enums ---

View File

@ -14,7 +14,11 @@ def setup_global_error_handler(
The handler logs unhandled exceptions so they don't crash the bot.
"""
if loop is None:
loop = asyncio.get_event_loop()
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if not logging.getLogger().hasHandlers():
setup_logging(logging.ERROR)

View File

@ -1,8 +1,10 @@
# disagreement/ext/commands/converters.py
# pyright: reportIncompatibleMethodOverride=false
from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, Generic
from abc import ABC, abstractmethod
import re
import inspect
from .errors import BadArgument
from disagreement.models import Member, Guild, Role
@ -36,6 +38,20 @@ class Converter(ABC, Generic[T]):
raise NotImplementedError("Converter subclass must implement convert method.")
class Greedy(list):
"""Type hint helper to greedily consume arguments."""
converter: Any = None
def __class_getitem__(cls, param: Any) -> type: # pyright: ignore[override]
if isinstance(param, tuple):
if len(param) != 1:
raise TypeError("Greedy[...] expects a single parameter")
param = param[0]
name = f"Greedy[{getattr(param, '__name__', str(param))}]"
return type(name, (Greedy,), {"converter": param})
# --- Built-in Type Converters ---
@ -169,7 +185,3 @@ async def run_converters(ctx: "CommandContext", annotation: Any, argument: str)
raise BadArgument(f"No converter found for type annotation '{annotation}'.")
return argument # Default to string if no annotation or annotation is str
# Need to import inspect for the run_converters function
import inspect

View File

@ -29,7 +29,7 @@ from .errors import (
CheckFailure,
CommandInvokeError,
)
from .converters import run_converters, DEFAULT_CONVERTERS, Converter
from .converters import Greedy, run_converters, DEFAULT_CONVERTERS, Converter
from disagreement.typing import Typing
logger = logging.getLogger(__name__)
@ -46,29 +46,39 @@ class GroupMixin:
self.commands: Dict[str, "Command"] = {}
self.name: str = ""
def command(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Command"]:
def command(
self, **attrs: Any
) -> Callable[[Callable[..., Awaitable[None]]], "Command"]:
def decorator(func: Callable[..., Awaitable[None]]) -> "Command":
cmd = Command(func, **attrs)
cmd.cog = getattr(self, "cog", None)
self.add_command(cmd)
return cmd
return decorator
def group(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Group"]:
def group(
self, **attrs: Any
) -> Callable[[Callable[..., Awaitable[None]]], "Group"]:
def decorator(func: Callable[..., Awaitable[None]]) -> "Group":
cmd = Group(func, **attrs)
cmd.cog = getattr(self, "cog", None)
self.add_command(cmd)
return cmd
return decorator
def add_command(self, command: "Command") -> None:
if command.name in self.commands:
raise ValueError(f"Command '{command.name}' is already registered in group '{self.name}'.")
raise ValueError(
f"Command '{command.name}' is already registered in group '{self.name}'."
)
self.commands[command.name.lower()] = command
for alias in command.aliases:
if alias in self.commands:
logger.warning(f"Alias '{alias}' for command '{command.name}' in group '{self.name}' conflicts with an existing command or alias.")
logger.warning(
f"Alias '{alias}' for command '{command.name}' in group '{self.name}' conflicts with an existing command or alias."
)
self.commands[alias.lower()] = command
def get_command(self, name: str) -> Optional["Command"]:
@ -181,6 +191,7 @@ class Command(GroupMixin):
class Group(Command):
"""A command that can have subcommands."""
def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any):
super().__init__(callback, **attrs)
@ -494,7 +505,34 @@ class CommandHandler:
None # Holds the raw string for current param
)
if view.eof: # No more input string
annotation = param.annotation
if inspect.isclass(annotation) and issubclass(annotation, Greedy):
greedy_values = []
converter_type = annotation.converter
while not view.eof:
view.skip_whitespace()
if view.eof:
break
start = view.index
if view.buffer[view.index] == '"':
arg_str_value = view.get_quoted_string()
if arg_str_value == "" and view.buffer[view.index] == '"':
raise BadArgument(
f"Unterminated quoted string for argument '{param.name}'."
)
else:
arg_str_value = view.get_word()
try:
converted = await run_converters(
ctx, converter_type, arg_str_value
)
except BadArgument:
view.index = start
break
greedy_values.append(converted)
final_value_for_param = greedy_values
arg_str_value = None
elif view.eof: # No more input string
if param.default is not inspect.Parameter.empty:
final_value_for_param = param.default
elif param.kind != inspect.Parameter.VAR_KEYWORD:
@ -656,7 +694,9 @@ class CommandHandler:
elif command.invoke_without_command:
view.index -= len(potential_subcommand) + view.previous
else:
raise CommandNotFound(f"Subcommand '{potential_subcommand}' not found.")
raise CommandNotFound(
f"Subcommand '{potential_subcommand}' not found."
)
ctx = CommandContext(
message=message,
@ -681,7 +721,9 @@ class CommandHandler:
if hasattr(self.client, "on_command_error"):
await self.client.on_command_error(ctx, e)
except Exception as e:
logger.error("Unexpected error invoking command '%s': %s", original_command.name, e)
logger.error(
"Unexpected error invoking command '%s': %s", original_command.name, e
)
exc = CommandInvokeError(e)
if hasattr(self.client, "on_command_error"):
await self.client.on_command_error(ctx, exc)

View File

@ -218,6 +218,7 @@ def requires_permissions(
return check(predicate)
def has_role(
name_or_id: str | int,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
@ -241,9 +242,7 @@ def has_role(
raise CheckFailure("Could not resolve author to a guild member.")
# Create a list of the member's role objects by looking them up in the guild's roles list
member_roles = [
role for role in ctx.guild.roles if role.id in author.roles
]
member_roles = [role for role in ctx.guild.roles if role.id in author.roles]
if any(
role.id == str(name_or_id) or role.name == name_or_id
@ -278,9 +277,7 @@ def has_any_role(
if not author:
raise CheckFailure("Could not resolve author to a guild member.")
member_roles = [
role for role in ctx.guild.roles if role.id in author.roles
]
member_roles = [role for role in ctx.guild.roles if role.id in author.roles]
# Convert names_or_ids to a set for efficient lookup
names_or_ids_set = set(map(str, names_or_ids))

View File

@ -14,6 +14,8 @@ import time
import random
from typing import Optional, TYPE_CHECKING, Any, Dict
from .models import Activity
from .enums import GatewayOpcode, GatewayIntent
from .errors import GatewayException, DisagreementException, AuthenticationError
from .interactions import Interaction
@ -63,7 +65,11 @@ class GatewayClient:
self._max_backoff: float = max_backoff
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
try:
self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._heartbeat_interval: Optional[float] = None
self._last_sequence: Optional[int] = None
self._session_id: Optional[str] = None
@ -213,26 +219,17 @@ class GatewayClient:
async def update_presence(
self,
status: str,
activity_name: Optional[str] = None,
activity_type: int = 0,
activity: Optional[Activity] = None,
*,
since: int = 0,
afk: bool = False,
):
) -> None:
"""Sends the presence update payload to the Gateway."""
payload = {
"op": GatewayOpcode.PRESENCE_UPDATE,
"d": {
"since": since,
"activities": (
[
{
"name": activity_name,
"type": activity_type,
}
]
if activity_name
else []
),
"activities": [activity.to_dict()] if activity else [],
"status": status,
"afk": afk,
},
@ -353,7 +350,10 @@ class GatewayClient:
future._members.extend(raw_event_d_payload.get("members", [])) # type: ignore
# If this is the last chunk, resolve the future
if raw_event_d_payload.get("chunk_index") == raw_event_d_payload.get("chunk_count", 1) - 1:
if (
raw_event_d_payload.get("chunk_index")
== raw_event_d_payload.get("chunk_count", 1) - 1
):
future.set_result(future._members) # type: ignore
del self._member_chunk_requests[nonce]

View File

@ -601,18 +601,18 @@ class HTTPClient:
)
async def delete_user_reaction(
self,
channel_id: "Snowflake",
message_id: "Snowflake",
emoji: str,
user_id: "Snowflake",
) -> None:
"""Removes another user's reaction from a message."""
encoded = quote(emoji)
await self.request(
"DELETE",
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/{user_id}",
)
self,
channel_id: "Snowflake",
message_id: "Snowflake",
emoji: str,
user_id: "Snowflake",
) -> None:
"""Removes another user's reaction from a message."""
encoded = quote(emoji)
await self.request(
"DELETE",
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/{user_id}",
)
async def get_reactions(
self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
@ -910,6 +910,20 @@ class HTTPClient:
"""Fetches a guild object for a given guild ID."""
return await self.request("GET", f"/guilds/{guild_id}")
async def get_guild_widget(self, guild_id: "Snowflake") -> Dict[str, Any]:
"""Fetches the guild widget settings."""
return await self.request("GET", f"/guilds/{guild_id}/widget")
async def edit_guild_widget(
self, guild_id: "Snowflake", payload: Dict[str, Any]
) -> Dict[str, Any]:
"""Edits the guild widget settings."""
return await self.request(
"PATCH", f"/guilds/{guild_id}/widget", payload=payload
)
async def get_guild_templates(self, guild_id: "Snowflake") -> List[Dict[str, Any]]:
"""Fetches all templates for the given guild."""
return await self.request("GET", f"/guilds/{guild_id}/templates")

View File

@ -6,6 +6,7 @@ Data models for Discord objects.
import asyncio
import json
import re
from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast
@ -24,6 +25,7 @@ from .enums import ( # These enums will need to be defined in disagreement/enum
PremiumTier,
GuildFeature,
ChannelType,
AutoArchiveDuration,
ComponentType,
ButtonStyle, # Added for Button
GuildScheduledEventPrivacyLevel,
@ -39,6 +41,7 @@ if TYPE_CHECKING:
from .enums import OverwriteType # For PermissionOverwrite model
from .ui.view import View
from .interactions import Snowflake
from .typing import Typing
# Forward reference Message if it were used in type hints before its definition
# from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc.
@ -114,31 +117,39 @@ class Message:
# self.mention_roles: List[str] = data.get("mention_roles", [])
# self.mention_everyone: bool = data.get("mention_everyone", False)
@property
def clean_content(self) -> str:
"""Returns message content without user, role, or channel mentions."""
pattern = re.compile(r"<@!?\d+>|<#\d+>|<@&\d+>")
cleaned = pattern.sub("", self.content)
return " ".join(cleaned.split())
async def pin(self) -> None:
"""|coro|
"""|coro|
Pins this message to its channel.
Pins this message to its channel.
Raises
------
HTTPException
Pinning the message failed.
"""
await self._client._http.pin_message(self.channel_id, self.id)
self.pinned = True
Raises
------
HTTPException
Pinning the message failed.
"""
await self._client._http.pin_message(self.channel_id, self.id)
self.pinned = True
async def unpin(self) -> None:
"""|coro|
"""|coro|
Unpins this message from its channel.
Unpins this message from its channel.
Raises
------
HTTPException
Unpinning the message failed.
"""
await self._client._http.unpin_message(self.channel_id, self.id)
self.pinned = False
Raises
------
HTTPException
Unpinning the message failed.
"""
await self._client._http.unpin_message(self.channel_id, self.id)
self.pinned = False
async def reply(
self,
@ -241,16 +252,16 @@ class Message:
await self._client.add_reaction(self.channel_id, self.id, emoji)
async def remove_reaction(self, emoji: str, member: Optional[User] = None) -> None:
"""|coro|
Removes a reaction from this message.
If no ``member`` is provided, removes the bot's own reaction.
"""
if member:
await self._client._http.delete_user_reaction(
self.channel_id, self.id, emoji, member.id
)
else:
await self._client.remove_reaction(self.channel_id, self.id, emoji)
"""|coro|
Removes a reaction from this message.
If no ``member`` is provided, removes the bot's own reaction.
"""
if member:
await self._client._http.delete_user_reaction(
self.channel_id, self.id, emoji, member.id
)
else:
await self._client.remove_reaction(self.channel_id, self.id, emoji)
async def clear_reactions(self) -> None:
"""|coro| Remove all reactions from this message."""
@ -280,7 +291,7 @@ class Message:
self,
name: str,
*,
auto_archive_duration: Optional[int] = None,
auto_archive_duration: Optional[AutoArchiveDuration] = None,
rate_limit_per_user: Optional[int] = None,
reason: Optional[str] = None,
) -> "Thread":
@ -292,9 +303,9 @@ class Message:
----------
name: str
The name of the thread.
auto_archive_duration: Optional[int]
The duration in minutes to automatically archive the thread after recent activity.
Can be one of 60, 1440, 4320, 10080.
auto_archive_duration: Optional[AutoArchiveDuration]
How long before the thread is automatically archived after recent activity.
See :class:`AutoArchiveDuration` for allowed values.
rate_limit_per_user: Optional[int]
The number of seconds a user has to wait before sending another message.
reason: Optional[str]
@ -307,7 +318,7 @@ class Message:
"""
payload: Dict[str, Any] = {"name": name}
if auto_archive_duration is not None:
payload["auto_archive_duration"] = auto_archive_duration
payload["auto_archive_duration"] = int(auto_archive_duration)
if rate_limit_per_user is not None:
payload["rate_limit_per_user"] = rate_limit_per_user
@ -530,8 +541,42 @@ class Embed:
payload["fields"] = [f.to_dict() for f in self.fields]
return payload
# Convenience methods for building embeds can be added here
# e.g., set_author, add_field, set_footer, set_image, etc.
# Convenience methods mirroring ``discord.py``'s ``Embed`` API
def set_author(
self, *, name: str, url: Optional[str] = None, icon_url: Optional[str] = None
) -> "Embed":
"""Set the embed author and return ``self`` for chaining."""
data: Dict[str, Any] = {"name": name}
if url:
data["url"] = url
if icon_url:
data["icon_url"] = icon_url
self.author = EmbedAuthor(data)
return self
def add_field(self, *, name: str, value: str, inline: bool = False) -> "Embed":
"""Add a field to the embed."""
field = EmbedField({"name": name, "value": value, "inline": inline})
self.fields.append(field)
return self
def set_footer(self, *, text: str, icon_url: Optional[str] = None) -> "Embed":
"""Set the embed footer."""
data: Dict[str, Any] = {"text": text}
if icon_url:
data["icon_url"] = icon_url
self.footer = EmbedFooter(data)
return self
def set_image(self, url: str) -> "Embed":
"""Set the embed image."""
self.image = EmbedImage({"url": url})
return self
class Attachment:
@ -1088,7 +1133,9 @@ class Guild:
# Internal caches, populated by events or specific fetches
self._channels: ChannelCache = ChannelCache()
self._members: MemberCache = MemberCache(getattr(client_instance, "member_cache_flags", MemberCacheFlags()))
self._members: MemberCache = MemberCache(
getattr(client_instance, "member_cache_flags", MemberCacheFlags())
)
self._threads: Dict[str, "Thread"] = {}
def get_channel(self, channel_id: str) -> Optional["Channel"]:
@ -1128,6 +1175,16 @@ class Guild:
def __repr__(self) -> str:
return f"<Guild id='{self.id}' name='{self.name}'>"
async def fetch_widget(self) -> Dict[str, Any]:
"""|coro| Fetch this guild's widget settings."""
return await self._client.fetch_widget(self.id)
async def edit_widget(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""|coro| Edit this guild's widget settings."""
return await self._client.edit_widget(self.id, payload)
async def fetch_members(self, *, limit: Optional[int] = None) -> List["Member"]:
"""|coro|
@ -1278,7 +1335,45 @@ class Channel:
return base
class TextChannel(Channel):
class Messageable:
"""Mixin for channels that can send messages and show typing."""
_client: "Client"
id: str
async def send(
self,
content: Optional[str] = None,
*,
embed: Optional["Embed"] = None,
embeds: Optional[List["Embed"]] = None,
components: Optional[List["ActionRow"]] = None,
) -> "Message":
if not hasattr(self._client, "send_message"):
raise NotImplementedError(
"Client.send_message is required for Messageable.send"
)
return await self._client.send_message(
channel_id=self.id,
content=content,
embed=embed,
embeds=embeds,
components=components,
)
async def trigger_typing(self) -> None:
await self._client._http.trigger_typing(self.id)
def typing(self) -> "Typing":
if not hasattr(self._client, "typing"):
raise NotImplementedError(
"Client.typing is required for Messageable.typing"
)
return self._client.typing(self.id)
class TextChannel(Channel, Messageable):
"""Represents a guild text channel or announcement channel."""
def __init__(self, data: Dict[str, Any], client_instance: "Client"):
@ -1304,27 +1399,6 @@ class TextChannel(Channel):
return message_pager(self, limit=limit, before=before, after=after)
async def send(
self,
content: Optional[str] = None,
*,
embed: Optional[Embed] = None,
embeds: Optional[List[Embed]] = None,
components: Optional[List["ActionRow"]] = None, # Added components
) -> "Message": # Forward reference Message
if not hasattr(self._client, "send_message"):
raise NotImplementedError(
"Client.send_message is required for TextChannel.send"
)
return await self._client.send_message(
channel_id=self.id,
content=content,
embed=embed,
embeds=embeds,
components=components,
)
async def purge(
self, limit: int, *, before: "Snowflake | None" = None
) -> List["Snowflake"]:
@ -1347,41 +1421,41 @@ class TextChannel(Channel):
return ids
def get_partial_message(self, id: int) -> "PartialMessage":
"""Returns a :class:`PartialMessage` for the given ID.
"""Returns a :class:`PartialMessage` for the given ID.
This allows performing actions on a message without fetching it first.
This allows performing actions on a message without fetching it first.
Parameters
----------
id: int
The ID of the message to get a partial instance of.
Parameters
----------
id: int
The ID of the message to get a partial instance of.
Returns
-------
PartialMessage
The partial message instance.
"""
return PartialMessage(id=str(id), channel=self)
Returns
-------
PartialMessage
The partial message instance.
"""
return PartialMessage(id=str(id), channel=self)
def __repr__(self) -> str:
return f"<TextChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>"
async def pins(self) -> List["Message"]:
"""|coro|
Fetches all pinned messages in this channel.
Returns
-------
List[Message]
The pinned messages.
Raises
------
HTTPException
Fetching the pinned messages failed.
"""
messages_data = await self._client._http.get_pinned_messages(self.id)
return [self._client.parse_message(m) for m in messages_data]
@ -1390,7 +1464,7 @@ class TextChannel(Channel):
name: str,
*,
type: ChannelType = ChannelType.PUBLIC_THREAD,
auto_archive_duration: Optional[int] = None,
auto_archive_duration: Optional[AutoArchiveDuration] = None,
invitable: Optional[bool] = None,
rate_limit_per_user: Optional[int] = None,
reason: Optional[str] = None,
@ -1406,8 +1480,8 @@ class TextChannel(Channel):
type: ChannelType
The type of thread to create. Defaults to PUBLIC_THREAD.
Can be PUBLIC_THREAD, PRIVATE_THREAD, or ANNOUNCEMENT_THREAD.
auto_archive_duration: Optional[int]
The duration in minutes to automatically archive the thread after recent activity.
auto_archive_duration: Optional[AutoArchiveDuration]
How long before the thread is automatically archived after recent activity.
invitable: Optional[bool]
Whether non-moderators can invite other non-moderators to a private thread.
Only applicable to private threads.
@ -1426,7 +1500,7 @@ class TextChannel(Channel):
"type": type.value,
}
if auto_archive_duration is not None:
payload["auto_archive_duration"] = auto_archive_duration
payload["auto_archive_duration"] = int(auto_archive_duration)
if invitable is not None and type == ChannelType.PRIVATE_THREAD:
payload["invitable"] = invitable
if rate_limit_per_user is not None:
@ -1606,7 +1680,9 @@ class Thread(TextChannel): # Threads are a specialized TextChannel
"""
await self._client._http.leave_thread(self.id)
async def archive(self, locked: bool = False, *, reason: Optional[str] = None) -> "Thread":
async def archive(
self, locked: bool = False, *, reason: Optional[str] = None
) -> "Thread":
"""|coro|
Archives this thread.
@ -1631,7 +1707,7 @@ class Thread(TextChannel): # Threads are a specialized TextChannel
return cast("Thread", self._client.parse_channel(data))
class DMChannel(Channel):
class DMChannel(Channel, Messageable):
"""Represents a Direct Message channel."""
def __init__(self, data: Dict[str, Any], client_instance: "Client"):
@ -1645,27 +1721,6 @@ class DMChannel(Channel):
def recipient(self) -> Optional[User]:
return self.recipients[0] if self.recipients else None
async def send(
self,
content: Optional[str] = None,
*,
embed: Optional[Embed] = None,
embeds: Optional[List[Embed]] = None,
components: Optional[List["ActionRow"]] = None, # Added components
) -> "Message":
if not hasattr(self._client, "send_message"):
raise NotImplementedError(
"Client.send_message is required for DMChannel.send"
)
return await self._client.send_message(
channel_id=self.id,
content=content,
embed=embed,
embeds=embeds,
components=components,
)
async def history(
self,
*,
@ -2356,6 +2411,37 @@ class ThreadMember:
return f"<ThreadMember user_id='{self.user_id}' thread_id='{self.id}'>"
class Activity:
"""Represents a user's presence activity."""
def __init__(self, name: str, type: int) -> None:
self.name = name
self.type = type
def to_dict(self) -> Dict[str, Any]:
return {"name": self.name, "type": self.type}
class Game(Activity):
"""Represents a playing activity."""
def __init__(self, name: str) -> None:
super().__init__(name, 0)
class Streaming(Activity):
"""Represents a streaming activity."""
def __init__(self, name: str, url: str) -> None:
super().__init__(name, 1)
self.url = url
def to_dict(self) -> Dict[str, Any]:
payload = super().to_dict()
payload["url"] = self.url
return payload
class PresenceUpdate:
"""Represents a PRESENCE_UPDATE event."""
@ -2366,7 +2452,17 @@ class PresenceUpdate:
self.user = User(data["user"])
self.guild_id: Optional[str] = data.get("guild_id")
self.status: Optional[str] = data.get("status")
self.activities: List[Dict[str, Any]] = data.get("activities", [])
self.activities: List[Activity] = []
for activity in data.get("activities", []):
act_type = activity.get("type", 0)
name = activity.get("name", "")
if act_type == 0:
obj = Game(name)
elif act_type == 1:
obj = Streaming(name, activity.get("url", ""))
else:
obj = Activity(name, act_type)
self.activities.append(obj)
self.client_status: Dict[str, Any] = data.get("client_status", {})
def __repr__(self) -> str:

View File

@ -72,7 +72,7 @@ class View:
rows: List[ActionRow] = []
for item in self.children:
rows.append(ActionRow(components=[item]))
rows.append(ActionRow(components=[item]))
return rows

View File

@ -7,9 +7,26 @@ import asyncio
import contextlib
import socket
import threading
from array import array
def _apply_volume(data: bytes, volume: float) -> bytes:
samples = array("h")
samples.frombytes(data)
for i, sample in enumerate(samples):
scaled = int(sample * volume)
if scaled > 32767:
scaled = 32767
elif scaled < -32768:
scaled = -32768
samples[i] = scaled
return samples.tobytes()
from typing import TYPE_CHECKING, Optional, Sequence
import aiohttp
# The following import is correct, but may be flagged by Pylance if the virtual
# environment is not configured correctly.
from nacl.secret import SecretBox
@ -180,6 +197,9 @@ class VoiceClient:
data = await self._current_source.read()
if not data:
break
volume = getattr(self._current_source, "volume", 1.0)
if volume != 1.0:
data = _apply_volume(data, volume)
await self.send_audio_frame(data)
finally:
await self._current_source.close()

22
docs/embeds.md Normal file
View File

@ -0,0 +1,22 @@
# Embeds
`Embed` objects can be constructed piece by piece much like in `discord.py`.
These helper methods return the embed instance so you can chain calls.
```python
from disagreement.models import Embed
embed = (
Embed()
.set_author(name="Disagreement", url="https://example.com", icon_url="https://cdn.example.com/bot.png")
.add_field(name="Info", value="Some details")
.set_footer(text="Made with Disagreement")
.set_image(url="https://cdn.example.com/image.png")
)
```
Call `to_dict()` to convert the embed back to a payload dictionary before sending:
```python
payload = embed.to_dict()
```

23
docs/mentions.md Normal file
View File

@ -0,0 +1,23 @@
# Controlling Mentions
The client exposes settings to control how mentions behave in outgoing messages.
## Default Allowed Mentions
Use the ``allowed_mentions`` parameter of :class:`disagreement.Client` to set a
default for all messages:
```python
client = disagreement.Client(
token="YOUR_TOKEN",
allowed_mentions={"parse": [], "replied_user": False},
)
```
When ``Client.send_message`` is called without an explicit ``allowed_mentions``
argument this value will be used.
## Next Steps
- [Commands](commands.md)
- [HTTP Client Options](http_client.md)

View File

@ -1,6 +1,7 @@
# Updating Presence
The `Client.change_presence` method allows you to update the bot's status and displayed activity.
Pass an :class:`~disagreement.models.Activity` (such as :class:`~disagreement.models.Game` or :class:`~disagreement.models.Streaming`) to describe what your bot is doing.
## Status Strings
@ -22,8 +23,18 @@ An activity dictionary must include a `name` and a `type` field. The type value
| `4` | Custom |
| `5` | Competing |
Example:
Example using the provided activity classes:
```python
await client.change_presence(status="idle", activity={"name": "with Discord", "type": 0})
from disagreement.models import Game
await client.change_presence(status="idle", activity=Game("with Discord"))
```
You can also specify a streaming URL:
```python
from disagreement.models import Streaming
await client.change_presence(status="online", activity=Streaming("My Stream", "https://twitch.tv/someone"))
```

18
docs/threads.md Normal file
View File

@ -0,0 +1,18 @@
# Threads
`Message.create_thread` and `TextChannel.create_thread` let you start new threads.
Use :class:`AutoArchiveDuration` to control when a thread is automatically archived.
```python
from disagreement.enums import AutoArchiveDuration
await message.create_thread(
"discussion",
auto_archive_duration=AutoArchiveDuration.DAY,
)
```
## Next Steps
- [Message History](message_history.md)
- [Caching](caching.md)

View File

@ -39,9 +39,14 @@ except ImportError:
)
sys.exit(1)
from dotenv import load_dotenv
try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
load_dotenv()
if load_dotenv:
load_dotenv()
# Optional: Configure logging for more insight, especially for gateway events
# logging.basicConfig(level=logging.DEBUG) # For very verbose output

View File

@ -37,9 +37,15 @@ from disagreement.interactions import (
InteractionResponsePayload,
InteractionCallbackData,
)
from dotenv import load_dotenv
load_dotenv()
try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
if load_dotenv:
load_dotenv()
# Get the bot token and application ID from the environment variables
token = os.getenv("DISCORD_BOT_TOKEN")

View File

@ -15,9 +15,14 @@ from disagreement.ext.app_commands import (
)
from disagreement.models import User, Message
from dotenv import load_dotenv
try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
load_dotenv()
if load_dotenv:
load_dotenv()
BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "")
APP_ID = os.environ.get("DISCORD_APPLICATION_ID", "")

View File

@ -4,7 +4,11 @@ import asyncio
import os
import sys
from dotenv import load_dotenv
try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
# Allow running from the examples folder without installing
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
@ -12,7 +16,8 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi
from disagreement import Client
load_dotenv()
if load_dotenv:
load_dotenv()
TOKEN = os.environ.get("DISCORD_BOT_TOKEN")

View File

@ -36,9 +36,14 @@ from disagreement.enums import (
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from dotenv import load_dotenv
try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
load_dotenv()
if load_dotenv:
load_dotenv()
# --- Define a Test Cog ---

View File

@ -10,9 +10,15 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi
from disagreement.client import Client
from disagreement.models import TextChannel
from dotenv import load_dotenv
load_dotenv()
try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
if load_dotenv:
load_dotenv()
BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "")
CHANNEL_ID = os.environ.get("DISCORD_CHANNEL_ID", "")

View File

@ -2,14 +2,20 @@
import os
import asyncio
from dotenv import load_dotenv
try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
from disagreement import Client, ui
from disagreement.enums import GatewayIntent, TextInputStyle
from disagreement.ext.app_commands.decorators import slash_command
from disagreement.ext.app_commands.context import AppCommandContext
load_dotenv()
if load_dotenv:
load_dotenv()
token = os.getenv("DISCORD_BOT_TOKEN", "")
application_id = os.getenv("DISCORD_APPLICATION_ID", "")

View File

@ -3,7 +3,11 @@
import os
import sys
from dotenv import load_dotenv
try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@ -11,7 +15,8 @@ from disagreement import Client, GatewayIntent, ui # type: ignore
from disagreement.ext.app_commands.decorators import slash_command
from disagreement.ext.app_commands.context import AppCommandContext
load_dotenv()
if load_dotenv:
load_dotenv()
TOKEN = os.getenv("DISCORD_BOT_TOKEN", "")
APP_ID = os.getenv("DISCORD_APPLICATION_ID", "")

View File

@ -9,9 +9,15 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import disagreement
from dotenv import load_dotenv
load_dotenv()
try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
if load_dotenv:
load_dotenv()
TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
if not TOKEN:

View File

@ -10,11 +10,16 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi
from typing import cast
from dotenv import load_dotenv
try:
from dotenv import load_dotenv
except ImportError: # pragma: no cover - example helper
load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded")
import disagreement
load_dotenv()
if load_dotenv:
load_dotenv()
_TOKEN = os.getenv("DISCORD_BOT_TOKEN")
_GUILD_ID = os.getenv("DISCORD_GUILD_ID")

View File

@ -15,3 +15,14 @@ def test_cache_ttl_expiry():
assert cache.get("b") == 1
time.sleep(0.02)
assert cache.get("b") is None
def test_cache_lru_eviction():
cache = Cache(maxlen=2)
cache.set("a", 1)
cache.set("b", 2)
assert cache.get("a") == 1
cache.set("c", 3)
assert cache.get("b") is None
assert cache.get("a") == 1
assert cache.get("c") == 3

View File

@ -0,0 +1,23 @@
import pytest
from disagreement.client import Client
def _add_message(client: Client, message_id: str) -> None:
data = {
"id": message_id,
"channel_id": "c",
"author": {"id": "u", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
client.parse_message(data)
def test_client_message_cache_size():
client = Client(token="t", message_cache_maxlen=1)
_add_message(client, "1")
assert client._messages.get("1").id == "1"
_add_message(client, "2")
assert client._messages.get("1") is None
assert client._messages.get("2").id == "2"

View File

@ -0,0 +1,18 @@
from disagreement.models import Embed
def test_embed_helper_methods():
embed = (
Embed()
.set_author(name="name", url="url", icon_url="icon")
.add_field(name="n", value="v")
.set_footer(text="footer", icon_url="icon")
.set_image(url="https://example.com/image.png")
)
assert embed.author.name == "name"
assert embed.author.url == "url"
assert embed.author.icon_url == "icon"
assert len(embed.fields) == 1 and embed.fields[0].name == "n"
assert embed.footer.text == "footer"
assert embed.image.url == "https://example.com/image.png"

View File

@ -24,7 +24,7 @@ class DummyDispatcher:
class DummyClient:
def __init__(self):
self.loop = asyncio.get_event_loop()
self.loop = asyncio.get_running_loop()
self.application_id = None # Mock application_id for Client.connect
@ -39,7 +39,7 @@ async def test_client_connect_backoff(monkeypatch):
client = Client(
token="test_token",
intents=0,
loop=asyncio.get_event_loop(),
loop=asyncio.get_running_loop(),
command_prefix="!",
verbose=False,
mention_replies=False,

View File

@ -0,0 +1,23 @@
import types
from disagreement.models import Message
def make_message(content: str) -> Message:
data = {
"id": "1",
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": content,
"timestamp": "t",
}
return Message(data, client_instance=types.SimpleNamespace())
def test_clean_content_removes_mentions():
msg = make_message("Hello <@123> <#456> <@&789> world")
assert msg.clean_content == "Hello world"
def test_clean_content_no_mentions():
msg = make_message("Just text")
assert msg.clean_content == "Just text"

View File

@ -2,6 +2,7 @@ import pytest
from unittest.mock import AsyncMock
from disagreement.client import Client
from disagreement.models import Game
from disagreement.errors import DisagreementException
@ -18,11 +19,11 @@ class DummyGateway(MagicMock):
async def test_change_presence_passes_arguments():
client = Client(token="t")
client._gateway = DummyGateway()
await client.change_presence(status="idle", activity_name="hi", activity_type=0)
game = Game("hi")
await client.change_presence(status="idle", activity=game)
client._gateway.update_presence.assert_awaited_once_with(
status="idle", activity_name="hi", activity_type=0, since=0, afk=False
status="idle", activity=game, since=0, afk=False
)

View File

@ -1,8 +1,11 @@
import asyncio
import io
from array import array
import pytest
from disagreement.audio import AudioSource, FFmpegAudioSource
from disagreement.voice_client import VoiceClient
from disagreement.audio import AudioSource
from disagreement.client import Client
@ -137,3 +140,68 @@ async def test_play_and_switch_sources():
await vc.play(DummySource([b"c"]))
assert udp.sent == [b"a", b"b", b"c"]
@pytest.mark.asyncio
async def test_ffmpeg_source_custom_options(monkeypatch):
captured = {}
class DummyProcess:
def __init__(self):
self.stdout = io.BytesIO(b"")
async def wait(self):
return 0
async def fake_exec(*args, **kwargs):
captured["args"] = args
return DummyProcess()
monkeypatch.setattr(asyncio, "create_subprocess_exec", fake_exec)
src = FFmpegAudioSource(
"file.mp3", before_options="-reconnect 1", options="-vn", volume=0.5
)
await src._spawn()
cmd = captured["args"]
assert "-reconnect" in cmd
assert "-vn" in cmd
assert src.volume == 0.5
@pytest.mark.asyncio
async def test_voice_client_volume_scaling(monkeypatch):
ws = DummyWebSocket(
[
{"d": {"heartbeat_interval": 50}},
{"d": {"ssrc": 1, "ip": "127.0.0.1", "port": 4000}},
{"d": {"secret_key": []}},
]
)
udp = DummyUDP()
vc = VoiceClient(
client=DummyVoiceClient(),
endpoint="ws://localhost",
session_id="sess",
token="tok",
guild_id=1,
user_id=2,
ws=ws,
udp=udp,
)
await vc.connect()
vc._heartbeat_task.cancel()
chunk = b"\x10\x00\x10\x00"
src = DummySource([chunk])
src.volume = 0.5
await vc.play(src)
samples = array("h")
samples.frombytes(chunk)
samples[0] = int(samples[0] * 0.5)
samples[1] = int(samples[1] * 0.5)
expected = samples.tobytes()
assert udp.sent == [expected]

50
tests/test_widget.py Normal file
View File

@ -0,0 +1,50 @@
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock
from disagreement.http import HTTPClient
from disagreement.client import Client
@pytest.mark.asyncio
async def test_get_guild_widget_calls_request():
http = HTTPClient(token="t")
http.request = AsyncMock(return_value={})
await http.get_guild_widget("1")
http.request.assert_called_once_with("GET", "/guilds/1/widget")
@pytest.mark.asyncio
async def test_edit_guild_widget_calls_request():
http = HTTPClient(token="t")
http.request = AsyncMock(return_value={})
payload = {"enabled": True}
await http.edit_guild_widget("1", payload)
http.request.assert_called_once_with("PATCH", "/guilds/1/widget", payload=payload)
@pytest.mark.asyncio
async def test_client_fetch_widget_returns_data():
http = SimpleNamespace(get_guild_widget=AsyncMock(return_value={"enabled": True}))
client = Client.__new__(Client)
client._http = http
client._closed = False
data = await client.fetch_widget("1")
http.get_guild_widget.assert_awaited_once_with("1")
assert data == {"enabled": True}
@pytest.mark.asyncio
async def test_client_edit_widget_returns_data():
http = SimpleNamespace(edit_guild_widget=AsyncMock(return_value={"enabled": False}))
client = Client.__new__(Client)
client._http = http
client._closed = False
payload = {"enabled": False}
data = await client.edit_widget("1", payload)
http.edit_guild_widget.assert_awaited_once_with("1", payload)
assert data == {"enabled": False}