Add channel and emoji converters (#112)

This commit is contained in:
Slipstream 2025-06-15 20:39:28 -06:00 committed by GitHub
parent e2061adc55
commit 506adeca20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 173 additions and 9 deletions

View File

@ -6,7 +6,16 @@ import re
import inspect
from .errors import BadArgument
from disagreement.models import Member, Guild, Role, User
from disagreement.models import (
Member,
Guild,
Role,
User,
TextChannel,
VoiceChannel,
Emoji,
PartialEmoji,
)
if TYPE_CHECKING:
from .core import CommandContext
@ -158,6 +167,82 @@ class UserConverter(Converter["User"]):
raise BadArgument(f"User '{argument}' not found.")
class TextChannelConverter(Converter["TextChannel"]):
async def convert(self, ctx: "CommandContext", argument: str) -> "TextChannel":
if not ctx.message.guild_id:
raise BadArgument("TextChannel converter requires guild context.")
match = re.match(r"<#(?P<id>\d+)>$", argument)
channel_id = match.group("id") if match else argument
guild = ctx.bot.get_guild(ctx.message.guild_id)
if guild:
channel = guild.get_channel(channel_id)
if isinstance(channel, TextChannel):
return channel
channel = (
ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None
)
if isinstance(channel, TextChannel):
return channel
if hasattr(ctx.bot, "fetch_channel"):
channel = await ctx.bot.fetch_channel(channel_id)
if isinstance(channel, TextChannel):
return channel
raise BadArgument(f"Text channel '{argument}' not found.")
class VoiceChannelConverter(Converter["VoiceChannel"]):
async def convert(self, ctx: "CommandContext", argument: str) -> "VoiceChannel":
if not ctx.message.guild_id:
raise BadArgument("VoiceChannel converter requires guild context.")
match = re.match(r"<#(?P<id>\d+)>$", argument)
channel_id = match.group("id") if match else argument
guild = ctx.bot.get_guild(ctx.message.guild_id)
if guild:
channel = guild.get_channel(channel_id)
if isinstance(channel, VoiceChannel):
return channel
channel = (
ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None
)
if isinstance(channel, VoiceChannel):
return channel
if hasattr(ctx.bot, "fetch_channel"):
channel = await ctx.bot.fetch_channel(channel_id)
if isinstance(channel, VoiceChannel):
return channel
raise BadArgument(f"Voice channel '{argument}' not found.")
class EmojiConverter(Converter["PartialEmoji"]):
_CUSTOM_RE = re.compile(r"<(?P<animated>a)?:(?P<name>[^:]+):(?P<id>\d+)>$")
async def convert(self, ctx: "CommandContext", argument: str) -> "PartialEmoji":
match = self._CUSTOM_RE.match(argument)
if match:
return PartialEmoji(
{
"id": match.group("id"),
"name": match.group("name"),
"animated": bool(match.group("animated")),
}
)
if argument:
return PartialEmoji({"id": None, "name": argument})
raise BadArgument(f"Emoji '{argument}' not found.")
# Default converters mapping
DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
int: IntConverter(),
@ -168,6 +253,10 @@ DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
Guild: GuildConverter(),
Role: RoleConverter(),
User: UserConverter(),
TextChannel: TextChannelConverter(),
VoiceChannel: VoiceChannelConverter(),
PartialEmoji: EmojiConverter(),
Emoji: EmojiConverter(),
}

View File

@ -3,7 +3,16 @@ import pytest
from disagreement.ext.commands.converters import run_converters
from disagreement.ext.commands.core import CommandContext, Command
from disagreement.ext.commands.errors import BadArgument
from disagreement.models import Message, Member, Role, Guild, User
from disagreement.models import (
Message,
Member,
Role,
Guild,
User,
TextChannel,
VoiceChannel,
PartialEmoji,
)
from disagreement.enums import (
VerificationLevel,
MessageNotificationLevel,
@ -11,11 +20,12 @@ from disagreement.enums import (
MFALevel,
GuildNSFWLevel,
PremiumTier,
ChannelType,
)
from disagreement.client import Client
from disagreement.cache import GuildCache, Cache
from disagreement.cache import GuildCache, Cache, ChannelCache
class DummyBot(Client):
@ -23,10 +33,14 @@ class DummyBot(Client):
super().__init__(token="test")
self._guilds = GuildCache()
self._users = Cache()
self._channels = ChannelCache()
def get_guild(self, guild_id):
return self._guilds.get(guild_id)
def get_channel(self, channel_id):
return self._channels.get(channel_id)
async def fetch_member(self, guild_id, member_id):
guild = self._guilds.get(guild_id)
return guild.get_member(member_id) if guild else None
@ -41,6 +55,9 @@ class DummyBot(Client):
async def fetch_user(self, user_id):
return self._users.get(user_id)
async def fetch_channel(self, channel_id):
return self._channels.get(channel_id)
@pytest.fixture()
def guild_objects():
@ -93,12 +110,38 @@ def guild_objects():
guild._members.set(member.id, member)
guild.roles.append(role)
return guild, member, role, user
text_channel = TextChannel(
{
"id": "20",
"type": ChannelType.GUILD_TEXT.value,
"guild_id": guild.id,
"permission_overwrites": [],
},
client_instance=bot,
)
voice_channel = VoiceChannel(
{
"id": "21",
"type": ChannelType.GUILD_VOICE.value,
"guild_id": guild.id,
"permission_overwrites": [],
},
client_instance=bot,
)
guild._channels.set(text_channel.id, text_channel)
guild.text_channels.append(text_channel)
guild._channels.set(voice_channel.id, voice_channel)
guild.voice_channels.append(voice_channel)
bot._channels.set(text_channel.id, text_channel)
bot._channels.set(voice_channel.id, voice_channel)
return guild, member, role, user, text_channel, voice_channel
@pytest.fixture()
def command_context(guild_objects):
guild, member, role, _ = guild_objects
guild, member, role, _, _, _ = guild_objects
bot = guild._client
message_data = {
"id": "10",
@ -121,7 +164,7 @@ def command_context(guild_objects):
@pytest.mark.asyncio
async def test_member_converter(command_context, guild_objects):
_, member, _, _ = guild_objects
_, member, _, _, _, _ = guild_objects
mention = f"<@!{member.id}>"
result = await run_converters(command_context, Member, mention)
assert result is member
@ -131,7 +174,7 @@ async def test_member_converter(command_context, guild_objects):
@pytest.mark.asyncio
async def test_role_converter(command_context, guild_objects):
_, _, role, _ = guild_objects
_, _, role, _, _, _ = guild_objects
mention = f"<@&{role.id}>"
result = await run_converters(command_context, Role, mention)
assert result is role
@ -141,7 +184,7 @@ async def test_role_converter(command_context, guild_objects):
@pytest.mark.asyncio
async def test_user_converter(command_context, guild_objects):
_, _, _, user = guild_objects
_, _, _, user, _, _ = guild_objects
mention = f"<@{user.id}>"
result = await run_converters(command_context, User, mention)
assert result is user
@ -151,11 +194,43 @@ async def test_user_converter(command_context, guild_objects):
@pytest.mark.asyncio
async def test_guild_converter(command_context, guild_objects):
guild, _, _, _ = guild_objects
guild, _, _, _, _, _ = guild_objects
result = await run_converters(command_context, Guild, guild.id)
assert result is guild
@pytest.mark.asyncio
async def test_text_channel_converter(command_context, guild_objects):
_, _, _, _, text_channel, _ = guild_objects
mention = f"<#{text_channel.id}>"
result = await run_converters(command_context, TextChannel, mention)
assert result is text_channel
result = await run_converters(command_context, TextChannel, text_channel.id)
assert result is text_channel
@pytest.mark.asyncio
async def test_voice_channel_converter(command_context, guild_objects):
_, _, _, _, _, voice_channel = guild_objects
mention = f"<#{voice_channel.id}>"
result = await run_converters(command_context, VoiceChannel, mention)
assert result is voice_channel
result = await run_converters(command_context, VoiceChannel, voice_channel.id)
assert result is voice_channel
@pytest.mark.asyncio
async def test_emoji_converter(command_context):
result = await run_converters(command_context, PartialEmoji, "<:smile:1>")
assert isinstance(result, PartialEmoji)
assert result.id == "1"
assert result.name == "smile"
result = await run_converters(command_context, PartialEmoji, "😄")
assert result.id is None
assert result.name == "😄"
@pytest.mark.asyncio
async def test_member_converter_no_guild():
guild_data = {