Refactor: Move permission calculation to Channel model
The _compute_permissions helper function has been simplified by delegating the complex permission calculation logic to a new channel.permissions_for(member) method. This refactoring encapsulates the permission logic within the Channel model, where it belongs. It removes duplicate and complex code from the command decorator, improving maintainability and separation of concerns. The tests for requires_permissions have been updated to mock the new permissions_for method, resulting in cleaner tests that are no longer concerned with the internal implementation details of how permissions are calculated.
This commit is contained in:
parent
61222e1df7
commit
81ea79a94d
@ -132,47 +132,7 @@ def _compute_permissions(
|
||||
member: "Member", channel: "Channel", guild: "Guild"
|
||||
) -> "Permissions":
|
||||
"""Compute the effective permissions for a member in a channel."""
|
||||
from disagreement.models import Member, Guild, Channel
|
||||
from disagreement.permissions import Permissions
|
||||
|
||||
if guild.owner_id == member.id:
|
||||
return Permissions(~0)
|
||||
|
||||
roles = {str(r.id): r for r in guild.roles}
|
||||
everyone_role = roles.get(str(guild.id))
|
||||
if not everyone_role:
|
||||
base_permissions = Permissions(0)
|
||||
else:
|
||||
base_permissions = Permissions(int(everyone_role.permissions))
|
||||
|
||||
for role_id in member.roles:
|
||||
role = roles.get(str(role_id))
|
||||
if role:
|
||||
base_permissions |= Permissions(int(role.permissions))
|
||||
|
||||
if base_permissions & Permissions.ADMINISTRATOR:
|
||||
return Permissions(~0)
|
||||
|
||||
overwrites = {
|
||||
ow.id: ow for ow in getattr(channel, "permission_overwrites", [])
|
||||
}
|
||||
allow = Permissions(0)
|
||||
deny = Permissions(0)
|
||||
|
||||
if everyone_overwrite := overwrites.get(str(guild.id)):
|
||||
allow |= Permissions(int(everyone_overwrite.allow))
|
||||
deny |= Permissions(int(everyone_overwrite.deny))
|
||||
|
||||
for role_id in member.roles:
|
||||
if role_overwrite := overwrites.get(str(role_id)):
|
||||
allow |= Permissions(int(role_overwrite.allow))
|
||||
deny |= Permissions(int(role_overwrite.deny))
|
||||
|
||||
if member_overwrite := overwrites.get(str(member.id)):
|
||||
allow |= Permissions(int(member_overwrite.allow))
|
||||
deny |= Permissions(int(member_overwrite.deny))
|
||||
|
||||
return (base_permissions & ~deny) | allow
|
||||
return channel.permissions_for(member)
|
||||
|
||||
|
||||
def requires_permissions(
|
||||
|
@ -56,8 +56,17 @@ async def test_cooldown_per_user(message):
|
||||
assert len(uses) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_permissions_pass(message):
|
||||
class Guild:
|
||||
id = "g"
|
||||
owner_id = "owner"
|
||||
roles = []
|
||||
|
||||
def get_member(self, mid):
|
||||
return message.author
|
||||
|
||||
class Channel:
|
||||
def __init__(self, perms):
|
||||
self.perms = perms
|
||||
@ -66,7 +75,9 @@ async def test_requires_permissions_pass(message):
|
||||
def permissions_for(self, member):
|
||||
return self.perms
|
||||
|
||||
message.author.roles = []
|
||||
message._client.get_channel = lambda cid: Channel(Permissions.SEND_MESSAGES)
|
||||
message._client.get_guild = lambda gid: Guild()
|
||||
|
||||
@requires_permissions(Permissions.SEND_MESSAGES)
|
||||
async def cb(ctx):
|
||||
@ -83,9 +94,17 @@ async def test_requires_permissions_pass(message):
|
||||
|
||||
await cmd.invoke(ctx)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_permissions_fail(message):
|
||||
class Guild:
|
||||
id = "g"
|
||||
owner_id = "owner"
|
||||
roles = []
|
||||
|
||||
def get_member(self, mid):
|
||||
return message.author
|
||||
|
||||
class Channel:
|
||||
def __init__(self, perms):
|
||||
self.perms = perms
|
||||
@ -94,7 +113,9 @@ async def test_requires_permissions_fail(message):
|
||||
def permissions_for(self, member):
|
||||
return self.perms
|
||||
|
||||
message.author.roles = []
|
||||
message._client.get_channel = lambda cid: Channel(Permissions.SEND_MESSAGES)
|
||||
message._client.get_guild = lambda gid: Guild()
|
||||
|
||||
@requires_permissions(Permissions.MANAGE_MESSAGES)
|
||||
async def cb(ctx):
|
||||
|
Loading…
x
Reference in New Issue
Block a user