diff --git a/disagreement/ext/commands/__init__.py b/disagreement/ext/commands/__init__.py index 21b1e35..7dcde6b 100644 --- a/disagreement/ext/commands/__init__.py +++ b/disagreement/ext/commands/__init__.py @@ -16,6 +16,7 @@ from .decorators import ( check, check_any, cooldown, + max_concurrency, requires_permissions, ) from .errors import ( @@ -28,6 +29,7 @@ from .errors import ( CheckAnyFailure, CommandOnCooldown, CommandInvokeError, + MaxConcurrencyReached, ) __all__ = [ @@ -43,6 +45,7 @@ __all__ = [ "check", "check_any", "cooldown", + "max_concurrency", "requires_permissions", # Errors "CommandError", @@ -54,4 +57,5 @@ __all__ = [ "CheckAnyFailure", "CommandOnCooldown", "CommandInvokeError", + "MaxConcurrencyReached", ] diff --git a/disagreement/ext/commands/core.py b/disagreement/ext/commands/core.py index adb816a..1eb4cdf 100644 --- a/disagreement/ext/commands/core.py +++ b/disagreement/ext/commands/core.py @@ -70,6 +70,10 @@ class Command: if hasattr(callback, "__command_checks__"): self.checks.extend(getattr(callback, "__command_checks__")) + self.max_concurrency: Optional[Tuple[int, str]] = None + if hasattr(callback, "__max_concurrency__"): + self.max_concurrency = getattr(callback, "__max_concurrency__") + def add_check( self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool] ) -> None: @@ -215,6 +219,7 @@ class CommandHandler: ] = prefix self.commands: Dict[str, Command] = {} self.cogs: Dict[str, "Cog"] = {} + self._concurrency: Dict[str, Dict[str, int]] = {} from .help import HelpCommand @@ -300,6 +305,47 @@ class CommandHandler: logger.info("Cog '%s' removed.", cog_name) return cog_to_remove + def _acquire_concurrency(self, ctx: CommandContext) -> None: + mc = getattr(ctx.command, "max_concurrency", None) + if not mc: + return + limit, scope = mc + if scope == "user": + key = ctx.author.id + elif scope == "guild": + key = ctx.message.guild_id or ctx.author.id + else: + key = "global" + buckets = self._concurrency.setdefault(ctx.command.name, {}) + current = buckets.get(key, 0) + if current >= limit: + from .errors import MaxConcurrencyReached + + raise MaxConcurrencyReached(limit) + buckets[key] = current + 1 + + def _release_concurrency(self, ctx: CommandContext) -> None: + mc = getattr(ctx.command, "max_concurrency", None) + if not mc: + return + _, scope = mc + if scope == "user": + key = ctx.author.id + elif scope == "guild": + key = ctx.message.guild_id or ctx.author.id + else: + key = "global" + buckets = self._concurrency.get(ctx.command.name) + if not buckets: + return + current = buckets.get(key, 0) + if current <= 1: + buckets.pop(key, None) + else: + buckets[key] = current - 1 + if not buckets: + self._concurrency.pop(ctx.command.name, None) + async def get_prefix(self, message: "Message") -> Union[str, List[str], None]: if callable(self.prefix): if inspect.iscoroutinefunction(self.prefix): @@ -501,7 +547,11 @@ class CommandHandler: parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view) ctx.args = parsed_args ctx.kwargs = parsed_kwargs - await command.invoke(ctx, *parsed_args, **parsed_kwargs) + self._acquire_concurrency(ctx) + try: + await command.invoke(ctx, *parsed_args, **parsed_kwargs) + finally: + self._release_concurrency(ctx) except CommandError as e: logger.error("Command error for '%s': %s", command.name, e) if hasattr(self.client, "on_command_error"): diff --git a/disagreement/ext/commands/decorators.py b/disagreement/ext/commands/decorators.py index 53f540c..8413cc1 100644 --- a/disagreement/ext/commands/decorators.py +++ b/disagreement/ext/commands/decorators.py @@ -107,6 +107,33 @@ def check_any( return check(predicate) +def max_concurrency( + number: int, per: str = "user" +) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: + """Limit how many concurrent invocations of a command are allowed. + + Parameters + ---------- + number: + The maximum number of concurrent invocations. + per: + The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``. + """ + + if number < 1: + raise ValueError("Concurrency number must be at least 1.") + if per not in {"user", "guild", "global"}: + raise ValueError("per must be 'user', 'guild', or 'global'.") + + def decorator( + func: Callable[..., Awaitable[None]], + ) -> Callable[..., Awaitable[None]]: + setattr(func, "__max_concurrency__", (number, per)) + return func + + return decorator + + def cooldown( rate: int, per: float ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: diff --git a/disagreement/ext/commands/errors.py b/disagreement/ext/commands/errors.py index 5fa6f06..0600ad7 100644 --- a/disagreement/ext/commands/errors.py +++ b/disagreement/ext/commands/errors.py @@ -72,5 +72,13 @@ class CommandInvokeError(CommandError): super().__init__(f"Error during command invocation: {original}") +class MaxConcurrencyReached(CommandError): + """Raised when a command exceeds its concurrency limit.""" + + def __init__(self, limit: int): + self.limit = limit + super().__init__(f"Max concurrency of {limit} reached") + + # Add more specific errors as needed, e.g., UserNotFound, ChannelNotFound, etc. # These might inherit from BadArgument. diff --git a/tests/test_max_concurrency.py b/tests/test_max_concurrency.py new file mode 100644 index 0000000..f225769 --- /dev/null +++ b/tests/test_max_concurrency.py @@ -0,0 +1,103 @@ +import asyncio +import pytest + +from disagreement.ext.commands.core import CommandHandler +from disagreement.ext.commands.decorators import command, max_concurrency +from disagreement.ext.commands.errors import MaxConcurrencyReached +from disagreement.models import Message + + +class DummyBot: + def __init__(self): + self.errors = [] + + async def on_command_error(self, ctx, error): + self.errors.append(error) + + +@pytest.mark.asyncio +async def test_max_concurrency_per_user(): + bot = DummyBot() + handler = CommandHandler(client=bot, prefix="!") + started = asyncio.Event() + release = asyncio.Event() + + @command() + @max_concurrency(1, per="user") + async def foo(ctx): + started.set() + await release.wait() + + handler.add_command(foo.__command_object__) + + data = { + "id": "1", + "channel_id": "c", + "guild_id": "g", + "author": {"id": "a", "username": "u", "discriminator": "0001"}, + "content": "!foo", + "timestamp": "t", + } + msg1 = Message(data, client_instance=bot) + msg2 = Message({**data, "id": "2"}, client_instance=bot) + + task = asyncio.create_task(handler.process_commands(msg1)) + await started.wait() + + await handler.process_commands(msg2) + assert any(isinstance(e, MaxConcurrencyReached) for e in bot.errors) + + release.set() + await task + + await handler.process_commands(msg2) + + +@pytest.mark.asyncio +async def test_max_concurrency_per_guild(): + bot = DummyBot() + handler = CommandHandler(client=bot, prefix="!") + started = asyncio.Event() + release = asyncio.Event() + + @command() + @max_concurrency(1, per="guild") + async def foo(ctx): + started.set() + await release.wait() + + handler.add_command(foo.__command_object__) + + base = { + "channel_id": "c", + "guild_id": "g", + "content": "!foo", + "timestamp": "t", + } + msg1 = Message( + { + **base, + "id": "1", + "author": {"id": "a", "username": "u", "discriminator": "0001"}, + }, + client_instance=bot, + ) + msg2 = Message( + { + **base, + "id": "2", + "author": {"id": "b", "username": "v", "discriminator": "0001"}, + }, + client_instance=bot, + ) + + task = asyncio.create_task(handler.process_commands(msg1)) + await started.wait() + + await handler.process_commands(msg2) + assert any(isinstance(e, MaxConcurrencyReached) for e in bot.errors) + + release.set() + await task + + await handler.process_commands(msg2)