diff --git a/disagreement/color.py b/disagreement/color.py index 6b8c063..28bd96a 100644 --- a/disagreement/color.py +++ b/disagreement/color.py @@ -46,6 +46,103 @@ class Color: def blue(cls) -> "Color": return cls(0x0000FF) + # Discord brand colors + @classmethod + def blurple(cls) -> "Color": + """Discord brand blurple (#5865F2).""" + return cls(0x5865F2) + + @classmethod + def light_blurple(cls) -> "Color": + """Light blurple used by Discord (#E0E3FF).""" + return cls(0xE0E3FF) + + @classmethod + def legacy_blurple(cls) -> "Color": + """Legacy Discord blurple (#7289DA).""" + return cls(0x7289DA) + + # Additional assorted colors + @classmethod + def teal(cls) -> "Color": + return cls(0x1ABC9C) + + @classmethod + def dark_teal(cls) -> "Color": + return cls(0x11806A) + + @classmethod + def brand_green(cls) -> "Color": + return cls(0x57F287) + + @classmethod + def dark_green(cls) -> "Color": + return cls(0x206694) + + @classmethod + def orange(cls) -> "Color": + return cls(0xE67E22) + + @classmethod + def dark_orange(cls) -> "Color": + return cls(0xA84300) + + @classmethod + def brand_red(cls) -> "Color": + return cls(0xED4245) + + @classmethod + def dark_red(cls) -> "Color": + return cls(0x992D22) + + @classmethod + def magenta(cls) -> "Color": + return cls(0xE91E63) + + @classmethod + def dark_magenta(cls) -> "Color": + return cls(0xAD1457) + + @classmethod + def purple(cls) -> "Color": + return cls(0x9B59B6) + + @classmethod + def dark_purple(cls) -> "Color": + return cls(0x71368A) + + @classmethod + def yellow(cls) -> "Color": + return cls(0xF1C40F) + + @classmethod + def dark_gold(cls) -> "Color": + return cls(0xC27C0E) + + @classmethod + def light_gray(cls) -> "Color": + return cls(0x99AAB5) + + @classmethod + def dark_gray(cls) -> "Color": + return cls(0x2C2F33) + + @classmethod + def lighter_gray(cls) -> "Color": + return cls(0xBFBFBF) + + @classmethod + def darker_gray(cls) -> "Color": + return cls(0x23272A) + + @classmethod + def black(cls) -> "Color": + return cls(0x000000) + + @classmethod + def white(cls) -> "Color": + return cls(0xFFFFFF) + def to_rgb(self) -> tuple[int, int, int]: return ((self.value >> 16) & 0xFF, (self.value >> 8) & 0xFF, self.value & 0xFF) diff --git a/disagreement/ext/app_commands/handler.py b/disagreement/ext/app_commands/handler.py index b9ca1e6..7c5e4fa 100644 --- a/disagreement/ext/app_commands/handler.py +++ b/disagreement/ext/app_commands/handler.py @@ -1,7 +1,9 @@ # disagreement/ext/app_commands/handler.py import inspect +import json import logging +import os from typing import ( TYPE_CHECKING, Dict, @@ -67,6 +69,8 @@ if not TYPE_CHECKING: logger = logging.getLogger(__name__) +COMMANDS_CACHE_FILE = ".disagreement_commands.json" + class AppCommandHandler: """ @@ -84,6 +88,33 @@ class AppCommandHandler: self._app_command_groups: Dict[str, AppCommandGroup] = {} self._converter_registry: Dict[type, type] = {} + def _load_cached_ids(self) -> Dict[str, Dict[str, str]]: + try: + with open(COMMANDS_CACHE_FILE, "r", encoding="utf-8") as fp: + return json.load(fp) + except FileNotFoundError: + return {} + except json.JSONDecodeError: + logger.warning("Invalid command cache file. Ignoring.") + return {} + + def _save_cached_ids(self, data: Dict[str, Dict[str, str]]) -> None: + try: + with open(COMMANDS_CACHE_FILE, "w", encoding="utf-8") as fp: + json.dump(data, fp, indent=2) + except Exception as e: # pragma: no cover - logging only + logger.error("Failed to write command cache: %s", e) + + def clear_stored_registrations(self) -> None: + """Remove persisted command registration data.""" + if os.path.exists(COMMANDS_CACHE_FILE): + os.remove(COMMANDS_CACHE_FILE) + + def migrate_stored_registrations(self, new_path: str) -> None: + """Move stored registrations to ``new_path``.""" + if os.path.exists(COMMANDS_CACHE_FILE): + os.replace(COMMANDS_CACHE_FILE, new_path) + def add_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None: """Adds an application command or a command group to the handler.""" if isinstance(command, AppCommandGroup): @@ -564,11 +595,13 @@ class AppCommandHandler: Synchronizes (registers/updates) all application commands with Discord. If guild_id is provided, syncs commands for that guild. Otherwise, syncs global commands. """ - commands_to_sync: List[Dict[str, Any]] = [] + cache = self._load_cached_ids() + scope_key = str(guild_id) if guild_id else "global" + stored = cache.get(scope_key, {}) + + current_payloads: Dict[str, Dict[str, Any]] = {} # Collect commands based on scope (global or specific guild) - # This needs to be more sophisticated to handle guild_ids on commands/groups - source_commands = ( list(self._slash_commands.values()) + list(self._user_commands.values()) @@ -577,26 +610,22 @@ class AppCommandHandler: ) for cmd_or_group in source_commands: - # Determine if this command/group should be synced for the current scope is_guild_specific_command = ( cmd_or_group.guild_ids is not None and len(cmd_or_group.guild_ids) > 0 ) - if guild_id: # Syncing for a specific guild - # Skip if not a guild-specific command OR if it's for a different guild + if guild_id: if not is_guild_specific_command or ( cmd_or_group.guild_ids is not None and guild_id not in cmd_or_group.guild_ids ): continue - else: # Syncing global commands + else: if is_guild_specific_command: - continue # Skip guild-specific commands when syncing global + continue - # Use the to_dict() method from AppCommand or AppCommandGroup try: - payload = cmd_or_group.to_dict() - commands_to_sync.append(payload) + current_payloads[cmd_or_group.name] = cmd_or_group.to_dict() except AttributeError: logger.warning( "Command or group '%s' does not have a to_dict() method. Skipping.", @@ -609,32 +638,74 @@ class AppCommandHandler: e, ) - if not commands_to_sync: + if not current_payloads: logger.info( "No commands to sync for %s scope.", f"guild {guild_id}" if guild_id else "global", ) return + names_current = set(current_payloads) + names_stored = set(stored) + + to_delete = names_stored - names_current + to_create = names_current - names_stored + to_update = names_current & names_stored + + if not to_delete and not to_create and not to_update: + logger.info( + "Application commands already up to date for %s scope.", scope_key + ) + return + try: - if guild_id: - logger.info( - "Syncing %s commands for guild %s...", - len(commands_to_sync), - guild_id, - ) - await self.client._http.bulk_overwrite_guild_application_commands( - application_id, guild_id, commands_to_sync - ) - else: - logger.info( - "Syncing %s global commands...", - len(commands_to_sync), - ) - await self.client._http.bulk_overwrite_global_application_commands( - application_id, commands_to_sync - ) + for name in to_delete: + cmd_id = stored[name] + if guild_id: + await self.client._http.delete_guild_application_command( + application_id, guild_id, cmd_id + ) + else: + await self.client._http.delete_global_application_command( + application_id, cmd_id + ) + + new_ids: Dict[str, str] = {} + for name in to_create: + payload = current_payloads[name] + if guild_id: + result = await self.client._http.create_guild_application_command( + application_id, guild_id, payload + ) + else: + result = await self.client._http.create_global_application_command( + application_id, payload + ) + if result.id: + new_ids[name] = str(result.id) + + for name in to_update: + payload = current_payloads[name] + cmd_id = stored[name] + if guild_id: + await self.client._http.edit_guild_application_command( + application_id, guild_id, cmd_id, payload + ) + else: + await self.client._http.edit_global_application_command( + application_id, cmd_id, payload + ) + new_ids[name] = cmd_id + + final_ids: Dict[str, str] = {} + for name in names_current: + if name in new_ids: + final_ids[name] = new_ids[name] + else: + final_ids[name] = stored[name] + + cache[scope_key] = final_ids + self._save_cached_ids(cache) logger.info("Command sync successful.") except Exception as e: logger.error("Error syncing application commands: %s", e) - # Consider re-raising or specific error handling diff --git a/disagreement/ext/commands/__init__.py b/disagreement/ext/commands/__init__.py index 21b1e35..7dcde6b 100644 --- a/disagreement/ext/commands/__init__.py +++ b/disagreement/ext/commands/__init__.py @@ -16,6 +16,7 @@ from .decorators import ( check, check_any, cooldown, + max_concurrency, requires_permissions, ) from .errors import ( @@ -28,6 +29,7 @@ from .errors import ( CheckAnyFailure, CommandOnCooldown, CommandInvokeError, + MaxConcurrencyReached, ) __all__ = [ @@ -43,6 +45,7 @@ __all__ = [ "check", "check_any", "cooldown", + "max_concurrency", "requires_permissions", # Errors "CommandError", @@ -54,4 +57,5 @@ __all__ = [ "CheckAnyFailure", "CommandOnCooldown", "CommandInvokeError", + "MaxConcurrencyReached", ] diff --git a/disagreement/ext/commands/core.py b/disagreement/ext/commands/core.py index adb816a..1eb4cdf 100644 --- a/disagreement/ext/commands/core.py +++ b/disagreement/ext/commands/core.py @@ -70,6 +70,10 @@ class Command: if hasattr(callback, "__command_checks__"): self.checks.extend(getattr(callback, "__command_checks__")) + self.max_concurrency: Optional[Tuple[int, str]] = None + if hasattr(callback, "__max_concurrency__"): + self.max_concurrency = getattr(callback, "__max_concurrency__") + def add_check( self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool] ) -> None: @@ -215,6 +219,7 @@ class CommandHandler: ] = prefix self.commands: Dict[str, Command] = {} self.cogs: Dict[str, "Cog"] = {} + self._concurrency: Dict[str, Dict[str, int]] = {} from .help import HelpCommand @@ -300,6 +305,47 @@ class CommandHandler: logger.info("Cog '%s' removed.", cog_name) return cog_to_remove + def _acquire_concurrency(self, ctx: CommandContext) -> None: + mc = getattr(ctx.command, "max_concurrency", None) + if not mc: + return + limit, scope = mc + if scope == "user": + key = ctx.author.id + elif scope == "guild": + key = ctx.message.guild_id or ctx.author.id + else: + key = "global" + buckets = self._concurrency.setdefault(ctx.command.name, {}) + current = buckets.get(key, 0) + if current >= limit: + from .errors import MaxConcurrencyReached + + raise MaxConcurrencyReached(limit) + buckets[key] = current + 1 + + def _release_concurrency(self, ctx: CommandContext) -> None: + mc = getattr(ctx.command, "max_concurrency", None) + if not mc: + return + _, scope = mc + if scope == "user": + key = ctx.author.id + elif scope == "guild": + key = ctx.message.guild_id or ctx.author.id + else: + key = "global" + buckets = self._concurrency.get(ctx.command.name) + if not buckets: + return + current = buckets.get(key, 0) + if current <= 1: + buckets.pop(key, None) + else: + buckets[key] = current - 1 + if not buckets: + self._concurrency.pop(ctx.command.name, None) + async def get_prefix(self, message: "Message") -> Union[str, List[str], None]: if callable(self.prefix): if inspect.iscoroutinefunction(self.prefix): @@ -501,7 +547,11 @@ class CommandHandler: parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view) ctx.args = parsed_args ctx.kwargs = parsed_kwargs - await command.invoke(ctx, *parsed_args, **parsed_kwargs) + self._acquire_concurrency(ctx) + try: + await command.invoke(ctx, *parsed_args, **parsed_kwargs) + finally: + self._release_concurrency(ctx) except CommandError as e: logger.error("Command error for '%s': %s", command.name, e) if hasattr(self.client, "on_command_error"): diff --git a/disagreement/ext/commands/decorators.py b/disagreement/ext/commands/decorators.py index 53f540c..8413cc1 100644 --- a/disagreement/ext/commands/decorators.py +++ b/disagreement/ext/commands/decorators.py @@ -107,6 +107,33 @@ def check_any( return check(predicate) +def max_concurrency( + number: int, per: str = "user" +) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: + """Limit how many concurrent invocations of a command are allowed. + + Parameters + ---------- + number: + The maximum number of concurrent invocations. + per: + The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``. + """ + + if number < 1: + raise ValueError("Concurrency number must be at least 1.") + if per not in {"user", "guild", "global"}: + raise ValueError("per must be 'user', 'guild', or 'global'.") + + def decorator( + func: Callable[..., Awaitable[None]], + ) -> Callable[..., Awaitable[None]]: + setattr(func, "__max_concurrency__", (number, per)) + return func + + return decorator + + def cooldown( rate: int, per: float ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: diff --git a/disagreement/ext/commands/errors.py b/disagreement/ext/commands/errors.py index 5fa6f06..0600ad7 100644 --- a/disagreement/ext/commands/errors.py +++ b/disagreement/ext/commands/errors.py @@ -72,5 +72,13 @@ class CommandInvokeError(CommandError): super().__init__(f"Error during command invocation: {original}") +class MaxConcurrencyReached(CommandError): + """Raised when a command exceeds its concurrency limit.""" + + def __init__(self, limit: int): + self.limit = limit + super().__init__(f"Max concurrency of {limit} reached") + + # Add more specific errors as needed, e.g., UserNotFound, ChannelNotFound, etc. # These might inherit from BadArgument. diff --git a/docs/slash_commands.md b/docs/slash_commands.md index 42b24ff..31f526c 100644 --- a/docs/slash_commands.md +++ b/docs/slash_commands.md @@ -20,3 +20,14 @@ Use `AppCommandGroup` to group related commands. See the [components guide](usin - [Caching](caching.md) - [Voice Features](voice_features.md) +## Command Persistence + +`AppCommandHandler.sync_commands` can persist registered command IDs in +`.disagreement_commands.json`. When enabled, subsequent syncs compare the +stored IDs to the commands defined in code and only create, edit or delete +commands when changes are detected. + +Call `AppCommandHandler.clear_stored_registrations()` if you need to wipe the +stored IDs or migrate them elsewhere with +`AppCommandHandler.migrate_stored_registrations()`. + diff --git a/tests/test_color.py b/tests/test_color.py index c2a5aa8..31edde8 100644 --- a/tests/test_color.py +++ b/tests/test_color.py @@ -11,3 +11,6 @@ def test_static_colors(): assert Color.red().value == 0xFF0000 assert Color.green().value == 0x00FF00 assert Color.blue().value == 0x0000FF + assert Color.blurple().value == 0x5865F2 + assert Color.light_blurple().value == 0xE0E3FF + assert Color.legacy_blurple().value == 0x7289DA diff --git a/tests/test_max_concurrency.py b/tests/test_max_concurrency.py new file mode 100644 index 0000000..f225769 --- /dev/null +++ b/tests/test_max_concurrency.py @@ -0,0 +1,103 @@ +import asyncio +import pytest + +from disagreement.ext.commands.core import CommandHandler +from disagreement.ext.commands.decorators import command, max_concurrency +from disagreement.ext.commands.errors import MaxConcurrencyReached +from disagreement.models import Message + + +class DummyBot: + def __init__(self): + self.errors = [] + + async def on_command_error(self, ctx, error): + self.errors.append(error) + + +@pytest.mark.asyncio +async def test_max_concurrency_per_user(): + bot = DummyBot() + handler = CommandHandler(client=bot, prefix="!") + started = asyncio.Event() + release = asyncio.Event() + + @command() + @max_concurrency(1, per="user") + async def foo(ctx): + started.set() + await release.wait() + + handler.add_command(foo.__command_object__) + + data = { + "id": "1", + "channel_id": "c", + "guild_id": "g", + "author": {"id": "a", "username": "u", "discriminator": "0001"}, + "content": "!foo", + "timestamp": "t", + } + msg1 = Message(data, client_instance=bot) + msg2 = Message({**data, "id": "2"}, client_instance=bot) + + task = asyncio.create_task(handler.process_commands(msg1)) + await started.wait() + + await handler.process_commands(msg2) + assert any(isinstance(e, MaxConcurrencyReached) for e in bot.errors) + + release.set() + await task + + await handler.process_commands(msg2) + + +@pytest.mark.asyncio +async def test_max_concurrency_per_guild(): + bot = DummyBot() + handler = CommandHandler(client=bot, prefix="!") + started = asyncio.Event() + release = asyncio.Event() + + @command() + @max_concurrency(1, per="guild") + async def foo(ctx): + started.set() + await release.wait() + + handler.add_command(foo.__command_object__) + + base = { + "channel_id": "c", + "guild_id": "g", + "content": "!foo", + "timestamp": "t", + } + msg1 = Message( + { + **base, + "id": "1", + "author": {"id": "a", "username": "u", "discriminator": "0001"}, + }, + client_instance=bot, + ) + msg2 = Message( + { + **base, + "id": "2", + "author": {"id": "b", "username": "v", "discriminator": "0001"}, + }, + client_instance=bot, + ) + + task = asyncio.create_task(handler.process_commands(msg1)) + await started.wait() + + await handler.process_commands(msg2) + assert any(isinstance(e, MaxConcurrencyReached) for e in bot.errors) + + release.set() + await task + + await handler.process_commands(msg2)