diff --git a/disagreement/ext/commands/converters.py b/disagreement/ext/commands/converters.py index 05158a7..ea09879 100644 --- a/disagreement/ext/commands/converters.py +++ b/disagreement/ext/commands/converters.py @@ -6,7 +6,7 @@ import re import inspect from .errors import BadArgument -from disagreement.models import Member, Guild, Role +from disagreement.models import Member, Guild, Role, User if TYPE_CHECKING: from .core import CommandContext @@ -143,6 +143,21 @@ class GuildConverter(Converter["Guild"]): raise BadArgument(f"Guild '{argument}' not found.") +class UserConverter(Converter["User"]): + async def convert(self, ctx: "CommandContext", argument: str) -> "User": + match = re.match(r"<@!?(\d+)>$", argument) + user_id = match.group(1) if match else argument + + user = ctx.bot._users.get(user_id) + if user: + return user + + user = await ctx.bot.fetch_user(user_id) + if user: + return user + raise BadArgument(f"User '{argument}' not found.") + + # Default converters mapping DEFAULT_CONVERTERS: dict[type, Converter[Any]] = { int: IntConverter(), @@ -152,7 +167,7 @@ DEFAULT_CONVERTERS: dict[type, Converter[Any]] = { Member: MemberConverter(), Guild: GuildConverter(), Role: RoleConverter(), - # User: UserConverter(), # Add when User model and converter are ready + User: UserConverter(), } diff --git a/tests/test_additional_converters.py b/tests/test_additional_converters.py index 40da104..8ee751f 100644 --- a/tests/test_additional_converters.py +++ b/tests/test_additional_converters.py @@ -3,7 +3,7 @@ import pytest from disagreement.ext.commands.converters import run_converters from disagreement.ext.commands.core import CommandContext, Command from disagreement.ext.commands.errors import BadArgument -from disagreement.models import Message, Member, Role, Guild +from disagreement.models import Message, Member, Role, Guild, User from disagreement.enums import ( VerificationLevel, MessageNotificationLevel, @@ -15,13 +15,14 @@ from disagreement.enums import ( from disagreement.client import Client -from disagreement.cache import GuildCache +from disagreement.cache import GuildCache, Cache class DummyBot(Client): def __init__(self): super().__init__(token="test") self._guilds = GuildCache() + self._users = Cache() def get_guild(self, guild_id): return self._guilds.get(guild_id) @@ -37,6 +38,9 @@ class DummyBot(Client): async def fetch_guild(self, guild_id): return self._guilds.get(guild_id) + async def fetch_user(self, user_id): + return self._users.get(user_id) + @pytest.fixture() def guild_objects(): @@ -60,6 +64,9 @@ def guild_objects(): guild = Guild(guild_data, client_instance=bot) bot._guilds.set(guild.id, guild) + user = User({"id": "7", "username": "u", "discriminator": "0001"}) + bot._users.set(user.id, user) + member = Member( { "user": {"id": "3", "username": "m", "discriminator": "0001"}, @@ -86,12 +93,12 @@ def guild_objects(): guild._members.set(member.id, member) guild.roles.append(role) - return guild, member, role + return guild, member, role, user @pytest.fixture() def command_context(guild_objects): - guild, member, role = guild_objects + guild, member, role, _ = guild_objects bot = guild._client message_data = { "id": "10", @@ -114,7 +121,7 @@ def command_context(guild_objects): @pytest.mark.asyncio async def test_member_converter(command_context, guild_objects): - _, member, _ = guild_objects + _, member, _, _ = guild_objects mention = f"<@!{member.id}>" result = await run_converters(command_context, Member, mention) assert result is member @@ -124,7 +131,7 @@ async def test_member_converter(command_context, guild_objects): @pytest.mark.asyncio async def test_role_converter(command_context, guild_objects): - _, _, role = guild_objects + _, _, role, _ = guild_objects mention = f"<@&{role.id}>" result = await run_converters(command_context, Role, mention) assert result is role @@ -132,9 +139,19 @@ async def test_role_converter(command_context, guild_objects): assert result is role +@pytest.mark.asyncio +async def test_user_converter(command_context, guild_objects): + _, _, _, user = guild_objects + mention = f"<@{user.id}>" + result = await run_converters(command_context, User, mention) + assert result is user + result = await run_converters(command_context, User, user.id) + assert result is user + + @pytest.mark.asyncio async def test_guild_converter(command_context, guild_objects): - guild, _, _ = guild_objects + guild, _, _, _ = guild_objects result = await run_converters(command_context, Guild, guild.id) assert result is guild