Merge branch 'master' of https://github.com/Slipstreamm/disagreement
This commit is contained in:
commit
a4aa4335a5
@ -46,6 +46,103 @@ class Color:
|
|||||||
def blue(cls) -> "Color":
|
def blue(cls) -> "Color":
|
||||||
return cls(0x0000FF)
|
return cls(0x0000FF)
|
||||||
|
|
||||||
|
# Discord brand colors
|
||||||
|
@classmethod
|
||||||
|
def blurple(cls) -> "Color":
|
||||||
|
"""Discord brand blurple (#5865F2)."""
|
||||||
|
return cls(0x5865F2)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def light_blurple(cls) -> "Color":
|
||||||
|
"""Light blurple used by Discord (#E0E3FF)."""
|
||||||
|
return cls(0xE0E3FF)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def legacy_blurple(cls) -> "Color":
|
||||||
|
"""Legacy Discord blurple (#7289DA)."""
|
||||||
|
return cls(0x7289DA)
|
||||||
|
|
||||||
|
# Additional assorted colors
|
||||||
|
@classmethod
|
||||||
|
def teal(cls) -> "Color":
|
||||||
|
return cls(0x1ABC9C)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dark_teal(cls) -> "Color":
|
||||||
|
return cls(0x11806A)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def brand_green(cls) -> "Color":
|
||||||
|
return cls(0x57F287)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dark_green(cls) -> "Color":
|
||||||
|
return cls(0x206694)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def orange(cls) -> "Color":
|
||||||
|
return cls(0xE67E22)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dark_orange(cls) -> "Color":
|
||||||
|
return cls(0xA84300)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def brand_red(cls) -> "Color":
|
||||||
|
return cls(0xED4245)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dark_red(cls) -> "Color":
|
||||||
|
return cls(0x992D22)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def magenta(cls) -> "Color":
|
||||||
|
return cls(0xE91E63)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dark_magenta(cls) -> "Color":
|
||||||
|
return cls(0xAD1457)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def purple(cls) -> "Color":
|
||||||
|
return cls(0x9B59B6)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dark_purple(cls) -> "Color":
|
||||||
|
return cls(0x71368A)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def yellow(cls) -> "Color":
|
||||||
|
return cls(0xF1C40F)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dark_gold(cls) -> "Color":
|
||||||
|
return cls(0xC27C0E)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def light_gray(cls) -> "Color":
|
||||||
|
return cls(0x99AAB5)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dark_gray(cls) -> "Color":
|
||||||
|
return cls(0x2C2F33)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def lighter_gray(cls) -> "Color":
|
||||||
|
return cls(0xBFBFBF)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def darker_gray(cls) -> "Color":
|
||||||
|
return cls(0x23272A)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def black(cls) -> "Color":
|
||||||
|
return cls(0x000000)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def white(cls) -> "Color":
|
||||||
|
return cls(0xFFFFFF)
|
||||||
|
|
||||||
def to_rgb(self) -> tuple[int, int, int]:
|
def to_rgb(self) -> tuple[int, int, int]:
|
||||||
return ((self.value >> 16) & 0xFF, (self.value >> 8) & 0xFF, self.value & 0xFF)
|
return ((self.value >> 16) & 0xFF, (self.value >> 8) & 0xFF, self.value & 0xFF)
|
||||||
|
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
# disagreement/ext/app_commands/handler.py
|
# disagreement/ext/app_commands/handler.py
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Dict,
|
Dict,
|
||||||
@ -67,6 +69,8 @@ if not TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
COMMANDS_CACHE_FILE = ".disagreement_commands.json"
|
||||||
|
|
||||||
|
|
||||||
class AppCommandHandler:
|
class AppCommandHandler:
|
||||||
"""
|
"""
|
||||||
@ -84,6 +88,33 @@ class AppCommandHandler:
|
|||||||
self._app_command_groups: Dict[str, AppCommandGroup] = {}
|
self._app_command_groups: Dict[str, AppCommandGroup] = {}
|
||||||
self._converter_registry: Dict[type, type] = {}
|
self._converter_registry: Dict[type, type] = {}
|
||||||
|
|
||||||
|
def _load_cached_ids(self) -> Dict[str, Dict[str, str]]:
|
||||||
|
try:
|
||||||
|
with open(COMMANDS_CACHE_FILE, "r", encoding="utf-8") as fp:
|
||||||
|
return json.load(fp)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return {}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("Invalid command cache file. Ignoring.")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _save_cached_ids(self, data: Dict[str, Dict[str, str]]) -> None:
|
||||||
|
try:
|
||||||
|
with open(COMMANDS_CACHE_FILE, "w", encoding="utf-8") as fp:
|
||||||
|
json.dump(data, fp, indent=2)
|
||||||
|
except Exception as e: # pragma: no cover - logging only
|
||||||
|
logger.error("Failed to write command cache: %s", e)
|
||||||
|
|
||||||
|
def clear_stored_registrations(self) -> None:
|
||||||
|
"""Remove persisted command registration data."""
|
||||||
|
if os.path.exists(COMMANDS_CACHE_FILE):
|
||||||
|
os.remove(COMMANDS_CACHE_FILE)
|
||||||
|
|
||||||
|
def migrate_stored_registrations(self, new_path: str) -> None:
|
||||||
|
"""Move stored registrations to ``new_path``."""
|
||||||
|
if os.path.exists(COMMANDS_CACHE_FILE):
|
||||||
|
os.replace(COMMANDS_CACHE_FILE, new_path)
|
||||||
|
|
||||||
def add_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None:
|
def add_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None:
|
||||||
"""Adds an application command or a command group to the handler."""
|
"""Adds an application command or a command group to the handler."""
|
||||||
if isinstance(command, AppCommandGroup):
|
if isinstance(command, AppCommandGroup):
|
||||||
@ -564,11 +595,13 @@ class AppCommandHandler:
|
|||||||
Synchronizes (registers/updates) all application commands with Discord.
|
Synchronizes (registers/updates) all application commands with Discord.
|
||||||
If guild_id is provided, syncs commands for that guild. Otherwise, syncs global commands.
|
If guild_id is provided, syncs commands for that guild. Otherwise, syncs global commands.
|
||||||
"""
|
"""
|
||||||
commands_to_sync: List[Dict[str, Any]] = []
|
cache = self._load_cached_ids()
|
||||||
|
scope_key = str(guild_id) if guild_id else "global"
|
||||||
|
stored = cache.get(scope_key, {})
|
||||||
|
|
||||||
|
current_payloads: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
# Collect commands based on scope (global or specific guild)
|
# Collect commands based on scope (global or specific guild)
|
||||||
# This needs to be more sophisticated to handle guild_ids on commands/groups
|
|
||||||
|
|
||||||
source_commands = (
|
source_commands = (
|
||||||
list(self._slash_commands.values())
|
list(self._slash_commands.values())
|
||||||
+ list(self._user_commands.values())
|
+ list(self._user_commands.values())
|
||||||
@ -577,26 +610,22 @@ class AppCommandHandler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
for cmd_or_group in source_commands:
|
for cmd_or_group in source_commands:
|
||||||
# Determine if this command/group should be synced for the current scope
|
|
||||||
is_guild_specific_command = (
|
is_guild_specific_command = (
|
||||||
cmd_or_group.guild_ids is not None and len(cmd_or_group.guild_ids) > 0
|
cmd_or_group.guild_ids is not None and len(cmd_or_group.guild_ids) > 0
|
||||||
)
|
)
|
||||||
|
|
||||||
if guild_id: # Syncing for a specific guild
|
if guild_id:
|
||||||
# Skip if not a guild-specific command OR if it's for a different guild
|
|
||||||
if not is_guild_specific_command or (
|
if not is_guild_specific_command or (
|
||||||
cmd_or_group.guild_ids is not None
|
cmd_or_group.guild_ids is not None
|
||||||
and guild_id not in cmd_or_group.guild_ids
|
and guild_id not in cmd_or_group.guild_ids
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
else: # Syncing global commands
|
else:
|
||||||
if is_guild_specific_command:
|
if is_guild_specific_command:
|
||||||
continue # Skip guild-specific commands when syncing global
|
continue
|
||||||
|
|
||||||
# Use the to_dict() method from AppCommand or AppCommandGroup
|
|
||||||
try:
|
try:
|
||||||
payload = cmd_or_group.to_dict()
|
current_payloads[cmd_or_group.name] = cmd_or_group.to_dict()
|
||||||
commands_to_sync.append(payload)
|
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Command or group '%s' does not have a to_dict() method. Skipping.",
|
"Command or group '%s' does not have a to_dict() method. Skipping.",
|
||||||
@ -609,32 +638,74 @@ class AppCommandHandler:
|
|||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not commands_to_sync:
|
if not current_payloads:
|
||||||
logger.info(
|
logger.info(
|
||||||
"No commands to sync for %s scope.",
|
"No commands to sync for %s scope.",
|
||||||
f"guild {guild_id}" if guild_id else "global",
|
f"guild {guild_id}" if guild_id else "global",
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
names_current = set(current_payloads)
|
||||||
|
names_stored = set(stored)
|
||||||
|
|
||||||
|
to_delete = names_stored - names_current
|
||||||
|
to_create = names_current - names_stored
|
||||||
|
to_update = names_current & names_stored
|
||||||
|
|
||||||
|
if not to_delete and not to_create and not to_update:
|
||||||
|
logger.info(
|
||||||
|
"Application commands already up to date for %s scope.", scope_key
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if guild_id:
|
for name in to_delete:
|
||||||
logger.info(
|
cmd_id = stored[name]
|
||||||
"Syncing %s commands for guild %s...",
|
if guild_id:
|
||||||
len(commands_to_sync),
|
await self.client._http.delete_guild_application_command(
|
||||||
guild_id,
|
application_id, guild_id, cmd_id
|
||||||
)
|
)
|
||||||
await self.client._http.bulk_overwrite_guild_application_commands(
|
else:
|
||||||
application_id, guild_id, commands_to_sync
|
await self.client._http.delete_global_application_command(
|
||||||
)
|
application_id, cmd_id
|
||||||
else:
|
)
|
||||||
logger.info(
|
|
||||||
"Syncing %s global commands...",
|
new_ids: Dict[str, str] = {}
|
||||||
len(commands_to_sync),
|
for name in to_create:
|
||||||
)
|
payload = current_payloads[name]
|
||||||
await self.client._http.bulk_overwrite_global_application_commands(
|
if guild_id:
|
||||||
application_id, commands_to_sync
|
result = await self.client._http.create_guild_application_command(
|
||||||
)
|
application_id, guild_id, payload
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = await self.client._http.create_global_application_command(
|
||||||
|
application_id, payload
|
||||||
|
)
|
||||||
|
if result.id:
|
||||||
|
new_ids[name] = str(result.id)
|
||||||
|
|
||||||
|
for name in to_update:
|
||||||
|
payload = current_payloads[name]
|
||||||
|
cmd_id = stored[name]
|
||||||
|
if guild_id:
|
||||||
|
await self.client._http.edit_guild_application_command(
|
||||||
|
application_id, guild_id, cmd_id, payload
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self.client._http.edit_global_application_command(
|
||||||
|
application_id, cmd_id, payload
|
||||||
|
)
|
||||||
|
new_ids[name] = cmd_id
|
||||||
|
|
||||||
|
final_ids: Dict[str, str] = {}
|
||||||
|
for name in names_current:
|
||||||
|
if name in new_ids:
|
||||||
|
final_ids[name] = new_ids[name]
|
||||||
|
else:
|
||||||
|
final_ids[name] = stored[name]
|
||||||
|
|
||||||
|
cache[scope_key] = final_ids
|
||||||
|
self._save_cached_ids(cache)
|
||||||
logger.info("Command sync successful.")
|
logger.info("Command sync successful.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error syncing application commands: %s", e)
|
logger.error("Error syncing application commands: %s", e)
|
||||||
# Consider re-raising or specific error handling
|
|
||||||
|
@ -16,6 +16,7 @@ from .decorators import (
|
|||||||
check,
|
check,
|
||||||
check_any,
|
check_any,
|
||||||
cooldown,
|
cooldown,
|
||||||
|
max_concurrency,
|
||||||
requires_permissions,
|
requires_permissions,
|
||||||
)
|
)
|
||||||
from .errors import (
|
from .errors import (
|
||||||
@ -28,6 +29,7 @@ from .errors import (
|
|||||||
CheckAnyFailure,
|
CheckAnyFailure,
|
||||||
CommandOnCooldown,
|
CommandOnCooldown,
|
||||||
CommandInvokeError,
|
CommandInvokeError,
|
||||||
|
MaxConcurrencyReached,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -43,6 +45,7 @@ __all__ = [
|
|||||||
"check",
|
"check",
|
||||||
"check_any",
|
"check_any",
|
||||||
"cooldown",
|
"cooldown",
|
||||||
|
"max_concurrency",
|
||||||
"requires_permissions",
|
"requires_permissions",
|
||||||
# Errors
|
# Errors
|
||||||
"CommandError",
|
"CommandError",
|
||||||
@ -54,4 +57,5 @@ __all__ = [
|
|||||||
"CheckAnyFailure",
|
"CheckAnyFailure",
|
||||||
"CommandOnCooldown",
|
"CommandOnCooldown",
|
||||||
"CommandInvokeError",
|
"CommandInvokeError",
|
||||||
|
"MaxConcurrencyReached",
|
||||||
]
|
]
|
||||||
|
@ -70,6 +70,10 @@ class Command:
|
|||||||
if hasattr(callback, "__command_checks__"):
|
if hasattr(callback, "__command_checks__"):
|
||||||
self.checks.extend(getattr(callback, "__command_checks__"))
|
self.checks.extend(getattr(callback, "__command_checks__"))
|
||||||
|
|
||||||
|
self.max_concurrency: Optional[Tuple[int, str]] = None
|
||||||
|
if hasattr(callback, "__max_concurrency__"):
|
||||||
|
self.max_concurrency = getattr(callback, "__max_concurrency__")
|
||||||
|
|
||||||
def add_check(
|
def add_check(
|
||||||
self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool]
|
self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool]
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -215,6 +219,7 @@ class CommandHandler:
|
|||||||
] = prefix
|
] = prefix
|
||||||
self.commands: Dict[str, Command] = {}
|
self.commands: Dict[str, Command] = {}
|
||||||
self.cogs: Dict[str, "Cog"] = {}
|
self.cogs: Dict[str, "Cog"] = {}
|
||||||
|
self._concurrency: Dict[str, Dict[str, int]] = {}
|
||||||
|
|
||||||
from .help import HelpCommand
|
from .help import HelpCommand
|
||||||
|
|
||||||
@ -300,6 +305,47 @@ class CommandHandler:
|
|||||||
logger.info("Cog '%s' removed.", cog_name)
|
logger.info("Cog '%s' removed.", cog_name)
|
||||||
return cog_to_remove
|
return cog_to_remove
|
||||||
|
|
||||||
|
def _acquire_concurrency(self, ctx: CommandContext) -> None:
|
||||||
|
mc = getattr(ctx.command, "max_concurrency", None)
|
||||||
|
if not mc:
|
||||||
|
return
|
||||||
|
limit, scope = mc
|
||||||
|
if scope == "user":
|
||||||
|
key = ctx.author.id
|
||||||
|
elif scope == "guild":
|
||||||
|
key = ctx.message.guild_id or ctx.author.id
|
||||||
|
else:
|
||||||
|
key = "global"
|
||||||
|
buckets = self._concurrency.setdefault(ctx.command.name, {})
|
||||||
|
current = buckets.get(key, 0)
|
||||||
|
if current >= limit:
|
||||||
|
from .errors import MaxConcurrencyReached
|
||||||
|
|
||||||
|
raise MaxConcurrencyReached(limit)
|
||||||
|
buckets[key] = current + 1
|
||||||
|
|
||||||
|
def _release_concurrency(self, ctx: CommandContext) -> None:
|
||||||
|
mc = getattr(ctx.command, "max_concurrency", None)
|
||||||
|
if not mc:
|
||||||
|
return
|
||||||
|
_, scope = mc
|
||||||
|
if scope == "user":
|
||||||
|
key = ctx.author.id
|
||||||
|
elif scope == "guild":
|
||||||
|
key = ctx.message.guild_id or ctx.author.id
|
||||||
|
else:
|
||||||
|
key = "global"
|
||||||
|
buckets = self._concurrency.get(ctx.command.name)
|
||||||
|
if not buckets:
|
||||||
|
return
|
||||||
|
current = buckets.get(key, 0)
|
||||||
|
if current <= 1:
|
||||||
|
buckets.pop(key, None)
|
||||||
|
else:
|
||||||
|
buckets[key] = current - 1
|
||||||
|
if not buckets:
|
||||||
|
self._concurrency.pop(ctx.command.name, None)
|
||||||
|
|
||||||
async def get_prefix(self, message: "Message") -> Union[str, List[str], None]:
|
async def get_prefix(self, message: "Message") -> Union[str, List[str], None]:
|
||||||
if callable(self.prefix):
|
if callable(self.prefix):
|
||||||
if inspect.iscoroutinefunction(self.prefix):
|
if inspect.iscoroutinefunction(self.prefix):
|
||||||
@ -501,7 +547,11 @@ class CommandHandler:
|
|||||||
parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view)
|
parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view)
|
||||||
ctx.args = parsed_args
|
ctx.args = parsed_args
|
||||||
ctx.kwargs = parsed_kwargs
|
ctx.kwargs = parsed_kwargs
|
||||||
await command.invoke(ctx, *parsed_args, **parsed_kwargs)
|
self._acquire_concurrency(ctx)
|
||||||
|
try:
|
||||||
|
await command.invoke(ctx, *parsed_args, **parsed_kwargs)
|
||||||
|
finally:
|
||||||
|
self._release_concurrency(ctx)
|
||||||
except CommandError as e:
|
except CommandError as e:
|
||||||
logger.error("Command error for '%s': %s", command.name, e)
|
logger.error("Command error for '%s': %s", command.name, e)
|
||||||
if hasattr(self.client, "on_command_error"):
|
if hasattr(self.client, "on_command_error"):
|
||||||
|
@ -107,6 +107,33 @@ def check_any(
|
|||||||
return check(predicate)
|
return check(predicate)
|
||||||
|
|
||||||
|
|
||||||
|
def max_concurrency(
|
||||||
|
number: int, per: str = "user"
|
||||||
|
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||||
|
"""Limit how many concurrent invocations of a command are allowed.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
number:
|
||||||
|
The maximum number of concurrent invocations.
|
||||||
|
per:
|
||||||
|
The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if number < 1:
|
||||||
|
raise ValueError("Concurrency number must be at least 1.")
|
||||||
|
if per not in {"user", "guild", "global"}:
|
||||||
|
raise ValueError("per must be 'user', 'guild', or 'global'.")
|
||||||
|
|
||||||
|
def decorator(
|
||||||
|
func: Callable[..., Awaitable[None]],
|
||||||
|
) -> Callable[..., Awaitable[None]]:
|
||||||
|
setattr(func, "__max_concurrency__", (number, per))
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def cooldown(
|
def cooldown(
|
||||||
rate: int, per: float
|
rate: int, per: float
|
||||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||||
|
@ -72,5 +72,13 @@ class CommandInvokeError(CommandError):
|
|||||||
super().__init__(f"Error during command invocation: {original}")
|
super().__init__(f"Error during command invocation: {original}")
|
||||||
|
|
||||||
|
|
||||||
|
class MaxConcurrencyReached(CommandError):
|
||||||
|
"""Raised when a command exceeds its concurrency limit."""
|
||||||
|
|
||||||
|
def __init__(self, limit: int):
|
||||||
|
self.limit = limit
|
||||||
|
super().__init__(f"Max concurrency of {limit} reached")
|
||||||
|
|
||||||
|
|
||||||
# Add more specific errors as needed, e.g., UserNotFound, ChannelNotFound, etc.
|
# Add more specific errors as needed, e.g., UserNotFound, ChannelNotFound, etc.
|
||||||
# These might inherit from BadArgument.
|
# These might inherit from BadArgument.
|
||||||
|
@ -20,3 +20,14 @@ Use `AppCommandGroup` to group related commands. See the [components guide](usin
|
|||||||
- [Caching](caching.md)
|
- [Caching](caching.md)
|
||||||
- [Voice Features](voice_features.md)
|
- [Voice Features](voice_features.md)
|
||||||
|
|
||||||
|
## Command Persistence
|
||||||
|
|
||||||
|
`AppCommandHandler.sync_commands` can persist registered command IDs in
|
||||||
|
`.disagreement_commands.json`. When enabled, subsequent syncs compare the
|
||||||
|
stored IDs to the commands defined in code and only create, edit or delete
|
||||||
|
commands when changes are detected.
|
||||||
|
|
||||||
|
Call `AppCommandHandler.clear_stored_registrations()` if you need to wipe the
|
||||||
|
stored IDs or migrate them elsewhere with
|
||||||
|
`AppCommandHandler.migrate_stored_registrations()`.
|
||||||
|
|
||||||
|
@ -11,3 +11,6 @@ def test_static_colors():
|
|||||||
assert Color.red().value == 0xFF0000
|
assert Color.red().value == 0xFF0000
|
||||||
assert Color.green().value == 0x00FF00
|
assert Color.green().value == 0x00FF00
|
||||||
assert Color.blue().value == 0x0000FF
|
assert Color.blue().value == 0x0000FF
|
||||||
|
assert Color.blurple().value == 0x5865F2
|
||||||
|
assert Color.light_blurple().value == 0xE0E3FF
|
||||||
|
assert Color.legacy_blurple().value == 0x7289DA
|
||||||
|
103
tests/test_max_concurrency.py
Normal file
103
tests/test_max_concurrency.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.ext.commands.core import CommandHandler
|
||||||
|
from disagreement.ext.commands.decorators import command, max_concurrency
|
||||||
|
from disagreement.ext.commands.errors import MaxConcurrencyReached
|
||||||
|
from disagreement.models import Message
|
||||||
|
|
||||||
|
|
||||||
|
class DummyBot:
|
||||||
|
def __init__(self):
|
||||||
|
self.errors = []
|
||||||
|
|
||||||
|
async def on_command_error(self, ctx, error):
|
||||||
|
self.errors.append(error)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_max_concurrency_per_user():
|
||||||
|
bot = DummyBot()
|
||||||
|
handler = CommandHandler(client=bot, prefix="!")
|
||||||
|
started = asyncio.Event()
|
||||||
|
release = asyncio.Event()
|
||||||
|
|
||||||
|
@command()
|
||||||
|
@max_concurrency(1, per="user")
|
||||||
|
async def foo(ctx):
|
||||||
|
started.set()
|
||||||
|
await release.wait()
|
||||||
|
|
||||||
|
handler.add_command(foo.__command_object__)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"id": "1",
|
||||||
|
"channel_id": "c",
|
||||||
|
"guild_id": "g",
|
||||||
|
"author": {"id": "a", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "!foo",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
msg1 = Message(data, client_instance=bot)
|
||||||
|
msg2 = Message({**data, "id": "2"}, client_instance=bot)
|
||||||
|
|
||||||
|
task = asyncio.create_task(handler.process_commands(msg1))
|
||||||
|
await started.wait()
|
||||||
|
|
||||||
|
await handler.process_commands(msg2)
|
||||||
|
assert any(isinstance(e, MaxConcurrencyReached) for e in bot.errors)
|
||||||
|
|
||||||
|
release.set()
|
||||||
|
await task
|
||||||
|
|
||||||
|
await handler.process_commands(msg2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_max_concurrency_per_guild():
|
||||||
|
bot = DummyBot()
|
||||||
|
handler = CommandHandler(client=bot, prefix="!")
|
||||||
|
started = asyncio.Event()
|
||||||
|
release = asyncio.Event()
|
||||||
|
|
||||||
|
@command()
|
||||||
|
@max_concurrency(1, per="guild")
|
||||||
|
async def foo(ctx):
|
||||||
|
started.set()
|
||||||
|
await release.wait()
|
||||||
|
|
||||||
|
handler.add_command(foo.__command_object__)
|
||||||
|
|
||||||
|
base = {
|
||||||
|
"channel_id": "c",
|
||||||
|
"guild_id": "g",
|
||||||
|
"content": "!foo",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
msg1 = Message(
|
||||||
|
{
|
||||||
|
**base,
|
||||||
|
"id": "1",
|
||||||
|
"author": {"id": "a", "username": "u", "discriminator": "0001"},
|
||||||
|
},
|
||||||
|
client_instance=bot,
|
||||||
|
)
|
||||||
|
msg2 = Message(
|
||||||
|
{
|
||||||
|
**base,
|
||||||
|
"id": "2",
|
||||||
|
"author": {"id": "b", "username": "v", "discriminator": "0001"},
|
||||||
|
},
|
||||||
|
client_instance=bot,
|
||||||
|
)
|
||||||
|
|
||||||
|
task = asyncio.create_task(handler.process_commands(msg1))
|
||||||
|
await started.wait()
|
||||||
|
|
||||||
|
await handler.process_commands(msg2)
|
||||||
|
assert any(isinstance(e, MaxConcurrencyReached) for e in bot.errors)
|
||||||
|
|
||||||
|
release.set()
|
||||||
|
await task
|
||||||
|
|
||||||
|
await handler.process_commands(msg2)
|
Loading…
x
Reference in New Issue
Block a user