Compare commits

...

2 Commits

5 changed files with 193 additions and 1 deletions

View File

@ -16,6 +16,7 @@ from .decorators import (
check, check,
check_any, check_any,
cooldown, cooldown,
max_concurrency,
requires_permissions, requires_permissions,
) )
from .errors import ( from .errors import (
@ -28,6 +29,7 @@ from .errors import (
CheckAnyFailure, CheckAnyFailure,
CommandOnCooldown, CommandOnCooldown,
CommandInvokeError, CommandInvokeError,
MaxConcurrencyReached,
) )
__all__ = [ __all__ = [
@ -43,6 +45,7 @@ __all__ = [
"check", "check",
"check_any", "check_any",
"cooldown", "cooldown",
"max_concurrency",
"requires_permissions", "requires_permissions",
# Errors # Errors
"CommandError", "CommandError",
@ -54,4 +57,5 @@ __all__ = [
"CheckAnyFailure", "CheckAnyFailure",
"CommandOnCooldown", "CommandOnCooldown",
"CommandInvokeError", "CommandInvokeError",
"MaxConcurrencyReached",
] ]

View File

@ -70,6 +70,10 @@ class Command:
if hasattr(callback, "__command_checks__"): if hasattr(callback, "__command_checks__"):
self.checks.extend(getattr(callback, "__command_checks__")) self.checks.extend(getattr(callback, "__command_checks__"))
self.max_concurrency: Optional[Tuple[int, str]] = None
if hasattr(callback, "__max_concurrency__"):
self.max_concurrency = getattr(callback, "__max_concurrency__")
def add_check( def add_check(
self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool] self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool]
) -> None: ) -> None:
@ -215,6 +219,7 @@ class CommandHandler:
] = prefix ] = prefix
self.commands: Dict[str, Command] = {} self.commands: Dict[str, Command] = {}
self.cogs: Dict[str, "Cog"] = {} self.cogs: Dict[str, "Cog"] = {}
self._concurrency: Dict[str, Dict[str, int]] = {}
from .help import HelpCommand from .help import HelpCommand
@ -300,6 +305,47 @@ class CommandHandler:
logger.info("Cog '%s' removed.", cog_name) logger.info("Cog '%s' removed.", cog_name)
return cog_to_remove return cog_to_remove
def _acquire_concurrency(self, ctx: CommandContext) -> None:
mc = getattr(ctx.command, "max_concurrency", None)
if not mc:
return
limit, scope = mc
if scope == "user":
key = ctx.author.id
elif scope == "guild":
key = ctx.message.guild_id or ctx.author.id
else:
key = "global"
buckets = self._concurrency.setdefault(ctx.command.name, {})
current = buckets.get(key, 0)
if current >= limit:
from .errors import MaxConcurrencyReached
raise MaxConcurrencyReached(limit)
buckets[key] = current + 1
def _release_concurrency(self, ctx: CommandContext) -> None:
mc = getattr(ctx.command, "max_concurrency", None)
if not mc:
return
_, scope = mc
if scope == "user":
key = ctx.author.id
elif scope == "guild":
key = ctx.message.guild_id or ctx.author.id
else:
key = "global"
buckets = self._concurrency.get(ctx.command.name)
if not buckets:
return
current = buckets.get(key, 0)
if current <= 1:
buckets.pop(key, None)
else:
buckets[key] = current - 1
if not buckets:
self._concurrency.pop(ctx.command.name, None)
async def get_prefix(self, message: "Message") -> Union[str, List[str], None]: async def get_prefix(self, message: "Message") -> Union[str, List[str], None]:
if callable(self.prefix): if callable(self.prefix):
if inspect.iscoroutinefunction(self.prefix): if inspect.iscoroutinefunction(self.prefix):
@ -501,7 +547,11 @@ class CommandHandler:
parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view) parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view)
ctx.args = parsed_args ctx.args = parsed_args
ctx.kwargs = parsed_kwargs ctx.kwargs = parsed_kwargs
await command.invoke(ctx, *parsed_args, **parsed_kwargs) self._acquire_concurrency(ctx)
try:
await command.invoke(ctx, *parsed_args, **parsed_kwargs)
finally:
self._release_concurrency(ctx)
except CommandError as e: except CommandError as e:
logger.error("Command error for '%s': %s", command.name, e) logger.error("Command error for '%s': %s", command.name, e)
if hasattr(self.client, "on_command_error"): if hasattr(self.client, "on_command_error"):

View File

@ -107,6 +107,33 @@ def check_any(
return check(predicate) return check(predicate)
def max_concurrency(
number: int, per: str = "user"
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Limit how many concurrent invocations of a command are allowed.
Parameters
----------
number:
The maximum number of concurrent invocations.
per:
The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``.
"""
if number < 1:
raise ValueError("Concurrency number must be at least 1.")
if per not in {"user", "guild", "global"}:
raise ValueError("per must be 'user', 'guild', or 'global'.")
def decorator(
func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
setattr(func, "__max_concurrency__", (number, per))
return func
return decorator
def cooldown( def cooldown(
rate: int, per: float rate: int, per: float
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:

View File

@ -72,5 +72,13 @@ class CommandInvokeError(CommandError):
super().__init__(f"Error during command invocation: {original}") super().__init__(f"Error during command invocation: {original}")
class MaxConcurrencyReached(CommandError):
"""Raised when a command exceeds its concurrency limit."""
def __init__(self, limit: int):
self.limit = limit
super().__init__(f"Max concurrency of {limit} reached")
# Add more specific errors as needed, e.g., UserNotFound, ChannelNotFound, etc. # Add more specific errors as needed, e.g., UserNotFound, ChannelNotFound, etc.
# These might inherit from BadArgument. # These might inherit from BadArgument.

View File

@ -0,0 +1,103 @@
import asyncio
import pytest
from disagreement.ext.commands.core import CommandHandler
from disagreement.ext.commands.decorators import command, max_concurrency
from disagreement.ext.commands.errors import MaxConcurrencyReached
from disagreement.models import Message
class DummyBot:
def __init__(self):
self.errors = []
async def on_command_error(self, ctx, error):
self.errors.append(error)
@pytest.mark.asyncio
async def test_max_concurrency_per_user():
bot = DummyBot()
handler = CommandHandler(client=bot, prefix="!")
started = asyncio.Event()
release = asyncio.Event()
@command()
@max_concurrency(1, per="user")
async def foo(ctx):
started.set()
await release.wait()
handler.add_command(foo.__command_object__)
data = {
"id": "1",
"channel_id": "c",
"guild_id": "g",
"author": {"id": "a", "username": "u", "discriminator": "0001"},
"content": "!foo",
"timestamp": "t",
}
msg1 = Message(data, client_instance=bot)
msg2 = Message({**data, "id": "2"}, client_instance=bot)
task = asyncio.create_task(handler.process_commands(msg1))
await started.wait()
await handler.process_commands(msg2)
assert any(isinstance(e, MaxConcurrencyReached) for e in bot.errors)
release.set()
await task
await handler.process_commands(msg2)
@pytest.mark.asyncio
async def test_max_concurrency_per_guild():
bot = DummyBot()
handler = CommandHandler(client=bot, prefix="!")
started = asyncio.Event()
release = asyncio.Event()
@command()
@max_concurrency(1, per="guild")
async def foo(ctx):
started.set()
await release.wait()
handler.add_command(foo.__command_object__)
base = {
"channel_id": "c",
"guild_id": "g",
"content": "!foo",
"timestamp": "t",
}
msg1 = Message(
{
**base,
"id": "1",
"author": {"id": "a", "username": "u", "discriminator": "0001"},
},
client_instance=bot,
)
msg2 = Message(
{
**base,
"id": "2",
"author": {"id": "b", "username": "v", "discriminator": "0001"},
},
client_instance=bot,
)
task = asyncio.create_task(handler.process_commands(msg1))
await started.wait()
await handler.process_commands(msg2)
assert any(isinstance(e, MaxConcurrencyReached) for e in bot.errors)
release.set()
await task
await handler.process_commands(msg2)