Add channel and emoji converters (#112)
This commit is contained in:
parent
e2061adc55
commit
506adeca20
@ -6,7 +6,16 @@ import re
|
||||
import inspect
|
||||
|
||||
from .errors import BadArgument
|
||||
from disagreement.models import Member, Guild, Role, User
|
||||
from disagreement.models import (
|
||||
Member,
|
||||
Guild,
|
||||
Role,
|
||||
User,
|
||||
TextChannel,
|
||||
VoiceChannel,
|
||||
Emoji,
|
||||
PartialEmoji,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core import CommandContext
|
||||
@ -158,6 +167,82 @@ class UserConverter(Converter["User"]):
|
||||
raise BadArgument(f"User '{argument}' not found.")
|
||||
|
||||
|
||||
class TextChannelConverter(Converter["TextChannel"]):
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> "TextChannel":
|
||||
if not ctx.message.guild_id:
|
||||
raise BadArgument("TextChannel converter requires guild context.")
|
||||
|
||||
match = re.match(r"<#(?P<id>\d+)>$", argument)
|
||||
channel_id = match.group("id") if match else argument
|
||||
|
||||
guild = ctx.bot.get_guild(ctx.message.guild_id)
|
||||
if guild:
|
||||
channel = guild.get_channel(channel_id)
|
||||
if isinstance(channel, TextChannel):
|
||||
return channel
|
||||
|
||||
channel = (
|
||||
ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None
|
||||
)
|
||||
if isinstance(channel, TextChannel):
|
||||
return channel
|
||||
|
||||
if hasattr(ctx.bot, "fetch_channel"):
|
||||
channel = await ctx.bot.fetch_channel(channel_id)
|
||||
if isinstance(channel, TextChannel):
|
||||
return channel
|
||||
|
||||
raise BadArgument(f"Text channel '{argument}' not found.")
|
||||
|
||||
|
||||
class VoiceChannelConverter(Converter["VoiceChannel"]):
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> "VoiceChannel":
|
||||
if not ctx.message.guild_id:
|
||||
raise BadArgument("VoiceChannel converter requires guild context.")
|
||||
|
||||
match = re.match(r"<#(?P<id>\d+)>$", argument)
|
||||
channel_id = match.group("id") if match else argument
|
||||
|
||||
guild = ctx.bot.get_guild(ctx.message.guild_id)
|
||||
if guild:
|
||||
channel = guild.get_channel(channel_id)
|
||||
if isinstance(channel, VoiceChannel):
|
||||
return channel
|
||||
|
||||
channel = (
|
||||
ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None
|
||||
)
|
||||
if isinstance(channel, VoiceChannel):
|
||||
return channel
|
||||
|
||||
if hasattr(ctx.bot, "fetch_channel"):
|
||||
channel = await ctx.bot.fetch_channel(channel_id)
|
||||
if isinstance(channel, VoiceChannel):
|
||||
return channel
|
||||
|
||||
raise BadArgument(f"Voice channel '{argument}' not found.")
|
||||
|
||||
|
||||
class EmojiConverter(Converter["PartialEmoji"]):
|
||||
_CUSTOM_RE = re.compile(r"<(?P<animated>a)?:(?P<name>[^:]+):(?P<id>\d+)>$")
|
||||
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> "PartialEmoji":
|
||||
match = self._CUSTOM_RE.match(argument)
|
||||
if match:
|
||||
return PartialEmoji(
|
||||
{
|
||||
"id": match.group("id"),
|
||||
"name": match.group("name"),
|
||||
"animated": bool(match.group("animated")),
|
||||
}
|
||||
)
|
||||
|
||||
if argument:
|
||||
return PartialEmoji({"id": None, "name": argument})
|
||||
|
||||
raise BadArgument(f"Emoji '{argument}' not found.")
|
||||
|
||||
|
||||
# Default converters mapping
|
||||
DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
|
||||
int: IntConverter(),
|
||||
@ -168,6 +253,10 @@ DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
|
||||
Guild: GuildConverter(),
|
||||
Role: RoleConverter(),
|
||||
User: UserConverter(),
|
||||
TextChannel: TextChannelConverter(),
|
||||
VoiceChannel: VoiceChannelConverter(),
|
||||
PartialEmoji: EmojiConverter(),
|
||||
Emoji: EmojiConverter(),
|
||||
}
|
||||
|
||||
|
||||
|
@ -3,7 +3,16 @@ import pytest
|
||||
from disagreement.ext.commands.converters import run_converters
|
||||
from disagreement.ext.commands.core import CommandContext, Command
|
||||
from disagreement.ext.commands.errors import BadArgument
|
||||
from disagreement.models import Message, Member, Role, Guild, User
|
||||
from disagreement.models import (
|
||||
Message,
|
||||
Member,
|
||||
Role,
|
||||
Guild,
|
||||
User,
|
||||
TextChannel,
|
||||
VoiceChannel,
|
||||
PartialEmoji,
|
||||
)
|
||||
from disagreement.enums import (
|
||||
VerificationLevel,
|
||||
MessageNotificationLevel,
|
||||
@ -11,11 +20,12 @@ from disagreement.enums import (
|
||||
MFALevel,
|
||||
GuildNSFWLevel,
|
||||
PremiumTier,
|
||||
ChannelType,
|
||||
)
|
||||
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.cache import GuildCache, Cache
|
||||
from disagreement.cache import GuildCache, Cache, ChannelCache
|
||||
|
||||
|
||||
class DummyBot(Client):
|
||||
@ -23,10 +33,14 @@ class DummyBot(Client):
|
||||
super().__init__(token="test")
|
||||
self._guilds = GuildCache()
|
||||
self._users = Cache()
|
||||
self._channels = ChannelCache()
|
||||
|
||||
def get_guild(self, guild_id):
|
||||
return self._guilds.get(guild_id)
|
||||
|
||||
def get_channel(self, channel_id):
|
||||
return self._channels.get(channel_id)
|
||||
|
||||
async def fetch_member(self, guild_id, member_id):
|
||||
guild = self._guilds.get(guild_id)
|
||||
return guild.get_member(member_id) if guild else None
|
||||
@ -41,6 +55,9 @@ class DummyBot(Client):
|
||||
async def fetch_user(self, user_id):
|
||||
return self._users.get(user_id)
|
||||
|
||||
async def fetch_channel(self, channel_id):
|
||||
return self._channels.get(channel_id)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def guild_objects():
|
||||
@ -93,12 +110,38 @@ def guild_objects():
|
||||
guild._members.set(member.id, member)
|
||||
guild.roles.append(role)
|
||||
|
||||
return guild, member, role, user
|
||||
text_channel = TextChannel(
|
||||
{
|
||||
"id": "20",
|
||||
"type": ChannelType.GUILD_TEXT.value,
|
||||
"guild_id": guild.id,
|
||||
"permission_overwrites": [],
|
||||
},
|
||||
client_instance=bot,
|
||||
)
|
||||
voice_channel = VoiceChannel(
|
||||
{
|
||||
"id": "21",
|
||||
"type": ChannelType.GUILD_VOICE.value,
|
||||
"guild_id": guild.id,
|
||||
"permission_overwrites": [],
|
||||
},
|
||||
client_instance=bot,
|
||||
)
|
||||
|
||||
guild._channels.set(text_channel.id, text_channel)
|
||||
guild.text_channels.append(text_channel)
|
||||
guild._channels.set(voice_channel.id, voice_channel)
|
||||
guild.voice_channels.append(voice_channel)
|
||||
bot._channels.set(text_channel.id, text_channel)
|
||||
bot._channels.set(voice_channel.id, voice_channel)
|
||||
|
||||
return guild, member, role, user, text_channel, voice_channel
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def command_context(guild_objects):
|
||||
guild, member, role, _ = guild_objects
|
||||
guild, member, role, _, _, _ = guild_objects
|
||||
bot = guild._client
|
||||
message_data = {
|
||||
"id": "10",
|
||||
@ -121,7 +164,7 @@ def command_context(guild_objects):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_member_converter(command_context, guild_objects):
|
||||
_, member, _, _ = guild_objects
|
||||
_, member, _, _, _, _ = guild_objects
|
||||
mention = f"<@!{member.id}>"
|
||||
result = await run_converters(command_context, Member, mention)
|
||||
assert result is member
|
||||
@ -131,7 +174,7 @@ async def test_member_converter(command_context, guild_objects):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_role_converter(command_context, guild_objects):
|
||||
_, _, role, _ = guild_objects
|
||||
_, _, role, _, _, _ = guild_objects
|
||||
mention = f"<@&{role.id}>"
|
||||
result = await run_converters(command_context, Role, mention)
|
||||
assert result is role
|
||||
@ -141,7 +184,7 @@ async def test_role_converter(command_context, guild_objects):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_converter(command_context, guild_objects):
|
||||
_, _, _, user = guild_objects
|
||||
_, _, _, user, _, _ = guild_objects
|
||||
mention = f"<@{user.id}>"
|
||||
result = await run_converters(command_context, User, mention)
|
||||
assert result is user
|
||||
@ -151,11 +194,43 @@ async def test_user_converter(command_context, guild_objects):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guild_converter(command_context, guild_objects):
|
||||
guild, _, _, _ = guild_objects
|
||||
guild, _, _, _, _, _ = guild_objects
|
||||
result = await run_converters(command_context, Guild, guild.id)
|
||||
assert result is guild
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_channel_converter(command_context, guild_objects):
|
||||
_, _, _, _, text_channel, _ = guild_objects
|
||||
mention = f"<#{text_channel.id}>"
|
||||
result = await run_converters(command_context, TextChannel, mention)
|
||||
assert result is text_channel
|
||||
result = await run_converters(command_context, TextChannel, text_channel.id)
|
||||
assert result is text_channel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_channel_converter(command_context, guild_objects):
|
||||
_, _, _, _, _, voice_channel = guild_objects
|
||||
mention = f"<#{voice_channel.id}>"
|
||||
result = await run_converters(command_context, VoiceChannel, mention)
|
||||
assert result is voice_channel
|
||||
result = await run_converters(command_context, VoiceChannel, voice_channel.id)
|
||||
assert result is voice_channel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emoji_converter(command_context):
|
||||
result = await run_converters(command_context, PartialEmoji, "<:smile:1>")
|
||||
assert isinstance(result, PartialEmoji)
|
||||
assert result.id == "1"
|
||||
assert result.name == "smile"
|
||||
|
||||
result = await run_converters(command_context, PartialEmoji, "😄")
|
||||
assert result.id is None
|
||||
assert result.name == "😄"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_member_converter_no_guild():
|
||||
guild_data = {
|
||||
|
Loading…
x
Reference in New Issue
Block a user