diff --git a/disagreement/client.py b/disagreement/client.py index fe0d91f..c0e8240 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -593,14 +593,19 @@ class Client: self._event_dispatcher.unregister(event_name, coro) - async def _process_message_for_commands(self, message: "Message") -> None: - """Internal listener to process messages for commands.""" - # Make sure message object is valid and not from a bot (optional, common check) - if ( - not message or not message.author or message.author.bot - ): # Add .bot check to User model - return - await self.command_handler.process_commands(message) + async def _process_message_for_commands(self, message: "Message") -> None: + """Internal listener to process messages for commands.""" + # Make sure message object is valid and not from a bot (optional, common check) + if ( + not message or not message.author or message.author.bot + ): # Add .bot check to User model + return + await self.command_handler.process_commands(message) + + async def get_context(self, message: "Message") -> Optional["CommandContext"]: + """Return a :class:`CommandContext` for ``message`` without executing the command.""" + + return await self.command_handler.get_context(message) # --- Command Framework Methods --- diff --git a/disagreement/ext/commands/core.py b/disagreement/ext/commands/core.py index 7dd6c26..0e76282 100644 --- a/disagreement/ext/commands/core.py +++ b/disagreement/ext/commands/core.py @@ -471,9 +471,9 @@ class CommandHandler: return self.prefix(self.client, message) # type: ignore return self.prefix - async def _parse_arguments( - self, command: Command, ctx: CommandContext, view: StringView - ) -> Tuple[List[Any], Dict[str, Any]]: + async def _parse_arguments( + self, command: Command, ctx: CommandContext, view: StringView + ) -> Tuple[List[Any], Dict[str, Any]]: args_list = [] kwargs_dict = {} params_to_parse = list(command.params.values()) @@ -636,7 +636,79 @@ class CommandHandler: elif param.kind == inspect.Parameter.KEYWORD_ONLY: kwargs_dict[param.name] = final_value_for_param - return args_list, kwargs_dict + return args_list, kwargs_dict + + async def get_context(self, message: "Message") -> Optional[CommandContext]: + """Parse a message and return a :class:`CommandContext` without executing the command. + + Returns ``None`` if the message does not invoke a command.""" + + if not message.content: + return None + + prefix_to_use = await self.get_prefix(message) + if not prefix_to_use: + return None + + actual_prefix: Optional[str] = None + if isinstance(prefix_to_use, list): + for p in prefix_to_use: + if message.content.startswith(p): + actual_prefix = p + break + if not actual_prefix: + return None + elif isinstance(prefix_to_use, str): + if message.content.startswith(prefix_to_use): + actual_prefix = prefix_to_use + else: + return None + else: + return None + + if actual_prefix is None: + return None + + view = StringView(message.content[len(actual_prefix) :]) + + command_name = view.get_word() + if not command_name: + return None + + command = self.get_command(command_name) + if not command: + return None + + invoked_with = command_name + + if isinstance(command, Group): + view.skip_whitespace() + potential_subcommand = view.get_word() + if potential_subcommand: + subcommand = command.get_command(potential_subcommand) + if subcommand: + command = subcommand + invoked_with += f" {potential_subcommand}" + elif command.invoke_without_command: + view.index -= len(potential_subcommand) + view.previous + else: + raise CommandNotFound( + f"Subcommand '{potential_subcommand}' not found." + ) + + ctx = CommandContext( + message=message, + bot=self.client, + prefix=actual_prefix, + command=command, + invoked_with=invoked_with, + cog=command.cog, + ) + + parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view) + ctx.args = parsed_args + ctx.kwargs = parsed_kwargs + return ctx async def process_commands(self, message: "Message") -> None: if not message.content: diff --git a/tests/test_client_context_manager.py b/tests/test_client_context_manager.py index ba4ef7f..dc6c68d 100644 --- a/tests/test_client_context_manager.py +++ b/tests/test_client_context_manager.py @@ -1,7 +1,10 @@ import asyncio -import pytest from unittest.mock import AsyncMock +import pytest + +# pylint: disable=no-member + from disagreement.client import Client diff --git a/tests/test_get_context.py b/tests/test_get_context.py new file mode 100644 index 0000000..8cff434 --- /dev/null +++ b/tests/test_get_context.py @@ -0,0 +1,59 @@ +import pytest + +from disagreement.client import Client +from disagreement.ext.commands.core import Command, CommandHandler +from disagreement.models import Message + + +class DummyBot: + def __init__(self): + self.executed = False + + +@pytest.mark.asyncio +async def test_get_context_parses_without_execution(): + bot = DummyBot() + handler = CommandHandler(client=bot, prefix="!") + + async def foo(ctx, number: int, word: str): + bot.executed = True + + handler.add_command(Command(foo, name="foo")) + + msg_data = { + "id": "1", + "channel_id": "c", + "author": {"id": "2", "username": "u", "discriminator": "0001"}, + "content": "!foo 1 bar", + "timestamp": "t", + } + msg = Message(msg_data, client_instance=bot) + + ctx = await handler.get_context(msg) + assert ctx is not None + assert ctx.command.name == "foo" + assert ctx.args == [1, "bar"] + assert bot.executed is False + + +@pytest.mark.asyncio +async def test_client_get_context(): + client = Client(token="t") + + async def foo(ctx): + raise RuntimeError("should not run") + + client.command_handler.add_command(Command(foo, name="foo")) + + msg_data = { + "id": "1", + "channel_id": "c", + "author": {"id": "2", "username": "u", "discriminator": "0001"}, + "content": "!foo", + "timestamp": "t", + } + msg = Message(msg_data, client_instance=client) + + ctx = await client.get_context(msg) + assert ctx is not None + assert ctx.command.name == "foo"