Add is_owner decorator and owner checks (#81)

This commit is contained in:
Slipstream 2025-06-15 15:23:52 -06:00 committed by GitHub
parent c7eb8563de
commit 5d72643390
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 73 additions and 8 deletions

View File

@ -101,6 +101,7 @@ from .ext.commands import (
command,
cooldown,
has_any_role,
is_owner,
has_role,
listener,
max_concurrency,
@ -195,6 +196,7 @@ __all__ = [
"command",
"cooldown",
"has_any_role",
"is_owner",
"has_role",
"listener",
"max_concurrency",

View File

@ -125,8 +125,10 @@ class Client:
member_cache_flags: Optional[MemberCacheFlags] = None,
message_cache_maxlen: Optional[int] = None,
http_options: Optional[Dict[str, Any]] = None,
owner_ids: Optional[List[Union[str, int]]] = None,
sync_commands_on_ready: bool = True,
):
if not token:
raise ValueError("A bot token must be provided.")
@ -163,6 +165,7 @@ class Client:
self.gateway_max_retries: int = gateway_max_retries
self.gateway_max_backoff: float = gateway_max_backoff
self._shard_manager: Optional[ShardManager] = None
self.owner_ids: List[str] = [str(o) for o in owner_ids] if owner_ids else []
# Initialize CommandHandler
self.command_handler: CommandHandler = CommandHandler(

View File

@ -14,11 +14,12 @@ from .decorators import (
check,
check_any,
cooldown,
max_concurrency,
requires_permissions,
has_role,
has_any_role,
)
max_concurrency,
requires_permissions,
has_role,
has_any_role,
is_owner,
)
from .errors import (
CommandError,
CommandNotFound,
@ -47,9 +48,10 @@ __all__ = [
"cooldown",
"max_concurrency",
"requires_permissions",
"has_role",
"has_any_role",
# Errors
"has_role",
"has_any_role",
"is_owner",
# Errors
"CommandError",
"CommandNotFound",
"BadArgument",

View File

@ -292,3 +292,19 @@ def has_any_role(
)
return check(predicate)
def is_owner() -> (
Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]
):
"""Check that the invoking user is listed as a bot owner."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckFailure
owner_ids = getattr(ctx.bot, "owner_ids", [])
if str(ctx.author.id) not in {str(o) for o in owner_ids}:
raise CheckFailure("This command can only be used by the bot owner.")
return True
return check(predicate)

View File

@ -6,6 +6,7 @@ from disagreement.ext.commands.decorators import (
check,
cooldown,
requires_permissions,
is_owner,
)
from disagreement.ext.commands.errors import CheckFailure, CommandOnCooldown
from disagreement.permissions import Permissions
@ -133,3 +134,44 @@ async def test_requires_permissions_fail(message):
with pytest.raises(CheckFailure):
await cmd.invoke(ctx)
@pytest.mark.asyncio
async def test_is_owner_pass(message):
message._client.owner_ids = ["2"]
@is_owner()
async def cb(ctx):
pass
cmd = Command(cb)
ctx = CommandContext(
message=message,
bot=message._client,
prefix="!",
command=cmd,
invoked_with="test",
)
await cmd.invoke(ctx)
@pytest.mark.asyncio
async def test_is_owner_fail(message):
message._client.owner_ids = ["1"]
@is_owner()
async def cb(ctx):
pass
cmd = Command(cb)
ctx = CommandContext(
message=message,
bot=message._client,
prefix="!",
command=cmd,
invoked_with="test",
)
with pytest.raises(CheckFailure):
await cmd.invoke(ctx)