From e9375a5a364f24a5cdf6aa748dedf477386568fe Mon Sep 17 00:00:00 2001 From: Slipstream Date: Tue, 10 Jun 2025 15:54:00 -0600 Subject: [PATCH] Implement requires_permissions decorator (#22) * Add permissions check decorator * Refactor command decorator and add permission computation logic --- disagreement/ext/commands/__init__.py | 10 +- disagreement/ext/commands/decorators.py | 142 +++++++++++++++++++----- docs/commands.md | 17 +++ tests/test_command_checks.py | 62 ++++++++++- 4 files changed, 199 insertions(+), 32 deletions(-) diff --git a/disagreement/ext/commands/__init__.py b/disagreement/ext/commands/__init__.py index 779a7cb..21b1e35 100644 --- a/disagreement/ext/commands/__init__.py +++ b/disagreement/ext/commands/__init__.py @@ -10,7 +10,14 @@ from .core import ( CommandContext, CommandHandler, ) # CommandHandler might be internal -from .decorators import command, listener, check, check_any, cooldown +from .decorators import ( + command, + listener, + check, + check_any, + cooldown, + requires_permissions, +) from .errors import ( CommandError, CommandNotFound, @@ -36,6 +43,7 @@ __all__ = [ "check", "check_any", "cooldown", + "requires_permissions", # Errors "CommandError", "CommandNotFound", diff --git a/disagreement/ext/commands/decorators.py b/disagreement/ext/commands/decorators.py index e988ea2..8fd14f7 100644 --- a/disagreement/ext/commands/decorators.py +++ b/disagreement/ext/commands/decorators.py @@ -1,4 +1,5 @@ # disagreement/ext/commands/decorators.py +from __future__ import annotations import asyncio import inspect @@ -6,9 +7,9 @@ import time from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable if TYPE_CHECKING: - from .core import Command, CommandContext # For type hinting return or internal use - - # from .cog import Cog # For Cog specific decorators + from .core import Command, CommandContext + from disagreement.permissions import Permissions + from disagreement.models import Member, Guild, Channel def command( @@ -33,32 +34,16 @@ def command( if not asyncio.iscoroutinefunction(func): raise TypeError("Command callback must be a coroutine function.") - from .core import ( - Command, - ) # Late import to avoid circular dependencies at module load time - - # The actual registration will happen when a Cog is added or if commands are global. - # For now, this decorator creates a Command instance and attaches it to the function, - # or returns a Command instance that can be collected. + from .core import Command cmd_name = name or func.__name__ - # Store command attributes on the function itself for later collection by Cog or Client - # This is a common pattern. if hasattr(func, "__command_attrs__"): - # This case might occur if decorators are stacked in an unusual way, - # or if a function is decorated multiple times (which should be disallowed or handled). - # For now, let's assume one @command decorator per function. raise TypeError("Function is already a command or has command attributes.") - # Create the command object. It will be registered by the Cog or Client. cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs) - - # We can attach the command object to the function, so Cogs can find it. - func.__command_object__ = cmd # type: ignore # type: ignore[attr-defined] - return func # Return the original function, now marked. - # Or return `cmd` if commands are registered globally immediately. - # For Cogs, returning `func` and letting Cog collect is cleaner. + func.__command_object__ = cmd # type: ignore + return func return decorator @@ -68,11 +53,6 @@ def listener( ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: """ A decorator that marks a function as an event listener within a Cog. - The actual registration happens when the Cog is added to the client. - - Args: - name (Optional[str]): The name of the event to listen to. - Defaults to the function name (e.g., `on_message`). """ def decorator( @@ -81,13 +61,11 @@ def listener( if not asyncio.iscoroutinefunction(func): raise TypeError("Listener callback must be a coroutine function.") - # 'name' here is from the outer 'listener' scope (closure) actual_event_name = name or func.__name__ - # Store listener info on the function for Cog to collect setattr(func, "__listener_name__", actual_event_name) return func - return decorator # This must be correctly indented under 'listener' + return decorator def check( @@ -148,3 +126,107 @@ def cooldown( return True return check(predicate) + + +def _compute_permissions( + member: "Member", channel: "Channel", guild: "Guild" +) -> "Permissions": + """Compute the effective permissions for a member in a channel.""" + from disagreement.models import Member, Guild, Channel + from disagreement.permissions import Permissions + + if guild.owner_id == member.id: + return Permissions(~0) + + roles = {str(r.id): r for r in guild.roles} + everyone_role = roles.get(str(guild.id)) + if not everyone_role: + base_permissions = Permissions(0) + else: + base_permissions = Permissions(int(everyone_role.permissions)) + + for role_id in member.roles: + role = roles.get(str(role_id)) + if role: + base_permissions |= Permissions(int(role.permissions)) + + if base_permissions & Permissions.ADMINISTRATOR: + return Permissions(~0) + + overwrites = { + ow.id: ow for ow in getattr(channel, "permission_overwrites", []) + } + allow = Permissions(0) + deny = Permissions(0) + + if everyone_overwrite := overwrites.get(str(guild.id)): + allow |= Permissions(int(everyone_overwrite.allow)) + deny |= Permissions(int(everyone_overwrite.deny)) + + for role_id in member.roles: + if role_overwrite := overwrites.get(str(role_id)): + allow |= Permissions(int(role_overwrite.allow)) + deny |= Permissions(int(role_overwrite.deny)) + + if member_overwrite := overwrites.get(str(member.id)): + allow |= Permissions(int(member_overwrite.allow)) + deny |= Permissions(int(member_overwrite.deny)) + + return (base_permissions & ~deny) | allow + + +def requires_permissions( + *perms: "Permissions", +) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: + """Check that the invoking member has the given permissions in the channel.""" + + async def predicate(ctx: "CommandContext") -> bool: + from .errors import CheckFailure + from disagreement.permissions import ( + has_permissions, + missing_permissions, + ) + from disagreement.models import Member + + channel = getattr(ctx, "channel", None) + if channel is None and hasattr(ctx.bot, "get_channel"): + channel = ctx.bot.get_channel(ctx.message.channel_id) + if channel is None and hasattr(ctx.bot, "fetch_channel"): + channel = await ctx.bot.fetch_channel(ctx.message.channel_id) + + if channel is None: + raise CheckFailure("Channel for permission check not found.") + + guild = getattr(channel, "guild", None) + if not guild and hasattr(channel, "guild_id") and channel.guild_id: + if hasattr(ctx.bot, "get_guild"): + guild = ctx.bot.get_guild(channel.guild_id) + if not guild and hasattr(ctx.bot, "fetch_guild"): + guild = await ctx.bot.fetch_guild(channel.guild_id) + + if not guild: + is_dm = not hasattr(channel, "guild_id") or not channel.guild_id + if is_dm: + if perms: + raise CheckFailure("Permission checks are not supported in DMs.") + return True + raise CheckFailure("Guild for permission check not found.") + + member = ctx.author + if not isinstance(member, Member): + member = guild.get_member(ctx.author.id) + if not member and hasattr(ctx.bot, "fetch_member"): + member = await ctx.bot.fetch_member(guild.id, ctx.author.id) + + if not member: + raise CheckFailure("Could not resolve author to a guild member.") + + perms_value = _compute_permissions(member, channel, guild) + + if not has_permissions(perms_value, *perms): + missing = missing_permissions(perms_value, *perms) + missing_names = ", ".join(p.name for p in missing if p.name) + raise CheckFailure(f"Missing permissions: {missing_names}") + return True + + return check(predicate) diff --git a/docs/commands.md b/docs/commands.md index e109899..cbd3e0a 100644 --- a/docs/commands.md +++ b/docs/commands.md @@ -49,3 +49,20 @@ async def ping(ctx): ``` Invoking a command while it is on cooldown raises :class:`CommandOnCooldown`. + +## Permission Checks + +Use `commands.requires_permissions` to ensure the invoking member has the +required permissions in the channel. + +```python +from disagreement.ext.commands import command, requires_permissions +from disagreement.permissions import Permissions + +@command() +@requires_permissions(Permissions.MANAGE_MESSAGES) +async def purge(ctx): + await ctx.send("Purged!") +``` + +Missing permissions raise :class:`CheckFailure`. diff --git a/tests/test_command_checks.py b/tests/test_command_checks.py index 50754f2..8620395 100644 --- a/tests/test_command_checks.py +++ b/tests/test_command_checks.py @@ -2,8 +2,13 @@ import asyncio import pytest from disagreement.ext.commands.core import Command, CommandContext -from disagreement.ext.commands.decorators import check, cooldown +from disagreement.ext.commands.decorators import ( + check, + cooldown, + requires_permissions, +) from disagreement.ext.commands.errors import CheckFailure, CommandOnCooldown +from disagreement.permissions import Permissions @pytest.mark.asyncio @@ -49,3 +54,58 @@ async def test_cooldown_per_user(message): await asyncio.sleep(0.05) await cmd.invoke(ctx) assert len(uses) == 2 + + +@pytest.mark.asyncio +async def test_requires_permissions_pass(message): + class Channel: + def __init__(self, perms): + self.perms = perms + + def permissions_for(self, member): + return self.perms + + message._client.get_channel = lambda cid: Channel(Permissions.SEND_MESSAGES) + + @requires_permissions(Permissions.SEND_MESSAGES) + async def cb(ctx): + pass + + cmd = Command(cb) + ctx = CommandContext( + message=message, + bot=message._client, + prefix="!", + command=cmd, + invoked_with="test", + ) + + await cmd.invoke(ctx) + + +@pytest.mark.asyncio +async def test_requires_permissions_fail(message): + class Channel: + def __init__(self, perms): + self.perms = perms + + def permissions_for(self, member): + return self.perms + + message._client.get_channel = lambda cid: Channel(Permissions.SEND_MESSAGES) + + @requires_permissions(Permissions.MANAGE_MESSAGES) + async def cb(ctx): + pass + + cmd = Command(cb) + ctx = CommandContext( + message=message, + bot=message._client, + prefix="!", + command=cmd, + invoked_with="test", + ) + + with pytest.raises(CheckFailure): + await cmd.invoke(ctx)