feat: Enhance command framework with groups, checks, and events

This commit is contained in:
Slipstream 2025-06-11 02:06:11 -06:00
parent be85444aa0
commit 2bd45c87ca
Signed by: slipstream
GPG Key ID: 13E498CE010AC6FD
2 changed files with 181 additions and 9 deletions

View File

@ -579,6 +579,41 @@ class Client:
# For now, assuming name is sufficient for removal from the handler's flat list.
return removed_cog
def check(self, coro: Callable[["CommandContext"], Awaitable[bool]]):
"""
A decorator that adds a global check to the bot.
This check will be called for every command before it's executed.
Example:
@bot.check
async def block_dms(ctx):
return ctx.guild is not None
"""
self.command_handler.add_check(coro)
return coro
def command(
self, **attrs: Any
) -> Callable[[Callable[..., Awaitable[None]]], Command]:
"""A decorator that transforms a function into a Command."""
def decorator(func: Callable[..., Awaitable[None]]) -> Command:
cmd = Command(func, **attrs)
self.command_handler.add_command(cmd)
return cmd
return decorator
def group(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], Group]:
"""A decorator that transforms a function into a Group command."""
def decorator(func: Callable[..., Awaitable[None]]) -> Group:
cmd = Group(func, **attrs)
self.command_handler.add_command(cmd)
return cmd
return decorator
def add_app_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None:
"""
Adds a standalone application command or group to the bot.
@ -648,6 +683,16 @@ class Client:
)
# import traceback
# traceback.print_exception(type(error.original), error.original, error.original.__traceback__)
async def on_command_completion(self, ctx: "CommandContext") -> None:
"""
Default command completion handler. Called when a command has successfully completed.
Users can override this method in a subclass of Client.
Args:
ctx (CommandContext): The context of the command that completed.
"""
pass
# --- Extension Management Methods ---

View File

@ -40,7 +40,42 @@ if TYPE_CHECKING:
from disagreement.models import Message, User
class Command:
class GroupMixin:
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.commands: Dict[str, "Command"] = {}
self.name: str = ""
def command(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Command"]:
def decorator(func: Callable[..., Awaitable[None]]) -> "Command":
cmd = Command(func, **attrs)
cmd.cog = getattr(self, "cog", None)
self.add_command(cmd)
return cmd
return decorator
def group(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Group"]:
def decorator(func: Callable[..., Awaitable[None]]) -> "Group":
cmd = Group(func, **attrs)
cmd.cog = getattr(self, "cog", None)
self.add_command(cmd)
return cmd
return decorator
def add_command(self, command: "Command") -> None:
if command.name in self.commands:
raise ValueError(f"Command '{command.name}' is already registered in group '{self.name}'.")
self.commands[command.name.lower()] = command
for alias in command.aliases:
if alias in self.commands:
logger.warning(f"Alias '{alias}' for command '{command.name}' in group '{self.name}' conflicts with an existing command or alias.")
self.commands[alias.lower()] = command
def get_command(self, name: str) -> Optional["Command"]:
return self.commands.get(name.lower())
class Command(GroupMixin):
"""
Represents a bot command.
@ -58,12 +93,14 @@ class Command:
if not asyncio.iscoroutinefunction(callback):
raise TypeError("Command callback must be a coroutine function.")
super().__init__(**attrs)
self.callback: Callable[..., Awaitable[None]] = callback
self.name: str = attrs.get("name", callback.__name__)
self.aliases: List[str] = attrs.get("aliases", [])
self.brief: Optional[str] = attrs.get("brief")
self.description: Optional[str] = attrs.get("description") or callback.__doc__
self.cog: Optional["Cog"] = attrs.get("cog")
self.invoke_without_command: bool = attrs.get("invoke_without_command", False)
self.params = inspect.signature(callback).parameters
self.checks: List[Callable[["CommandContext"], Awaitable[bool] | bool]] = []
@ -79,20 +116,73 @@ class Command:
) -> None:
self.checks.append(predicate)
async def invoke(self, ctx: "CommandContext", *args: Any, **kwargs: Any) -> None:
async def _run_checks(self, ctx: "CommandContext") -> None:
"""Runs all cog, local and global checks for the command."""
from .errors import CheckFailure
# Run cog-level check first
if self.cog:
cog_check = getattr(self.cog, "cog_check", None)
if cog_check:
try:
result = cog_check(ctx)
if inspect.isawaitable(result):
result = await result
if not result:
raise CheckFailure(
f"The cog-level check for command '{self.name}' failed."
)
except CheckFailure:
raise
except Exception as e:
raise CommandInvokeError(e) from e
# Run local checks
for predicate in self.checks:
result = predicate(ctx)
if inspect.isawaitable(result):
result = await result
if not result:
raise CheckFailure("Check predicate failed.")
raise CheckFailure(f"A local check for command '{self.name}' failed.")
# Then run global checks from the handler
if hasattr(ctx.bot, "command_handler"):
for predicate in ctx.bot.command_handler._global_checks:
result = predicate(ctx)
if inspect.isawaitable(result):
result = await result
if not result:
raise CheckFailure(
f"A global check failed for command '{self.name}'."
)
async def invoke(self, ctx: "CommandContext", *args: Any, **kwargs: Any) -> None:
await self._run_checks(ctx)
before_invoke = None
after_invoke = None
if self.cog:
await self.callback(self.cog, ctx, *args, **kwargs)
else:
await self.callback(ctx, *args, **kwargs)
before_invoke = getattr(self.cog, "cog_before_invoke", None)
after_invoke = getattr(self.cog, "cog_after_invoke", None)
if before_invoke:
await before_invoke(ctx)
try:
if self.cog:
await self.callback(self.cog, ctx, *args, **kwargs)
else:
await self.callback(ctx, *args, **kwargs)
finally:
if after_invoke:
await after_invoke(ctx)
class Group(Command):
"""A command that can have subcommands."""
def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any):
super().__init__(callback, **attrs)
PrefixCommand = Command # Alias for clarity in hybrid commands
@ -220,11 +310,20 @@ class CommandHandler:
self.commands: Dict[str, Command] = {}
self.cogs: Dict[str, "Cog"] = {}
self._concurrency: Dict[str, Dict[str, int]] = {}
self._global_checks: List[
Callable[["CommandContext"], Awaitable[bool] | bool]
] = []
from .help import HelpCommand
self.add_command(HelpCommand(self))
def add_check(
self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool]
) -> None:
"""Adds a global check to the command handler."""
self._global_checks.append(predicate)
def add_command(self, command: Command) -> None:
if command.name in self.commands:
raise ValueError(f"Command '{command.name}' is already registered.")
@ -239,6 +338,15 @@ class CommandHandler:
)
self.commands[alias.lower()] = command
if isinstance(command, Group):
for sub_cmd in command.commands.values():
if sub_cmd.name in self.commands:
logger.warning(
"Subcommand '%s' of group '%s' conflicts with a top-level command.",
sub_cmd.name,
command.name,
)
def remove_command(self, name: str) -> Optional[Command]:
command = self.commands.pop(name.lower(), None)
if command:
@ -534,12 +642,28 @@ class CommandHandler:
if not command:
return
invoked_with = command_name
original_command = command
if isinstance(command, Group):
view.skip_whitespace()
potential_subcommand = view.get_word()
if potential_subcommand:
subcommand = command.get_command(potential_subcommand)
if subcommand:
command = subcommand
invoked_with += f" {potential_subcommand}"
elif command.invoke_without_command:
view.index -= len(potential_subcommand) + view.previous
else:
raise CommandNotFound(f"Subcommand '{potential_subcommand}' not found.")
ctx = CommandContext(
message=message,
bot=self.client,
prefix=actual_prefix,
command=command,
invoked_with=command_name,
invoked_with=invoked_with,
cog=command.cog,
)
@ -553,11 +677,14 @@ class CommandHandler:
finally:
self._release_concurrency(ctx)
except CommandError as e:
logger.error("Command error for '%s': %s", command.name, e)
logger.error("Command error for '%s': %s", original_command.name, e)
if hasattr(self.client, "on_command_error"):
await self.client.on_command_error(ctx, e)
except Exception as e:
logger.error("Unexpected error invoking command '%s': %s", command.name, e)
logger.error("Unexpected error invoking command '%s': %s", original_command.name, e)
exc = CommandInvokeError(e)
if hasattr(self.client, "on_command_error"):
await self.client.on_command_error(ctx, exc)
else:
if hasattr(self.client, "on_command_completion"):
await self.client.on_command_completion(ctx)