Add Greedy converter and support in parser (#58)
This commit is contained in:
parent
afeb86a395
commit
6d55a2ca98
@ -1,8 +1,10 @@
|
||||
# disagreement/ext/commands/converters.py
|
||||
# pyright: reportIncompatibleMethodOverride=false
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, Generic
|
||||
from abc import ABC, abstractmethod
|
||||
import re
|
||||
import inspect
|
||||
|
||||
from .errors import BadArgument
|
||||
from disagreement.models import Member, Guild, Role
|
||||
@ -36,6 +38,20 @@ class Converter(ABC, Generic[T]):
|
||||
raise NotImplementedError("Converter subclass must implement convert method.")
|
||||
|
||||
|
||||
class Greedy(list):
|
||||
"""Type hint helper to greedily consume arguments."""
|
||||
|
||||
converter: Any = None
|
||||
|
||||
def __class_getitem__(cls, param: Any) -> type: # pyright: ignore[override]
|
||||
if isinstance(param, tuple):
|
||||
if len(param) != 1:
|
||||
raise TypeError("Greedy[...] expects a single parameter")
|
||||
param = param[0]
|
||||
name = f"Greedy[{getattr(param, '__name__', str(param))}]"
|
||||
return type(name, (Greedy,), {"converter": param})
|
||||
|
||||
|
||||
# --- Built-in Type Converters ---
|
||||
|
||||
|
||||
@ -169,7 +185,3 @@ async def run_converters(ctx: "CommandContext", annotation: Any, argument: str)
|
||||
raise BadArgument(f"No converter found for type annotation '{annotation}'.")
|
||||
|
||||
return argument # Default to string if no annotation or annotation is str
|
||||
|
||||
|
||||
# Need to import inspect for the run_converters function
|
||||
import inspect
|
||||
|
@ -29,7 +29,7 @@ from .errors import (
|
||||
CheckFailure,
|
||||
CommandInvokeError,
|
||||
)
|
||||
from .converters import run_converters, DEFAULT_CONVERTERS, Converter
|
||||
from .converters import Greedy, run_converters, DEFAULT_CONVERTERS, Converter
|
||||
from disagreement.typing import Typing
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -46,29 +46,39 @@ class GroupMixin:
|
||||
self.commands: Dict[str, "Command"] = {}
|
||||
self.name: str = ""
|
||||
|
||||
def command(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Command"]:
|
||||
def command(
|
||||
self, **attrs: Any
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], "Command"]:
|
||||
def decorator(func: Callable[..., Awaitable[None]]) -> "Command":
|
||||
cmd = Command(func, **attrs)
|
||||
cmd.cog = getattr(self, "cog", None)
|
||||
self.add_command(cmd)
|
||||
return cmd
|
||||
|
||||
return decorator
|
||||
|
||||
def group(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Group"]:
|
||||
def group(
|
||||
self, **attrs: Any
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], "Group"]:
|
||||
def decorator(func: Callable[..., Awaitable[None]]) -> "Group":
|
||||
cmd = Group(func, **attrs)
|
||||
cmd.cog = getattr(self, "cog", None)
|
||||
self.add_command(cmd)
|
||||
return cmd
|
||||
|
||||
return decorator
|
||||
|
||||
def add_command(self, command: "Command") -> None:
|
||||
if command.name in self.commands:
|
||||
raise ValueError(f"Command '{command.name}' is already registered in group '{self.name}'.")
|
||||
raise ValueError(
|
||||
f"Command '{command.name}' is already registered in group '{self.name}'."
|
||||
)
|
||||
self.commands[command.name.lower()] = command
|
||||
for alias in command.aliases:
|
||||
if alias in self.commands:
|
||||
logger.warning(f"Alias '{alias}' for command '{command.name}' in group '{self.name}' conflicts with an existing command or alias.")
|
||||
logger.warning(
|
||||
f"Alias '{alias}' for command '{command.name}' in group '{self.name}' conflicts with an existing command or alias."
|
||||
)
|
||||
self.commands[alias.lower()] = command
|
||||
|
||||
def get_command(self, name: str) -> Optional["Command"]:
|
||||
@ -181,6 +191,7 @@ class Command(GroupMixin):
|
||||
|
||||
class Group(Command):
|
||||
"""A command that can have subcommands."""
|
||||
|
||||
def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any):
|
||||
super().__init__(callback, **attrs)
|
||||
|
||||
@ -494,7 +505,34 @@ class CommandHandler:
|
||||
None # Holds the raw string for current param
|
||||
)
|
||||
|
||||
if view.eof: # No more input string
|
||||
annotation = param.annotation
|
||||
if inspect.isclass(annotation) and issubclass(annotation, Greedy):
|
||||
greedy_values = []
|
||||
converter_type = annotation.converter
|
||||
while not view.eof:
|
||||
view.skip_whitespace()
|
||||
if view.eof:
|
||||
break
|
||||
start = view.index
|
||||
if view.buffer[view.index] == '"':
|
||||
arg_str_value = view.get_quoted_string()
|
||||
if arg_str_value == "" and view.buffer[view.index] == '"':
|
||||
raise BadArgument(
|
||||
f"Unterminated quoted string for argument '{param.name}'."
|
||||
)
|
||||
else:
|
||||
arg_str_value = view.get_word()
|
||||
try:
|
||||
converted = await run_converters(
|
||||
ctx, converter_type, arg_str_value
|
||||
)
|
||||
except BadArgument:
|
||||
view.index = start
|
||||
break
|
||||
greedy_values.append(converted)
|
||||
final_value_for_param = greedy_values
|
||||
arg_str_value = None
|
||||
elif view.eof: # No more input string
|
||||
if param.default is not inspect.Parameter.empty:
|
||||
final_value_for_param = param.default
|
||||
elif param.kind != inspect.Parameter.VAR_KEYWORD:
|
||||
@ -656,7 +694,9 @@ class CommandHandler:
|
||||
elif command.invoke_without_command:
|
||||
view.index -= len(potential_subcommand) + view.previous
|
||||
else:
|
||||
raise CommandNotFound(f"Subcommand '{potential_subcommand}' not found.")
|
||||
raise CommandNotFound(
|
||||
f"Subcommand '{potential_subcommand}' not found."
|
||||
)
|
||||
|
||||
ctx = CommandContext(
|
||||
message=message,
|
||||
@ -681,7 +721,9 @@ class CommandHandler:
|
||||
if hasattr(self.client, "on_command_error"):
|
||||
await self.client.on_command_error(ctx, e)
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error invoking command '%s': %s", original_command.name, e)
|
||||
logger.error(
|
||||
"Unexpected error invoking command '%s': %s", original_command.name, e
|
||||
)
|
||||
exc = CommandInvokeError(e)
|
||||
if hasattr(self.client, "on_command_error"):
|
||||
await self.client.on_command_error(ctx, exc)
|
||||
|
Loading…
x
Reference in New Issue
Block a user