1899 lines
73 KiB
Python

"""
The main Client class for interacting with the Discord API.
"""
import asyncio
import signal
import json
import os
import importlib
from typing import (
Optional,
Callable,
Any,
TYPE_CHECKING,
Awaitable,
AsyncIterator,
Union,
List,
Dict,
cast,
)
from types import ModuleType
# JSON file used to persist registered persistent views between runs. It maps
# each view's custom ID to a dotted "module.ClassName" path. The path is
# relative, so it is resolved against the current working directory.
PERSISTENT_VIEWS_FILE = "persistent_views.json"
from datetime import datetime, timedelta
from .http import HTTPClient
from .gateway import GatewayClient
from .shard_manager import ShardManager
from .event_dispatcher import EventDispatcher
from .enums import GatewayIntent, InteractionType, GatewayOpcode, VoiceRegion
from .errors import DisagreementException, AuthenticationError
from .typing import Typing
from .caching import MemberCacheFlags
from .cache import Cache, GuildCache, ChannelCache, MemberCache
from .ext.commands.core import Command, CommandHandler, Group
from .ext.commands.cog import Cog
from .ext.app_commands.handler import AppCommandHandler
from .ext.app_commands.context import AppCommandContext
from .ext import loader as ext_loader
from .interactions import Interaction, Snowflake
from .error_handler import setup_global_error_handler
from .voice_client import VoiceClient
from .models import Activity
from .utils import utcnow
if TYPE_CHECKING:
from .models import (
Message,
Embed,
ActionRow,
Guild,
Channel,
User,
Member,
Role,
TextChannel,
VoiceChannel,
CategoryChannel,
Thread,
DMChannel,
Webhook,
GuildTemplate,
ScheduledEvent,
AuditLogEntry,
Invite,
)
from .ui.view import View
from .enums import ChannelType as EnumChannelType
from .ext.commands.core import CommandContext
from .ext.commands.errors import CommandError, CommandInvokeError
from .ext.app_commands.commands import AppCommand, AppCommandGroup
def _update_list(lst: List[Any], item: Any) -> None:
    """Swap ``item`` in for the first entry sharing its ``id``; append it otherwise.

    Objects without an ``id`` attribute are treated as having ``id`` ``None``.
    The list is modified in place.
    """
    wanted_id = getattr(item, "id", None)
    for position, current in enumerate(lst):
        if getattr(current, "id", None) != wanted_id:
            continue
        lst[position] = item
        return
    lst.append(item)
class Client:
"""
Represents a client connection that connects to Discord.
This class is used to interact with the Discord WebSocket and API.
Args:
token (str): The bot token for authentication.
intents (Optional[int]): The Gateway Intents to use. Defaults to `GatewayIntent.default()`.
You might need to enable privileged intents in your bot's application page.
loop (Optional[asyncio.AbstractEventLoop]): The event loop to use for asynchronous operations.
Defaults to the running loop
via `asyncio.get_running_loop()`,
or a new loop from
`asyncio.new_event_loop()` if
none is running.
command_prefix (Union[str, List[str], Callable[['Client', Message], Union[str, List[str]]]]):
The prefix(es) for commands. Defaults to '!'.
verbose (bool): If True, print raw HTTP and Gateway traffic for debugging.
mention_replies (bool): Whether replies mention the author by default.
allowed_mentions (Optional[Dict[str, Any]]): Default allowed mentions for messages.
http_options (Optional[Dict[str, Any]]): Extra options passed to
:class:`HTTPClient` for creating the internal
:class:`aiohttp.ClientSession`.
message_cache_maxlen (Optional[int]): Maximum number of messages to keep
in the cache. When ``None``, the cache size is unlimited.
sync_commands_on_ready (bool): If ``True``, automatically call
:meth:`Client.sync_application_commands` after the ``READY`` event
when :attr:`Client.application_id` is available.
"""
def __init__(
    self,
    token: str,
    intents: Optional[int] = None,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    command_prefix: Union[
        str, List[str], Callable[["Client", "Message"], Union[str, List[str]]]
    ] = "!",
    application_id: Optional[Union[str, int]] = None,
    verbose: bool = False,
    mention_replies: bool = False,
    allowed_mentions: Optional[Dict[str, Any]] = None,
    shard_count: Optional[int] = None,
    gateway_max_retries: int = 5,
    gateway_max_backoff: float = 60.0,
    member_cache_flags: Optional[MemberCacheFlags] = None,
    message_cache_maxlen: Optional[int] = None,
    http_options: Optional[Dict[str, Any]] = None,
    owner_ids: Optional[List[Union[str, int]]] = None,
    sync_commands_on_ready: bool = True,
):
    """Create the client and wire up its HTTP, event, command and cache state.

    No network connection is made here; the gateway is created later by
    :meth:`connect`.

    Args:
        token: Bot token. Must be non-empty.
        intents: Gateway intents; ``GatewayIntent.default()`` when ``None``.
        loop: Event loop to use. When omitted, the running loop is used, or
            a new loop is created and installed as the current loop.
        command_prefix: Prefix(es) handed to the :class:`CommandHandler`.
        application_id: Application ID; stored as a string when provided.
        verbose: Forwarded to the HTTP client and (later) the gateway.
        mention_replies: Default for whether replies mention the author.
        allowed_mentions: Default allowed-mentions payload.
        shard_count: Number of shards; sharding is used by :meth:`connect`
            only when this is greater than 1.
        gateway_max_retries: Passed to the gateway as ``max_retries``.
        gateway_max_backoff: Passed to the gateway as ``max_backoff``.
        member_cache_flags: Member cache flags; a default
            ``MemberCacheFlags()`` is used when ``None``.
        message_cache_maxlen: ``maxlen`` of the message cache.
        http_options: Extra keyword arguments unpacked into ``HTTPClient``.
        owner_ids: Owner IDs; each one is normalised to a string.
        sync_commands_on_ready: Stored as ``self.sync_commands_on_ready``.

    Raises:
        ValueError: If ``token`` is empty.
    """
    if not token:
        raise ValueError("A bot token must be provided.")
    self.token: str = token
    self.member_cache_flags: MemberCacheFlags = (
        member_cache_flags if member_cache_flags is not None else MemberCacheFlags()
    )
    self.message_cache_maxlen: Optional[int] = message_cache_maxlen
    self.intents: int = intents if intents is not None else GatewayIntent.default()
    if loop:
        self.loop: asyncio.AbstractEventLoop = loop
    else:
        try:
            self.loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running: create one and make it the current loop.
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
    self.application_id: Optional[Snowflake] = (
        str(application_id) if application_id else None
    )
    setup_global_error_handler(self.loop)
    self.verbose: bool = verbose
    self._http: HTTPClient = HTTPClient(
        token=self.token,
        verbose=verbose,
        **(http_options or {}),
    )
    self._event_dispatcher: EventDispatcher = EventDispatcher(client_instance=self)
    self._gateway: Optional[GatewayClient] = (
        None  # Initialized in start() or connect()
    )
    self.shard_count: Optional[int] = shard_count
    self.gateway_max_retries: int = gateway_max_retries
    self.gateway_max_backoff: float = gateway_max_backoff
    self._shard_manager: Optional[ShardManager] = None
    self.owner_ids: List[str] = [str(o) for o in owner_ids] if owner_ids else []
    # Initialize CommandHandler
    self.command_handler: CommandHandler = CommandHandler(
        client=self, prefix=command_prefix
    )
    self.app_command_handler: AppCommandHandler = AppCommandHandler(client=self)
    # Register internal listener for processing commands from messages
    self._event_dispatcher.register(
        "MESSAGE_CREATE", self._process_message_for_commands
    )
    self._closed: bool = False
    self._ready_event: asyncio.Event = asyncio.Event()
    self.user: Optional["User"] = (
        None  # The bot's own user object, populated on READY
    )
    self.start_time: Optional[datetime] = None
    # Internal Caches
    self._guilds: GuildCache = GuildCache()
    self._channels: ChannelCache = ChannelCache()
    self._users: Cache["User"] = Cache()
    self._messages: Cache["Message"] = Cache(ttl=3600, maxlen=message_cache_maxlen)
    self._views: Dict[Snowflake, "View"] = {}
    self._persistent_views: Dict[str, "View"] = {}
    self._voice_clients: Dict[Snowflake, VoiceClient] = {}
    self._webhooks: Dict[Snowflake, "Webhook"] = {}
    # Load persistent views stored on disk
    self._load_persistent_views()
    # Default whether replies mention the user
    self.mention_replies: bool = mention_replies
    self.allowed_mentions: Optional[Dict[str, Any]] = allowed_mentions
    self.sync_commands_on_ready: bool = sync_commands_on_ready
    # Basic signal handling for graceful shutdown: SIGINT/SIGTERM schedule
    # close() on the client's loop. This is a convenience default; an
    # application may prefer to manage shutdown itself.
    try:
        self.loop.add_signal_handler(
            signal.SIGINT, lambda: self.loop.create_task(self.close())
        )
        self.loop.add_signal_handler(
            signal.SIGTERM, lambda: self.loop.create_task(self.close())
        )
    except (NotImplementedError, RuntimeError):
        # NotImplementedError: the loop implementation has no signal support
        # (e.g. Windows). RuntimeError: the handler could not be installed,
        # which happens when the client is created outside the main thread.
        # Neither should prevent the client from being constructed.
        print(
            "Warning: Signal handlers for SIGINT/SIGTERM could not be added. "
            "Graceful shutdown via signals might not work as expected on this platform."
        )
def _load_persistent_views(self) -> None:
    """Load registered persistent views from disk (best effort).

    Reads ``PERSISTENT_VIEWS_FILE`` as a JSON object mapping custom IDs to
    dotted ``module.ClassName`` paths, imports each class, instantiates it
    with no arguments and stores it in ``self._persistent_views``.

    Every failure (unreadable file, malformed JSON, unexpected JSON shape,
    import or construction error) is reported with ``print`` and skipped,
    so this method never raises for bad file contents.
    """
    if not os.path.isfile(PERSISTENT_VIEWS_FILE):
        return
    try:
        with open(PERSISTENT_VIEWS_FILE, "r", encoding="utf-8") as fp:
            mapping = json.load(fp)
    except Exception as e:  # pragma: no cover - best effort load
        print(f"Failed to load persistent views: {e}")
        return
    if not isinstance(mapping, dict):
        # Valid JSON with the wrong top-level type (list, string, ...) would
        # otherwise fail on ``.items()`` and abort client construction.
        print(
            "Failed to load persistent views: expected a JSON object "
            f"mapping custom IDs to class paths, got {type(mapping).__name__}"
        )
        return
    # NOTE(security): every module named in the file is imported, which runs
    # its top-level code. The file must only be writable by trusted parties.
    for custom_id, path in mapping.items():
        try:
            module_name, class_name = path.rsplit(".", 1)
            module = importlib.import_module(module_name)
            cls = getattr(module, class_name)
            view = cls()
            self._persistent_views[custom_id] = view
        except Exception as e:  # pragma: no cover - best effort load
            print(f"Failed to initialize persistent view {path}: {e}")
def _save_persistent_views(self) -> None:
    """Write the custom-ID -> ``module.ClassName`` mapping of persistent views to disk.

    Saving is best effort: any error while writing is printed, not raised.
    """
    data = {
        custom_id: f"{view.__class__.__module__}.{view.__class__.__name__}"
        for custom_id, view in self._persistent_views.items()
    }
    try:
        with open(PERSISTENT_VIEWS_FILE, "w") as fp:
            json.dump(data, fp)
    except Exception as e:  # pragma: no cover - best effort save
        print(f"Failed to save persistent views: {e}")
async def _initialize_gateway(self):
    """Create ``self._gateway`` from the client's settings unless one already exists."""
    if self._gateway is not None:
        return
    gateway_options = dict(
        http_client=self._http,
        event_dispatcher=self._event_dispatcher,
        token=self.token,
        intents=self.intents,
        client_instance=self,
        verbose=self.verbose,
        max_retries=self.gateway_max_retries,
        max_backoff=self.gateway_max_backoff,
    )
    self._gateway = GatewayClient(**gateway_options)
async def _initialize_shard_manager(self) -> None:
    """Create the :class:`ShardManager` on first use; later calls are no-ops.

    A missing or zero ``shard_count`` is treated as a single shard.
    """
    if self._shard_manager is not None:
        return
    self._shard_manager = ShardManager(self, self.shard_count or 1)
async def connect(self, reconnect: bool = True) -> None:
    """
    Establish a connection to Discord and wait for the READY signal.

    With ``shard_count`` greater than 1 the shard manager is started;
    otherwise a single gateway connection is made, retried up to five
    times with a delay that doubles from 5 seconds up to 60 seconds.

    This method is a coroutine.

    Args:
        reconnect (bool): Currently unused by this method; kept for
            interface compatibility.

    Raises:
        DisagreementException: If the client is already closed, or if the
            final connection attempt fails (the client is closed first).
        AuthenticationError: If authentication fails (the client is closed
            first; no retry is attempted).
    """
    if self._closed:
        raise DisagreementException("Client is closed and cannot connect.")

    async def _await_ready() -> None:
        """Wait for READY, then record the start time if still open."""
        await self.wait_until_ready()
        if self._closed:
            # close() sets the ready event purely to release waiters. In
            # that case the client is not ready: do not announce READY or
            # overwrite the start_time that close() just reset.
            return
        self.start_time = utcnow()
        print("Client is READY!")

    if self.shard_count and self.shard_count > 1:
        await self._initialize_shard_manager()
        assert self._shard_manager is not None
        await self._shard_manager.start()
        print(
            f"Client connected using {self.shard_count} shards, waiting for READY signal..."
        )
        await _await_ready()
        return

    await self._initialize_gateway()
    assert self._gateway is not None  # Should be initialized by now
    retry_delay = 5  # seconds
    # Attempts made here for the initial connection; the gateway has its own
    # internal retry settings.
    max_retries = 5
    for attempt in range(max_retries):
        try:
            await self._gateway.connect()
            print("Client connected to Gateway, waiting for READY signal...")
            await _await_ready()
            return  # Successfully connected
        except AuthenticationError:  # Non-recoverable by retry here
            print("Authentication failed. Please check your bot token.")
            await self.close()  # Ensure cleanup
            raise
        except DisagreementException as e:  # Includes GatewayException
            print(f"Failed to connect (Attempt {attempt + 1}/{max_retries}): {e}")
            if attempt >= max_retries - 1:
                print("Max connection retries reached. Giving up.")
                await self.close()  # Ensure cleanup
                raise
            print(f"Retrying in {retry_delay} seconds...")
            await asyncio.sleep(retry_delay)
            # Exponential backoff up to 60s
            retry_delay = min(retry_delay * 2, 60)
async def start(self) -> None:
"""
Connect the client to Discord and run until the client is closed.
This method is a coroutine containing the main run loop logic.

After connecting, the loop polls roughly once per second and inspects the
gateway's receive task. If that task finished with an exception, the gateway
is closed and, after a 5 second pause, :meth:`connect` is called again. The
loop ends when the client is closed, when the receive task was cancelled, or
when inspecting the task fails.

A ``DisagreementException`` or ``asyncio.CancelledError`` escaping the loop
is printed and swallowed. :meth:`close` is always called on the way out if
the client is not closed yet.

Raises:
DisagreementException: If the client is already closed when called.
"""
if self._closed:
raise DisagreementException("Client is already closed.")
try:
await self.connect()
# The GatewayClient's _receive_loop will keep running.
# This run method effectively waits until the client is closed or an unhandled error occurs.
# A more robust implementation might have a main loop here that monitors gateway health.
# For now, we rely on the gateway's tasks.
while not self._closed:
# Only act once the gateway's receive task exists and has finished.
if (
self._gateway
and self._gateway._receive_task
and self._gateway._receive_task.done()
):
# If receive task ended unexpectedly, try to handle it or re-raise
try:
# Task.exception() raises CancelledError for a cancelled task; handled below.
exc = self._gateway._receive_task.exception()
if exc:
print(
f"Gateway receive task ended with exception: {exc}. Attempting to reconnect..."
)
# This is a basic reconnect strategy from the client side.
# GatewayClient itself might handle some reconnects.
await self.close_gateway(
code=1000
) # Close current gateway state
await asyncio.sleep(5) # Wait before reconnecting
if (
not self._closed
): # If client wasn't closed by the exception handler
await self.connect()
else:
break # Client was closed, exit run loop
else:
print(
"Gateway receive task ended without exception. Assuming clean shutdown or reconnect handled internally."
)
if (
not self._closed
): # If not explicitly closed, might be an issue
print(
"Warning: Gateway receive task ended but client not closed. This might indicate an issue."
)
# Consider a more robust health check or reconnect strategy here.
await asyncio.sleep(
1
) # Prevent tight loop if something is wrong
else:
break # Client was closed
except asyncio.CancelledError:
print("Gateway receive task was cancelled.")
break # Exit if cancelled
except Exception as e:
print(f"Error checking gateway receive task: {e}")
break # Exit on other errors
await asyncio.sleep(1) # Main loop check interval
except DisagreementException as e:
print(f"Client run loop encountered an error: {e}")
# Error already logged by connect or other methods
except asyncio.CancelledError:
print("Client run loop was cancelled.")
finally:
# Guarantee cleanup however the loop exits; close() is a no-op once closed.
if not self._closed:
await self.close()
def run(self) -> None:
    """Synchronously start the client and block until :meth:`start` returns.

    The client is driven on ``self.loop``, the loop chosen in ``__init__``.
    Signal handlers were registered on that loop, so running elsewhere (as
    :func:`asyncio.run` does, by creating a fresh loop) would leave them on
    a loop that never runs. When ``start`` returns, remaining tasks are
    cancelled, async generators are shut down and the loop is closed, which
    mirrors the cleanup :func:`asyncio.run` performs.

    If ``self.loop`` is already closed or running it cannot be driven from
    here, so this falls back to :func:`asyncio.run`.
    """
    loop = self.loop
    if loop.is_closed() or loop.is_running():
        asyncio.run(self.start())
        return

    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(self.start())
    finally:
        try:
            # Cancel whatever is still scheduled so nothing is destroyed
            # while pending when the loop is closed.
            leftover = asyncio.all_tasks(loop)
            for task in leftover:
                task.cancel()
            if leftover:
                loop.run_until_complete(
                    asyncio.gather(*leftover, return_exceptions=True)
                )
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            asyncio.set_event_loop(None)
            loop.close()
async def close(self) -> None:
    """
    Shut the client down. This method is a coroutine.

    Closes the shard manager, the gateway and the HTTP client (in that
    order), releases anything waiting on the ready event and clears
    ``start_time``. Calling it again after the first call does nothing.
    """
    if not self._closed:
        self._closed = True
        print("Closing client...")

        shard_manager = self._shard_manager
        if shard_manager:
            await shard_manager.close()
            self._shard_manager = None

        gateway = self._gateway
        if gateway:
            await gateway.close()

        http = self._http
        if http:  # HTTPClient has its own session to close
            await http.close()

        # Wake up anything blocked in wait_until_ready().
        self._ready_event.set()
        self.start_time = None
        print("Client closed.")
async def __aenter__(self) -> "Client":
    """Connect to Discord on entering ``async with`` and hand back the client."""
    connecting = self.connect()
    await connecting
    return self
async def __aexit__(
    self,
    exc_type: Optional[type],
    exc: Optional[BaseException],
    tb: Optional[BaseException],
) -> bool:
    """Close the client on leaving ``async with``.

    Always returns ``False`` so an exception raised inside the block
    continues to propagate.
    """
    suppress_exception = False
    await self.close()
    return suppress_exception
async def close_gateway(self, code: int = 1000) -> None:
    """Tear down the gateway side only, leaving the client open for a reconnect.

    The shard manager (if any) is closed and dropped, the gateway (if any)
    is closed with ``code`` and dropped, and the ready event is cleared.
    """
    shard_manager = self._shard_manager
    if shard_manager:
        await shard_manager.close()
        self._shard_manager = None

    gateway = self._gateway
    if gateway:
        await gateway.close(code=code)
        self._gateway = None

    # Without a gateway connection the client is no longer ready.
    self._ready_event.clear()
async def logout(self) -> None:
    """Disconnect from the Gateway and forget the bot token and session state.

    The token is blanked on both the client and its HTTP client, and
    ``user`` and ``start_time`` are reset to ``None``.
    """
    await self.close_gateway()
    blank_token = ""
    self.token = blank_token
    self._http.token = blank_token
    self.user = self.start_time = None
def is_closed(self) -> bool:
    """Return ``True`` once :meth:`close` has been called on this client."""
    closed = self._closed
    return closed
def is_ready(self) -> bool:
    """Return whether the client's ready event is currently set."""
    ready_event = self._ready_event
    return ready_event.is_set()
@property
def latency(self) -> Optional[float]:
    """Gateway latency as reported by the gateway client, or ``None`` without a gateway."""
    gateway = self._gateway
    return gateway.latency if gateway else None
@property
def latency_ms(self) -> Optional[float]:
    """Gateway ``latency_ms`` rounded to two decimals, or ``None`` if unavailable."""
    raw_latency = getattr(self._gateway, "latency_ms", None)
    if raw_latency is None:
        return None
    return round(raw_latency, 2)
@property
def guilds(self) -> List["Guild"]:
    """Every guild currently held in the client's guild cache."""
    guild_cache = self._guilds
    return guild_cache.values()
def uptime(self) -> Optional[timedelta]:
    """Time elapsed since ``start_time``, or ``None`` when no start time is recorded."""
    started_at = self.start_time
    return None if started_at is None else utcnow() - started_at
async def wait_until_ready(self) -> None:
    """|coro|
    Block until the client's ready event is set.

    The event is set when the client becomes ready, and also by
    :meth:`close` so that waiters are not left hanging.
    """
    ready_event = self._ready_event
    await ready_event.wait()
async def wait_for(
    self,
    event_name: str,
    check: Optional[Callable[[Any], bool]] = None,
    timeout: Optional[float] = None,
) -> Any:
    """|coro|
    Waits for a specific event to occur that satisfies the ``check``.

    Parameters
    ----------
    event_name: str
        The name of the event to wait for.
    check: Optional[Callable[[Any], bool]]
        A function that determines whether the received event should resolve the wait.
    timeout: Optional[float]
        How long to wait for the event before raising :class:`asyncio.TimeoutError`.

    Returns
    -------
    Any
        The result the event dispatcher sets on the waiter's future.

    Raises
    ------
    asyncio.TimeoutError
        If ``timeout`` elapses before the event arrives.
    """
    # Create the future on the loop that is running this coroutine.
    # ``self.loop`` may be a different loop (for example when the client was
    # constructed before the loop that runs it existed), and awaiting a
    # future that belongs to another loop raises instead of waiting.
    future: asyncio.Future = asyncio.get_running_loop().create_future()
    self._event_dispatcher.add_waiter(event_name, future, check)
    try:
        return await asyncio.wait_for(future, timeout=timeout)
    finally:
        # Always deregister, whether resolved, timed out or cancelled.
        self._event_dispatcher.remove_waiter(event_name, future)
async def change_presence(
    self,
    status: str,
    activity: Optional[Activity] = None,
    since: int = 0,
    afk: bool = False,
):
    """
    Update the client's presence through the gateway.

    Nothing is sent when no gateway is attached.

    Args:
        status (str): The new status for the client (e.g., "online", "idle", "dnd", "invisible").
        activity (Optional[Activity]): Activity instance describing what the bot is doing.
        since (int): The timestamp (in milliseconds) of when the client went idle.
        afk (bool): Whether the client is AFK.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")

    gateway = self._gateway
    if not gateway:
        return
    await gateway.update_presence(
        status=status,
        activity=activity,
        since=since,
        afk=afk,
    )
# --- Event Handling ---
def event(
    self, coro: Callable[..., Awaitable[None]]
) -> Callable[..., Awaitable[None]]:
    """
    Decorator that registers a coroutine as an event listener, using the
    coroutine's own name to pick the event.

    A leading ``on_`` is stripped and the rest upper-cased; a handful of
    friendly names are then translated (``on_message`` -> ``MESSAGE_CREATE``,
    ``on_message_edit`` -> ``MESSAGE_UPDATE``, ``on_reaction_add`` ->
    ``MESSAGE_REACTION_ADD``, ``on_reaction_remove`` ->
    ``MESSAGE_REACTION_REMOVE``). Names without the ``on_`` prefix are
    upper-cased and used as-is.

    Example:
        @client.event
        async def on_ready(): # Will listen for the 'READY' event
            print("Bot is ready!")

        @client.event
        async def on_message(message: disagreement.Message): # Will listen for 'MESSAGE_CREATE'
            print(f"Message from {message.author}: {message.content}")

    Raises:
        TypeError: If ``coro`` is not a coroutine function.
    """
    if not asyncio.iscoroutinefunction(coro):
        raise TypeError("Event registered must be a coroutine function.")

    handler_name = coro.__name__
    if not handler_name.startswith("on_"):
        # No "on_" prefix: treat the name as the Discord event name itself
        # (e.g. "typing_start" -> "TYPING_START").
        self._event_dispatcher.register(handler_name.upper(), coro)
        return coro

    # Friendly handler names that differ from the Discord event name.
    aliases = {
        "MESSAGE": "MESSAGE_CREATE",
        "MESSAGE_EDIT": "MESSAGE_UPDATE",
        "MESSAGE_UPDATE": "MESSAGE_UPDATE",
        "MESSAGE_DELETE": "MESSAGE_DELETE",
        "REACTION_ADD": "MESSAGE_REACTION_ADD",
        "REACTION_REMOVE": "MESSAGE_REACTION_REMOVE",
    }
    stripped = handler_name[3:].upper()
    self._event_dispatcher.register(aliases.get(stripped, stripped), coro)
    return coro  # Hand back the original coroutine unchanged
def on_event(
    self, event_name: str
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """
    Decorator factory that registers a coroutine for an explicitly named event.

    The event name is upper-cased before registration.

    Example:
        @client.on_event('MESSAGE_CREATE')
        async def my_message_handler(message: disagreement.Message):
            print(f"Message: {message.content}")

    Raises:
        TypeError: (from the returned decorator) if the decorated object is
            not a coroutine function.
    """

    def decorator(
        coro: Callable[..., Awaitable[None]],
    ) -> Callable[..., Awaitable[None]]:
        if asyncio.iscoroutinefunction(coro):
            self._event_dispatcher.register(event_name.upper(), coro)
            return coro
        raise TypeError("Event registered must be a coroutine function.")

    return decorator
def add_listener(
    self, event_name: str, coro: Callable[..., Awaitable[None]]
) -> None:
    """Attach ``coro`` as a listener for ``event_name`` (the name is passed through unchanged)."""
    dispatcher = self._event_dispatcher
    dispatcher.register(event_name, coro)
def remove_listener(
    self, event_name: str, coro: Callable[..., Awaitable[None]]
) -> None:
    """Detach ``coro`` from the listeners of ``event_name`` (the name is passed through unchanged)."""
    dispatcher = self._event_dispatcher
    dispatcher.unregister(event_name, coro)
async def _process_message_for_commands(self, message: "Message") -> None:
    """Internal ``MESSAGE_CREATE`` listener that feeds messages to the command handler.

    Messages that are missing, have no author, or were written by a bot
    are ignored.
    """
    from_human = message and message.author and not message.author.bot
    if from_human:
        await self.command_handler.process_commands(message)
async def get_context(self, message: "Message") -> Optional["CommandContext"]:
    """Build the :class:`CommandContext` for ``message`` via the command handler, without running a command."""
    context = await self.command_handler.get_context(message)
    return context
# --- Command Framework Methods ---
def add_cog(self, cog: Cog) -> None:
    """
    Register a Cog with the bot.

    The cog is added to the prefix command handler, and every application
    command or group it exposes through ``get_app_commands_and_groups()``
    is added to the application command handler.

    Args:
        cog (Cog): An instance of a class derived from `disagreement.ext.commands.Cog`.
    """
    # Prefix commands first.
    self.command_handler.add_cog(cog)

    # Then the application commands/groups the cog exposes.
    register_app_command = self.app_command_handler.add_command
    for app_cmd_obj in cog.get_app_commands_and_groups():
        register_app_command(app_cmd_obj)
        print(
            f"Registered app command/group '{app_cmd_obj.name}' from cog '{cog.cog_name}'."
        )
def remove_cog(self, cog_name: str) -> Optional[Cog]:
    """
    Unregister a Cog from the bot, together with its application commands.

    Application commands are removed from the handler by name only.

    Args:
        cog_name (str): The name of the Cog to remove.

    Returns:
        Optional[Cog]: The Cog that was removed, or None if not found.
    """
    removed_cog = self.command_handler.remove_cog(cog_name)
    if not removed_cog:
        return removed_cog

    # Drop every application command/group the cog contributed.
    # NOTE: removal is keyed on the name alone, which assumes names are
    # unique across commands and groups in the handler.
    for app_cmd_or_group in removed_cog.get_app_commands_and_groups():
        self.app_command_handler.remove_command(app_cmd_or_group.name)
        print(
            f"Removed app command/group '{app_cmd_or_group.name}' from cog '{cog_name}'."
        )
    return removed_cog
def get_cog(self, name: str) -> Optional[Cog]:
    """Look up a loaded cog by name through the command handler."""
    handler = self.command_handler
    return handler.get_cog(name)
def check(self, coro: Callable[["CommandContext"], Awaitable[bool]]):
    """
    Decorator that installs ``coro`` as a global check on the command handler.

    The coroutine is returned unchanged so it remains usable directly.

    Example:
        @bot.check
        async def block_dms(ctx):
            return ctx.guild is not None
    """
    handler = self.command_handler
    handler.add_check(coro)
    return coro
def command(
    self, **attrs: Any
) -> Callable[[Callable[..., Awaitable[None]]], Command]:
    """Decorator factory: wrap a function in a :class:`Command` and register it.

    ``attrs`` are forwarded to the :class:`Command` constructor.
    """

    def decorator(func: Callable[..., Awaitable[None]]) -> Command:
        registered = Command(func, **attrs)
        self.command_handler.add_command(registered)
        return registered

    return decorator
def group(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], Group]:
    """Decorator factory: wrap a function in a :class:`Group` command and register it.

    ``attrs`` are forwarded to the :class:`Group` constructor.
    """

    def decorator(func: Callable[..., Awaitable[None]]) -> Group:
        registered = Group(func, **attrs)
        self.command_handler.add_command(registered)
        return registered

    return decorator
def add_app_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None:
    """
    Adds a standalone application command or group to the bot.
    Use this for commands not defined within a Cog.

    Accepts either the command/group object itself or an object (such as a
    decorated function) that carries one on ``__app_command_object__``.

    Args:
        command (Union[AppCommand, AppCommandGroup]): The application command or group instance.
            This is typically the object returned by a decorator like @slash_command.

    Raises:
        TypeError: If no ``AppCommand``/``AppCommandGroup`` can be obtained
            from ``command``.
    """
    # Imported here because the module-level import is TYPE_CHECKING-only.
    from .ext.app_commands.commands import (
        AppCommand,
        AppCommandGroup,
    )

    command_types = (AppCommand, AppCommandGroup)

    # Unwrap first: a decorated function is not a command instance itself,
    # so type-checking before unwrapping would reject it outright.
    attached = getattr(command, "__app_command_object__", None)
    target = attached if isinstance(attached, command_types) else command

    if not isinstance(target, command_types):
        raise TypeError(
            "Command must be an instance of AppCommand or AppCommandGroup."
        )

    self.app_command_handler.add_command(target)
    print(f"Registered standalone app command/group '{target.name}'.")
async def on_command_error(
    self, ctx: "CommandContext", error: "CommandError"
) -> None:
    """
    Default command error handler, invoked when a command raises an error.

    Prints the failing command's name and the error; for a
    ``CommandInvokeError`` it also prints the wrapped original exception.
    Override in a ``Client`` subclass for custom handling.

    Args:
        ctx (CommandContext): The context of the command that raised the error.
        error (CommandError): The error that was raised.
    """
    # Default behaviour is console output only.
    print(
        f"Error in command '{ctx.command.name if ctx.command else 'unknown'}': {error}"
    )

    # Imported locally: the module-level import is TYPE_CHECKING-only, and
    # isinstance needs the real class at runtime.
    from .ext.commands.errors import CommandInvokeError as CIE

    if not isinstance(error, CIE):
        return
    # Only CommandInvokeError carries ``original``.
    print(
        f"Original exception: {type(error.original).__name__}: {error.original}"
    )
async def on_command_completion(self, ctx: "CommandContext") -> None:
    """
    Default command completion hook; intentionally does nothing.

    Override in a ``Client`` subclass to react to completed commands.

    Args:
        ctx (CommandContext): The context of the command that completed.
    """
    return None
# --- Extension Management Methods ---
def load_extension(self, name: str) -> ModuleType:
    """Load the extension called ``name`` through :mod:`disagreement.ext.loader` and return the loader's result."""
    loaded = ext_loader.load_extension(name)
    return loaded
def unload_extension(self, name: str) -> None:
    """Unload the extension called ``name`` through the extension loader."""
    ext_loader.unload_extension(name)
    return None
def reload_extension(self, name: str) -> ModuleType:
    """Reload the extension called ``name`` through the extension loader and return the loader's result."""
    reloaded = ext_loader.reload_extension(name)
    return reloaded
# --- Model Parsing and Fetching ---
def parse_user(self, data: Dict[str, Any]) -> "User":
    """Build a ``User`` from raw data and store it in the user cache under its ID."""
    # Runtime import: the module-level import of User is TYPE_CHECKING-only.
    from .models import User

    parsed = User(data, client_instance=self)
    self._users.set(parsed.id, parsed)
    return parsed
def parse_channel(self, data: Dict[str, Any]) -> "Channel":
    """Build a channel from raw data and refresh the client and guild caches.

    The channel always goes into the client-wide channel cache. If it
    belongs to a cached guild, it is also stored in that guild's channel
    cache and, for text, voice and category channels, replaced in or
    appended to the guild's matching channel list.
    """
    from .models import (
        channel_factory,
        TextChannel,
        VoiceChannel,
        CategoryChannel,
    )

    channel = channel_factory(data, self)
    self._channels.set(channel.id, channel)

    if not channel.guild_id:
        return channel
    guild = self._guilds.get(channel.guild_id)
    if not guild:
        return channel

    guild._channels.set(channel.id, channel)

    # Pick the per-type list on the guild; other channel types have none.
    if isinstance(channel, TextChannel):
        typed_list = guild.text_channels
    elif isinstance(channel, VoiceChannel):
        typed_list = guild.voice_channels
    elif isinstance(channel, CategoryChannel):
        typed_list = guild.category_channels
    else:
        return channel

    _update_list(typed_list, channel)
    return channel
def parse_message(self, data: Dict[str, Any]) -> "Message":
    """Build a ``Message`` from raw data and store it in the message cache under its ID."""
    from .models import Message

    parsed = Message(data, client_instance=self)
    self._messages.set(parsed.id, parsed)
    return parsed
def parse_webhook(self, data: Union[Dict[str, Any], "Webhook"]) -> "Webhook":
    """Turn raw data (or an existing ``Webhook``) into a cached ``Webhook``.

    An existing ``Webhook`` instance is re-bound to this client rather
    than rebuilt. The result is stored in ``self._webhooks`` by ID.
    """
    from .models import Webhook

    if not isinstance(data, Webhook):
        webhook = Webhook(data, client_instance=self)
    else:
        webhook = data
        webhook._client = self  # type: ignore[attr-defined]

    self._webhooks[webhook.id] = webhook
    return webhook
def parse_template(self, data: Dict[str, Any]) -> "GuildTemplate":
    """Build a ``GuildTemplate`` bound to this client from raw data (nothing is cached)."""
    from .models import GuildTemplate

    template = GuildTemplate(data, client_instance=self)
    return template
def parse_scheduled_event(self, data: Dict[str, Any]) -> "ScheduledEvent":
    """Build a ``ScheduledEvent`` from raw data and record it on its guild when cached.

    Events are kept per guild in a ``_scheduled_events`` dict keyed by
    event ID; the dict is created on first use.
    """
    from .models import ScheduledEvent

    event = ScheduledEvent(data, client_instance=self)

    guild = self._guilds.get(event.guild_id)
    if guild is None:
        # Guild not cached: nothing to attach the event to.
        return event

    per_guild_events = getattr(guild, "_scheduled_events", {})
    per_guild_events[event.id] = event
    guild._scheduled_events = per_guild_events
    return event
def parse_audit_log_entry(self, data: Dict[str, Any]) -> "AuditLogEntry":
    """Build an ``AuditLogEntry`` bound to this client from raw data (nothing is cached)."""
    from .models import AuditLogEntry

    entry = AuditLogEntry(data, client_instance=self)
    return entry
def parse_invite(self, data: Dict[str, Any]) -> "Invite":
    """Build an :class:`Invite` from raw data using ``Invite.from_dict``."""
    from .models import Invite

    invite = Invite.from_dict(data)
    return invite
async def fetch_user(self, user_id: Snowflake) -> Optional["User"]:
    """Return the user with ``user_id``, from the cache if present, else from Discord.

    A fetched user is parsed and cached. If the request fails with a
    ``DisagreementException`` the failure is printed and ``None`` returned.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")

    from_cache = self._users.get(user_id)
    if from_cache:
        return from_cache

    try:
        payload = await self._http.get_user(user_id)
        return self.parse_user(payload)
    except DisagreementException as e:  # HTTP-layer failures
        print(f"Failed to fetch user {user_id}: {e}")
        return None
async def fetch_message(
    self, channel_id: Snowflake, message_id: Snowflake
) -> Optional["Message"]:
    """Return the message with ``message_id``, from the cache if present, else from Discord.

    A fetched message is parsed and cached. If the request fails with a
    ``DisagreementException`` the failure is printed and ``None`` returned.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")

    from_cache = self._messages.get(message_id)
    if from_cache:
        return from_cache

    try:
        payload = await self._http.get_message(channel_id, message_id)
        return self.parse_message(payload)
    except DisagreementException as e:
        print(
            f"Failed to fetch message {message_id} from channel {channel_id}: {e}"
        )
        return None
def parse_member(
self, data: Dict[str, Any], guild_id: Snowflake, *, just_joined: bool = False
) -> "Member":
"""Parses member data and returns a Member object, updating relevant caches.

Args:
data: Raw member payload passed to ``Member``.
guild_id: ID of the guild the member belongs to; stored on the member as a string.
just_joined: When true, the member carries a temporary ``_just_joined``
attribute while it is inserted into the guild's member cache.

Returns:
The parsed ``Member``. It is stored in the guild's member cache when the
guild is cached, and in the client's user cache.
"""
# Runtime import: the module-level import of Member is TYPE_CHECKING-only.
from .models import Member
member = Member(data, client_instance=self)
member.guild_id = str(guild_id)
if just_joined:
# Temporary marker; presumably read by the guild's member cache when
# deciding whether to keep the member — TODO confirm against MemberCache.
setattr(member, "_just_joined", True)
guild = self._guilds.get(guild_id)
if guild:
guild._members.set(member.id, member)
# Remove the temporary marker again so it does not stay on the member.
if just_joined and hasattr(member, "_just_joined"):
delattr(member, "_just_joined")
# The member is also stored in the user cache under the same ID.
self._users.set(member.id, member)
return member
async def fetch_member(
    self, guild_id: Snowflake, member_id: Snowflake
) -> Optional["Member"]:
    """Return a guild member by ID, from the guild's cache if present, else from Discord.

    A fetched member is parsed and cached. If the request fails with a
    ``DisagreementException`` the failure is printed and ``None`` returned.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")

    guild = self.get_guild(guild_id)
    from_cache = guild.get_member(member_id) if guild else None
    if from_cache:
        return from_cache

    try:
        payload = await self._http.get_guild_member(guild_id, member_id)
        return self.parse_member(payload, guild_id)
    except DisagreementException as e:
        print(f"Failed to fetch member {member_id} from guild {guild_id}: {e}")
        return None
def parse_role(self, data: Dict[str, Any], guild_id: Snowflake) -> "Role":
    """Build a ``Role`` from raw data and merge it into the guild's role list.

    Args:
        data: Raw role payload passed to the ``Role`` constructor.
        guild_id: ID of the guild that owns the role.

    Returns:
        The newly constructed role. If the guild is not cached, the role is
        returned without being stored anywhere.
    """
    from .models import Role  # Ensure Role model is available

    role = Role(data)
    guild = self._guilds.get(guild_id)
    if guild:
        # Replace the cached role with the same ID, or append it if new,
        # using the shared module-level helper instead of a private loop.
        _update_list(guild.roles, role)
    return role
def parse_guild(self, data: Dict[str, Any]) -> "Guild":
    """Build a ``Guild`` from raw data, cache it, and parse its children.

    Channels in ``data["channels"]`` go through :meth:`parse_channel` and
    members in ``data["members"]`` go through :meth:`parse_member`. Before a
    member is parsed, its payload is updated in place with a ``status`` from
    the matching presence and a ``voice_state`` from the matching voice
    state, when those exist.

    Args:
        data: Raw guild payload.

    Returns:
        The newly constructed guild.
    """
    from .models import Guild

    guild = Guild(data, client_instance=self, shard_id=data.get("shard_id"))
    self._guilds.set(guild.id, guild)

    # Index presences and voice states by user ID for the member loop below.
    presence_by_user = {}
    for entry in data.get("presences", []):
        presence_by_user[entry["user"]["id"]] = entry
    voice_by_user = {}
    for entry in data.get("voice_states", []):
        voice_by_user[entry["user_id"]] = entry

    for raw_channel in data.get("channels", []):
        self.parse_channel(raw_channel)

    for raw_member in data.get("members", []):
        uid = raw_member.get("user", {}).get("id")
        if uid:
            found_presence = presence_by_user.get(uid)
            if found_presence:
                raw_member["status"] = found_presence.get("status", "offline")
            found_voice = voice_by_user.get(uid)
            if found_voice:
                raw_member["voice_state"] = found_voice
        # Members without a user ID are still parsed, just not enriched.
        self.parse_member(raw_member, guild.id)

    return guild
async def fetch_roles(self, guild_id: Snowflake) -> List["Role"]:
    """Fetch every role of a guild and replace the guild's role list.

    If the guild is not cached it is retrieved first with
    :meth:`fetch_guild`.

    Args:
        guild_id: ID of the guild whose roles are requested.

    Returns:
        The freshly parsed roles. An empty list is returned when the guild
        cannot be obtained or a :class:`DisagreementException` is raised
        (the failure is printed).

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")

    guild = self.get_guild(guild_id) or await self.fetch_guild(guild_id)
    if not guild:
        return []

    try:
        payloads = await self._http.get_guild_roles(guild_id)
        # parse_role merges each role into guild.roles as a side effect.
        fresh = [self.parse_role(item, guild_id) for item in payloads]
        guild.roles = fresh  # Replace the entire list with the fresh one
        return fresh
    except DisagreementException as e:
        print(f"Failed to fetch roles for guild {guild_id}: {e}")
    return []
async def fetch_role(
    self, guild_id: Snowflake, role_id: Snowflake
) -> Optional["Role"]:
    """Return a role of a guild by ID.

    The cached guild's role list is searched first. On a miss, all roles
    are refreshed with :meth:`fetch_roles` and the search is repeated once.

    Args:
        guild_id: ID of the guild that owns the role.
        role_id: ID of the role to look up.

    Returns:
        The matching role, or ``None`` if it is still missing after the refresh.
    """

    def _lookup() -> Optional["Role"]:
        # Re-read the guild each time: fetch_roles may have cached it.
        guild = self.get_guild(guild_id)
        if not guild:
            return None
        return next((r for r in guild.roles if r.id == role_id), None)

    match = _lookup()
    if match is not None:
        return match

    await self.fetch_roles(guild_id)  # This will populate/update guild.roles
    return _lookup()
# --- API Methods ---
async def send_message(
    self,
    channel_id: str,
    content: Optional[str] = None,
    *,  # Everything below is keyword-only
    tts: bool = False,
    embed: Optional["Embed"] = None,
    embeds: Optional[List["Embed"]] = None,
    components: Optional[List["ActionRow"]] = None,
    allowed_mentions: Optional[Dict[str, Any]] = None,
    message_reference: Optional[Dict[str, Any]] = None,
    attachments: Optional[List[Any]] = None,
    files: Optional[List[Any]] = None,
    flags: Optional[int] = None,
    view: Optional["View"] = None,
) -> "Message":
    """|coro|
    Send a message to a channel.

    Args:
        channel_id (str): ID of the channel to send the message to.
        content (Optional[str]): Text content of the message.
        tts (bool): Whether to send as text-to-speech. Defaults to False.
        embed (Optional[Embed]): A single embed. Cannot be combined with `embeds`.
        embeds (Optional[List[Embed]]): Several embeds. Cannot be combined with
            `embed`. Entries that are not ``Embed`` instances are dropped.
        components (Optional[List[ActionRow]]): Components to include. Cannot be
            combined with `view`. Entries that are not ``Component`` instances
            are dropped.
        allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions. Defaults
            to :attr:`Client.allowed_mentions` when ``None``.
        message_reference (Optional[Dict[str, Any]]): Message reference for replying.
        attachments (Optional[List[Any]]): Attachments to include with the message.
        files (Optional[List[Any]]): Files to upload with the message.
        flags (Optional[int]): Message flags.
        view (Optional[View]): A view to send. It is started before sending and
            registered under the new message's ID afterwards.

    Returns:
        Message: The result of :meth:`parse_message` on the HTTP response.

    Raises:
        DisagreementException: If the client is closed.
        ValueError: If both `embed` and `embeds`, or both `components` and
            `view`, are provided.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    if embed and embeds:
        raise ValueError("Cannot provide both embed and embeds.")
    if components and view:
        raise ValueError("Cannot provide both 'components' and 'view'.")

    # Serialise embeds; a single embed is wrapped into a one-element list.
    embed_dicts: Optional[List[Dict[str, Any]]] = None
    if embed:
        embed_dicts = [embed.to_dict()]
    elif embeds:
        from .models import Embed as EmbedModel

        embed_dicts = [
            item.to_dict() for item in embeds if isinstance(item, EmbedModel)
        ]

    # Serialise components; a view takes the place of explicit components.
    component_dicts: Optional[List[Dict[str, Any]]] = None
    if view:
        await view._start(self)
        component_dicts = view.to_components_payload()
    elif components:
        from .models import Component as ComponentModel

        component_dicts = [
            item.to_dict() for item in components if isinstance(item, ComponentModel)
        ]

    if allowed_mentions is None:
        allowed_mentions = self.allowed_mentions

    response = await self._http.send_message(
        channel_id=channel_id,
        content=content,
        tts=tts,
        embeds=embed_dicts,
        components=component_dicts,
        allowed_mentions=allowed_mentions,
        message_reference=message_reference,
        attachments=attachments,
        files=files,
        flags=flags,
    )

    if view:
        # Remember the view so component interactions on this message reach it.
        sent_id = response["id"]
        view.message_id = sent_id
        self._views[sent_id] = view

    return self.parse_message(response)
async def create_dm(self, user_id: Snowflake) -> "DMChannel":
    """|coro| Create or fetch a DM channel with a user.

    The HTTP response is passed through :meth:`parse_channel`; the result is
    only cast to ``DMChannel`` for type checkers, not checked at runtime.
    """
    from .models import DMChannel

    raw = await self._http.create_dm(user_id)
    channel = self.parse_channel(raw)
    return cast(DMChannel, channel)
async def send_dm(
    self,
    user_id: Snowflake,
    content: Optional[str] = None,
    **kwargs: Any,
) -> "Message":
    """|coro| Convenience method to send a direct message to a user.

    Opens the DM channel with :meth:`create_dm`, then forwards ``content``
    and any extra keyword arguments to :meth:`send_message`.
    """
    dm = await self.create_dm(user_id)
    target_id = dm.id
    return await self.send_message(target_id, content=content, **kwargs)
def typing(self, channel_id: str) -> Typing:
    """Create a :class:`Typing` object bound to this client and a channel.

    Args:
        channel_id: ID of the channel the typing indicator is for.
    """
    indicator = Typing(self, channel_id)
    return indicator
async def join_voice(
    self,
    guild_id: Snowflake,
    channel_id: Snowflake,
    *,
    self_mute: bool = False,
    self_deaf: bool = False,
) -> VoiceClient:
    """|coro| Join a voice channel and return a :class:`VoiceClient`.

    If a voice client is already stored for ``guild_id`` it is returned
    unchanged. Otherwise a voice state update is sent over the gateway, the
    resulting ``VOICE_SERVER_UPDATE`` and ``VOICE_STATE_UPDATE`` events are
    awaited (10 second timeout each), and a new voice client is connected and
    stored.

    Args:
        guild_id: ID of the guild that owns the voice channel.
        channel_id: ID of the voice channel to join.
        self_mute: Value sent as ``self_mute`` in the gateway payload.
        self_deaf: Value sent as ``self_deaf`` in the gateway payload.

    Returns:
        The connected (or previously stored) voice client for the guild.

    Raises:
        DisagreementException: If the client is closed, the gateway is not
            connected, or the client user is unavailable.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    if not self.is_ready():
        await self.wait_until_ready()
    if self._gateway is None:
        raise DisagreementException("Gateway is not connected.")
    if not self.user:
        raise DisagreementException("Client user unavailable.")

    user_id = self.user.id
    if guild_id in self._voice_clients:
        return self._voice_clients[guild_id]

    guild_key = str(guild_id)
    payload = {
        "op": GatewayOpcode.VOICE_STATE_UPDATE,
        "d": {
            "guild_id": guild_key,
            "channel_id": str(channel_id),
            "self_mute": self_mute,
            "self_deaf": self_deaf,
        },
    }

    # Register both waiters BEFORE sending the request and await them
    # together. Waiting one after the other (and only after sending) loses
    # whichever event arrives while no waiter is registered for it, and the
    # second wait then times out.
    server_waiter = asyncio.ensure_future(
        self.wait_for(
            "VOICE_SERVER_UPDATE",
            check=lambda d: d.get("guild_id") == guild_key,
            timeout=10,
        )
    )
    state_waiter = asyncio.ensure_future(
        self.wait_for(
            "VOICE_STATE_UPDATE",
            check=lambda d, uid=user_id: d.get("guild_id") == guild_key
            and d.get("user_id") == str(uid),
            timeout=10,
        )
    )
    # Yield once so both tasks run up to their first suspension point.
    await asyncio.sleep(0)
    try:
        await self._gateway._send_json(payload)  # type: ignore[attr-defined]
        server, state = await asyncio.gather(server_waiter, state_waiter)
    except BaseException:
        # Do not leave a waiter pending if sending or either wait failed.
        server_waiter.cancel()
        state_waiter.cancel()
        raise

    endpoint = f"wss://{server['endpoint']}?v=10"
    token = server["token"]
    session_id = state["session_id"]
    voice = VoiceClient(
        self,
        endpoint,
        session_id,
        token,
        int(guild_id),
        int(self.user.id),
        verbose=self.verbose,
    )
    await voice.connect()
    self._voice_clients[guild_id] = voice
    return voice
async def add_reaction(self, channel_id: str, message_id: str, emoji: str) -> None:
    """|coro| Add a reaction to a message.

    Alias that forwards its arguments to :meth:`create_reaction`.
    """
    forwarded = (channel_id, message_id, emoji)
    await self.create_reaction(*forwarded)
async def remove_reaction(
    self, channel_id: str, message_id: str, emoji: str
) -> None:
    """|coro| Remove the bot's reaction from a message.

    Alias that forwards its arguments to :meth:`delete_reaction`.
    """
    forwarded = (channel_id, message_id, emoji)
    await self.delete_reaction(*forwarded)
async def clear_reactions(self, channel_id: str, message_id: str) -> None:
    """|coro| Remove all reactions from a message.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    http = self._http
    await http.clear_reactions(channel_id, message_id)
async def create_reaction(
    self, channel_id: str, message_id: str, emoji: str
) -> None:
    """|coro| Add a reaction to a message.

    After the HTTP call succeeds, a ``MESSAGE_REACTION_ADD`` event is
    dispatched locally through the event dispatcher, if the client has one.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")

    await self._http.create_reaction(channel_id, message_id, emoji)

    # The client user may be missing; the event then carries user_id=None.
    me = getattr(self, "user", None)
    event = {
        "user_id": getattr(me, "id", None),
        "channel_id": channel_id,
        "message_id": message_id,
        "emoji": {"name": emoji, "id": None},
    }
    if not hasattr(self, "_event_dispatcher"):
        return
    await self._event_dispatcher.dispatch("MESSAGE_REACTION_ADD", event)
async def delete_reaction(
    self, channel_id: str, message_id: str, emoji: str
) -> None:
    """|coro| Remove the bot's reaction from a message.

    After the HTTP call succeeds, a ``MESSAGE_REACTION_REMOVE`` event is
    dispatched locally through the event dispatcher, if the client has one.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")

    await self._http.delete_reaction(channel_id, message_id, emoji)

    # The client user may be missing; the event then carries user_id=None.
    me = getattr(self, "user", None)
    event = {
        "user_id": getattr(me, "id", None),
        "channel_id": channel_id,
        "message_id": message_id,
        "emoji": {"name": emoji, "id": None},
    }
    if not hasattr(self, "_event_dispatcher"):
        return
    await self._event_dispatcher.dispatch("MESSAGE_REACTION_REMOVE", event)
async def get_reactions(
    self, channel_id: str, message_id: str, emoji: str
) -> List["User"]:
    """|coro| Return the users who reacted with the given emoji.

    Each entry of the HTTP response is passed through :meth:`parse_user`.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    raw_users = await self._http.get_reactions(channel_id, message_id, emoji)
    return list(map(self.parse_user, raw_users))
async def edit_message(
    self,
    channel_id: str,
    message_id: str,
    *,
    content: Optional[str] = None,
    embed: Optional["Embed"] = None,
    embeds: Optional[List["Embed"]] = None,
    components: Optional[List["ActionRow"]] = None,
    allowed_mentions: Optional[Dict[str, Any]] = None,
    flags: Optional[int] = None,
    view: Optional["View"] = None,
) -> "Message":
    """Edit a previously sent message.

    Only fields that were explicitly provided are sent. Passing an empty list
    for ``embeds`` or ``components`` sends an empty array for that field, so
    the existing embeds or components can be cleared; ``None`` leaves the
    field untouched.

    Args:
        channel_id: ID of the channel containing the message.
        message_id: ID of the message to edit.
        content: New text content.
        embed: A single embed. Cannot be combined with a non-empty ``embeds``.
        embeds: Replacement list of embeds (``[]`` clears them).
        components: Replacement list of components (``[]`` clears them).
            Cannot be combined with ``view``.
        allowed_mentions: Allowed mentions for the edited message.
        flags: Message flags.
        view: A view to attach. It is started before the request and
            registered under the message's ID afterwards.

    Returns:
        The result of :meth:`parse_message` on the HTTP response.

    Raises:
        DisagreementException: If the client is closed.
        ValueError: If both ``embed`` and ``embeds``, or both ``components``
            and ``view``, are provided.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    if embed and embeds:
        raise ValueError("Cannot provide both embed and embeds.")
    if components and view:
        raise ValueError("Cannot provide both 'components' and 'view'.")

    final_embeds_payload: Optional[List[Dict[str, Any]]] = None
    if embed:
        final_embeds_payload = [embed.to_dict()]
    elif embeds is not None:
        # `is not None` rather than truthiness: [] must reach the API so
        # callers can remove all embeds.
        final_embeds_payload = [e.to_dict() for e in embeds]

    components_payload: Optional[List[Dict[str, Any]]] = None
    if view:
        await view._start(self)
        components_payload = view.to_components_payload()
    elif components is not None:
        # Same reasoning as embeds: [] clears the components.
        components_payload = [c.to_dict() for c in components]

    # Build the payload from provided fields only.
    optional_fields = {
        "content": content,
        "embeds": final_embeds_payload,
        "components": components_payload,
        "allowed_mentions": allowed_mentions,
        "flags": flags,
    }
    payload: Dict[str, Any] = {
        key: value for key, value in optional_fields.items() if value is not None
    }

    message_data = await self._http.edit_message(
        channel_id=channel_id,
        message_id=message_id,
        payload=payload,
    )

    if view:
        # Remember the view so component interactions on this message reach it.
        view.message_id = message_data["id"]
        self._views[message_data["id"]] = view

    return self.parse_message(message_data)
def get_guild(self, guild_id: Snowflake) -> Optional["Guild"]:
    """Look up a guild in the internal cache only.

    No request is made; use :meth:`fetch_guild` to retrieve a guild from
    Discord when it is not cached.
    """
    cached = self._guilds.get(guild_id)
    return cached
def get_channel(self, channel_id: Snowflake) -> Optional["Channel"]:
    """Look up a channel in the internal cache only; no request is made."""
    cached = self._channels.get(channel_id)
    return cached
def get_message(self, message_id: Snowflake) -> Optional["Message"]:
    """Look up a message in the internal cache only; no request is made."""
    cached = self._messages.get(message_id)
    return cached
def get_all_channels(self) -> List["Channel"]:
    """Collect the cached channels of every cached guild into one list."""
    return [
        channel
        for guild in self._guilds.values()
        for channel in guild._channels.values()
    ]
def get_all_members(self) -> List["Member"]:
    """Collect the cached members of every cached guild into one list.

    When member caching is disabled via :class:`MemberCacheFlags.none`, this
    list will always be empty.
    """
    return [
        member
        for guild in self._guilds.values()
        for member in guild._members.values()
    ]
async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]:
    """Return a guild by ID, preferring the local cache.

    A truthy cache entry is returned as-is. Otherwise the guild is requested
    over HTTP and passed to :meth:`parse_guild`.

    Returns:
        The guild, or ``None`` if a :class:`DisagreementException` was raised
        while fetching or parsing (the failure is printed).

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")

    hit = self._guilds.get(guild_id)
    if hit:
        return hit

    try:
        raw = await self._http.get_guild(guild_id)
        return self.parse_guild(raw)
    except DisagreementException as e:
        print(f"Failed to fetch guild {guild_id}: {e}")
    return None
async def fetch_guilds(self) -> List["Guild"]:
    """Fetch all guilds the current user is in.

    Each entry of the HTTP response is passed through :meth:`parse_guild`.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    payloads = await self._http.get_current_user_guilds()
    return [self.parse_guild(item) for item in payloads]
async def fetch_channel(self, channel_id: Snowflake) -> Optional["Channel"]:
    """Fetch a channel from Discord by ID and store it in the channel cache.

    Unlike the other ``fetch_*`` helpers, the cache is not consulted first;
    the channel is always requested over HTTP.

    Returns:
        The channel built by ``channel_factory``, or ``None`` when the
        response is empty or a :class:`DisagreementException` is raised
        (the failure is printed).

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")

    try:
        raw = await self._http.get_channel(channel_id)
        if raw:
            from .models import channel_factory

            built = channel_factory(raw, self)
            self._channels.set(built.id, built)
            return built
        return None
    except DisagreementException as e:  # Includes HTTPException
        print(f"Failed to fetch channel {channel_id}: {e}")
        return None
async def fetch_audit_logs(
    self, guild_id: Snowflake, **filters: Any
) -> AsyncIterator["AuditLogEntry"]:
    """Yield parsed audit log entries for a guild.

    This is an async generator: the closed check and the single HTTP request
    run when iteration starts, not when the method is called.

    Args:
        guild_id: ID of the guild whose audit log is requested.
        **filters: Passed through unchanged to the HTTP client.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    response = await self._http.get_audit_logs(guild_id, **filters)
    raw_entries = response.get("audit_log_entries", [])
    for raw in raw_entries:
        yield self.parse_audit_log_entry(raw)
async def fetch_voice_regions(self) -> List[VoiceRegion]:
    """Fetch the available voice regions.

    Entries of the HTTP response without a truthy ``id`` are skipped; every
    other ``id`` is converted with ``VoiceRegion(id)``.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    payload = await self._http.get_voice_regions()
    region_ids = (entry.get("id") for entry in payload)
    return [VoiceRegion(rid) for rid in region_ids if rid]
async def create_webhook(
    self, channel_id: Snowflake, payload: Dict[str, Any]
) -> "Webhook":
    """|coro| Create a webhook in the given channel.

    Returns:
        The result of :meth:`parse_webhook` on the HTTP response.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    raw = await self._http.create_webhook(channel_id, payload)
    webhook = self.parse_webhook(raw)
    return webhook
async def edit_webhook(
    self, webhook_id: Snowflake, payload: Dict[str, Any]
) -> "Webhook":
    """|coro| Edit an existing webhook.

    Returns:
        The result of :meth:`parse_webhook` on the HTTP response.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    raw = await self._http.edit_webhook(webhook_id, payload)
    webhook = self.parse_webhook(raw)
    return webhook
async def delete_webhook(self, webhook_id: Snowflake) -> None:
    """|coro| Delete a webhook by ID.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    http = self._http
    await http.delete_webhook(webhook_id)
async def fetch_webhook(self, webhook_id: Snowflake) -> Optional["Webhook"]:
    """|coro| Fetch a webhook by ID, preferring the local cache.

    Returns:
        The webhook, or ``None`` if a :class:`DisagreementException` was
        raised while fetching or parsing (the failure is printed).

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")

    hit = self._webhooks.get(webhook_id)
    if hit:
        return hit

    try:
        raw = await self._http.get_webhook(webhook_id)
        return self.parse_webhook(raw)
    except DisagreementException as e:
        print(f"Failed to fetch webhook {webhook_id}: {e}")
    return None
async def fetch_templates(self, guild_id: Snowflake) -> List["GuildTemplate"]:
    """|coro| Fetch all templates for a guild.

    Each entry of the HTTP response is passed through :meth:`parse_template`.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    raw_templates = await self._http.get_guild_templates(guild_id)
    return list(map(self.parse_template, raw_templates))
async def create_template(
    self, guild_id: Snowflake, payload: Dict[str, Any]
) -> "GuildTemplate":
    """|coro| Create a template for a guild.

    Returns:
        The result of :meth:`parse_template` on the HTTP response.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    raw = await self._http.create_guild_template(guild_id, payload)
    template = self.parse_template(raw)
    return template
async def sync_template(
    self, guild_id: Snowflake, template_code: str
) -> "GuildTemplate":
    """|coro| Sync a template to the guild's current state.

    Returns:
        The result of :meth:`parse_template` on the HTTP response.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    raw = await self._http.sync_guild_template(guild_id, template_code)
    template = self.parse_template(raw)
    return template
async def delete_template(self, guild_id: Snowflake, template_code: str) -> None:
    """|coro| Delete a guild template.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    http = self._http
    await http.delete_guild_template(guild_id, template_code)
async def fetch_widget(self, guild_id: Snowflake) -> Dict[str, Any]:
    """|coro| Fetch a guild's widget settings.

    Returns:
        The HTTP client's response, unmodified.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    settings = await self._http.get_guild_widget(guild_id)
    return settings
async def edit_widget(
    self, guild_id: Snowflake, payload: Dict[str, Any]
) -> Dict[str, Any]:
    """|coro| Edit a guild's widget settings.

    Returns:
        The HTTP client's response, unmodified.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    settings = await self._http.edit_guild_widget(guild_id, payload)
    return settings
async def fetch_scheduled_events(
    self, guild_id: Snowflake
) -> List["ScheduledEvent"]:
    """|coro| Fetch all scheduled events for a guild.

    Each entry of the HTTP response is passed through
    :meth:`parse_scheduled_event`.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    raw_events = await self._http.get_guild_scheduled_events(guild_id)
    return list(map(self.parse_scheduled_event, raw_events))
async def fetch_scheduled_event(
    self, guild_id: Snowflake, event_id: Snowflake
) -> Optional["ScheduledEvent"]:
    """|coro| Fetch a single scheduled event.

    Returns:
        The parsed event, or ``None`` if a :class:`DisagreementException`
        was raised while fetching or parsing (the failure is printed).

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    try:
        raw = await self._http.get_guild_scheduled_event(guild_id, event_id)
        return self.parse_scheduled_event(raw)
    except DisagreementException as e:
        print(f"Failed to fetch scheduled event {event_id}: {e}")
    return None
async def create_scheduled_event(
    self, guild_id: Snowflake, payload: Dict[str, Any]
) -> "ScheduledEvent":
    """|coro| Create a scheduled event in a guild.

    Returns:
        The result of :meth:`parse_scheduled_event` on the HTTP response.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    raw = await self._http.create_guild_scheduled_event(guild_id, payload)
    event = self.parse_scheduled_event(raw)
    return event
async def edit_scheduled_event(
    self, guild_id: Snowflake, event_id: Snowflake, payload: Dict[str, Any]
) -> "ScheduledEvent":
    """|coro| Edit an existing scheduled event.

    Returns:
        The result of :meth:`parse_scheduled_event` on the HTTP response.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    raw = await self._http.edit_guild_scheduled_event(guild_id, event_id, payload)
    event = self.parse_scheduled_event(raw)
    return event
async def delete_scheduled_event(
    self, guild_id: Snowflake, event_id: Snowflake
) -> None:
    """|coro| Delete a scheduled event.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    http = self._http
    await http.delete_guild_scheduled_event(guild_id, event_id)
async def create_invite(
    self, channel_id: Snowflake, payload: Dict[str, Any]
) -> "Invite":
    """|coro| Create an invite for the given channel.

    Returns:
        Whatever the HTTP client's ``create_invite`` returns; no further
        parsing happens here.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    invite = await self._http.create_invite(channel_id, payload)
    return invite
async def delete_invite(self, code: str) -> None:
    """|coro| Delete an invite by code.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    http = self._http
    await http.delete_invite(code)
async def fetch_invite(self, code: Snowflake) -> Optional["Invite"]:
    """|coro| Fetch a single invite by code.

    Returns:
        The parsed invite, or ``None`` if a :class:`DisagreementException`
        was raised while fetching or parsing (the failure is printed).

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    try:
        raw = await self._http.get_invite(code)
        return self.parse_invite(raw)
    except DisagreementException as e:
        print(f"Failed to fetch invite {code}: {e}")
    return None
async def fetch_invites(self, channel_id: Snowflake) -> List["Invite"]:
    """|coro| Fetch all invites for a channel.

    Each entry of the HTTP response is passed through :meth:`parse_invite`.

    Raises:
        DisagreementException: If the client is closed.
    """
    if self._closed:
        raise DisagreementException("Client is closed.")
    raw_invites = await self._http.get_channel_invites(channel_id)
    return list(map(self.parse_invite, raw_invites))
def add_persistent_view(self, view: "View") -> None:
    """
    Registers a persistent view with the client.

    Persistent views have a timeout of `None` and their components must have a `custom_id`.
    This allows the view to be re-instantiated across bot restarts.

    Registration is all-or-nothing: every component's custom_id is validated
    before any of them is stored, so a rejected view leaves the registry
    untouched. Components without a custom_id are skipped.

    Args:
        view (View): The view instance to register.

    Raises:
        ValueError: If the view is not persistent (timeout is not None) or if a component's
            custom_id is already registered (by another view or earlier in this one).
    """
    if self.is_ready():
        print(
            "Warning: Adding a persistent view after the client is ready. "
            "This view will only be available for interactions on this session."
        )

    if view.timeout is not None:
        raise ValueError("Persistent views must have a timeout of None.")

    # First pass: validate without mutating the registry. Registering while
    # iterating used to leave earlier components registered (and unsaved)
    # when a later one collided.
    pending: List[str] = []
    for item in view.children:
        custom_id = item.custom_id
        if not custom_id:  # Ensure custom_id is not None
            continue
        if custom_id in self._persistent_views or custom_id in pending:
            raise ValueError(
                f"A component with custom_id '{custom_id}' is already registered."
            )
        pending.append(custom_id)

    # Second pass: nothing can fail any more, so commit and persist.
    for custom_id in pending:
        self._persistent_views[custom_id] = view
    self._save_persistent_views()
# --- Application Command Methods ---
async def process_interaction(self, interaction: Interaction) -> None:
    """Internal method to process an interaction from the gateway.

    The ``on_interaction_create`` hook is always scheduled as a task. A
    message-component interaction is then routed to the view registered for
    its message, or to a fresh instance of the persistent view registered for
    its custom_id. Anything not handled by a view is passed to the
    application command handler.
    """
    if hasattr(self, "on_interaction_create"):
        asyncio.create_task(self.on_interaction_create(interaction))

    is_component = (
        interaction.type == InteractionType.MESSAGE_COMPONENT
        and interaction.message
        and interaction.data
    )
    if is_component:
        # An active view registered for this message takes priority.
        active = self._views.get(interaction.message.id)
        if active:
            asyncio.create_task(active._dispatch(interaction))
            return

        # No active view found, check for persistent views
        custom_id = interaction.data.custom_id
        template = self._persistent_views.get(custom_id) if custom_id else None
        if template:
            # Create a new instance of the persistent view (no constructor
            # arguments) and bind it to this message.
            revived = template.__class__()
            await revived._start(self)
            revived.message_id = interaction.message.id
            self._views[interaction.message.id] = revived
            asyncio.create_task(revived._dispatch(interaction))
            return

    await self.app_command_handler.process_interaction(interaction)
async def sync_application_commands(
    self, guild_id: Optional[Snowflake] = None
) -> None:
    """Synchronize application commands with Discord.

    Delegates to the application command handler's ``sync_commands``. When
    ``application_id`` is not set, a warning is printed and nothing is synced.

    Args:
        guild_id: Forwarded unchanged to the handler.
    """
    if self.application_id:
        await self.app_command_handler.sync_commands(
            application_id=self.application_id, guild_id=guild_id
        )
        return

    print(
        "Warning: Cannot sync application commands, application_id is not set. "
        "Ensure the client is connected and READY."
    )
async def on_interaction_create(self, interaction: Interaction) -> None:
    """|coro| Called when an interaction is created.

    The default implementation does nothing.
    """
    return None
async def on_presence_update(self, presence) -> None:
    """|coro| Called when a user's presence is updated.

    The default implementation does nothing.
    """
    return None
async def on_typing_start(self, typing) -> None:
    """|coro| Called when a user starts typing in a channel.

    The default implementation does nothing.
    """
    return None
async def on_connect(self) -> None:
    """|coro| Called when the WebSocket connection opens.

    The default implementation does nothing.
    """
    return None
async def on_disconnect(self) -> None:
    """|coro| Called when the WebSocket connection closes.

    The default implementation does nothing.
    """
    return None
async def on_app_command_error(
    self, context: AppCommandContext, error: Exception
) -> None:
    """Default error handler for application commands.

    Prints the error, then sends an ephemeral notice through the context
    unless it has already responded. A failure while sending the notice is
    printed instead of raised.

    Args:
        context: Context of the failing command.
        error: The exception that was raised.
    """
    command_name = context.command.name if context.command else 'unknown'
    print(f"Error in application command '{command_name}': {error}")

    try:
        already_responded = context._responded
        if not already_responded:
            await context.send(
                "An error occurred while running this command.", ephemeral=True
            )
    except Exception as e:
        print(f"Failed to send error message for app command: {e}")
async def on_error(
    self, event_method: str, exc: Exception, *args: Any, **kwargs: Any
) -> None:
    """Default event listener error handler.

    Prints a header naming the event, followed by the exception's type and
    message. Extra positional and keyword arguments are accepted and ignored.
    """
    header = f"Unhandled exception in event listener for '{event_method}':"
    detail = f"{type(exc).__name__}: {exc}"
    print(header, detail, sep="\n")
class AutoShardedClient(Client):
    """A :class:`Client` that automatically determines the shard count.

    If ``shard_count`` is not provided, the client queries the Discord API
    via :meth:`HTTPClient.get_gateway_bot` for the recommended shard count
    and uses that when connecting.
    """

    async def connect(self, reconnect: bool = True) -> None:  # type: ignore[override]
        """Resolve a missing shard count, then connect as the base client does.

        Args:
            reconnect: Forwarded unchanged to the base ``connect``.
        """
        if self.shard_count is None:
            gateway_info = await self._http.get_gateway_bot()
            # Fall back to a single shard when the response has no "shards" key.
            self.shard_count = gateway_info.get("shards", 1)

        await super().connect(reconnect=reconnect)