296 lines
10 KiB
Python

# disagreement/ext/commands/decorators.py
from __future__ import annotations
import asyncio
import inspect
import time
from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
if TYPE_CHECKING:
from .core import Command, CommandContext
from disagreement.permissions import Permissions
from disagreement.models import Member, Guild, Channel
def command(
name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any
) -> Callable:
"""
A decorator that transforms a function into a Command.
Args:
name (Optional[str]): The name of the command. Defaults to the function name.
aliases (Optional[List[str]]): Alternative names for the command.
**attrs: Additional attributes to pass to the Command constructor
(e.g., brief, description, hidden).
Returns:
Callable: A decorator that registers the command.
"""
def decorator(
func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
if not asyncio.iscoroutinefunction(func):
raise TypeError("Command callback must be a coroutine function.")
from .core import Command
cmd_name = name or func.__name__
if hasattr(func, "__command_attrs__"):
raise TypeError("Function is already a command or has command attributes.")
cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
func.__command_object__ = cmd # type: ignore
return func
return decorator
def listener(
name: Optional[str] = None,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""
A decorator that marks a function as an event listener within a Cog.
"""
def decorator(
func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
if not asyncio.iscoroutinefunction(func):
raise TypeError("Listener callback must be a coroutine function.")
actual_event_name = name or func.__name__
setattr(func, "__listener_name__", actual_event_name)
return func
return decorator
def check(
predicate: Callable[["CommandContext"], Awaitable[bool] | bool],
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Decorator to add a check to a command."""
def decorator(
func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
checks = getattr(func, "__command_checks__", [])
checks.append(predicate)
setattr(func, "__command_checks__", checks)
return func
return decorator
def check_any(
*predicates: Callable[["CommandContext"], Awaitable[bool] | bool]
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Decorator that passes if any predicate returns ``True``."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckAnyFailure, CheckFailure
errors = []
for p in predicates:
try:
result = p(ctx)
if inspect.isawaitable(result):
result = await result
if result:
return True
except CheckFailure as e:
errors.append(e)
raise CheckAnyFailure(errors)
return check(predicate)
def max_concurrency(
number: int, per: str = "user"
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Limit how many concurrent invocations of a command are allowed.
Parameters
----------
number:
The maximum number of concurrent invocations.
per:
The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``.
"""
if number < 1:
raise ValueError("Concurrency number must be at least 1.")
if per not in {"user", "guild", "global"}:
raise ValueError("per must be 'user', 'guild', or 'global'.")
def decorator(
func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
setattr(func, "__max_concurrency__", (number, per))
return func
return decorator
def cooldown(
rate: int, per: float
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Simple per-user cooldown decorator."""
buckets: dict[str, dict[str, float]] = {}
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CommandOnCooldown
now = time.monotonic()
user_buckets = buckets.setdefault(ctx.command.name, {})
reset = user_buckets.get(ctx.author.id, 0)
if now < reset:
raise CommandOnCooldown(reset - now)
user_buckets[ctx.author.id] = now + per
return True
return check(predicate)
def _compute_permissions(
member: "Member", channel: "Channel", guild: "Guild"
) -> "Permissions":
"""Compute the effective permissions for a member in a channel."""
return channel.permissions_for(member)
def requires_permissions(
*perms: "Permissions",
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Check that the invoking member has the given permissions in the channel."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckFailure
from disagreement.permissions import (
has_permissions,
missing_permissions,
)
from disagreement.models import Member
channel = getattr(ctx, "channel", None)
if channel is None and hasattr(ctx.bot, "get_channel"):
channel = ctx.bot.get_channel(ctx.message.channel_id)
if channel is None and hasattr(ctx.bot, "fetch_channel"):
channel = await ctx.bot.fetch_channel(ctx.message.channel_id)
if channel is None:
raise CheckFailure("Channel for permission check not found.")
guild = getattr(channel, "guild", None)
if not guild and hasattr(channel, "guild_id") and channel.guild_id:
if hasattr(ctx.bot, "get_guild"):
guild = ctx.bot.get_guild(channel.guild_id)
if not guild and hasattr(ctx.bot, "fetch_guild"):
guild = await ctx.bot.fetch_guild(channel.guild_id)
if not guild:
is_dm = not hasattr(channel, "guild_id") or not channel.guild_id
if is_dm:
if perms:
raise CheckFailure("Permission checks are not supported in DMs.")
return True
raise CheckFailure("Guild for permission check not found.")
member = ctx.author
if not isinstance(member, Member):
member = guild.get_member(ctx.author.id)
if not member and hasattr(ctx.bot, "fetch_member"):
member = await ctx.bot.fetch_member(guild.id, ctx.author.id)
if not member:
raise CheckFailure("Could not resolve author to a guild member.")
perms_value = _compute_permissions(member, channel, guild)
if not has_permissions(perms_value, *perms):
missing = missing_permissions(perms_value, *perms)
missing_names = ", ".join(p.name for p in missing if p.name)
raise CheckFailure(f"Missing permissions: {missing_names}")
return True
return check(predicate)
def has_role(
name_or_id: str | int,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Check that the invoking member has a role with the given name or ID."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckFailure
from disagreement.models import Member
if not ctx.guild:
raise CheckFailure("This command cannot be used in DMs.")
author = ctx.author
if not isinstance(author, Member):
try:
author = await ctx.bot.fetch_member(ctx.guild.id, author.id)
except Exception:
raise CheckFailure("Could not resolve author to a guild member.")
if not author:
raise CheckFailure("Could not resolve author to a guild member.")
# Create a list of the member's role objects by looking them up in the guild's roles list
member_roles = [role for role in ctx.guild.roles if role.id in author.roles]
if any(
role.id == str(name_or_id) or role.name == name_or_id
for role in member_roles
):
return True
raise CheckFailure(f"You need the '{name_or_id}' role to use this command.")
return check(predicate)
def has_any_role(
*names_or_ids: str | int,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Check that the invoking member has any of the roles with the given names or IDs."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckFailure
from disagreement.models import Member
if not ctx.guild:
raise CheckFailure("This command cannot be used in DMs.")
author = ctx.author
if not isinstance(author, Member):
try:
author = await ctx.bot.fetch_member(ctx.guild.id, author.id)
except Exception:
raise CheckFailure("Could not resolve author to a guild member.")
if not author:
raise CheckFailure("Could not resolve author to a guild member.")
member_roles = [role for role in ctx.guild.roles if role.id in author.roles]
# Convert names_or_ids to a set for efficient lookup
names_or_ids_set = set(map(str, names_or_ids))
if any(
role.id in names_or_ids_set or role.name in names_or_ids_set
for role in member_roles
):
return True
role_list = ", ".join(f"'{r}'" for r in names_or_ids)
raise CheckFailure(
f"You need one of the following roles to use this command: {role_list}"
)
return check(predicate)