Add is_owner decorator and owner checks (#81)
This commit is contained in:
parent
c7eb8563de
commit
5d72643390
@ -101,6 +101,7 @@ from .ext.commands import (
|
|||||||
command,
|
command,
|
||||||
cooldown,
|
cooldown,
|
||||||
has_any_role,
|
has_any_role,
|
||||||
|
is_owner,
|
||||||
has_role,
|
has_role,
|
||||||
listener,
|
listener,
|
||||||
max_concurrency,
|
max_concurrency,
|
||||||
@ -195,6 +196,7 @@ __all__ = [
|
|||||||
"command",
|
"command",
|
||||||
"cooldown",
|
"cooldown",
|
||||||
"has_any_role",
|
"has_any_role",
|
||||||
|
"is_owner",
|
||||||
"has_role",
|
"has_role",
|
||||||
"listener",
|
"listener",
|
||||||
"max_concurrency",
|
"max_concurrency",
|
||||||
|
@ -125,8 +125,10 @@ class Client:
|
|||||||
member_cache_flags: Optional[MemberCacheFlags] = None,
|
member_cache_flags: Optional[MemberCacheFlags] = None,
|
||||||
message_cache_maxlen: Optional[int] = None,
|
message_cache_maxlen: Optional[int] = None,
|
||||||
http_options: Optional[Dict[str, Any]] = None,
|
http_options: Optional[Dict[str, Any]] = None,
|
||||||
|
owner_ids: Optional[List[Union[str, int]]] = None,
|
||||||
sync_commands_on_ready: bool = True,
|
sync_commands_on_ready: bool = True,
|
||||||
):
|
):
|
||||||
|
|
||||||
if not token:
|
if not token:
|
||||||
raise ValueError("A bot token must be provided.")
|
raise ValueError("A bot token must be provided.")
|
||||||
|
|
||||||
@ -163,6 +165,7 @@ class Client:
|
|||||||
self.gateway_max_retries: int = gateway_max_retries
|
self.gateway_max_retries: int = gateway_max_retries
|
||||||
self.gateway_max_backoff: float = gateway_max_backoff
|
self.gateway_max_backoff: float = gateway_max_backoff
|
||||||
self._shard_manager: Optional[ShardManager] = None
|
self._shard_manager: Optional[ShardManager] = None
|
||||||
|
self.owner_ids: List[str] = [str(o) for o in owner_ids] if owner_ids else []
|
||||||
|
|
||||||
# Initialize CommandHandler
|
# Initialize CommandHandler
|
||||||
self.command_handler: CommandHandler = CommandHandler(
|
self.command_handler: CommandHandler = CommandHandler(
|
||||||
|
@ -18,6 +18,7 @@ from .decorators import (
|
|||||||
requires_permissions,
|
requires_permissions,
|
||||||
has_role,
|
has_role,
|
||||||
has_any_role,
|
has_any_role,
|
||||||
|
is_owner,
|
||||||
)
|
)
|
||||||
from .errors import (
|
from .errors import (
|
||||||
CommandError,
|
CommandError,
|
||||||
@ -49,6 +50,7 @@ __all__ = [
|
|||||||
"requires_permissions",
|
"requires_permissions",
|
||||||
"has_role",
|
"has_role",
|
||||||
"has_any_role",
|
"has_any_role",
|
||||||
|
"is_owner",
|
||||||
# Errors
|
# Errors
|
||||||
"CommandError",
|
"CommandError",
|
||||||
"CommandNotFound",
|
"CommandNotFound",
|
||||||
|
@ -292,3 +292,19 @@ def has_any_role(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return check(predicate)
|
return check(predicate)
|
||||||
|
|
||||||
|
|
||||||
|
def is_owner() -> (
|
||||||
|
Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]
|
||||||
|
):
|
||||||
|
"""Check that the invoking user is listed as a bot owner."""
|
||||||
|
|
||||||
|
async def predicate(ctx: "CommandContext") -> bool:
|
||||||
|
from .errors import CheckFailure
|
||||||
|
|
||||||
|
owner_ids = getattr(ctx.bot, "owner_ids", [])
|
||||||
|
if str(ctx.author.id) not in {str(o) for o in owner_ids}:
|
||||||
|
raise CheckFailure("This command can only be used by the bot owner.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
return check(predicate)
|
||||||
|
@ -6,6 +6,7 @@ from disagreement.ext.commands.decorators import (
|
|||||||
check,
|
check,
|
||||||
cooldown,
|
cooldown,
|
||||||
requires_permissions,
|
requires_permissions,
|
||||||
|
is_owner,
|
||||||
)
|
)
|
||||||
from disagreement.ext.commands.errors import CheckFailure, CommandOnCooldown
|
from disagreement.ext.commands.errors import CheckFailure, CommandOnCooldown
|
||||||
from disagreement.permissions import Permissions
|
from disagreement.permissions import Permissions
|
||||||
@ -133,3 +134,44 @@ async def test_requires_permissions_fail(message):
|
|||||||
|
|
||||||
with pytest.raises(CheckFailure):
|
with pytest.raises(CheckFailure):
|
||||||
await cmd.invoke(ctx)
|
await cmd.invoke(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_is_owner_pass(message):
|
||||||
|
message._client.owner_ids = ["2"]
|
||||||
|
|
||||||
|
@is_owner()
|
||||||
|
async def cb(ctx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
cmd = Command(cb)
|
||||||
|
ctx = CommandContext(
|
||||||
|
message=message,
|
||||||
|
bot=message._client,
|
||||||
|
prefix="!",
|
||||||
|
command=cmd,
|
||||||
|
invoked_with="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
await cmd.invoke(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_is_owner_fail(message):
|
||||||
|
message._client.owner_ids = ["1"]
|
||||||
|
|
||||||
|
@is_owner()
|
||||||
|
async def cb(ctx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
cmd = Command(cb)
|
||||||
|
ctx = CommandContext(
|
||||||
|
message=message,
|
||||||
|
bot=message._client,
|
||||||
|
prefix="!",
|
||||||
|
command=cmd,
|
||||||
|
invoked_with="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(CheckFailure):
|
||||||
|
await cmd.invoke(ctx)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user