feat: Enhance command framework with groups, checks, and events
This commit is contained in:
parent
be85444aa0
commit
2bd45c87ca
@ -579,6 +579,41 @@ class Client:
|
||||
# For now, assuming name is sufficient for removal from the handler's flat list.
|
||||
return removed_cog
|
||||
|
||||
def check(self, coro: Callable[["CommandContext"], Awaitable[bool]]):
|
||||
"""
|
||||
A decorator that adds a global check to the bot.
|
||||
This check will be called for every command before it's executed.
|
||||
|
||||
Example:
|
||||
@bot.check
|
||||
async def block_dms(ctx):
|
||||
return ctx.guild is not None
|
||||
"""
|
||||
self.command_handler.add_check(coro)
|
||||
return coro
|
||||
|
||||
def command(
|
||||
self, **attrs: Any
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Command]:
|
||||
"""A decorator that transforms a function into a Command."""
|
||||
|
||||
def decorator(func: Callable[..., Awaitable[None]]) -> Command:
|
||||
cmd = Command(func, **attrs)
|
||||
self.command_handler.add_command(cmd)
|
||||
return cmd
|
||||
|
||||
return decorator
|
||||
|
||||
def group(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], Group]:
|
||||
"""A decorator that transforms a function into a Group command."""
|
||||
|
||||
def decorator(func: Callable[..., Awaitable[None]]) -> Group:
|
||||
cmd = Group(func, **attrs)
|
||||
self.command_handler.add_command(cmd)
|
||||
return cmd
|
||||
|
||||
return decorator
|
||||
|
||||
def add_app_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None:
|
||||
"""
|
||||
Adds a standalone application command or group to the bot.
|
||||
@ -648,6 +683,16 @@ class Client:
|
||||
)
|
||||
# import traceback
|
||||
# traceback.print_exception(type(error.original), error.original, error.original.__traceback__)
|
||||
|
||||
async def on_command_completion(self, ctx: "CommandContext") -> None:
|
||||
"""
|
||||
Default command completion handler. Called when a command has successfully completed.
|
||||
Users can override this method in a subclass of Client.
|
||||
|
||||
Args:
|
||||
ctx (CommandContext): The context of the command that completed.
|
||||
"""
|
||||
pass
|
||||
|
||||
# --- Extension Management Methods ---
|
||||
|
||||
|
@ -40,7 +40,42 @@ if TYPE_CHECKING:
|
||||
from disagreement.models import Message, User
|
||||
|
||||
|
||||
class Command:
|
||||
class GroupMixin:
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.commands: Dict[str, "Command"] = {}
|
||||
self.name: str = ""
|
||||
|
||||
def command(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Command"]:
|
||||
def decorator(func: Callable[..., Awaitable[None]]) -> "Command":
|
||||
cmd = Command(func, **attrs)
|
||||
cmd.cog = getattr(self, "cog", None)
|
||||
self.add_command(cmd)
|
||||
return cmd
|
||||
return decorator
|
||||
|
||||
def group(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Group"]:
|
||||
def decorator(func: Callable[..., Awaitable[None]]) -> "Group":
|
||||
cmd = Group(func, **attrs)
|
||||
cmd.cog = getattr(self, "cog", None)
|
||||
self.add_command(cmd)
|
||||
return cmd
|
||||
return decorator
|
||||
|
||||
def add_command(self, command: "Command") -> None:
|
||||
if command.name in self.commands:
|
||||
raise ValueError(f"Command '{command.name}' is already registered in group '{self.name}'.")
|
||||
self.commands[command.name.lower()] = command
|
||||
for alias in command.aliases:
|
||||
if alias in self.commands:
|
||||
logger.warning(f"Alias '{alias}' for command '{command.name}' in group '{self.name}' conflicts with an existing command or alias.")
|
||||
self.commands[alias.lower()] = command
|
||||
|
||||
def get_command(self, name: str) -> Optional["Command"]:
|
||||
return self.commands.get(name.lower())
|
||||
|
||||
|
||||
class Command(GroupMixin):
|
||||
"""
|
||||
Represents a bot command.
|
||||
|
||||
@ -58,12 +93,14 @@ class Command:
|
||||
if not asyncio.iscoroutinefunction(callback):
|
||||
raise TypeError("Command callback must be a coroutine function.")
|
||||
|
||||
super().__init__(**attrs)
|
||||
self.callback: Callable[..., Awaitable[None]] = callback
|
||||
self.name: str = attrs.get("name", callback.__name__)
|
||||
self.aliases: List[str] = attrs.get("aliases", [])
|
||||
self.brief: Optional[str] = attrs.get("brief")
|
||||
self.description: Optional[str] = attrs.get("description") or callback.__doc__
|
||||
self.cog: Optional["Cog"] = attrs.get("cog")
|
||||
self.invoke_without_command: bool = attrs.get("invoke_without_command", False)
|
||||
|
||||
self.params = inspect.signature(callback).parameters
|
||||
self.checks: List[Callable[["CommandContext"], Awaitable[bool] | bool]] = []
|
||||
@ -79,20 +116,73 @@ class Command:
|
||||
) -> None:
|
||||
self.checks.append(predicate)
|
||||
|
||||
async def invoke(self, ctx: "CommandContext", *args: Any, **kwargs: Any) -> None:
|
||||
async def _run_checks(self, ctx: "CommandContext") -> None:
|
||||
"""Runs all cog, local and global checks for the command."""
|
||||
from .errors import CheckFailure
|
||||
|
||||
# Run cog-level check first
|
||||
if self.cog:
|
||||
cog_check = getattr(self.cog, "cog_check", None)
|
||||
if cog_check:
|
||||
try:
|
||||
result = cog_check(ctx)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if not result:
|
||||
raise CheckFailure(
|
||||
f"The cog-level check for command '{self.name}' failed."
|
||||
)
|
||||
except CheckFailure:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise CommandInvokeError(e) from e
|
||||
|
||||
# Run local checks
|
||||
for predicate in self.checks:
|
||||
result = predicate(ctx)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if not result:
|
||||
raise CheckFailure("Check predicate failed.")
|
||||
raise CheckFailure(f"A local check for command '{self.name}' failed.")
|
||||
|
||||
# Then run global checks from the handler
|
||||
if hasattr(ctx.bot, "command_handler"):
|
||||
for predicate in ctx.bot.command_handler._global_checks:
|
||||
result = predicate(ctx)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if not result:
|
||||
raise CheckFailure(
|
||||
f"A global check failed for command '{self.name}'."
|
||||
)
|
||||
|
||||
async def invoke(self, ctx: "CommandContext", *args: Any, **kwargs: Any) -> None:
|
||||
await self._run_checks(ctx)
|
||||
|
||||
before_invoke = None
|
||||
after_invoke = None
|
||||
|
||||
if self.cog:
|
||||
await self.callback(self.cog, ctx, *args, **kwargs)
|
||||
else:
|
||||
await self.callback(ctx, *args, **kwargs)
|
||||
before_invoke = getattr(self.cog, "cog_before_invoke", None)
|
||||
after_invoke = getattr(self.cog, "cog_after_invoke", None)
|
||||
|
||||
if before_invoke:
|
||||
await before_invoke(ctx)
|
||||
|
||||
try:
|
||||
if self.cog:
|
||||
await self.callback(self.cog, ctx, *args, **kwargs)
|
||||
else:
|
||||
await self.callback(ctx, *args, **kwargs)
|
||||
finally:
|
||||
if after_invoke:
|
||||
await after_invoke(ctx)
|
||||
|
||||
|
||||
class Group(Command):
|
||||
"""A command that can have subcommands."""
|
||||
def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any):
|
||||
super().__init__(callback, **attrs)
|
||||
|
||||
|
||||
PrefixCommand = Command # Alias for clarity in hybrid commands
|
||||
@ -220,11 +310,20 @@ class CommandHandler:
|
||||
self.commands: Dict[str, Command] = {}
|
||||
self.cogs: Dict[str, "Cog"] = {}
|
||||
self._concurrency: Dict[str, Dict[str, int]] = {}
|
||||
self._global_checks: List[
|
||||
Callable[["CommandContext"], Awaitable[bool] | bool]
|
||||
] = []
|
||||
|
||||
from .help import HelpCommand
|
||||
|
||||
self.add_command(HelpCommand(self))
|
||||
|
||||
def add_check(
|
||||
self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool]
|
||||
) -> None:
|
||||
"""Adds a global check to the command handler."""
|
||||
self._global_checks.append(predicate)
|
||||
|
||||
def add_command(self, command: Command) -> None:
|
||||
if command.name in self.commands:
|
||||
raise ValueError(f"Command '{command.name}' is already registered.")
|
||||
@ -239,6 +338,15 @@ class CommandHandler:
|
||||
)
|
||||
self.commands[alias.lower()] = command
|
||||
|
||||
if isinstance(command, Group):
|
||||
for sub_cmd in command.commands.values():
|
||||
if sub_cmd.name in self.commands:
|
||||
logger.warning(
|
||||
"Subcommand '%s' of group '%s' conflicts with a top-level command.",
|
||||
sub_cmd.name,
|
||||
command.name,
|
||||
)
|
||||
|
||||
def remove_command(self, name: str) -> Optional[Command]:
|
||||
command = self.commands.pop(name.lower(), None)
|
||||
if command:
|
||||
@ -534,12 +642,28 @@ class CommandHandler:
|
||||
if not command:
|
||||
return
|
||||
|
||||
invoked_with = command_name
|
||||
original_command = command
|
||||
|
||||
if isinstance(command, Group):
|
||||
view.skip_whitespace()
|
||||
potential_subcommand = view.get_word()
|
||||
if potential_subcommand:
|
||||
subcommand = command.get_command(potential_subcommand)
|
||||
if subcommand:
|
||||
command = subcommand
|
||||
invoked_with += f" {potential_subcommand}"
|
||||
elif command.invoke_without_command:
|
||||
view.index -= len(potential_subcommand) + view.previous
|
||||
else:
|
||||
raise CommandNotFound(f"Subcommand '{potential_subcommand}' not found.")
|
||||
|
||||
ctx = CommandContext(
|
||||
message=message,
|
||||
bot=self.client,
|
||||
prefix=actual_prefix,
|
||||
command=command,
|
||||
invoked_with=command_name,
|
||||
invoked_with=invoked_with,
|
||||
cog=command.cog,
|
||||
)
|
||||
|
||||
@ -553,11 +677,14 @@ class CommandHandler:
|
||||
finally:
|
||||
self._release_concurrency(ctx)
|
||||
except CommandError as e:
|
||||
logger.error("Command error for '%s': %s", command.name, e)
|
||||
logger.error("Command error for '%s': %s", original_command.name, e)
|
||||
if hasattr(self.client, "on_command_error"):
|
||||
await self.client.on_command_error(ctx, e)
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error invoking command '%s': %s", command.name, e)
|
||||
logger.error("Unexpected error invoking command '%s': %s", original_command.name, e)
|
||||
exc = CommandInvokeError(e)
|
||||
if hasattr(self.client, "on_command_error"):
|
||||
await self.client.on_command_error(ctx, exc)
|
||||
else:
|
||||
if hasattr(self.client, "on_command_completion"):
|
||||
await self.client.on_command_completion(ctx)
|
||||
|
Loading…
x
Reference in New Issue
Block a user