Implement requires_permissions decorator (#22)

* Add permissions check decorator

* Refactor command decorator and add permission computation logic
This commit is contained in:
Slipstream 2025-06-10 15:54:00 -06:00 committed by GitHub
parent 1ff56106c9
commit e9375a5a36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 199 additions and 32 deletions

View File

@ -10,7 +10,14 @@ from .core import (
CommandContext,
CommandHandler,
) # CommandHandler might be internal
from .decorators import command, listener, check, check_any, cooldown
from .decorators import (
command,
listener,
check,
check_any,
cooldown,
requires_permissions,
)
from .errors import (
CommandError,
CommandNotFound,
@ -36,6 +43,7 @@ __all__ = [
"check",
"check_any",
"cooldown",
"requires_permissions",
# Errors
"CommandError",
"CommandNotFound",

View File

@ -1,4 +1,5 @@
# disagreement/ext/commands/decorators.py
from __future__ import annotations
import asyncio
import inspect
@ -6,9 +7,9 @@ import time
from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
if TYPE_CHECKING:
from .core import Command, CommandContext # For type hinting return or internal use
# from .cog import Cog # For Cog specific decorators
from .core import Command, CommandContext
from disagreement.permissions import Permissions
from disagreement.models import Member, Guild, Channel
def command(
@ -33,32 +34,16 @@ def command(
if not asyncio.iscoroutinefunction(func):
raise TypeError("Command callback must be a coroutine function.")
from .core import (
Command,
) # Late import to avoid circular dependencies at module load time
# The actual registration will happen when a Cog is added or if commands are global.
# For now, this decorator creates a Command instance and attaches it to the function,
# or returns a Command instance that can be collected.
from .core import Command
cmd_name = name or func.__name__
# Store command attributes on the function itself for later collection by Cog or Client
# This is a common pattern.
if hasattr(func, "__command_attrs__"):
# This case might occur if decorators are stacked in an unusual way,
# or if a function is decorated multiple times (which should be disallowed or handled).
# For now, let's assume one @command decorator per function.
raise TypeError("Function is already a command or has command attributes.")
# Create the command object. It will be registered by the Cog or Client.
cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
# We can attach the command object to the function, so Cogs can find it.
func.__command_object__ = cmd # type: ignore # type: ignore[attr-defined]
return func # Return the original function, now marked.
# Or return `cmd` if commands are registered globally immediately.
# For Cogs, returning `func` and letting Cog collect is cleaner.
func.__command_object__ = cmd # type: ignore
return func
return decorator
@ -68,11 +53,6 @@ def listener(
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""
A decorator that marks a function as an event listener within a Cog.
The actual registration happens when the Cog is added to the client.
Args:
name (Optional[str]): The name of the event to listen to.
Defaults to the function name (e.g., `on_message`).
"""
def decorator(
@ -81,13 +61,11 @@ def listener(
if not asyncio.iscoroutinefunction(func):
raise TypeError("Listener callback must be a coroutine function.")
# 'name' here is from the outer 'listener' scope (closure)
actual_event_name = name or func.__name__
# Store listener info on the function for Cog to collect
setattr(func, "__listener_name__", actual_event_name)
return func
return decorator # This must be correctly indented under 'listener'
return decorator
def check(
@ -148,3 +126,107 @@ def cooldown(
return True
return check(predicate)
def _compute_permissions(
member: "Member", channel: "Channel", guild: "Guild"
) -> "Permissions":
"""Compute the effective permissions for a member in a channel."""
from disagreement.models import Member, Guild, Channel
from disagreement.permissions import Permissions
if guild.owner_id == member.id:
return Permissions(~0)
roles = {str(r.id): r for r in guild.roles}
everyone_role = roles.get(str(guild.id))
if not everyone_role:
base_permissions = Permissions(0)
else:
base_permissions = Permissions(int(everyone_role.permissions))
for role_id in member.roles:
role = roles.get(str(role_id))
if role:
base_permissions |= Permissions(int(role.permissions))
if base_permissions & Permissions.ADMINISTRATOR:
return Permissions(~0)
overwrites = {
ow.id: ow for ow in getattr(channel, "permission_overwrites", [])
}
allow = Permissions(0)
deny = Permissions(0)
if everyone_overwrite := overwrites.get(str(guild.id)):
allow |= Permissions(int(everyone_overwrite.allow))
deny |= Permissions(int(everyone_overwrite.deny))
for role_id in member.roles:
if role_overwrite := overwrites.get(str(role_id)):
allow |= Permissions(int(role_overwrite.allow))
deny |= Permissions(int(role_overwrite.deny))
if member_overwrite := overwrites.get(str(member.id)):
allow |= Permissions(int(member_overwrite.allow))
deny |= Permissions(int(member_overwrite.deny))
return (base_permissions & ~deny) | allow
def requires_permissions(
*perms: "Permissions",
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Check that the invoking member has the given permissions in the channel."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckFailure
from disagreement.permissions import (
has_permissions,
missing_permissions,
)
from disagreement.models import Member
channel = getattr(ctx, "channel", None)
if channel is None and hasattr(ctx.bot, "get_channel"):
channel = ctx.bot.get_channel(ctx.message.channel_id)
if channel is None and hasattr(ctx.bot, "fetch_channel"):
channel = await ctx.bot.fetch_channel(ctx.message.channel_id)
if channel is None:
raise CheckFailure("Channel for permission check not found.")
guild = getattr(channel, "guild", None)
if not guild and hasattr(channel, "guild_id") and channel.guild_id:
if hasattr(ctx.bot, "get_guild"):
guild = ctx.bot.get_guild(channel.guild_id)
if not guild and hasattr(ctx.bot, "fetch_guild"):
guild = await ctx.bot.fetch_guild(channel.guild_id)
if not guild:
is_dm = not hasattr(channel, "guild_id") or not channel.guild_id
if is_dm:
if perms:
raise CheckFailure("Permission checks are not supported in DMs.")
return True
raise CheckFailure("Guild for permission check not found.")
member = ctx.author
if not isinstance(member, Member):
member = guild.get_member(ctx.author.id)
if not member and hasattr(ctx.bot, "fetch_member"):
member = await ctx.bot.fetch_member(guild.id, ctx.author.id)
if not member:
raise CheckFailure("Could not resolve author to a guild member.")
perms_value = _compute_permissions(member, channel, guild)
if not has_permissions(perms_value, *perms):
missing = missing_permissions(perms_value, *perms)
missing_names = ", ".join(p.name for p in missing if p.name)
raise CheckFailure(f"Missing permissions: {missing_names}")
return True
return check(predicate)

View File

@ -49,3 +49,20 @@ async def ping(ctx):
```
Invoking a command while it is on cooldown raises :class:`CommandOnCooldown`.
## Permission Checks
Use `commands.requires_permissions` to ensure the invoking member has the
required permissions in the channel.
```python
from disagreement.ext.commands import command, requires_permissions
from disagreement.permissions import Permissions
@command()
@requires_permissions(Permissions.MANAGE_MESSAGES)
async def purge(ctx):
await ctx.send("Purged!")
```
Missing permissions raise :class:`CheckFailure`.

View File

@ -2,8 +2,13 @@ import asyncio
import pytest
from disagreement.ext.commands.core import Command, CommandContext
from disagreement.ext.commands.decorators import check, cooldown
from disagreement.ext.commands.decorators import (
check,
cooldown,
requires_permissions,
)
from disagreement.ext.commands.errors import CheckFailure, CommandOnCooldown
from disagreement.permissions import Permissions
@pytest.mark.asyncio
@ -49,3 +54,58 @@ async def test_cooldown_per_user(message):
await asyncio.sleep(0.05)
await cmd.invoke(ctx)
assert len(uses) == 2
@pytest.mark.asyncio
async def test_requires_permissions_pass(message):
class Channel:
def __init__(self, perms):
self.perms = perms
def permissions_for(self, member):
return self.perms
message._client.get_channel = lambda cid: Channel(Permissions.SEND_MESSAGES)
@requires_permissions(Permissions.SEND_MESSAGES)
async def cb(ctx):
pass
cmd = Command(cb)
ctx = CommandContext(
message=message,
bot=message._client,
prefix="!",
command=cmd,
invoked_with="test",
)
await cmd.invoke(ctx)
@pytest.mark.asyncio
async def test_requires_permissions_fail(message):
class Channel:
def __init__(self, perms):
self.perms = perms
def permissions_for(self, member):
return self.perms
message._client.get_channel = lambda cid: Channel(Permissions.SEND_MESSAGES)
@requires_permissions(Permissions.MANAGE_MESSAGES)
async def cb(ctx):
pass
cmd = Command(cb)
ctx = CommandContext(
message=message,
bot=message._client,
prefix="!",
command=cmd,
invoked_with="test",
)
with pytest.raises(CheckFailure):
await cmd.invoke(ctx)