diff --git a/disagreement/__init__.py b/disagreement/__init__.py index 01b306f..bfd7d3b 100644 --- a/disagreement/__init__.py +++ b/disagreement/__init__.py @@ -101,6 +101,7 @@ from .ext.commands import ( command, cooldown, has_any_role, + is_owner, has_role, listener, max_concurrency, @@ -195,6 +196,7 @@ __all__ = [ "command", "cooldown", "has_any_role", + "is_owner", "has_role", "listener", "max_concurrency", diff --git a/disagreement/client.py b/disagreement/client.py index e9e954b..2adae6b 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -125,8 +125,10 @@ class Client: member_cache_flags: Optional[MemberCacheFlags] = None, message_cache_maxlen: Optional[int] = None, http_options: Optional[Dict[str, Any]] = None, + owner_ids: Optional[List[Union[str, int]]] = None, sync_commands_on_ready: bool = True, ): + if not token: raise ValueError("A bot token must be provided.") @@ -163,6 +165,7 @@ class Client: self.gateway_max_retries: int = gateway_max_retries self.gateway_max_backoff: float = gateway_max_backoff self._shard_manager: Optional[ShardManager] = None + self.owner_ids: List[str] = [str(o) for o in owner_ids] if owner_ids else [] # Initialize CommandHandler self.command_handler: CommandHandler = CommandHandler( diff --git a/disagreement/ext/commands/__init__.py b/disagreement/ext/commands/__init__.py index ef92e0d..7adf886 100644 --- a/disagreement/ext/commands/__init__.py +++ b/disagreement/ext/commands/__init__.py @@ -14,11 +14,12 @@ from .decorators import ( check, check_any, cooldown, - max_concurrency, - requires_permissions, - has_role, - has_any_role, -) + max_concurrency, + requires_permissions, + has_role, + has_any_role, + is_owner, +) from .errors import ( CommandError, CommandNotFound, @@ -47,9 +48,10 @@ __all__ = [ "cooldown", "max_concurrency", "requires_permissions", - "has_role", - "has_any_role", - # Errors + "has_role", + "has_any_role", + "is_owner", + # Errors "CommandError", "CommandNotFound", "BadArgument", diff --git a/disagreement/ext/commands/decorators.py b/disagreement/ext/commands/decorators.py index 15ab249..d9fa226 100644 --- a/disagreement/ext/commands/decorators.py +++ b/disagreement/ext/commands/decorators.py @@ -292,3 +292,19 @@ def has_any_role( ) return check(predicate) + + +def is_owner() -> ( + Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]] +): + """Check that the invoking user is listed as a bot owner.""" + + async def predicate(ctx: "CommandContext") -> bool: + from .errors import CheckFailure + + owner_ids = getattr(ctx.bot, "owner_ids", []) + if str(ctx.author.id) not in {str(o) for o in owner_ids}: + raise CheckFailure("This command can only be used by the bot owner.") + return True + + return check(predicate) diff --git a/tests/test_command_checks.py b/tests/test_command_checks.py index 33d5730..63935a4 100644 --- a/tests/test_command_checks.py +++ b/tests/test_command_checks.py @@ -6,6 +6,7 @@ from disagreement.ext.commands.decorators import ( check, cooldown, requires_permissions, + is_owner, ) from disagreement.ext.commands.errors import CheckFailure, CommandOnCooldown from disagreement.permissions import Permissions @@ -133,3 +134,44 @@ async def test_requires_permissions_fail(message): with pytest.raises(CheckFailure): await cmd.invoke(ctx) + + +@pytest.mark.asyncio +async def test_is_owner_pass(message): + message._client.owner_ids = ["2"] + + @is_owner() + async def cb(ctx): + pass + + cmd = Command(cb) + ctx = CommandContext( + message=message, + bot=message._client, + prefix="!", + command=cmd, + invoked_with="test", + ) + + await cmd.invoke(ctx) + + +@pytest.mark.asyncio +async def test_is_owner_fail(message): + message._client.owner_ids = ["1"] + + @is_owner() + async def cb(ctx): + pass + + cmd = Command(cb) + ctx = CommandContext( + message=message, + bot=message._client, + prefix="!", + command=cmd, + invoked_with="test", + ) + + with pytest.raises(CheckFailure): + await cmd.invoke(ctx)