Add UserConverter and tests (#99)
This commit is contained in:
parent
2056a3ddcf
commit
de40aa2c29
@ -6,7 +6,7 @@ import re
|
|||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
from .errors import BadArgument
|
from .errors import BadArgument
|
||||||
from disagreement.models import Member, Guild, Role
|
from disagreement.models import Member, Guild, Role, User
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .core import CommandContext
|
from .core import CommandContext
|
||||||
@ -143,6 +143,21 @@ class GuildConverter(Converter["Guild"]):
|
|||||||
raise BadArgument(f"Guild '{argument}' not found.")
|
raise BadArgument(f"Guild '{argument}' not found.")
|
||||||
|
|
||||||
|
|
||||||
|
class UserConverter(Converter["User"]):
|
||||||
|
async def convert(self, ctx: "CommandContext", argument: str) -> "User":
|
||||||
|
match = re.match(r"<@!?(\d+)>$", argument)
|
||||||
|
user_id = match.group(1) if match else argument
|
||||||
|
|
||||||
|
user = ctx.bot._users.get(user_id)
|
||||||
|
if user:
|
||||||
|
return user
|
||||||
|
|
||||||
|
user = await ctx.bot.fetch_user(user_id)
|
||||||
|
if user:
|
||||||
|
return user
|
||||||
|
raise BadArgument(f"User '{argument}' not found.")
|
||||||
|
|
||||||
|
|
||||||
# Default converters mapping
|
# Default converters mapping
|
||||||
DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
|
DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
|
||||||
int: IntConverter(),
|
int: IntConverter(),
|
||||||
@ -152,7 +167,7 @@ DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
|
|||||||
Member: MemberConverter(),
|
Member: MemberConverter(),
|
||||||
Guild: GuildConverter(),
|
Guild: GuildConverter(),
|
||||||
Role: RoleConverter(),
|
Role: RoleConverter(),
|
||||||
# User: UserConverter(), # Add when User model and converter are ready
|
User: UserConverter(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ import pytest
|
|||||||
from disagreement.ext.commands.converters import run_converters
|
from disagreement.ext.commands.converters import run_converters
|
||||||
from disagreement.ext.commands.core import CommandContext, Command
|
from disagreement.ext.commands.core import CommandContext, Command
|
||||||
from disagreement.ext.commands.errors import BadArgument
|
from disagreement.ext.commands.errors import BadArgument
|
||||||
from disagreement.models import Message, Member, Role, Guild
|
from disagreement.models import Message, Member, Role, Guild, User
|
||||||
from disagreement.enums import (
|
from disagreement.enums import (
|
||||||
VerificationLevel,
|
VerificationLevel,
|
||||||
MessageNotificationLevel,
|
MessageNotificationLevel,
|
||||||
@ -15,13 +15,14 @@ from disagreement.enums import (
|
|||||||
|
|
||||||
|
|
||||||
from disagreement.client import Client
|
from disagreement.client import Client
|
||||||
from disagreement.cache import GuildCache
|
from disagreement.cache import GuildCache, Cache
|
||||||
|
|
||||||
|
|
||||||
class DummyBot(Client):
|
class DummyBot(Client):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(token="test")
|
super().__init__(token="test")
|
||||||
self._guilds = GuildCache()
|
self._guilds = GuildCache()
|
||||||
|
self._users = Cache()
|
||||||
|
|
||||||
def get_guild(self, guild_id):
|
def get_guild(self, guild_id):
|
||||||
return self._guilds.get(guild_id)
|
return self._guilds.get(guild_id)
|
||||||
@ -37,6 +38,9 @@ class DummyBot(Client):
|
|||||||
async def fetch_guild(self, guild_id):
|
async def fetch_guild(self, guild_id):
|
||||||
return self._guilds.get(guild_id)
|
return self._guilds.get(guild_id)
|
||||||
|
|
||||||
|
async def fetch_user(self, user_id):
|
||||||
|
return self._users.get(user_id)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def guild_objects():
|
def guild_objects():
|
||||||
@ -60,6 +64,9 @@ def guild_objects():
|
|||||||
guild = Guild(guild_data, client_instance=bot)
|
guild = Guild(guild_data, client_instance=bot)
|
||||||
bot._guilds.set(guild.id, guild)
|
bot._guilds.set(guild.id, guild)
|
||||||
|
|
||||||
|
user = User({"id": "7", "username": "u", "discriminator": "0001"})
|
||||||
|
bot._users.set(user.id, user)
|
||||||
|
|
||||||
member = Member(
|
member = Member(
|
||||||
{
|
{
|
||||||
"user": {"id": "3", "username": "m", "discriminator": "0001"},
|
"user": {"id": "3", "username": "m", "discriminator": "0001"},
|
||||||
@ -86,12 +93,12 @@ def guild_objects():
|
|||||||
guild._members.set(member.id, member)
|
guild._members.set(member.id, member)
|
||||||
guild.roles.append(role)
|
guild.roles.append(role)
|
||||||
|
|
||||||
return guild, member, role
|
return guild, member, role, user
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def command_context(guild_objects):
|
def command_context(guild_objects):
|
||||||
guild, member, role = guild_objects
|
guild, member, role, _ = guild_objects
|
||||||
bot = guild._client
|
bot = guild._client
|
||||||
message_data = {
|
message_data = {
|
||||||
"id": "10",
|
"id": "10",
|
||||||
@ -114,7 +121,7 @@ def command_context(guild_objects):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_member_converter(command_context, guild_objects):
|
async def test_member_converter(command_context, guild_objects):
|
||||||
_, member, _ = guild_objects
|
_, member, _, _ = guild_objects
|
||||||
mention = f"<@!{member.id}>"
|
mention = f"<@!{member.id}>"
|
||||||
result = await run_converters(command_context, Member, mention)
|
result = await run_converters(command_context, Member, mention)
|
||||||
assert result is member
|
assert result is member
|
||||||
@ -124,7 +131,7 @@ async def test_member_converter(command_context, guild_objects):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_role_converter(command_context, guild_objects):
|
async def test_role_converter(command_context, guild_objects):
|
||||||
_, _, role = guild_objects
|
_, _, role, _ = guild_objects
|
||||||
mention = f"<@&{role.id}>"
|
mention = f"<@&{role.id}>"
|
||||||
result = await run_converters(command_context, Role, mention)
|
result = await run_converters(command_context, Role, mention)
|
||||||
assert result is role
|
assert result is role
|
||||||
@ -132,9 +139,19 @@ async def test_role_converter(command_context, guild_objects):
|
|||||||
assert result is role
|
assert result is role
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_converter(command_context, guild_objects):
|
||||||
|
_, _, _, user = guild_objects
|
||||||
|
mention = f"<@{user.id}>"
|
||||||
|
result = await run_converters(command_context, User, mention)
|
||||||
|
assert result is user
|
||||||
|
result = await run_converters(command_context, User, user.id)
|
||||||
|
assert result is user
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_guild_converter(command_context, guild_objects):
|
async def test_guild_converter(command_context, guild_objects):
|
||||||
guild, _, _ = guild_objects
|
guild, _, _, _ = guild_objects
|
||||||
result = await run_converters(command_context, Guild, guild.id)
|
result = await run_converters(command_context, Guild, guild.id)
|
||||||
assert result is guild
|
assert result is guild
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user