disagreement/tests/test_max_concurrency.py

104 lines
2.5 KiB
Python

import asyncio
import pytest
from disagreement.ext.commands.core import CommandHandler
from disagreement.ext.commands.decorators import command, max_concurrency
from disagreement.ext.commands.errors import MaxConcurrencyReached
from disagreement.models import Message
class DummyBot:
    """Minimal stand-in client that records every command error it receives.

    Tests pass an instance as the ``client`` of a ``CommandHandler`` and then
    inspect ``errors`` to see which exceptions the handler reported.
    """

    def __init__(self):
        # Errors handed to on_command_error, kept in arrival order.
        self.errors: list = []

    async def on_command_error(self, ctx, error):
        """Append *error* to ``errors``; *ctx* is accepted but not used."""
        self.errors += [error]


@pytest.mark.asyncio
async def test_max_concurrency_per_user():
    """A user's second concurrent invocation is rejected with MaxConcurrencyReached.

    While the first invocation of ``foo`` is parked on ``release``, a second
    message from the same author must be refused without running the command
    body. After the first invocation finishes, the slot must be freed so a
    later invocation by that user runs normally.
    """
    bot = DummyBot()
    handler = CommandHandler(client=bot, prefix="!")
    started = asyncio.Event()
    release = asyncio.Event()
    calls = 0

    @command()
    @max_concurrency(1, per="user")
    async def foo(ctx):
        nonlocal calls
        calls += 1
        started.set()
        # Hold the concurrency slot until the test explicitly releases it.
        await release.wait()

    handler.add_command(foo.__command_object__)
    data = {
        "id": "1",
        "channel_id": "c",
        "guild_id": "g",
        "author": {"id": "a", "username": "u", "discriminator": "0001"},
        "content": "!foo",
        "timestamp": "t",
    }
    msg1 = Message(data, client_instance=bot)
    # Same author, same guild: only the message id differs.
    msg2 = Message({**data, "id": "2"}, client_instance=bot)

    task = asyncio.create_task(handler.process_commands(msg1))
    try:
        # Bound the wait so a command that never starts fails the test
        # instead of hanging the event loop indefinitely.
        await asyncio.wait_for(started.wait(), timeout=5)
        await handler.process_commands(msg2)
        assert any(isinstance(e, MaxConcurrencyReached) for e in bot.errors)
        # The rejected invocation must not have executed the command body.
        assert calls == 1
    finally:
        # Always unblock the first invocation so its task never leaks,
        # even when an assertion above fails.
        release.set()
        await task

    errors_before = len(bot.errors)
    await handler.process_commands(msg2)
    # The slot was freed when the first invocation finished, so this runs.
    assert calls == 2
    new_errors = bot.errors[errors_before:]
    assert not any(isinstance(e, MaxConcurrencyReached) for e in new_errors)


@pytest.mark.asyncio
async def test_max_concurrency_per_guild():
    """A concurrent invocation from another user in the same guild is rejected.

    While one user's invocation of ``foo`` is parked on ``release``, a message
    from a different author in the same guild must be refused without running
    the command body. Once the first invocation finishes, the guild's slot
    must be freed so that second user's invocation runs normally.
    """
    bot = DummyBot()
    handler = CommandHandler(client=bot, prefix="!")
    started = asyncio.Event()
    release = asyncio.Event()
    calls = 0

    @command()
    @max_concurrency(1, per="guild")
    async def foo(ctx):
        nonlocal calls
        calls += 1
        started.set()
        # Hold the concurrency slot until the test explicitly releases it.
        await release.wait()

    handler.add_command(foo.__command_object__)
    base = {
        "channel_id": "c",
        "guild_id": "g",
        "content": "!foo",
        "timestamp": "t",
    }
    msg1 = Message(
        {
            **base,
            "id": "1",
            "author": {"id": "a", "username": "u", "discriminator": "0001"},
        },
        client_instance=bot,
    )
    # Different author, same guild: the guild bucket is what must collide.
    msg2 = Message(
        {
            **base,
            "id": "2",
            "author": {"id": "b", "username": "v", "discriminator": "0001"},
        },
        client_instance=bot,
    )

    task = asyncio.create_task(handler.process_commands(msg1))
    try:
        # Bound the wait so a command that never starts fails the test
        # instead of hanging the event loop indefinitely.
        await asyncio.wait_for(started.wait(), timeout=5)
        await handler.process_commands(msg2)
        assert any(isinstance(e, MaxConcurrencyReached) for e in bot.errors)
        # The rejected invocation must not have executed the command body.
        assert calls == 1
    finally:
        # Always unblock the first invocation so its task never leaks,
        # even when an assertion above fails.
        release.set()
        await task

    errors_before = len(bot.errors)
    await handler.process_commands(msg2)
    # The guild slot was freed when the first invocation finished.
    assert calls == 2
    new_errors = bot.errors[errors_before:]
    assert not any(isinstance(e, MaxConcurrencyReached) for e in new_errors)