104 lines
2.5 KiB
Python
104 lines
2.5 KiB
Python
import asyncio
|
|
import pytest
|
|
|
|
from disagreement.ext.commands.core import CommandHandler
|
|
from disagreement.ext.commands.decorators import command, max_concurrency
|
|
from disagreement.ext.commands.errors import MaxConcurrencyReached
|
|
from disagreement.models import Message
|
|
|
|
|
|
class DummyBot:
    """Minimal bot stand-in that records every command error it is handed.

    Errors are kept in ``errors`` in the order they were reported, so a test
    can inspect them after driving a command handler.
    """

    def __init__(self):
        # Errors received through ``on_command_error``, oldest first.
        self.errors = list()

    async def on_command_error(self, ctx, error):
        """Remember ``error``; ``ctx`` is accepted but not used."""
        self.errors += [error]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_max_concurrency_per_user():
    """Check ``max_concurrency(1, per="user")`` for two messages by one author.

    While the first invocation is still running, a second one from the same
    author must be reported to the bot as ``MaxConcurrencyReached``. Once the
    first invocation has finished, a further invocation must not produce
    another such error.
    """
    bot = DummyBot()
    handler = CommandHandler(client=bot, prefix="!")
    started = asyncio.Event()
    release = asyncio.Event()

    @command()
    @max_concurrency(1, per="user")
    async def foo(ctx):
        # Signal that the command body is running, then block until the test
        # allows it to finish.
        started.set()
        await release.wait()

    handler.add_command(foo.__command_object__)

    data = {
        "id": "1",
        "channel_id": "c",
        "guild_id": "g",
        "author": {"id": "a", "username": "u", "discriminator": "0001"},
        "content": "!foo",
        "timestamp": "t",
    }
    msg1 = Message(data, client_instance=bot)
    # Same author as msg1; only the message id differs.
    msg2 = Message({**data, "id": "2"}, client_instance=bot)

    def rejections():
        """Count the MaxConcurrencyReached errors recorded so far."""
        return sum(isinstance(e, MaxConcurrencyReached) for e in bot.errors)

    task = asyncio.create_task(handler.process_commands(msg1))
    try:
        # Bounded wait: if the command never starts, fail instead of hanging.
        await asyncio.wait_for(started.wait(), timeout=5)

        await handler.process_commands(msg2)
        assert any(isinstance(e, MaxConcurrencyReached) for e in bot.errors)
    finally:
        # Always unblock and reap the first invocation, even when an
        # assertion above fails, so no pending task is left behind.
        release.set()
        await task

    # With the first invocation finished, another one must be accepted, i.e.
    # it must not add a further MaxConcurrencyReached error.
    rejected_before = rejections()
    await handler.process_commands(msg2)
    assert rejections() == rejected_before
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_max_concurrency_per_guild():
    """Check ``max_concurrency(1, per="guild")`` for two authors in one guild.

    While the first invocation is still running, a second one from a
    different author in the same guild must be reported to the bot as
    ``MaxConcurrencyReached``. Once the first invocation has finished, a
    further invocation must not produce another such error.
    """
    bot = DummyBot()
    handler = CommandHandler(client=bot, prefix="!")
    started = asyncio.Event()
    release = asyncio.Event()

    @command()
    @max_concurrency(1, per="guild")
    async def foo(ctx):
        # Signal that the command body is running, then block until the test
        # allows it to finish.
        started.set()
        await release.wait()

    handler.add_command(foo.__command_object__)

    # Fields shared by both messages; note the identical guild_id.
    base = {
        "channel_id": "c",
        "guild_id": "g",
        "content": "!foo",
        "timestamp": "t",
    }
    msg1 = Message(
        {
            **base,
            "id": "1",
            "author": {"id": "a", "username": "u", "discriminator": "0001"},
        },
        client_instance=bot,
    )
    # Different author than msg1, same guild.
    msg2 = Message(
        {
            **base,
            "id": "2",
            "author": {"id": "b", "username": "v", "discriminator": "0001"},
        },
        client_instance=bot,
    )

    def rejections():
        """Count the MaxConcurrencyReached errors recorded so far."""
        return sum(isinstance(e, MaxConcurrencyReached) for e in bot.errors)

    task = asyncio.create_task(handler.process_commands(msg1))
    try:
        # Bounded wait: if the command never starts, fail instead of hanging.
        await asyncio.wait_for(started.wait(), timeout=5)

        await handler.process_commands(msg2)
        assert any(isinstance(e, MaxConcurrencyReached) for e in bot.errors)
    finally:
        # Always unblock and reap the first invocation, even when an
        # assertion above fails, so no pending task is left behind.
        release.set()
        await task

    # With the first invocation finished, another one must be accepted, i.e.
    # it must not add a further MaxConcurrencyReached error.
    rejected_before = rejections()
    await handler.process_commands(msg2)
    assert rejections() == rejected_before
|