Add get_context parsing for commands (#91)
This commit is contained in:
parent
c1c5cfb41a
commit
095e7e7192
@ -593,14 +593,19 @@ class Client:
|
|||||||
|
|
||||||
self._event_dispatcher.unregister(event_name, coro)
|
self._event_dispatcher.unregister(event_name, coro)
|
||||||
|
|
||||||
async def _process_message_for_commands(self, message: "Message") -> None:
|
async def _process_message_for_commands(self, message: "Message") -> None:
|
||||||
"""Internal listener to process messages for commands."""
|
"""Internal listener to process messages for commands."""
|
||||||
# Make sure message object is valid and not from a bot (optional, common check)
|
# Make sure message object is valid and not from a bot (optional, common check)
|
||||||
if (
|
if (
|
||||||
not message or not message.author or message.author.bot
|
not message or not message.author or message.author.bot
|
||||||
): # Add .bot check to User model
|
): # Add .bot check to User model
|
||||||
return
|
return
|
||||||
await self.command_handler.process_commands(message)
|
await self.command_handler.process_commands(message)
|
||||||
|
|
||||||
|
async def get_context(self, message: "Message") -> Optional["CommandContext"]:
|
||||||
|
"""Return a :class:`CommandContext` for ``message`` without executing the command."""
|
||||||
|
|
||||||
|
return await self.command_handler.get_context(message)
|
||||||
|
|
||||||
# --- Command Framework Methods ---
|
# --- Command Framework Methods ---
|
||||||
|
|
||||||
|
@ -471,9 +471,9 @@ class CommandHandler:
|
|||||||
return self.prefix(self.client, message) # type: ignore
|
return self.prefix(self.client, message) # type: ignore
|
||||||
return self.prefix
|
return self.prefix
|
||||||
|
|
||||||
async def _parse_arguments(
|
async def _parse_arguments(
|
||||||
self, command: Command, ctx: CommandContext, view: StringView
|
self, command: Command, ctx: CommandContext, view: StringView
|
||||||
) -> Tuple[List[Any], Dict[str, Any]]:
|
) -> Tuple[List[Any], Dict[str, Any]]:
|
||||||
args_list = []
|
args_list = []
|
||||||
kwargs_dict = {}
|
kwargs_dict = {}
|
||||||
params_to_parse = list(command.params.values())
|
params_to_parse = list(command.params.values())
|
||||||
@ -636,7 +636,79 @@ class CommandHandler:
|
|||||||
elif param.kind == inspect.Parameter.KEYWORD_ONLY:
|
elif param.kind == inspect.Parameter.KEYWORD_ONLY:
|
||||||
kwargs_dict[param.name] = final_value_for_param
|
kwargs_dict[param.name] = final_value_for_param
|
||||||
|
|
||||||
return args_list, kwargs_dict
|
return args_list, kwargs_dict
|
||||||
|
|
||||||
|
async def get_context(self, message: "Message") -> Optional[CommandContext]:
|
||||||
|
"""Parse a message and return a :class:`CommandContext` without executing the command.
|
||||||
|
|
||||||
|
Returns ``None`` if the message does not invoke a command."""
|
||||||
|
|
||||||
|
if not message.content:
|
||||||
|
return None
|
||||||
|
|
||||||
|
prefix_to_use = await self.get_prefix(message)
|
||||||
|
if not prefix_to_use:
|
||||||
|
return None
|
||||||
|
|
||||||
|
actual_prefix: Optional[str] = None
|
||||||
|
if isinstance(prefix_to_use, list):
|
||||||
|
for p in prefix_to_use:
|
||||||
|
if message.content.startswith(p):
|
||||||
|
actual_prefix = p
|
||||||
|
break
|
||||||
|
if not actual_prefix:
|
||||||
|
return None
|
||||||
|
elif isinstance(prefix_to_use, str):
|
||||||
|
if message.content.startswith(prefix_to_use):
|
||||||
|
actual_prefix = prefix_to_use
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if actual_prefix is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
view = StringView(message.content[len(actual_prefix) :])
|
||||||
|
|
||||||
|
command_name = view.get_word()
|
||||||
|
if not command_name:
|
||||||
|
return None
|
||||||
|
|
||||||
|
command = self.get_command(command_name)
|
||||||
|
if not command:
|
||||||
|
return None
|
||||||
|
|
||||||
|
invoked_with = command_name
|
||||||
|
|
||||||
|
if isinstance(command, Group):
|
||||||
|
view.skip_whitespace()
|
||||||
|
potential_subcommand = view.get_word()
|
||||||
|
if potential_subcommand:
|
||||||
|
subcommand = command.get_command(potential_subcommand)
|
||||||
|
if subcommand:
|
||||||
|
command = subcommand
|
||||||
|
invoked_with += f" {potential_subcommand}"
|
||||||
|
elif command.invoke_without_command:
|
||||||
|
view.index -= len(potential_subcommand) + view.previous
|
||||||
|
else:
|
||||||
|
raise CommandNotFound(
|
||||||
|
f"Subcommand '{potential_subcommand}' not found."
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = CommandContext(
|
||||||
|
message=message,
|
||||||
|
bot=self.client,
|
||||||
|
prefix=actual_prefix,
|
||||||
|
command=command,
|
||||||
|
invoked_with=invoked_with,
|
||||||
|
cog=command.cog,
|
||||||
|
)
|
||||||
|
|
||||||
|
parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view)
|
||||||
|
ctx.args = parsed_args
|
||||||
|
ctx.kwargs = parsed_kwargs
|
||||||
|
return ctx
|
||||||
|
|
||||||
async def process_commands(self, message: "Message") -> None:
|
async def process_commands(self, message: "Message") -> None:
|
||||||
if not message.content:
|
if not message.content:
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import pytest
|
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# pylint: disable=no-member
|
||||||
|
|
||||||
from disagreement.client import Client
|
from disagreement.client import Client
|
||||||
|
|
||||||
|
|
||||||
|
59
tests/test_get_context.py
Normal file
59
tests/test_get_context.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.ext.commands.core import Command, CommandHandler
|
||||||
|
from disagreement.models import Message
|
||||||
|
|
||||||
|
|
||||||
|
class DummyBot:
|
||||||
|
def __init__(self):
|
||||||
|
self.executed = False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_context_parses_without_execution():
|
||||||
|
bot = DummyBot()
|
||||||
|
handler = CommandHandler(client=bot, prefix="!")
|
||||||
|
|
||||||
|
async def foo(ctx, number: int, word: str):
|
||||||
|
bot.executed = True
|
||||||
|
|
||||||
|
handler.add_command(Command(foo, name="foo"))
|
||||||
|
|
||||||
|
msg_data = {
|
||||||
|
"id": "1",
|
||||||
|
"channel_id": "c",
|
||||||
|
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "!foo 1 bar",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
msg = Message(msg_data, client_instance=bot)
|
||||||
|
|
||||||
|
ctx = await handler.get_context(msg)
|
||||||
|
assert ctx is not None
|
||||||
|
assert ctx.command.name == "foo"
|
||||||
|
assert ctx.args == [1, "bar"]
|
||||||
|
assert bot.executed is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_get_context():
|
||||||
|
client = Client(token="t")
|
||||||
|
|
||||||
|
async def foo(ctx):
|
||||||
|
raise RuntimeError("should not run")
|
||||||
|
|
||||||
|
client.command_handler.add_command(Command(foo, name="foo"))
|
||||||
|
|
||||||
|
msg_data = {
|
||||||
|
"id": "1",
|
||||||
|
"channel_id": "c",
|
||||||
|
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "!foo",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
msg = Message(msg_data, client_instance=client)
|
||||||
|
|
||||||
|
ctx = await client.get_context(msg)
|
||||||
|
assert ctx is not None
|
||||||
|
assert ctx.command.name == "foo"
|
Loading…
x
Reference in New Issue
Block a user