Add is_owner decorator and owner checks (#81)
This commit is contained in:
parent
c7eb8563de
commit
5d72643390
@ -101,6 +101,7 @@ from .ext.commands import (
|
||||
command,
|
||||
cooldown,
|
||||
has_any_role,
|
||||
is_owner,
|
||||
has_role,
|
||||
listener,
|
||||
max_concurrency,
|
||||
@ -195,6 +196,7 @@ __all__ = [
|
||||
"command",
|
||||
"cooldown",
|
||||
"has_any_role",
|
||||
"is_owner",
|
||||
"has_role",
|
||||
"listener",
|
||||
"max_concurrency",
|
||||
|
@ -125,8 +125,10 @@ class Client:
|
||||
member_cache_flags: Optional[MemberCacheFlags] = None,
|
||||
message_cache_maxlen: Optional[int] = None,
|
||||
http_options: Optional[Dict[str, Any]] = None,
|
||||
owner_ids: Optional[List[Union[str, int]]] = None,
|
||||
sync_commands_on_ready: bool = True,
|
||||
):
|
||||
|
||||
if not token:
|
||||
raise ValueError("A bot token must be provided.")
|
||||
|
||||
@ -163,6 +165,7 @@ class Client:
|
||||
self.gateway_max_retries: int = gateway_max_retries
|
||||
self.gateway_max_backoff: float = gateway_max_backoff
|
||||
self._shard_manager: Optional[ShardManager] = None
|
||||
self.owner_ids: List[str] = [str(o) for o in owner_ids] if owner_ids else []
|
||||
|
||||
# Initialize CommandHandler
|
||||
self.command_handler: CommandHandler = CommandHandler(
|
||||
|
@ -14,11 +14,12 @@ from .decorators import (
|
||||
check,
|
||||
check_any,
|
||||
cooldown,
|
||||
max_concurrency,
|
||||
requires_permissions,
|
||||
has_role,
|
||||
has_any_role,
|
||||
)
|
||||
max_concurrency,
|
||||
requires_permissions,
|
||||
has_role,
|
||||
has_any_role,
|
||||
is_owner,
|
||||
)
|
||||
from .errors import (
|
||||
CommandError,
|
||||
CommandNotFound,
|
||||
@ -47,9 +48,10 @@ __all__ = [
|
||||
"cooldown",
|
||||
"max_concurrency",
|
||||
"requires_permissions",
|
||||
"has_role",
|
||||
"has_any_role",
|
||||
# Errors
|
||||
"has_role",
|
||||
"has_any_role",
|
||||
"is_owner",
|
||||
# Errors
|
||||
"CommandError",
|
||||
"CommandNotFound",
|
||||
"BadArgument",
|
||||
|
@ -292,3 +292,19 @@ def has_any_role(
|
||||
)
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def is_owner() -> (
|
||||
Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]
|
||||
):
|
||||
"""Check that the invoking user is listed as a bot owner."""
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CheckFailure
|
||||
|
||||
owner_ids = getattr(ctx.bot, "owner_ids", [])
|
||||
if str(ctx.author.id) not in {str(o) for o in owner_ids}:
|
||||
raise CheckFailure("This command can only be used by the bot owner.")
|
||||
return True
|
||||
|
||||
return check(predicate)
|
||||
|
@ -6,6 +6,7 @@ from disagreement.ext.commands.decorators import (
|
||||
check,
|
||||
cooldown,
|
||||
requires_permissions,
|
||||
is_owner,
|
||||
)
|
||||
from disagreement.ext.commands.errors import CheckFailure, CommandOnCooldown
|
||||
from disagreement.permissions import Permissions
|
||||
@ -133,3 +134,44 @@ async def test_requires_permissions_fail(message):
|
||||
|
||||
with pytest.raises(CheckFailure):
|
||||
await cmd.invoke(ctx)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_owner_pass(message):
|
||||
message._client.owner_ids = ["2"]
|
||||
|
||||
@is_owner()
|
||||
async def cb(ctx):
|
||||
pass
|
||||
|
||||
cmd = Command(cb)
|
||||
ctx = CommandContext(
|
||||
message=message,
|
||||
bot=message._client,
|
||||
prefix="!",
|
||||
command=cmd,
|
||||
invoked_with="test",
|
||||
)
|
||||
|
||||
await cmd.invoke(ctx)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_owner_fail(message):
|
||||
message._client.owner_ids = ["1"]
|
||||
|
||||
@is_owner()
|
||||
async def cb(ctx):
|
||||
pass
|
||||
|
||||
cmd = Command(cb)
|
||||
ctx = CommandContext(
|
||||
message=message,
|
||||
bot=message._client,
|
||||
prefix="!",
|
||||
command=cmd,
|
||||
invoked_with="test",
|
||||
)
|
||||
|
||||
with pytest.raises(CheckFailure):
|
||||
await cmd.invoke(ctx)
|
||||
|
Loading…
x
Reference in New Issue
Block a user