diff --git a/disagreement/ext/commands/converters.py b/disagreement/ext/commands/converters.py index ea09879..4235234 100644 --- a/disagreement/ext/commands/converters.py +++ b/disagreement/ext/commands/converters.py @@ -6,7 +6,16 @@ import re import inspect from .errors import BadArgument -from disagreement.models import Member, Guild, Role, User +from disagreement.models import ( + Member, + Guild, + Role, + User, + TextChannel, + VoiceChannel, + Emoji, + PartialEmoji, +) if TYPE_CHECKING: from .core import CommandContext @@ -158,6 +167,82 @@ class UserConverter(Converter["User"]): raise BadArgument(f"User '{argument}' not found.") +class TextChannelConverter(Converter["TextChannel"]): + async def convert(self, ctx: "CommandContext", argument: str) -> "TextChannel": + if not ctx.message.guild_id: + raise BadArgument("TextChannel converter requires guild context.") + + match = re.match(r"<#(?P\d+)>$", argument) + channel_id = match.group("id") if match else argument + + guild = ctx.bot.get_guild(ctx.message.guild_id) + if guild: + channel = guild.get_channel(channel_id) + if isinstance(channel, TextChannel): + return channel + + channel = ( + ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None + ) + if isinstance(channel, TextChannel): + return channel + + if hasattr(ctx.bot, "fetch_channel"): + channel = await ctx.bot.fetch_channel(channel_id) + if isinstance(channel, TextChannel): + return channel + + raise BadArgument(f"Text channel '{argument}' not found.") + + +class VoiceChannelConverter(Converter["VoiceChannel"]): + async def convert(self, ctx: "CommandContext", argument: str) -> "VoiceChannel": + if not ctx.message.guild_id: + raise BadArgument("VoiceChannel converter requires guild context.") + + match = re.match(r"<#(?P\d+)>$", argument) + channel_id = match.group("id") if match else argument + + guild = ctx.bot.get_guild(ctx.message.guild_id) + if guild: + channel = guild.get_channel(channel_id) + if isinstance(channel, VoiceChannel): + return channel + + channel = ( + ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None + ) + if isinstance(channel, VoiceChannel): + return channel + + if hasattr(ctx.bot, "fetch_channel"): + channel = await ctx.bot.fetch_channel(channel_id) + if isinstance(channel, VoiceChannel): + return channel + + raise BadArgument(f"Voice channel '{argument}' not found.") + + +class EmojiConverter(Converter["PartialEmoji"]): + _CUSTOM_RE = re.compile(r"<(?Pa)?:(?P[^:]+):(?P\d+)>$") + + async def convert(self, ctx: "CommandContext", argument: str) -> "PartialEmoji": + match = self._CUSTOM_RE.match(argument) + if match: + return PartialEmoji( + { + "id": match.group("id"), + "name": match.group("name"), + "animated": bool(match.group("animated")), + } + ) + + if argument: + return PartialEmoji({"id": None, "name": argument}) + + raise BadArgument(f"Emoji '{argument}' not found.") + + # Default converters mapping DEFAULT_CONVERTERS: dict[type, Converter[Any]] = { int: IntConverter(), @@ -168,6 +253,10 @@ DEFAULT_CONVERTERS: dict[type, Converter[Any]] = { Guild: GuildConverter(), Role: RoleConverter(), User: UserConverter(), + TextChannel: TextChannelConverter(), + VoiceChannel: VoiceChannelConverter(), + PartialEmoji: EmojiConverter(), + Emoji: EmojiConverter(), } diff --git a/tests/test_additional_converters.py b/tests/test_additional_converters.py index 8ee751f..952edc7 100644 --- a/tests/test_additional_converters.py +++ b/tests/test_additional_converters.py @@ -3,7 +3,16 @@ import pytest from disagreement.ext.commands.converters import run_converters from disagreement.ext.commands.core import CommandContext, Command from disagreement.ext.commands.errors import BadArgument -from disagreement.models import Message, Member, Role, Guild, User +from disagreement.models import ( + Message, + Member, + Role, + Guild, + User, + TextChannel, + VoiceChannel, + PartialEmoji, +) from disagreement.enums import ( VerificationLevel, MessageNotificationLevel, @@ -11,11 +20,12 @@ from disagreement.enums import ( MFALevel, GuildNSFWLevel, PremiumTier, + ChannelType, ) from disagreement.client import Client -from disagreement.cache import GuildCache, Cache +from disagreement.cache import GuildCache, Cache, ChannelCache class DummyBot(Client): @@ -23,10 +33,14 @@ class DummyBot(Client): super().__init__(token="test") self._guilds = GuildCache() self._users = Cache() + self._channels = ChannelCache() def get_guild(self, guild_id): return self._guilds.get(guild_id) + def get_channel(self, channel_id): + return self._channels.get(channel_id) + async def fetch_member(self, guild_id, member_id): guild = self._guilds.get(guild_id) return guild.get_member(member_id) if guild else None @@ -41,6 +55,9 @@ class DummyBot(Client): async def fetch_user(self, user_id): return self._users.get(user_id) + async def fetch_channel(self, channel_id): + return self._channels.get(channel_id) + @pytest.fixture() def guild_objects(): @@ -93,12 +110,38 @@ def guild_objects(): guild._members.set(member.id, member) guild.roles.append(role) - return guild, member, role, user + text_channel = TextChannel( + { + "id": "20", + "type": ChannelType.GUILD_TEXT.value, + "guild_id": guild.id, + "permission_overwrites": [], + }, + client_instance=bot, + ) + voice_channel = VoiceChannel( + { + "id": "21", + "type": ChannelType.GUILD_VOICE.value, + "guild_id": guild.id, + "permission_overwrites": [], + }, + client_instance=bot, + ) + + guild._channels.set(text_channel.id, text_channel) + guild.text_channels.append(text_channel) + guild._channels.set(voice_channel.id, voice_channel) + guild.voice_channels.append(voice_channel) + bot._channels.set(text_channel.id, text_channel) + bot._channels.set(voice_channel.id, voice_channel) + + return guild, member, role, user, text_channel, voice_channel @pytest.fixture() def command_context(guild_objects): - guild, member, role, _ = guild_objects + guild, member, role, _, _, _ = guild_objects bot = guild._client message_data = { "id": "10", @@ -121,7 +164,7 @@ def command_context(guild_objects): @pytest.mark.asyncio async def test_member_converter(command_context, guild_objects): - _, member, _, _ = guild_objects + _, member, _, _, _, _ = guild_objects mention = f"<@!{member.id}>" result = await run_converters(command_context, Member, mention) assert result is member @@ -131,7 +174,7 @@ async def test_member_converter(command_context, guild_objects): @pytest.mark.asyncio async def test_role_converter(command_context, guild_objects): - _, _, role, _ = guild_objects + _, _, role, _, _, _ = guild_objects mention = f"<@&{role.id}>" result = await run_converters(command_context, Role, mention) assert result is role @@ -141,7 +184,7 @@ async def test_role_converter(command_context, guild_objects): @pytest.mark.asyncio async def test_user_converter(command_context, guild_objects): - _, _, _, user = guild_objects + _, _, _, user, _, _ = guild_objects mention = f"<@{user.id}>" result = await run_converters(command_context, User, mention) assert result is user @@ -151,11 +194,43 @@ async def test_user_converter(command_context, guild_objects): @pytest.mark.asyncio async def test_guild_converter(command_context, guild_objects): - guild, _, _, _ = guild_objects + guild, _, _, _, _, _ = guild_objects result = await run_converters(command_context, Guild, guild.id) assert result is guild +@pytest.mark.asyncio +async def test_text_channel_converter(command_context, guild_objects): + _, _, _, _, text_channel, _ = guild_objects + mention = f"<#{text_channel.id}>" + result = await run_converters(command_context, TextChannel, mention) + assert result is text_channel + result = await run_converters(command_context, TextChannel, text_channel.id) + assert result is text_channel + + +@pytest.mark.asyncio +async def test_voice_channel_converter(command_context, guild_objects): + _, _, _, _, _, voice_channel = guild_objects + mention = f"<#{voice_channel.id}>" + result = await run_converters(command_context, VoiceChannel, mention) + assert result is voice_channel + result = await run_converters(command_context, VoiceChannel, voice_channel.id) + assert result is voice_channel + + +@pytest.mark.asyncio +async def test_emoji_converter(command_context): + result = await run_converters(command_context, PartialEmoji, "<:smile:1>") + assert isinstance(result, PartialEmoji) + assert result.id == "1" + assert result.name == "smile" + + result = await run_converters(command_context, PartialEmoji, "😄") + assert result.id is None + assert result.name == "😄" + + @pytest.mark.asyncio async def test_member_converter_no_guild(): guild_data = {