diff --git a/disagreement/ext/commands/decorators.py b/disagreement/ext/commands/decorators.py index 8fd14f7..53f540c 100644 --- a/disagreement/ext/commands/decorators.py +++ b/disagreement/ext/commands/decorators.py @@ -132,47 +132,7 @@ def _compute_permissions( member: "Member", channel: "Channel", guild: "Guild" ) -> "Permissions": """Compute the effective permissions for a member in a channel.""" - from disagreement.models import Member, Guild, Channel - from disagreement.permissions import Permissions - - if guild.owner_id == member.id: - return Permissions(~0) - - roles = {str(r.id): r for r in guild.roles} - everyone_role = roles.get(str(guild.id)) - if not everyone_role: - base_permissions = Permissions(0) - else: - base_permissions = Permissions(int(everyone_role.permissions)) - - for role_id in member.roles: - role = roles.get(str(role_id)) - if role: - base_permissions |= Permissions(int(role.permissions)) - - if base_permissions & Permissions.ADMINISTRATOR: - return Permissions(~0) - - overwrites = { - ow.id: ow for ow in getattr(channel, "permission_overwrites", []) - } - allow = Permissions(0) - deny = Permissions(0) - - if everyone_overwrite := overwrites.get(str(guild.id)): - allow |= Permissions(int(everyone_overwrite.allow)) - deny |= Permissions(int(everyone_overwrite.deny)) - - for role_id in member.roles: - if role_overwrite := overwrites.get(str(role_id)): - allow |= Permissions(int(role_overwrite.allow)) - deny |= Permissions(int(role_overwrite.deny)) - - if member_overwrite := overwrites.get(str(member.id)): - allow |= Permissions(int(member_overwrite.allow)) - deny |= Permissions(int(member_overwrite.deny)) - - return (base_permissions & ~deny) | allow + return channel.permissions_for(member) def requires_permissions( diff --git a/tests/test_command_checks.py b/tests/test_command_checks.py index 7c37407..6a6c8b2 100644 --- a/tests/test_command_checks.py +++ b/tests/test_command_checks.py @@ -56,8 +56,17 @@ async def test_cooldown_per_user(message): assert len(uses) == 2 +@pytest.mark.asyncio @pytest.mark.asyncio async def test_requires_permissions_pass(message): + class Guild: + id = "g" + owner_id = "owner" + roles = [] + + def get_member(self, mid): + return message.author + class Channel: def __init__(self, perms): self.perms = perms @@ -66,7 +75,9 @@ async def test_requires_permissions_pass(message): def permissions_for(self, member): return self.perms + message.author.roles = [] message._client.get_channel = lambda cid: Channel(Permissions.SEND_MESSAGES) + message._client.get_guild = lambda gid: Guild() @requires_permissions(Permissions.SEND_MESSAGES) async def cb(ctx): @@ -83,9 +94,17 @@ async def test_requires_permissions_pass(message): await cmd.invoke(ctx) - +@pytest.mark.asyncio @pytest.mark.asyncio async def test_requires_permissions_fail(message): + class Guild: + id = "g" + owner_id = "owner" + roles = [] + + def get_member(self, mid): + return message.author + class Channel: def __init__(self, perms): self.perms = perms @@ -94,7 +113,9 @@ async def test_requires_permissions_fail(message): def permissions_for(self, member): return self.perms + message.author.roles = [] message._client.get_channel = lambda cid: Channel(Permissions.SEND_MESSAGES) + message._client.get_guild = lambda gid: Guild() @requires_permissions(Permissions.MANAGE_MESSAGES) async def cb(ctx):