diff --git a/README.md b/README.md index df44da5..a99e865 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,13 @@ pip install -e . Requires Python 3.10 or newer. +To run the example scripts, you'll need the `python-dotenv` package to load +environment variables. Install the development extras with: + +```bash +pip install "disagreement[dev]" +``` + ## Basic Usage ```python @@ -102,6 +109,20 @@ These options are forwarded to ``HTTPClient`` when it creates the underlying ``aiohttp.ClientSession``. You can specify a custom ``connector`` or any other session parameter supported by ``aiohttp``. +### Default Allowed Mentions + +Specify default mention behaviour for all outgoing messages when constructing the client: + +```python +client = disagreement.Client( + token=token, + allowed_mentions={"parse": [], "replied_user": False}, +) +``` + +This dictionary is used whenever ``send_message`` is called without an explicit +``allowed_mentions`` argument. + ### Defining Subcommands with `AppCommandGroup` ```python @@ -120,6 +141,7 @@ async def show(ctx: AppCommandContext, key: str): @slash_command(name="set", description="Update a setting.", parent=admin_group) async def set_setting(ctx: AppCommandContext, key: str, value: str): ... +``` ## Fetching Guilds Use `Client.fetch_guild` to retrieve a guild from the Discord API if it diff --git a/disagreement/audio.py b/disagreement/audio.py index 9e58530..cd1eadf 100644 --- a/disagreement/audio.py +++ b/disagreement/audio.py @@ -5,6 +5,7 @@ from __future__ import annotations import asyncio import contextlib import io +import shlex from typing import Optional, Union @@ -35,15 +36,27 @@ class FFmpegAudioSource(AudioSource): A filename, URL, or file-like object to read from. """ - def __init__(self, source: Union[str, io.BufferedIOBase]): + def __init__( + self, + source: Union[str, io.BufferedIOBase], + *, + before_options: Optional[str] = None, + options: Optional[str] = None, + volume: float = 1.0, + ): self.source = source + self.before_options = before_options + self.options = options + self.volume = volume self.process: Optional[asyncio.subprocess.Process] = None self._feeder: Optional[asyncio.Task] = None async def _spawn(self) -> None: if isinstance(self.source, str): - args = [ - "ffmpeg", + args = ["ffmpeg"] + if self.before_options: + args += shlex.split(self.before_options) + args += [ "-i", self.source, "-f", @@ -54,14 +67,18 @@ class FFmpegAudioSource(AudioSource): "2", "pipe:1", ] + if self.options: + args += shlex.split(self.options) self.process = await asyncio.create_subprocess_exec( *args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.DEVNULL, ) else: - args = [ - "ffmpeg", + args = ["ffmpeg"] + if self.before_options: + args += shlex.split(self.before_options) + args += [ "-i", "pipe:0", "-f", @@ -72,6 +89,8 @@ class FFmpegAudioSource(AudioSource): "2", "pipe:1", ] + if self.options: + args += shlex.split(self.options) self.process = await asyncio.create_subprocess_exec( *args, stdin=asyncio.subprocess.PIPE, @@ -115,6 +134,7 @@ class FFmpegAudioSource(AudioSource): with contextlib.suppress(Exception): self.source.close() + class AudioSink: """Abstract base class for audio sinks.""" diff --git a/disagreement/cache.py b/disagreement/cache.py index 178e8ae..92eef02 100644 --- a/disagreement/cache.py +++ b/disagreement/cache.py @@ -2,6 +2,7 @@ from __future__ import annotations import time from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar +from collections import OrderedDict if TYPE_CHECKING: from .models import Channel, Guild, Member @@ -11,15 +12,22 @@ T = TypeVar("T") class Cache(Generic[T]): - """Simple in-memory cache with optional TTL support.""" + """Simple in-memory cache with optional TTL and max size support.""" - def __init__(self, ttl: Optional[float] = None) -> None: + def __init__( + self, ttl: Optional[float] = None, maxlen: Optional[int] = None + ) -> None: self.ttl = ttl - self._data: Dict[str, tuple[T, Optional[float]]] = {} + self.maxlen = maxlen + self._data: "OrderedDict[str, tuple[T, Optional[float]]]" = OrderedDict() def set(self, key: str, value: T) -> None: expiry = time.monotonic() + self.ttl if self.ttl is not None else None + if key in self._data: + self._data.move_to_end(key) self._data[key] = (value, expiry) + if self.maxlen is not None and len(self._data) > self.maxlen: + self._data.popitem(last=False) def get(self, key: str) -> Optional[T]: item = self._data.get(key) @@ -29,6 +37,7 @@ class Cache(Generic[T]): if expiry is not None and expiry < time.monotonic(): self.invalidate(key) return None + self._data.move_to_end(key) return value def invalidate(self, key: str) -> None: diff --git a/disagreement/caching.py b/disagreement/caching.py index f1205e4..e9a481f 100644 --- a/disagreement/caching.py +++ b/disagreement/caching.py @@ -8,10 +8,10 @@ class _MemberCacheFlagValue: flag: int def __init__(self, func: Callable[[Any], bool]): - self.flag = getattr(func, 'flag', 0) + self.flag = getattr(func, "flag", 0) self.__doc__ = func.__doc__ - def __get__(self, instance: 'MemberCacheFlags', owner: type) -> Any: + def __get__(self, instance: "MemberCacheFlags", owner: type) -> Any: if instance is None: return self return instance.value & self.flag != 0 @@ -23,23 +23,24 @@ class _MemberCacheFlagValue: instance.value &= ~self.flag def __repr__(self) -> str: - return f'<{self.__class__.__name__} flag={self.flag}>' + return f"<{self.__class__.__name__} flag={self.flag}>" def flag_value(flag: int) -> Callable[[Callable[[Any], bool]], _MemberCacheFlagValue]: def decorator(func: Callable[[Any], bool]) -> _MemberCacheFlagValue: - setattr(func, 'flag', flag) + setattr(func, "flag", flag) return _MemberCacheFlagValue(func) + return decorator class MemberCacheFlags: - __slots__ = ('value',) + __slots__ = ("value",) VALID_FLAGS: ClassVar[Dict[str, int]] = { - 'joined': 1 << 0, - 'voice': 1 << 1, - 'online': 1 << 2, + "joined": 1 << 0, + "voice": 1 << 1, + "online": 1 << 2, } DEFAULT_FLAGS: ClassVar[int] = 1 | 2 | 4 ALL_FLAGS: ClassVar[int] = sum(VALID_FLAGS.values()) @@ -48,7 +49,7 @@ class MemberCacheFlags: self.value = self.DEFAULT_FLAGS for key, value in kwargs.items(): if key not in self.VALID_FLAGS: - raise TypeError(f'{key!r} is not a valid member cache flag.') + raise TypeError(f"{key!r} is not a valid member cache flag.") setattr(self, key, value) @classmethod @@ -67,7 +68,7 @@ class MemberCacheFlags: return hash(self.value) def __repr__(self) -> str: - return f'' + return f"" def __iter__(self) -> Iterator[Tuple[str, bool]]: for name in self.VALID_FLAGS: @@ -92,17 +93,17 @@ class MemberCacheFlags: @classmethod def only_joined(cls) -> MemberCacheFlags: """A factory method that creates a :class:`MemberCacheFlags` with only the `joined` flag enabled.""" - return cls._from_value(cls.VALID_FLAGS['joined']) + return cls._from_value(cls.VALID_FLAGS["joined"]) @classmethod def only_voice(cls) -> MemberCacheFlags: """A factory method that creates a :class:`MemberCacheFlags` with only the `voice` flag enabled.""" - return cls._from_value(cls.VALID_FLAGS['voice']) + return cls._from_value(cls.VALID_FLAGS["voice"]) @classmethod def only_online(cls) -> MemberCacheFlags: """A factory method that creates a :class:`MemberCacheFlags` with only the `online` flag enabled.""" - return cls._from_value(cls.VALID_FLAGS['online']) + return cls._from_value(cls.VALID_FLAGS["online"]) @flag_value(1 << 0) def joined(self) -> bool: diff --git a/disagreement/client.py b/disagreement/client.py index 6ed4462..03fb37b 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -36,6 +36,7 @@ from .ext import loader as ext_loader from .interactions import Interaction, Snowflake from .error_handler import setup_global_error_handler from .voice_client import VoiceClient +from .models import Activity if TYPE_CHECKING: from .models import ( @@ -75,13 +76,21 @@ class Client: intents (Optional[int]): The Gateway Intents to use. Defaults to `GatewayIntent.default()`. You might need to enable privileged intents in your bot's application page. loop (Optional[asyncio.AbstractEventLoop]): The event loop to use for asynchronous operations. - Defaults to `asyncio.get_event_loop()`. + Defaults to the running loop + via `asyncio.get_running_loop()`, + or a new loop from + `asyncio.new_event_loop()` if + none is running. command_prefix (Union[str, List[str], Callable[['Client', Message], Union[str, List[str]]]]): The prefix(es) for commands. Defaults to '!'. verbose (bool): If True, print raw HTTP and Gateway traffic for debugging. + mention_replies (bool): Whether replies mention the author by default. + allowed_mentions (Optional[Dict[str, Any]]): Default allowed mentions for messages. http_options (Optional[Dict[str, Any]]): Extra options passed to :class:`HTTPClient` for creating the internal :class:`aiohttp.ClientSession`. + message_cache_maxlen (Optional[int]): Maximum number of messages to keep + in the cache. When ``None``, the cache size is unlimited. """ def __init__( @@ -95,10 +104,12 @@ class Client: application_id: Optional[Union[str, int]] = None, verbose: bool = False, mention_replies: bool = False, + allowed_mentions: Optional[Dict[str, Any]] = None, shard_count: Optional[int] = None, gateway_max_retries: int = 5, gateway_max_backoff: float = 60.0, member_cache_flags: Optional[MemberCacheFlags] = None, + message_cache_maxlen: Optional[int] = None, http_options: Optional[Dict[str, Any]] = None, ): if not token: @@ -108,6 +119,7 @@ class Client: self.member_cache_flags: MemberCacheFlags = ( member_cache_flags if member_cache_flags is not None else MemberCacheFlags() ) + self.message_cache_maxlen: Optional[int] = message_cache_maxlen self.intents: int = intents if intents is not None else GatewayIntent.default() if loop: self.loop: asyncio.AbstractEventLoop = loop @@ -157,7 +169,7 @@ class Client: self._guilds: GuildCache = GuildCache() self._channels: ChannelCache = ChannelCache() self._users: Cache["User"] = Cache() - self._messages: Cache["Message"] = Cache(ttl=3600) # Cache messages for an hour + self._messages: Cache["Message"] = Cache(ttl=3600, maxlen=message_cache_maxlen) self._views: Dict[Snowflake, "View"] = {} self._persistent_views: Dict[str, "View"] = {} self._voice_clients: Dict[Snowflake, VoiceClient] = {} @@ -165,6 +177,7 @@ class Client: # Default whether replies mention the user self.mention_replies: bool = mention_replies + self.allowed_mentions: Optional[Dict[str, Any]] = allowed_mentions # Basic signal handling for graceful shutdown # This might be better handled by the user's application code, but can be a nice default. @@ -435,8 +448,7 @@ class Client: async def change_presence( self, status: str, - activity_name: Optional[str] = None, - activity_type: int = 0, + activity: Optional[Activity] = None, since: int = 0, afk: bool = False, ): @@ -445,8 +457,7 @@ class Client: Args: status (str): The new status for the client (e.g., "online", "idle", "dnd", "invisible"). - activity_name (Optional[str]): The name of the activity. - activity_type (int): The type of the activity. + activity (Optional[Activity]): Activity instance describing what the bot is doing. since (int): The timestamp (in milliseconds) of when the client went idle. afk (bool): Whether the client is AFK. """ @@ -456,8 +467,7 @@ class Client: if self._gateway: await self._gateway.update_presence( status=status, - activity_name=activity_name, - activity_type=activity_type, + activity=activity, since=since, afk=afk, ) @@ -693,7 +703,7 @@ class Client: ) # import traceback # traceback.print_exception(type(error.original), error.original, error.original.__traceback__) - + async def on_command_completion(self, ctx: "CommandContext") -> None: """ Default command completion handler. Called when a command has successfully completed. @@ -1010,7 +1020,7 @@ class Client: embeds (Optional[List[Embed]]): A list of embeds to send. Cannot be used with `embed`. Discord supports up to 10 embeds per message. components (Optional[List[ActionRow]]): A list of ActionRow components to include. - allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message. + allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message. Defaults to :attr:`Client.allowed_mentions`. message_reference (Optional[Dict[str, Any]]): Message reference for replying. attachments (Optional[List[Any]]): Attachments to include with the message. files (Optional[List[Any]]): Files to upload with the message. @@ -1057,6 +1067,9 @@ class Client: if isinstance(comp, ComponentModel) ] + if allowed_mentions is None: + allowed_mentions = self.allowed_mentions + message_data = await self._http.send_message( channel_id=channel_id, content=content, @@ -1428,6 +1441,24 @@ class Client: await self._http.delete_guild_template(guild_id, template_code) + async def fetch_widget(self, guild_id: Snowflake) -> Dict[str, Any]: + """|coro| Fetch a guild's widget settings.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + return await self._http.get_guild_widget(guild_id) + + async def edit_widget( + self, guild_id: Snowflake, payload: Dict[str, Any] + ) -> Dict[str, Any]: + """|coro| Edit a guild's widget settings.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + return await self._http.edit_guild_widget(guild_id, payload) + async def fetch_scheduled_events( self, guild_id: Snowflake ) -> List["ScheduledEvent"]: @@ -1514,35 +1545,35 @@ class Client: return [self.parse_invite(inv) for inv in data] def add_persistent_view(self, view: "View") -> None: - """ - Registers a persistent view with the client. + """ + Registers a persistent view with the client. - Persistent views have a timeout of `None` and their components must have a `custom_id`. - This allows the view to be re-instantiated across bot restarts. + Persistent views have a timeout of `None` and their components must have a `custom_id`. + This allows the view to be re-instantiated across bot restarts. - Args: - view (View): The view instance to register. + Args: + view (View): The view instance to register. - Raises: - ValueError: If the view is not persistent (timeout is not None) or if a component's - custom_id is already registered. - """ - if self.is_ready(): - print( - "Warning: Adding a persistent view after the client is ready. " - "This view will only be available for interactions on this session." - ) + Raises: + ValueError: If the view is not persistent (timeout is not None) or if a component's + custom_id is already registered. + """ + if self.is_ready(): + print( + "Warning: Adding a persistent view after the client is ready. " + "This view will only be available for interactions on this session." + ) - if view.timeout is not None: - raise ValueError("Persistent views must have a timeout of None.") + if view.timeout is not None: + raise ValueError("Persistent views must have a timeout of None.") - for item in view.children: - if item.custom_id: # Ensure custom_id is not None - if item.custom_id in self._persistent_views: - raise ValueError( - f"A component with custom_id '{item.custom_id}' is already registered." - ) - self._persistent_views[item.custom_id] = view + for item in view.children: + if item.custom_id: # Ensure custom_id is not None + if item.custom_id in self._persistent_views: + raise ValueError( + f"A component with custom_id '{item.custom_id}' is already registered." + ) + self._persistent_views[item.custom_id] = view # --- Application Command Methods --- async def process_interaction(self, interaction: Interaction) -> None: diff --git a/disagreement/enums.py b/disagreement/enums.py index 0b63105..7b01c72 100644 --- a/disagreement/enums.py +++ b/disagreement/enums.py @@ -375,6 +375,15 @@ class OverwriteType(IntEnum): MEMBER = 1 +class AutoArchiveDuration(IntEnum): + """Thread auto-archive duration in minutes.""" + + HOUR = 60 + DAY = 1440 + THREE_DAYS = 4320 + WEEK = 10080 + + # --- Component Enums --- diff --git a/disagreement/error_handler.py b/disagreement/error_handler.py index 0240cca..e814518 100644 --- a/disagreement/error_handler.py +++ b/disagreement/error_handler.py @@ -14,7 +14,11 @@ def setup_global_error_handler( The handler logs unhandled exceptions so they don't crash the bot. """ if loop is None: - loop = asyncio.get_event_loop() + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) if not logging.getLogger().hasHandlers(): setup_logging(logging.ERROR) diff --git a/disagreement/ext/commands/converters.py b/disagreement/ext/commands/converters.py index 23ff879..e9357f5 100644 --- a/disagreement/ext/commands/converters.py +++ b/disagreement/ext/commands/converters.py @@ -1,8 +1,10 @@ # disagreement/ext/commands/converters.py +# pyright: reportIncompatibleMethodOverride=false from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, Generic from abc import ABC, abstractmethod import re +import inspect from .errors import BadArgument from disagreement.models import Member, Guild, Role @@ -36,6 +38,20 @@ class Converter(ABC, Generic[T]): raise NotImplementedError("Converter subclass must implement convert method.") +class Greedy(list): + """Type hint helper to greedily consume arguments.""" + + converter: Any = None + + def __class_getitem__(cls, param: Any) -> type: # pyright: ignore[override] + if isinstance(param, tuple): + if len(param) != 1: + raise TypeError("Greedy[...] expects a single parameter") + param = param[0] + name = f"Greedy[{getattr(param, '__name__', str(param))}]" + return type(name, (Greedy,), {"converter": param}) + + # --- Built-in Type Converters --- @@ -169,7 +185,3 @@ async def run_converters(ctx: "CommandContext", annotation: Any, argument: str) raise BadArgument(f"No converter found for type annotation '{annotation}'.") return argument # Default to string if no annotation or annotation is str - - -# Need to import inspect for the run_converters function -import inspect diff --git a/disagreement/ext/commands/core.py b/disagreement/ext/commands/core.py index 5a98997..97ce12e 100644 --- a/disagreement/ext/commands/core.py +++ b/disagreement/ext/commands/core.py @@ -29,7 +29,7 @@ from .errors import ( CheckFailure, CommandInvokeError, ) -from .converters import run_converters, DEFAULT_CONVERTERS, Converter +from .converters import Greedy, run_converters, DEFAULT_CONVERTERS, Converter from disagreement.typing import Typing logger = logging.getLogger(__name__) @@ -46,29 +46,39 @@ class GroupMixin: self.commands: Dict[str, "Command"] = {} self.name: str = "" - def command(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Command"]: + def command( + self, **attrs: Any + ) -> Callable[[Callable[..., Awaitable[None]]], "Command"]: def decorator(func: Callable[..., Awaitable[None]]) -> "Command": cmd = Command(func, **attrs) cmd.cog = getattr(self, "cog", None) self.add_command(cmd) return cmd + return decorator - def group(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Group"]: + def group( + self, **attrs: Any + ) -> Callable[[Callable[..., Awaitable[None]]], "Group"]: def decorator(func: Callable[..., Awaitable[None]]) -> "Group": cmd = Group(func, **attrs) cmd.cog = getattr(self, "cog", None) self.add_command(cmd) return cmd + return decorator def add_command(self, command: "Command") -> None: if command.name in self.commands: - raise ValueError(f"Command '{command.name}' is already registered in group '{self.name}'.") + raise ValueError( + f"Command '{command.name}' is already registered in group '{self.name}'." + ) self.commands[command.name.lower()] = command for alias in command.aliases: if alias in self.commands: - logger.warning(f"Alias '{alias}' for command '{command.name}' in group '{self.name}' conflicts with an existing command or alias.") + logger.warning( + f"Alias '{alias}' for command '{command.name}' in group '{self.name}' conflicts with an existing command or alias." + ) self.commands[alias.lower()] = command def get_command(self, name: str) -> Optional["Command"]: @@ -181,6 +191,7 @@ class Command(GroupMixin): class Group(Command): """A command that can have subcommands.""" + def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any): super().__init__(callback, **attrs) @@ -494,7 +505,34 @@ class CommandHandler: None # Holds the raw string for current param ) - if view.eof: # No more input string + annotation = param.annotation + if inspect.isclass(annotation) and issubclass(annotation, Greedy): + greedy_values = [] + converter_type = annotation.converter + while not view.eof: + view.skip_whitespace() + if view.eof: + break + start = view.index + if view.buffer[view.index] == '"': + arg_str_value = view.get_quoted_string() + if arg_str_value == "" and view.buffer[view.index] == '"': + raise BadArgument( + f"Unterminated quoted string for argument '{param.name}'." + ) + else: + arg_str_value = view.get_word() + try: + converted = await run_converters( + ctx, converter_type, arg_str_value + ) + except BadArgument: + view.index = start + break + greedy_values.append(converted) + final_value_for_param = greedy_values + arg_str_value = None + elif view.eof: # No more input string if param.default is not inspect.Parameter.empty: final_value_for_param = param.default elif param.kind != inspect.Parameter.VAR_KEYWORD: @@ -656,7 +694,9 @@ class CommandHandler: elif command.invoke_without_command: view.index -= len(potential_subcommand) + view.previous else: - raise CommandNotFound(f"Subcommand '{potential_subcommand}' not found.") + raise CommandNotFound( + f"Subcommand '{potential_subcommand}' not found." + ) ctx = CommandContext( message=message, @@ -681,7 +721,9 @@ class CommandHandler: if hasattr(self.client, "on_command_error"): await self.client.on_command_error(ctx, e) except Exception as e: - logger.error("Unexpected error invoking command '%s': %s", original_command.name, e) + logger.error( + "Unexpected error invoking command '%s': %s", original_command.name, e + ) exc = CommandInvokeError(e) if hasattr(self.client, "on_command_error"): await self.client.on_command_error(ctx, exc) diff --git a/disagreement/ext/commands/decorators.py b/disagreement/ext/commands/decorators.py index 7400118..b2d6680 100644 --- a/disagreement/ext/commands/decorators.py +++ b/disagreement/ext/commands/decorators.py @@ -218,6 +218,7 @@ def requires_permissions( return check(predicate) + def has_role( name_or_id: str | int, ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: @@ -241,9 +242,7 @@ def has_role( raise CheckFailure("Could not resolve author to a guild member.") # Create a list of the member's role objects by looking them up in the guild's roles list - member_roles = [ - role for role in ctx.guild.roles if role.id in author.roles - ] + member_roles = [role for role in ctx.guild.roles if role.id in author.roles] if any( role.id == str(name_or_id) or role.name == name_or_id @@ -278,9 +277,7 @@ def has_any_role( if not author: raise CheckFailure("Could not resolve author to a guild member.") - member_roles = [ - role for role in ctx.guild.roles if role.id in author.roles - ] + member_roles = [role for role in ctx.guild.roles if role.id in author.roles] # Convert names_or_ids to a set for efficient lookup names_or_ids_set = set(map(str, names_or_ids)) diff --git a/disagreement/gateway.py b/disagreement/gateway.py index 3875d92..6499607 100644 --- a/disagreement/gateway.py +++ b/disagreement/gateway.py @@ -14,6 +14,8 @@ import time import random from typing import Optional, TYPE_CHECKING, Any, Dict +from .models import Activity + from .enums import GatewayOpcode, GatewayIntent from .errors import GatewayException, DisagreementException, AuthenticationError from .interactions import Interaction @@ -63,7 +65,11 @@ class GatewayClient: self._max_backoff: float = max_backoff self._ws: Optional[aiohttp.ClientWebSocketResponse] = None - self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() + try: + self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) self._heartbeat_interval: Optional[float] = None self._last_sequence: Optional[int] = None self._session_id: Optional[str] = None @@ -213,26 +219,17 @@ class GatewayClient: async def update_presence( self, status: str, - activity_name: Optional[str] = None, - activity_type: int = 0, + activity: Optional[Activity] = None, + *, since: int = 0, afk: bool = False, - ): + ) -> None: """Sends the presence update payload to the Gateway.""" payload = { "op": GatewayOpcode.PRESENCE_UPDATE, "d": { "since": since, - "activities": ( - [ - { - "name": activity_name, - "type": activity_type, - } - ] - if activity_name - else [] - ), + "activities": [activity.to_dict()] if activity else [], "status": status, "afk": afk, }, @@ -353,7 +350,10 @@ class GatewayClient: future._members.extend(raw_event_d_payload.get("members", [])) # type: ignore # If this is the last chunk, resolve the future - if raw_event_d_payload.get("chunk_index") == raw_event_d_payload.get("chunk_count", 1) - 1: + if ( + raw_event_d_payload.get("chunk_index") + == raw_event_d_payload.get("chunk_count", 1) - 1 + ): future.set_result(future._members) # type: ignore del self._member_chunk_requests[nonce] diff --git a/disagreement/http.py b/disagreement/http.py index 2f190e2..78157a3 100644 --- a/disagreement/http.py +++ b/disagreement/http.py @@ -601,18 +601,18 @@ class HTTPClient: ) async def delete_user_reaction( - self, - channel_id: "Snowflake", - message_id: "Snowflake", - emoji: str, - user_id: "Snowflake", - ) -> None: - """Removes another user's reaction from a message.""" - encoded = quote(emoji) - await self.request( - "DELETE", - f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/{user_id}", - ) + self, + channel_id: "Snowflake", + message_id: "Snowflake", + emoji: str, + user_id: "Snowflake", + ) -> None: + """Removes another user's reaction from a message.""" + encoded = quote(emoji) + await self.request( + "DELETE", + f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/{user_id}", + ) async def get_reactions( self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str @@ -910,6 +910,20 @@ class HTTPClient: """Fetches a guild object for a given guild ID.""" return await self.request("GET", f"/guilds/{guild_id}") + async def get_guild_widget(self, guild_id: "Snowflake") -> Dict[str, Any]: + """Fetches the guild widget settings.""" + + return await self.request("GET", f"/guilds/{guild_id}/widget") + + async def edit_guild_widget( + self, guild_id: "Snowflake", payload: Dict[str, Any] + ) -> Dict[str, Any]: + """Edits the guild widget settings.""" + + return await self.request( + "PATCH", f"/guilds/{guild_id}/widget", payload=payload + ) + async def get_guild_templates(self, guild_id: "Snowflake") -> List[Dict[str, Any]]: """Fetches all templates for the given guild.""" return await self.request("GET", f"/guilds/{guild_id}/templates") diff --git a/disagreement/models.py b/disagreement/models.py index 9323415..79329a7 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -6,6 +6,7 @@ Data models for Discord objects. import asyncio import json +import re from dataclasses import dataclass from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast @@ -24,6 +25,7 @@ from .enums import ( # These enums will need to be defined in disagreement/enum PremiumTier, GuildFeature, ChannelType, + AutoArchiveDuration, ComponentType, ButtonStyle, # Added for Button GuildScheduledEventPrivacyLevel, @@ -39,6 +41,7 @@ if TYPE_CHECKING: from .enums import OverwriteType # For PermissionOverwrite model from .ui.view import View from .interactions import Snowflake + from .typing import Typing # Forward reference Message if it were used in type hints before its definition # from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc. @@ -114,31 +117,39 @@ class Message: # self.mention_roles: List[str] = data.get("mention_roles", []) # self.mention_everyone: bool = data.get("mention_everyone", False) + @property + def clean_content(self) -> str: + """Returns message content without user, role, or channel mentions.""" + + pattern = re.compile(r"<@!?\d+>|<#\d+>|<@&\d+>") + cleaned = pattern.sub("", self.content) + return " ".join(cleaned.split()) + async def pin(self) -> None: - """|coro| + """|coro| - Pins this message to its channel. + Pins this message to its channel. - Raises - ------ - HTTPException - Pinning the message failed. - """ - await self._client._http.pin_message(self.channel_id, self.id) - self.pinned = True + Raises + ------ + HTTPException + Pinning the message failed. + """ + await self._client._http.pin_message(self.channel_id, self.id) + self.pinned = True async def unpin(self) -> None: - """|coro| + """|coro| - Unpins this message from its channel. + Unpins this message from its channel. - Raises - ------ - HTTPException - Unpinning the message failed. - """ - await self._client._http.unpin_message(self.channel_id, self.id) - self.pinned = False + Raises + ------ + HTTPException + Unpinning the message failed. + """ + await self._client._http.unpin_message(self.channel_id, self.id) + self.pinned = False async def reply( self, @@ -241,16 +252,16 @@ class Message: await self._client.add_reaction(self.channel_id, self.id, emoji) async def remove_reaction(self, emoji: str, member: Optional[User] = None) -> None: - """|coro| - Removes a reaction from this message. - If no ``member`` is provided, removes the bot's own reaction. - """ - if member: - await self._client._http.delete_user_reaction( - self.channel_id, self.id, emoji, member.id - ) - else: - await self._client.remove_reaction(self.channel_id, self.id, emoji) + """|coro| + Removes a reaction from this message. + If no ``member`` is provided, removes the bot's own reaction. + """ + if member: + await self._client._http.delete_user_reaction( + self.channel_id, self.id, emoji, member.id + ) + else: + await self._client.remove_reaction(self.channel_id, self.id, emoji) async def clear_reactions(self) -> None: """|coro| Remove all reactions from this message.""" @@ -280,7 +291,7 @@ class Message: self, name: str, *, - auto_archive_duration: Optional[int] = None, + auto_archive_duration: Optional[AutoArchiveDuration] = None, rate_limit_per_user: Optional[int] = None, reason: Optional[str] = None, ) -> "Thread": @@ -292,9 +303,9 @@ class Message: ---------- name: str The name of the thread. - auto_archive_duration: Optional[int] - The duration in minutes to automatically archive the thread after recent activity. - Can be one of 60, 1440, 4320, 10080. + auto_archive_duration: Optional[AutoArchiveDuration] + How long before the thread is automatically archived after recent activity. + See :class:`AutoArchiveDuration` for allowed values. rate_limit_per_user: Optional[int] The number of seconds a user has to wait before sending another message. reason: Optional[str] @@ -307,7 +318,7 @@ class Message: """ payload: Dict[str, Any] = {"name": name} if auto_archive_duration is not None: - payload["auto_archive_duration"] = auto_archive_duration + payload["auto_archive_duration"] = int(auto_archive_duration) if rate_limit_per_user is not None: payload["rate_limit_per_user"] = rate_limit_per_user @@ -530,8 +541,42 @@ class Embed: payload["fields"] = [f.to_dict() for f in self.fields] return payload - # Convenience methods for building embeds can be added here - # e.g., set_author, add_field, set_footer, set_image, etc. + # Convenience methods mirroring ``discord.py``'s ``Embed`` API + + def set_author( + self, *, name: str, url: Optional[str] = None, icon_url: Optional[str] = None + ) -> "Embed": + """Set the embed author and return ``self`` for chaining.""" + + data: Dict[str, Any] = {"name": name} + if url: + data["url"] = url + if icon_url: + data["icon_url"] = icon_url + self.author = EmbedAuthor(data) + return self + + def add_field(self, *, name: str, value: str, inline: bool = False) -> "Embed": + """Add a field to the embed.""" + + field = EmbedField({"name": name, "value": value, "inline": inline}) + self.fields.append(field) + return self + + def set_footer(self, *, text: str, icon_url: Optional[str] = None) -> "Embed": + """Set the embed footer.""" + + data: Dict[str, Any] = {"text": text} + if icon_url: + data["icon_url"] = icon_url + self.footer = EmbedFooter(data) + return self + + def set_image(self, url: str) -> "Embed": + """Set the embed image.""" + + self.image = EmbedImage({"url": url}) + return self class Attachment: @@ -1088,7 +1133,9 @@ class Guild: # Internal caches, populated by events or specific fetches self._channels: ChannelCache = ChannelCache() - self._members: MemberCache = MemberCache(getattr(client_instance, "member_cache_flags", MemberCacheFlags())) + self._members: MemberCache = MemberCache( + getattr(client_instance, "member_cache_flags", MemberCacheFlags()) + ) self._threads: Dict[str, "Thread"] = {} def get_channel(self, channel_id: str) -> Optional["Channel"]: @@ -1128,6 +1175,16 @@ class Guild: def __repr__(self) -> str: return f"" + async def fetch_widget(self) -> Dict[str, Any]: + """|coro| Fetch this guild's widget settings.""" + + return await self._client.fetch_widget(self.id) + + async def edit_widget(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """|coro| Edit this guild's widget settings.""" + + return await self._client.edit_widget(self.id, payload) + async def fetch_members(self, *, limit: Optional[int] = None) -> List["Member"]: """|coro| @@ -1278,7 +1335,45 @@ class Channel: return base -class TextChannel(Channel): +class Messageable: + """Mixin for channels that can send messages and show typing.""" + + _client: "Client" + id: str + + async def send( + self, + content: Optional[str] = None, + *, + embed: Optional["Embed"] = None, + embeds: Optional[List["Embed"]] = None, + components: Optional[List["ActionRow"]] = None, + ) -> "Message": + if not hasattr(self._client, "send_message"): + raise NotImplementedError( + "Client.send_message is required for Messageable.send" + ) + + return await self._client.send_message( + channel_id=self.id, + content=content, + embed=embed, + embeds=embeds, + components=components, + ) + + async def trigger_typing(self) -> None: + await self._client._http.trigger_typing(self.id) + + def typing(self) -> "Typing": + if not hasattr(self._client, "typing"): + raise NotImplementedError( + "Client.typing is required for Messageable.typing" + ) + return self._client.typing(self.id) + + +class TextChannel(Channel, Messageable): """Represents a guild text channel or announcement channel.""" def __init__(self, data: Dict[str, Any], client_instance: "Client"): @@ -1304,27 +1399,6 @@ class TextChannel(Channel): return message_pager(self, limit=limit, before=before, after=after) - async def send( - self, - content: Optional[str] = None, - *, - embed: Optional[Embed] = None, - embeds: Optional[List[Embed]] = None, - components: Optional[List["ActionRow"]] = None, # Added components - ) -> "Message": # Forward reference Message - if not hasattr(self._client, "send_message"): - raise NotImplementedError( - "Client.send_message is required for TextChannel.send" - ) - - return await self._client.send_message( - channel_id=self.id, - content=content, - embed=embed, - embeds=embeds, - components=components, - ) - async def purge( self, limit: int, *, before: "Snowflake | None" = None ) -> List["Snowflake"]: @@ -1347,41 +1421,41 @@ class TextChannel(Channel): return ids def get_partial_message(self, id: int) -> "PartialMessage": - """Returns a :class:`PartialMessage` for the given ID. + """Returns a :class:`PartialMessage` for the given ID. - This allows performing actions on a message without fetching it first. + This allows performing actions on a message without fetching it first. - Parameters - ---------- - id: int - The ID of the message to get a partial instance of. + Parameters + ---------- + id: int + The ID of the message to get a partial instance of. - Returns - ------- - PartialMessage - The partial message instance. - """ - return PartialMessage(id=str(id), channel=self) + Returns + ------- + PartialMessage + The partial message instance. + """ + return PartialMessage(id=str(id), channel=self) def __repr__(self) -> str: return f"" async def pins(self) -> List["Message"]: """|coro| - + Fetches all pinned messages in this channel. - + Returns ------- List[Message] The pinned messages. - + Raises ------ HTTPException Fetching the pinned messages failed. """ - + messages_data = await self._client._http.get_pinned_messages(self.id) return [self._client.parse_message(m) for m in messages_data] @@ -1390,7 +1464,7 @@ class TextChannel(Channel): name: str, *, type: ChannelType = ChannelType.PUBLIC_THREAD, - auto_archive_duration: Optional[int] = None, + auto_archive_duration: Optional[AutoArchiveDuration] = None, invitable: Optional[bool] = None, rate_limit_per_user: Optional[int] = None, reason: Optional[str] = None, @@ -1406,8 +1480,8 @@ class TextChannel(Channel): type: ChannelType The type of thread to create. Defaults to PUBLIC_THREAD. Can be PUBLIC_THREAD, PRIVATE_THREAD, or ANNOUNCEMENT_THREAD. - auto_archive_duration: Optional[int] - The duration in minutes to automatically archive the thread after recent activity. + auto_archive_duration: Optional[AutoArchiveDuration] + How long before the thread is automatically archived after recent activity. invitable: Optional[bool] Whether non-moderators can invite other non-moderators to a private thread. Only applicable to private threads. @@ -1426,7 +1500,7 @@ class TextChannel(Channel): "type": type.value, } if auto_archive_duration is not None: - payload["auto_archive_duration"] = auto_archive_duration + payload["auto_archive_duration"] = int(auto_archive_duration) if invitable is not None and type == ChannelType.PRIVATE_THREAD: payload["invitable"] = invitable if rate_limit_per_user is not None: @@ -1606,7 +1680,9 @@ class Thread(TextChannel): # Threads are a specialized TextChannel """ await self._client._http.leave_thread(self.id) - async def archive(self, locked: bool = False, *, reason: Optional[str] = None) -> "Thread": + async def archive( + self, locked: bool = False, *, reason: Optional[str] = None + ) -> "Thread": """|coro| Archives this thread. @@ -1631,7 +1707,7 @@ class Thread(TextChannel): # Threads are a specialized TextChannel return cast("Thread", self._client.parse_channel(data)) -class DMChannel(Channel): +class DMChannel(Channel, Messageable): """Represents a Direct Message channel.""" def __init__(self, data: Dict[str, Any], client_instance: "Client"): @@ -1645,27 +1721,6 @@ class DMChannel(Channel): def recipient(self) -> Optional[User]: return self.recipients[0] if self.recipients else None - async def send( - self, - content: Optional[str] = None, - *, - embed: Optional[Embed] = None, - embeds: Optional[List[Embed]] = None, - components: Optional[List["ActionRow"]] = None, # Added components - ) -> "Message": - if not hasattr(self._client, "send_message"): - raise NotImplementedError( - "Client.send_message is required for DMChannel.send" - ) - - return await self._client.send_message( - channel_id=self.id, - content=content, - embed=embed, - embeds=embeds, - components=components, - ) - async def history( self, *, @@ -2356,6 +2411,37 @@ class ThreadMember: return f"" +class Activity: + """Represents a user's presence activity.""" + + def __init__(self, name: str, type: int) -> None: + self.name = name + self.type = type + + def to_dict(self) -> Dict[str, Any]: + return {"name": self.name, "type": self.type} + + +class Game(Activity): + """Represents a playing activity.""" + + def __init__(self, name: str) -> None: + super().__init__(name, 0) + + +class Streaming(Activity): + """Represents a streaming activity.""" + + def __init__(self, name: str, url: str) -> None: + super().__init__(name, 1) + self.url = url + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["url"] = self.url + return payload + + class PresenceUpdate: """Represents a PRESENCE_UPDATE event.""" @@ -2366,7 +2452,17 @@ class PresenceUpdate: self.user = User(data["user"]) self.guild_id: Optional[str] = data.get("guild_id") self.status: Optional[str] = data.get("status") - self.activities: List[Dict[str, Any]] = data.get("activities", []) + self.activities: List[Activity] = [] + for activity in data.get("activities", []): + act_type = activity.get("type", 0) + name = activity.get("name", "") + if act_type == 0: + obj = Game(name) + elif act_type == 1: + obj = Streaming(name, activity.get("url", "")) + else: + obj = Activity(name, act_type) + self.activities.append(obj) self.client_status: Dict[str, Any] = data.get("client_status", {}) def __repr__(self) -> str: diff --git a/disagreement/ui/view.py b/disagreement/ui/view.py index a1a13da..3e17b28 100644 --- a/disagreement/ui/view.py +++ b/disagreement/ui/view.py @@ -72,7 +72,7 @@ class View: rows: List[ActionRow] = [] for item in self.children: - rows.append(ActionRow(components=[item])) + rows.append(ActionRow(components=[item])) return rows diff --git a/disagreement/voice_client.py b/disagreement/voice_client.py index c771869..8696665 100644 --- a/disagreement/voice_client.py +++ b/disagreement/voice_client.py @@ -7,9 +7,26 @@ import asyncio import contextlib import socket import threading +from array import array + + +def _apply_volume(data: bytes, volume: float) -> bytes: + samples = array("h") + samples.frombytes(data) + for i, sample in enumerate(samples): + scaled = int(sample * volume) + if scaled > 32767: + scaled = 32767 + elif scaled < -32768: + scaled = -32768 + samples[i] = scaled + return samples.tobytes() + + from typing import TYPE_CHECKING, Optional, Sequence import aiohttp + # The following import is correct, but may be flagged by Pylance if the virtual # environment is not configured correctly. from nacl.secret import SecretBox @@ -180,6 +197,9 @@ class VoiceClient: data = await self._current_source.read() if not data: break + volume = getattr(self._current_source, "volume", 1.0) + if volume != 1.0: + data = _apply_volume(data, volume) await self.send_audio_frame(data) finally: await self._current_source.close() diff --git a/docs/embeds.md b/docs/embeds.md new file mode 100644 index 0000000..bac95c6 --- /dev/null +++ b/docs/embeds.md @@ -0,0 +1,22 @@ +# Embeds + +`Embed` objects can be constructed piece by piece much like in `discord.py`. +These helper methods return the embed instance so you can chain calls. + +```python +from disagreement.models import Embed + +embed = ( + Embed() + .set_author(name="Disagreement", url="https://example.com", icon_url="https://cdn.example.com/bot.png") + .add_field(name="Info", value="Some details") + .set_footer(text="Made with Disagreement") + .set_image(url="https://cdn.example.com/image.png") +) +``` + +Call `to_dict()` to convert the embed back to a payload dictionary before sending: + +```python +payload = embed.to_dict() +``` diff --git a/docs/mentions.md b/docs/mentions.md new file mode 100644 index 0000000..2a604cf --- /dev/null +++ b/docs/mentions.md @@ -0,0 +1,23 @@ +# Controlling Mentions + +The client exposes settings to control how mentions behave in outgoing messages. + +## Default Allowed Mentions + +Use the ``allowed_mentions`` parameter of :class:`disagreement.Client` to set a +default for all messages: + +```python +client = disagreement.Client( + token="YOUR_TOKEN", + allowed_mentions={"parse": [], "replied_user": False}, +) +``` + +When ``Client.send_message`` is called without an explicit ``allowed_mentions`` +argument this value will be used. + +## Next Steps + +- [Commands](commands.md) +- [HTTP Client Options](http_client.md) diff --git a/docs/presence.md b/docs/presence.md index 0e5da19..ac87644 100644 --- a/docs/presence.md +++ b/docs/presence.md @@ -1,6 +1,7 @@ # Updating Presence The `Client.change_presence` method allows you to update the bot's status and displayed activity. +Pass an :class:`~disagreement.models.Activity` (such as :class:`~disagreement.models.Game` or :class:`~disagreement.models.Streaming`) to describe what your bot is doing. ## Status Strings @@ -22,8 +23,18 @@ An activity dictionary must include a `name` and a `type` field. The type value | `4` | Custom | | `5` | Competing | -Example: +Example using the provided activity classes: ```python -await client.change_presence(status="idle", activity={"name": "with Discord", "type": 0}) +from disagreement.models import Game + +await client.change_presence(status="idle", activity=Game("with Discord")) +``` + +You can also specify a streaming URL: + +```python +from disagreement.models import Streaming + +await client.change_presence(status="online", activity=Streaming("My Stream", "https://twitch.tv/someone")) ``` diff --git a/docs/threads.md b/docs/threads.md new file mode 100644 index 0000000..3853e9c --- /dev/null +++ b/docs/threads.md @@ -0,0 +1,18 @@ +# Threads + +`Message.create_thread` and `TextChannel.create_thread` let you start new threads. +Use :class:`AutoArchiveDuration` to control when a thread is automatically archived. + +```python +from disagreement.enums import AutoArchiveDuration + +await message.create_thread( + "discussion", + auto_archive_duration=AutoArchiveDuration.DAY, +) +``` + +## Next Steps + +- [Message History](message_history.md) +- [Caching](caching.md) diff --git a/examples/basic_bot.py b/examples/basic_bot.py index d2836f1..5ad01d4 100644 --- a/examples/basic_bot.py +++ b/examples/basic_bot.py @@ -39,9 +39,14 @@ except ImportError: ) sys.exit(1) -from dotenv import load_dotenv +try: + from dotenv import load_dotenv +except ImportError: # pragma: no cover - example helper + load_dotenv = None + print("python-dotenv is not installed. Environment variables will not be loaded") -load_dotenv() +if load_dotenv: + load_dotenv() # Optional: Configure logging for more insight, especially for gateway events # logging.basicConfig(level=logging.DEBUG) # For very verbose output diff --git a/examples/component_bot.py b/examples/component_bot.py index 60a8405..d2f9753 100644 --- a/examples/component_bot.py +++ b/examples/component_bot.py @@ -37,9 +37,15 @@ from disagreement.interactions import ( InteractionResponsePayload, InteractionCallbackData, ) -from dotenv import load_dotenv -load_dotenv() +try: + from dotenv import load_dotenv +except ImportError: # pragma: no cover - example helper + load_dotenv = None + print("python-dotenv is not installed. Environment variables will not be loaded") + +if load_dotenv: + load_dotenv() # Get the bot token and application ID from the environment variables token = os.getenv("DISCORD_BOT_TOKEN") diff --git a/examples/context_menus.py b/examples/context_menus.py index 8ceb1e7..9ba3656 100644 --- a/examples/context_menus.py +++ b/examples/context_menus.py @@ -15,9 +15,14 @@ from disagreement.ext.app_commands import ( ) from disagreement.models import User, Message -from dotenv import load_dotenv +try: + from dotenv import load_dotenv +except ImportError: # pragma: no cover - example helper + load_dotenv = None + print("python-dotenv is not installed. Environment variables will not be loaded") -load_dotenv() +if load_dotenv: + load_dotenv() BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "") APP_ID = os.environ.get("DISCORD_APPLICATION_ID", "") diff --git a/examples/extension_management.py b/examples/extension_management.py index 8ba86cc..b128e08 100644 --- a/examples/extension_management.py +++ b/examples/extension_management.py @@ -4,7 +4,11 @@ import asyncio import os import sys -from dotenv import load_dotenv +try: + from dotenv import load_dotenv +except ImportError: # pragma: no cover - example helper + load_dotenv = None + print("python-dotenv is not installed. Environment variables will not be loaded") # Allow running from the examples folder without installing if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)): @@ -12,7 +16,8 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi from disagreement import Client -load_dotenv() +if load_dotenv: + load_dotenv() TOKEN = os.environ.get("DISCORD_BOT_TOKEN") diff --git a/examples/hybrid_bot.py b/examples/hybrid_bot.py index c6d3bdd..d4aeae3 100644 --- a/examples/hybrid_bot.py +++ b/examples/hybrid_bot.py @@ -36,9 +36,14 @@ from disagreement.enums import ( logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -from dotenv import load_dotenv +try: + from dotenv import load_dotenv +except ImportError: # pragma: no cover - example helper + load_dotenv = None + print("python-dotenv is not installed. Environment variables will not be loaded") -load_dotenv() +if load_dotenv: + load_dotenv() # --- Define a Test Cog --- diff --git a/examples/message_history.py b/examples/message_history.py index 104f6fa..d5c6ffc 100644 --- a/examples/message_history.py +++ b/examples/message_history.py @@ -10,9 +10,15 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi from disagreement.client import Client from disagreement.models import TextChannel -from dotenv import load_dotenv -load_dotenv() +try: + from dotenv import load_dotenv +except ImportError: # pragma: no cover - example helper + load_dotenv = None + print("python-dotenv is not installed. Environment variables will not be loaded") + +if load_dotenv: + load_dotenv() BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "") CHANNEL_ID = os.environ.get("DISCORD_CHANNEL_ID", "") diff --git a/examples/modal_command.py b/examples/modal_command.py index 120747f..17e5da8 100644 --- a/examples/modal_command.py +++ b/examples/modal_command.py @@ -2,14 +2,20 @@ import os import asyncio -from dotenv import load_dotenv + +try: + from dotenv import load_dotenv +except ImportError: # pragma: no cover - example helper + load_dotenv = None + print("python-dotenv is not installed. Environment variables will not be loaded") from disagreement import Client, ui from disagreement.enums import GatewayIntent, TextInputStyle from disagreement.ext.app_commands.decorators import slash_command from disagreement.ext.app_commands.context import AppCommandContext -load_dotenv() +if load_dotenv: + load_dotenv() token = os.getenv("DISCORD_BOT_TOKEN", "") application_id = os.getenv("DISCORD_APPLICATION_ID", "") diff --git a/examples/modal_send.py b/examples/modal_send.py index 0b15489..8e91dba 100644 --- a/examples/modal_send.py +++ b/examples/modal_send.py @@ -3,7 +3,11 @@ import os import sys -from dotenv import load_dotenv +try: + from dotenv import load_dotenv +except ImportError: # pragma: no cover - example helper + load_dotenv = None + print("python-dotenv is not installed. Environment variables will not be loaded") sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -11,7 +15,8 @@ from disagreement import Client, GatewayIntent, ui # type: ignore from disagreement.ext.app_commands.decorators import slash_command from disagreement.ext.app_commands.context import AppCommandContext -load_dotenv() +if load_dotenv: + load_dotenv() TOKEN = os.getenv("DISCORD_BOT_TOKEN", "") APP_ID = os.getenv("DISCORD_APPLICATION_ID", "") diff --git a/examples/sharded_bot.py b/examples/sharded_bot.py index be4a554..0d5e756 100644 --- a/examples/sharded_bot.py +++ b/examples/sharded_bot.py @@ -9,9 +9,15 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) import disagreement -from dotenv import load_dotenv -load_dotenv() +try: + from dotenv import load_dotenv +except ImportError: # pragma: no cover - example helper + load_dotenv = None + print("python-dotenv is not installed. Environment variables will not be loaded") + +if load_dotenv: + load_dotenv() TOKEN = os.environ.get("DISCORD_BOT_TOKEN") if not TOKEN: diff --git a/examples/voice_bot.py b/examples/voice_bot.py index a9b4766..077fccd 100644 --- a/examples/voice_bot.py +++ b/examples/voice_bot.py @@ -10,11 +10,16 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi from typing import cast -from dotenv import load_dotenv +try: + from dotenv import load_dotenv +except ImportError: # pragma: no cover - example helper + load_dotenv = None + print("python-dotenv is not installed. Environment variables will not be loaded") import disagreement -load_dotenv() +if load_dotenv: + load_dotenv() _TOKEN = os.getenv("DISCORD_BOT_TOKEN") _GUILD_ID = os.getenv("DISCORD_GUILD_ID") diff --git a/tests/test_cache.py b/tests/test_cache.py index 234077e..6909697 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -15,3 +15,14 @@ def test_cache_ttl_expiry(): assert cache.get("b") == 1 time.sleep(0.02) assert cache.get("b") is None + + +def test_cache_lru_eviction(): + cache = Cache(maxlen=2) + cache.set("a", 1) + cache.set("b", 2) + assert cache.get("a") == 1 + cache.set("c", 3) + assert cache.get("b") is None + assert cache.get("a") == 1 + assert cache.get("c") == 3 diff --git a/tests/test_client_message_cache.py b/tests/test_client_message_cache.py new file mode 100644 index 0000000..f89d98c --- /dev/null +++ b/tests/test_client_message_cache.py @@ -0,0 +1,23 @@ +import pytest + +from disagreement.client import Client + + +def _add_message(client: Client, message_id: str) -> None: + data = { + "id": message_id, + "channel_id": "c", + "author": {"id": "u", "username": "u", "discriminator": "0001"}, + "content": "hi", + "timestamp": "t", + } + client.parse_message(data) + + +def test_client_message_cache_size(): + client = Client(token="t", message_cache_maxlen=1) + _add_message(client, "1") + assert client._messages.get("1").id == "1" + _add_message(client, "2") + assert client._messages.get("1") is None + assert client._messages.get("2").id == "2" diff --git a/tests/test_embed_methods.py b/tests/test_embed_methods.py new file mode 100644 index 0000000..227b179 --- /dev/null +++ b/tests/test_embed_methods.py @@ -0,0 +1,18 @@ +from disagreement.models import Embed + + +def test_embed_helper_methods(): + embed = ( + Embed() + .set_author(name="name", url="url", icon_url="icon") + .add_field(name="n", value="v") + .set_footer(text="footer", icon_url="icon") + .set_image(url="https://example.com/image.png") + ) + + assert embed.author.name == "name" + assert embed.author.url == "url" + assert embed.author.icon_url == "icon" + assert len(embed.fields) == 1 and embed.fields[0].name == "n" + assert embed.footer.text == "footer" + assert embed.image.url == "https://example.com/image.png" diff --git a/tests/test_gateway_backoff.py b/tests/test_gateway_backoff.py index a4c3723..f64a62b 100644 --- a/tests/test_gateway_backoff.py +++ b/tests/test_gateway_backoff.py @@ -24,7 +24,7 @@ class DummyDispatcher: class DummyClient: def __init__(self): - self.loop = asyncio.get_event_loop() + self.loop = asyncio.get_running_loop() self.application_id = None # Mock application_id for Client.connect @@ -39,7 +39,7 @@ async def test_client_connect_backoff(monkeypatch): client = Client( token="test_token", intents=0, - loop=asyncio.get_event_loop(), + loop=asyncio.get_running_loop(), command_prefix="!", verbose=False, mention_replies=False, diff --git a/tests/test_message_clean_content.py b/tests/test_message_clean_content.py new file mode 100644 index 0000000..83b1b91 --- /dev/null +++ b/tests/test_message_clean_content.py @@ -0,0 +1,23 @@ +import types +from disagreement.models import Message + + +def make_message(content: str) -> Message: + data = { + "id": "1", + "channel_id": "c", + "author": {"id": "2", "username": "u", "discriminator": "0001"}, + "content": content, + "timestamp": "t", + } + return Message(data, client_instance=types.SimpleNamespace()) + + +def test_clean_content_removes_mentions(): + msg = make_message("Hello <@123> <#456> <@&789> world") + assert msg.clean_content == "Hello world" + + +def test_clean_content_no_mentions(): + msg = make_message("Just text") + assert msg.clean_content == "Just text" diff --git a/tests/test_presence_update.py b/tests/test_presence_update.py index 3135cf7..73551d9 100644 --- a/tests/test_presence_update.py +++ b/tests/test_presence_update.py @@ -2,6 +2,7 @@ import pytest from unittest.mock import AsyncMock from disagreement.client import Client +from disagreement.models import Game from disagreement.errors import DisagreementException @@ -18,11 +19,11 @@ class DummyGateway(MagicMock): async def test_change_presence_passes_arguments(): client = Client(token="t") client._gateway = DummyGateway() - - await client.change_presence(status="idle", activity_name="hi", activity_type=0) + game = Game("hi") + await client.change_presence(status="idle", activity=game) client._gateway.update_presence.assert_awaited_once_with( - status="idle", activity_name="hi", activity_type=0, since=0, afk=False + status="idle", activity=game, since=0, afk=False ) diff --git a/tests/test_voice_client.py b/tests/test_voice_client.py index da052bc..1c6cb98 100644 --- a/tests/test_voice_client.py +++ b/tests/test_voice_client.py @@ -1,8 +1,11 @@ import asyncio +import io +from array import array import pytest +from disagreement.audio import AudioSource, FFmpegAudioSource + from disagreement.voice_client import VoiceClient -from disagreement.audio import AudioSource from disagreement.client import Client @@ -137,3 +140,68 @@ async def test_play_and_switch_sources(): await vc.play(DummySource([b"c"])) assert udp.sent == [b"a", b"b", b"c"] + + +@pytest.mark.asyncio +async def test_ffmpeg_source_custom_options(monkeypatch): + captured = {} + + class DummyProcess: + def __init__(self): + self.stdout = io.BytesIO(b"") + + async def wait(self): + return 0 + + async def fake_exec(*args, **kwargs): + captured["args"] = args + return DummyProcess() + + monkeypatch.setattr(asyncio, "create_subprocess_exec", fake_exec) + src = FFmpegAudioSource( + "file.mp3", before_options="-reconnect 1", options="-vn", volume=0.5 + ) + + await src._spawn() + + cmd = captured["args"] + assert "-reconnect" in cmd + assert "-vn" in cmd + assert src.volume == 0.5 + + +@pytest.mark.asyncio +async def test_voice_client_volume_scaling(monkeypatch): + ws = DummyWebSocket( + [ + {"d": {"heartbeat_interval": 50}}, + {"d": {"ssrc": 1, "ip": "127.0.0.1", "port": 4000}}, + {"d": {"secret_key": []}}, + ] + ) + udp = DummyUDP() + vc = VoiceClient( + client=DummyVoiceClient(), + endpoint="ws://localhost", + session_id="sess", + token="tok", + guild_id=1, + user_id=2, + ws=ws, + udp=udp, + ) + await vc.connect() + vc._heartbeat_task.cancel() + + chunk = b"\x10\x00\x10\x00" + src = DummySource([chunk]) + src.volume = 0.5 + + await vc.play(src) + + samples = array("h") + samples.frombytes(chunk) + samples[0] = int(samples[0] * 0.5) + samples[1] = int(samples[1] * 0.5) + expected = samples.tobytes() + assert udp.sent == [expected] diff --git a/tests/test_widget.py b/tests/test_widget.py new file mode 100644 index 0000000..45a3e13 --- /dev/null +++ b/tests/test_widget.py @@ -0,0 +1,50 @@ +import pytest +from types import SimpleNamespace +from unittest.mock import AsyncMock + +from disagreement.http import HTTPClient +from disagreement.client import Client + + +@pytest.mark.asyncio +async def test_get_guild_widget_calls_request(): + http = HTTPClient(token="t") + http.request = AsyncMock(return_value={}) + await http.get_guild_widget("1") + http.request.assert_called_once_with("GET", "/guilds/1/widget") + + +@pytest.mark.asyncio +async def test_edit_guild_widget_calls_request(): + http = HTTPClient(token="t") + http.request = AsyncMock(return_value={}) + payload = {"enabled": True} + await http.edit_guild_widget("1", payload) + http.request.assert_called_once_with("PATCH", "/guilds/1/widget", payload=payload) + + +@pytest.mark.asyncio +async def test_client_fetch_widget_returns_data(): + http = SimpleNamespace(get_guild_widget=AsyncMock(return_value={"enabled": True})) + client = Client.__new__(Client) + client._http = http + client._closed = False + + data = await client.fetch_widget("1") + + http.get_guild_widget.assert_awaited_once_with("1") + assert data == {"enabled": True} + + +@pytest.mark.asyncio +async def test_client_edit_widget_returns_data(): + http = SimpleNamespace(edit_guild_widget=AsyncMock(return_value={"enabled": False})) + client = Client.__new__(Client) + client._http = http + client._closed = False + + payload = {"enabled": False} + data = await client.edit_widget("1", payload) + + http.edit_guild_widget.assert_awaited_once_with("1", payload) + assert data == {"enabled": False}