Implement requires_permissions decorator (#22)
* Add permissions check decorator * Refactor command decorator and add permission computation logic
This commit is contained in:
parent
1ff56106c9
commit
e9375a5a36
@ -10,7 +10,14 @@ from .core import (
|
||||
CommandContext,
|
||||
CommandHandler,
|
||||
) # CommandHandler might be internal
|
||||
from .decorators import command, listener, check, check_any, cooldown
|
||||
from .decorators import (
|
||||
command,
|
||||
listener,
|
||||
check,
|
||||
check_any,
|
||||
cooldown,
|
||||
requires_permissions,
|
||||
)
|
||||
from .errors import (
|
||||
CommandError,
|
||||
CommandNotFound,
|
||||
@ -36,6 +43,7 @@ __all__ = [
|
||||
"check",
|
||||
"check_any",
|
||||
"cooldown",
|
||||
"requires_permissions",
|
||||
# Errors
|
||||
"CommandError",
|
||||
"CommandNotFound",
|
||||
|
@ -1,4 +1,5 @@
|
||||
# disagreement/ext/commands/decorators.py
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
@ -6,9 +7,9 @@ import time
|
||||
from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core import Command, CommandContext # For type hinting return or internal use
|
||||
|
||||
# from .cog import Cog # For Cog specific decorators
|
||||
from .core import Command, CommandContext
|
||||
from disagreement.permissions import Permissions
|
||||
from disagreement.models import Member, Guild, Channel
|
||||
|
||||
|
||||
def command(
|
||||
@ -33,32 +34,16 @@ def command(
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError("Command callback must be a coroutine function.")
|
||||
|
||||
from .core import (
|
||||
Command,
|
||||
) # Late import to avoid circular dependencies at module load time
|
||||
|
||||
# The actual registration will happen when a Cog is added or if commands are global.
|
||||
# For now, this decorator creates a Command instance and attaches it to the function,
|
||||
# or returns a Command instance that can be collected.
|
||||
from .core import Command
|
||||
|
||||
cmd_name = name or func.__name__
|
||||
|
||||
# Store command attributes on the function itself for later collection by Cog or Client
|
||||
# This is a common pattern.
|
||||
if hasattr(func, "__command_attrs__"):
|
||||
# This case might occur if decorators are stacked in an unusual way,
|
||||
# or if a function is decorated multiple times (which should be disallowed or handled).
|
||||
# For now, let's assume one @command decorator per function.
|
||||
raise TypeError("Function is already a command or has command attributes.")
|
||||
|
||||
# Create the command object. It will be registered by the Cog or Client.
|
||||
cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
|
||||
|
||||
# We can attach the command object to the function, so Cogs can find it.
|
||||
func.__command_object__ = cmd # type: ignore # type: ignore[attr-defined]
|
||||
return func # Return the original function, now marked.
|
||||
# Or return `cmd` if commands are registered globally immediately.
|
||||
# For Cogs, returning `func` and letting Cog collect is cleaner.
|
||||
func.__command_object__ = cmd # type: ignore
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
@ -68,11 +53,6 @@ def listener(
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""
|
||||
A decorator that marks a function as an event listener within a Cog.
|
||||
The actual registration happens when the Cog is added to the client.
|
||||
|
||||
Args:
|
||||
name (Optional[str]): The name of the event to listen to.
|
||||
Defaults to the function name (e.g., `on_message`).
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
@ -81,13 +61,11 @@ def listener(
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError("Listener callback must be a coroutine function.")
|
||||
|
||||
# 'name' here is from the outer 'listener' scope (closure)
|
||||
actual_event_name = name or func.__name__
|
||||
# Store listener info on the function for Cog to collect
|
||||
setattr(func, "__listener_name__", actual_event_name)
|
||||
return func
|
||||
|
||||
return decorator # This must be correctly indented under 'listener'
|
||||
return decorator
|
||||
|
||||
|
||||
def check(
|
||||
@ -148,3 +126,107 @@ def cooldown(
|
||||
return True
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def _compute_permissions(
|
||||
member: "Member", channel: "Channel", guild: "Guild"
|
||||
) -> "Permissions":
|
||||
"""Compute the effective permissions for a member in a channel."""
|
||||
from disagreement.models import Member, Guild, Channel
|
||||
from disagreement.permissions import Permissions
|
||||
|
||||
if guild.owner_id == member.id:
|
||||
return Permissions(~0)
|
||||
|
||||
roles = {str(r.id): r for r in guild.roles}
|
||||
everyone_role = roles.get(str(guild.id))
|
||||
if not everyone_role:
|
||||
base_permissions = Permissions(0)
|
||||
else:
|
||||
base_permissions = Permissions(int(everyone_role.permissions))
|
||||
|
||||
for role_id in member.roles:
|
||||
role = roles.get(str(role_id))
|
||||
if role:
|
||||
base_permissions |= Permissions(int(role.permissions))
|
||||
|
||||
if base_permissions & Permissions.ADMINISTRATOR:
|
||||
return Permissions(~0)
|
||||
|
||||
overwrites = {
|
||||
ow.id: ow for ow in getattr(channel, "permission_overwrites", [])
|
||||
}
|
||||
allow = Permissions(0)
|
||||
deny = Permissions(0)
|
||||
|
||||
if everyone_overwrite := overwrites.get(str(guild.id)):
|
||||
allow |= Permissions(int(everyone_overwrite.allow))
|
||||
deny |= Permissions(int(everyone_overwrite.deny))
|
||||
|
||||
for role_id in member.roles:
|
||||
if role_overwrite := overwrites.get(str(role_id)):
|
||||
allow |= Permissions(int(role_overwrite.allow))
|
||||
deny |= Permissions(int(role_overwrite.deny))
|
||||
|
||||
if member_overwrite := overwrites.get(str(member.id)):
|
||||
allow |= Permissions(int(member_overwrite.allow))
|
||||
deny |= Permissions(int(member_overwrite.deny))
|
||||
|
||||
return (base_permissions & ~deny) | allow
|
||||
|
||||
|
||||
def requires_permissions(
|
||||
*perms: "Permissions",
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Check that the invoking member has the given permissions in the channel."""
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CheckFailure
|
||||
from disagreement.permissions import (
|
||||
has_permissions,
|
||||
missing_permissions,
|
||||
)
|
||||
from disagreement.models import Member
|
||||
|
||||
channel = getattr(ctx, "channel", None)
|
||||
if channel is None and hasattr(ctx.bot, "get_channel"):
|
||||
channel = ctx.bot.get_channel(ctx.message.channel_id)
|
||||
if channel is None and hasattr(ctx.bot, "fetch_channel"):
|
||||
channel = await ctx.bot.fetch_channel(ctx.message.channel_id)
|
||||
|
||||
if channel is None:
|
||||
raise CheckFailure("Channel for permission check not found.")
|
||||
|
||||
guild = getattr(channel, "guild", None)
|
||||
if not guild and hasattr(channel, "guild_id") and channel.guild_id:
|
||||
if hasattr(ctx.bot, "get_guild"):
|
||||
guild = ctx.bot.get_guild(channel.guild_id)
|
||||
if not guild and hasattr(ctx.bot, "fetch_guild"):
|
||||
guild = await ctx.bot.fetch_guild(channel.guild_id)
|
||||
|
||||
if not guild:
|
||||
is_dm = not hasattr(channel, "guild_id") or not channel.guild_id
|
||||
if is_dm:
|
||||
if perms:
|
||||
raise CheckFailure("Permission checks are not supported in DMs.")
|
||||
return True
|
||||
raise CheckFailure("Guild for permission check not found.")
|
||||
|
||||
member = ctx.author
|
||||
if not isinstance(member, Member):
|
||||
member = guild.get_member(ctx.author.id)
|
||||
if not member and hasattr(ctx.bot, "fetch_member"):
|
||||
member = await ctx.bot.fetch_member(guild.id, ctx.author.id)
|
||||
|
||||
if not member:
|
||||
raise CheckFailure("Could not resolve author to a guild member.")
|
||||
|
||||
perms_value = _compute_permissions(member, channel, guild)
|
||||
|
||||
if not has_permissions(perms_value, *perms):
|
||||
missing = missing_permissions(perms_value, *perms)
|
||||
missing_names = ", ".join(p.name for p in missing if p.name)
|
||||
raise CheckFailure(f"Missing permissions: {missing_names}")
|
||||
return True
|
||||
|
||||
return check(predicate)
|
||||
|
@ -49,3 +49,20 @@ async def ping(ctx):
|
||||
```
|
||||
|
||||
Invoking a command while it is on cooldown raises :class:`CommandOnCooldown`.
|
||||
|
||||
## Permission Checks
|
||||
|
||||
Use `commands.requires_permissions` to ensure the invoking member has the
|
||||
required permissions in the channel.
|
||||
|
||||
```python
|
||||
from disagreement.ext.commands import command, requires_permissions
|
||||
from disagreement.permissions import Permissions
|
||||
|
||||
@command()
|
||||
@requires_permissions(Permissions.MANAGE_MESSAGES)
|
||||
async def purge(ctx):
|
||||
await ctx.send("Purged!")
|
||||
```
|
||||
|
||||
Missing permissions raise :class:`CheckFailure`.
|
||||
|
@ -2,8 +2,13 @@ import asyncio
|
||||
import pytest
|
||||
|
||||
from disagreement.ext.commands.core import Command, CommandContext
|
||||
from disagreement.ext.commands.decorators import check, cooldown
|
||||
from disagreement.ext.commands.decorators import (
|
||||
check,
|
||||
cooldown,
|
||||
requires_permissions,
|
||||
)
|
||||
from disagreement.ext.commands.errors import CheckFailure, CommandOnCooldown
|
||||
from disagreement.permissions import Permissions
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -49,3 +54,58 @@ async def test_cooldown_per_user(message):
|
||||
await asyncio.sleep(0.05)
|
||||
await cmd.invoke(ctx)
|
||||
assert len(uses) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_permissions_pass(message):
|
||||
class Channel:
|
||||
def __init__(self, perms):
|
||||
self.perms = perms
|
||||
|
||||
def permissions_for(self, member):
|
||||
return self.perms
|
||||
|
||||
message._client.get_channel = lambda cid: Channel(Permissions.SEND_MESSAGES)
|
||||
|
||||
@requires_permissions(Permissions.SEND_MESSAGES)
|
||||
async def cb(ctx):
|
||||
pass
|
||||
|
||||
cmd = Command(cb)
|
||||
ctx = CommandContext(
|
||||
message=message,
|
||||
bot=message._client,
|
||||
prefix="!",
|
||||
command=cmd,
|
||||
invoked_with="test",
|
||||
)
|
||||
|
||||
await cmd.invoke(ctx)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_permissions_fail(message):
|
||||
class Channel:
|
||||
def __init__(self, perms):
|
||||
self.perms = perms
|
||||
|
||||
def permissions_for(self, member):
|
||||
return self.perms
|
||||
|
||||
message._client.get_channel = lambda cid: Channel(Permissions.SEND_MESSAGES)
|
||||
|
||||
@requires_permissions(Permissions.MANAGE_MESSAGES)
|
||||
async def cb(ctx):
|
||||
pass
|
||||
|
||||
cmd = Command(cb)
|
||||
ctx = CommandContext(
|
||||
message=message,
|
||||
bot=message._client,
|
||||
prefix="!",
|
||||
command=cmd,
|
||||
invoked_with="test",
|
||||
)
|
||||
|
||||
with pytest.raises(CheckFailure):
|
||||
await cmd.invoke(ctx)
|
||||
|
Loading…
x
Reference in New Issue
Block a user