Add max concurrency support
This commit is contained in:
parent
669f00e745
commit
b0f9381fa6
@ -16,6 +16,7 @@ from .decorators import (
|
||||
check,
|
||||
check_any,
|
||||
cooldown,
|
||||
max_concurrency,
|
||||
requires_permissions,
|
||||
)
|
||||
from .errors import (
|
||||
@ -28,6 +29,7 @@ from .errors import (
|
||||
CheckAnyFailure,
|
||||
CommandOnCooldown,
|
||||
CommandInvokeError,
|
||||
MaxConcurrencyReached,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -43,6 +45,7 @@ __all__ = [
|
||||
"check",
|
||||
"check_any",
|
||||
"cooldown",
|
||||
"max_concurrency",
|
||||
"requires_permissions",
|
||||
# Errors
|
||||
"CommandError",
|
||||
@ -54,4 +57,5 @@ __all__ = [
|
||||
"CheckAnyFailure",
|
||||
"CommandOnCooldown",
|
||||
"CommandInvokeError",
|
||||
"MaxConcurrencyReached",
|
||||
]
|
||||
|
@ -70,6 +70,10 @@ class Command:
|
||||
if hasattr(callback, "__command_checks__"):
|
||||
self.checks.extend(getattr(callback, "__command_checks__"))
|
||||
|
||||
self.max_concurrency: Optional[Tuple[int, str]] = None
|
||||
if hasattr(callback, "__max_concurrency__"):
|
||||
self.max_concurrency = getattr(callback, "__max_concurrency__")
|
||||
|
||||
def add_check(
|
||||
self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool]
|
||||
) -> None:
|
||||
@ -215,6 +219,7 @@ class CommandHandler:
|
||||
] = prefix
|
||||
self.commands: Dict[str, Command] = {}
|
||||
self.cogs: Dict[str, "Cog"] = {}
|
||||
self._concurrency: Dict[str, Dict[str, int]] = {}
|
||||
|
||||
from .help import HelpCommand
|
||||
|
||||
@ -300,6 +305,47 @@ class CommandHandler:
|
||||
logger.info("Cog '%s' removed.", cog_name)
|
||||
return cog_to_remove
|
||||
|
||||
def _acquire_concurrency(self, ctx: CommandContext) -> None:
|
||||
mc = getattr(ctx.command, "max_concurrency", None)
|
||||
if not mc:
|
||||
return
|
||||
limit, scope = mc
|
||||
if scope == "user":
|
||||
key = ctx.author.id
|
||||
elif scope == "guild":
|
||||
key = ctx.message.guild_id or ctx.author.id
|
||||
else:
|
||||
key = "global"
|
||||
buckets = self._concurrency.setdefault(ctx.command.name, {})
|
||||
current = buckets.get(key, 0)
|
||||
if current >= limit:
|
||||
from .errors import MaxConcurrencyReached
|
||||
|
||||
raise MaxConcurrencyReached(limit)
|
||||
buckets[key] = current + 1
|
||||
|
||||
def _release_concurrency(self, ctx: CommandContext) -> None:
|
||||
mc = getattr(ctx.command, "max_concurrency", None)
|
||||
if not mc:
|
||||
return
|
||||
_, scope = mc
|
||||
if scope == "user":
|
||||
key = ctx.author.id
|
||||
elif scope == "guild":
|
||||
key = ctx.message.guild_id or ctx.author.id
|
||||
else:
|
||||
key = "global"
|
||||
buckets = self._concurrency.get(ctx.command.name)
|
||||
if not buckets:
|
||||
return
|
||||
current = buckets.get(key, 0)
|
||||
if current <= 1:
|
||||
buckets.pop(key, None)
|
||||
else:
|
||||
buckets[key] = current - 1
|
||||
if not buckets:
|
||||
self._concurrency.pop(ctx.command.name, None)
|
||||
|
||||
async def get_prefix(self, message: "Message") -> Union[str, List[str], None]:
|
||||
if callable(self.prefix):
|
||||
if inspect.iscoroutinefunction(self.prefix):
|
||||
@ -501,7 +547,11 @@ class CommandHandler:
|
||||
parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view)
|
||||
ctx.args = parsed_args
|
||||
ctx.kwargs = parsed_kwargs
|
||||
await command.invoke(ctx, *parsed_args, **parsed_kwargs)
|
||||
self._acquire_concurrency(ctx)
|
||||
try:
|
||||
await command.invoke(ctx, *parsed_args, **parsed_kwargs)
|
||||
finally:
|
||||
self._release_concurrency(ctx)
|
||||
except CommandError as e:
|
||||
logger.error("Command error for '%s': %s", command.name, e)
|
||||
if hasattr(self.client, "on_command_error"):
|
||||
|
@ -107,6 +107,33 @@ def check_any(
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def max_concurrency(
|
||||
number: int, per: str = "user"
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Limit how many concurrent invocations of a command are allowed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
number:
|
||||
The maximum number of concurrent invocations.
|
||||
per:
|
||||
The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``.
|
||||
"""
|
||||
|
||||
if number < 1:
|
||||
raise ValueError("Concurrency number must be at least 1.")
|
||||
if per not in {"user", "guild", "global"}:
|
||||
raise ValueError("per must be 'user', 'guild', or 'global'.")
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
setattr(func, "__max_concurrency__", (number, per))
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def cooldown(
|
||||
rate: int, per: float
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
|
@ -72,5 +72,13 @@ class CommandInvokeError(CommandError):
|
||||
super().__init__(f"Error during command invocation: {original}")
|
||||
|
||||
|
||||
class MaxConcurrencyReached(CommandError):
|
||||
"""Raised when a command exceeds its concurrency limit."""
|
||||
|
||||
def __init__(self, limit: int):
|
||||
self.limit = limit
|
||||
super().__init__(f"Max concurrency of {limit} reached")
|
||||
|
||||
|
||||
# Add more specific errors as needed, e.g., UserNotFound, ChannelNotFound, etc.
|
||||
# These might inherit from BadArgument.
|
||||
|
103
tests/test_max_concurrency.py
Normal file
103
tests/test_max_concurrency.py
Normal file
@ -0,0 +1,103 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
|
||||
from disagreement.ext.commands.core import CommandHandler
|
||||
from disagreement.ext.commands.decorators import command, max_concurrency
|
||||
from disagreement.ext.commands.errors import MaxConcurrencyReached
|
||||
from disagreement.models import Message
|
||||
|
||||
|
||||
class DummyBot:
|
||||
def __init__(self):
|
||||
self.errors = []
|
||||
|
||||
async def on_command_error(self, ctx, error):
|
||||
self.errors.append(error)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_concurrency_per_user():
|
||||
bot = DummyBot()
|
||||
handler = CommandHandler(client=bot, prefix="!")
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
@command()
|
||||
@max_concurrency(1, per="user")
|
||||
async def foo(ctx):
|
||||
started.set()
|
||||
await release.wait()
|
||||
|
||||
handler.add_command(foo.__command_object__)
|
||||
|
||||
data = {
|
||||
"id": "1",
|
||||
"channel_id": "c",
|
||||
"guild_id": "g",
|
||||
"author": {"id": "a", "username": "u", "discriminator": "0001"},
|
||||
"content": "!foo",
|
||||
"timestamp": "t",
|
||||
}
|
||||
msg1 = Message(data, client_instance=bot)
|
||||
msg2 = Message({**data, "id": "2"}, client_instance=bot)
|
||||
|
||||
task = asyncio.create_task(handler.process_commands(msg1))
|
||||
await started.wait()
|
||||
|
||||
await handler.process_commands(msg2)
|
||||
assert any(isinstance(e, MaxConcurrencyReached) for e in bot.errors)
|
||||
|
||||
release.set()
|
||||
await task
|
||||
|
||||
await handler.process_commands(msg2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_concurrency_per_guild():
|
||||
bot = DummyBot()
|
||||
handler = CommandHandler(client=bot, prefix="!")
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
@command()
|
||||
@max_concurrency(1, per="guild")
|
||||
async def foo(ctx):
|
||||
started.set()
|
||||
await release.wait()
|
||||
|
||||
handler.add_command(foo.__command_object__)
|
||||
|
||||
base = {
|
||||
"channel_id": "c",
|
||||
"guild_id": "g",
|
||||
"content": "!foo",
|
||||
"timestamp": "t",
|
||||
}
|
||||
msg1 = Message(
|
||||
{
|
||||
**base,
|
||||
"id": "1",
|
||||
"author": {"id": "a", "username": "u", "discriminator": "0001"},
|
||||
},
|
||||
client_instance=bot,
|
||||
)
|
||||
msg2 = Message(
|
||||
{
|
||||
**base,
|
||||
"id": "2",
|
||||
"author": {"id": "b", "username": "v", "discriminator": "0001"},
|
||||
},
|
||||
client_instance=bot,
|
||||
)
|
||||
|
||||
task = asyncio.create_task(handler.process_commands(msg1))
|
||||
await started.wait()
|
||||
|
||||
await handler.process_commands(msg2)
|
||||
assert any(isinstance(e, MaxConcurrencyReached) for e in bot.errors)
|
||||
|
||||
release.set()
|
||||
await task
|
||||
|
||||
await handler.process_commands(msg2)
|
Loading…
x
Reference in New Issue
Block a user