chore: Apply code formatting across the codebase
This commit applies consistent code formatting to multiple files. No functional changes are included.
This commit is contained in:
parent
17b7ea35a9
commit
0151526d07
@ -614,75 +614,75 @@ class Client:
|
||||
|
||||
return decorator
|
||||
|
||||
def add_app_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None:
|
||||
"""
|
||||
Adds a standalone application command or group to the bot.
|
||||
Use this for commands not defined within a Cog.
|
||||
|
||||
Args:
|
||||
command (Union[AppCommand, AppCommandGroup]): The application command or group instance.
|
||||
This is typically the object returned by a decorator like @slash_command.
|
||||
"""
|
||||
from .ext.app_commands.commands import (
|
||||
AppCommand,
|
||||
AppCommandGroup,
|
||||
) # Ensure types
|
||||
|
||||
if not isinstance(command, (AppCommand, AppCommandGroup)):
|
||||
raise TypeError(
|
||||
"Command must be an instance of AppCommand or AppCommandGroup."
|
||||
)
|
||||
|
||||
# If it's a decorated function, the command object might be on __app_command_object__
|
||||
if hasattr(command, "__app_command_object__") and isinstance(
|
||||
getattr(command, "__app_command_object__"), (AppCommand, AppCommandGroup)
|
||||
):
|
||||
actual_command_obj = getattr(command, "__app_command_object__")
|
||||
self.app_command_handler.add_command(actual_command_obj)
|
||||
print(
|
||||
f"Registered standalone app command/group '{actual_command_obj.name}'."
|
||||
)
|
||||
elif isinstance(
|
||||
command, (AppCommand, AppCommandGroup)
|
||||
): # It's already the command object
|
||||
self.app_command_handler.add_command(command)
|
||||
print(f"Registered standalone app command/group '{command.name}'.")
|
||||
else:
|
||||
# This case should ideally not be hit if type checks are done by decorators
|
||||
print(
|
||||
f"Warning: Could not register app command {command}. It's not a recognized command object or decorated function."
|
||||
)
|
||||
|
||||
async def on_command_error(
|
||||
self, ctx: "CommandContext", error: "CommandError"
|
||||
) -> None:
|
||||
"""
|
||||
Default command error handler. Called when a command raises an error.
|
||||
Users can override this method in a subclass of Client to implement custom error handling.
|
||||
|
||||
Args:
|
||||
ctx (CommandContext): The context of the command that raised the error.
|
||||
error (CommandError): The error that was raised.
|
||||
"""
|
||||
# Default behavior: print to console.
|
||||
# Users might want to send a message to ctx.channel or log to a file.
|
||||
print(
|
||||
f"Error in command '{ctx.command.name if ctx.command else 'unknown'}': {error}"
|
||||
)
|
||||
|
||||
# Need to import CommandInvokeError for this check if not already globally available
|
||||
# For now, assuming it's imported via TYPE_CHECKING or directly if needed at runtime
|
||||
from .ext.commands.errors import (
|
||||
CommandInvokeError as CIE,
|
||||
) # Local import for isinstance check
|
||||
|
||||
if isinstance(error, CIE):
|
||||
# Now it's safe to access error.original
|
||||
print(
|
||||
f"Original exception: {type(error.original).__name__}: {error.original}"
|
||||
)
|
||||
# import traceback
|
||||
# traceback.print_exception(type(error.original), error.original, error.original.__traceback__)
|
||||
def add_app_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None:
|
||||
"""
|
||||
Adds a standalone application command or group to the bot.
|
||||
Use this for commands not defined within a Cog.
|
||||
|
||||
Args:
|
||||
command (Union[AppCommand, AppCommandGroup]): The application command or group instance.
|
||||
This is typically the object returned by a decorator like @slash_command.
|
||||
"""
|
||||
from .ext.app_commands.commands import (
|
||||
AppCommand,
|
||||
AppCommandGroup,
|
||||
) # Ensure types
|
||||
|
||||
if not isinstance(command, (AppCommand, AppCommandGroup)):
|
||||
raise TypeError(
|
||||
"Command must be an instance of AppCommand or AppCommandGroup."
|
||||
)
|
||||
|
||||
# If it's a decorated function, the command object might be on __app_command_object__
|
||||
if hasattr(command, "__app_command_object__") and isinstance(
|
||||
getattr(command, "__app_command_object__"), (AppCommand, AppCommandGroup)
|
||||
):
|
||||
actual_command_obj = getattr(command, "__app_command_object__")
|
||||
self.app_command_handler.add_command(actual_command_obj)
|
||||
print(
|
||||
f"Registered standalone app command/group '{actual_command_obj.name}'."
|
||||
)
|
||||
elif isinstance(
|
||||
command, (AppCommand, AppCommandGroup)
|
||||
): # It's already the command object
|
||||
self.app_command_handler.add_command(command)
|
||||
print(f"Registered standalone app command/group '{command.name}'.")
|
||||
else:
|
||||
# This case should ideally not be hit if type checks are done by decorators
|
||||
print(
|
||||
f"Warning: Could not register app command {command}. It's not a recognized command object or decorated function."
|
||||
)
|
||||
|
||||
async def on_command_error(
|
||||
self, ctx: "CommandContext", error: "CommandError"
|
||||
) -> None:
|
||||
"""
|
||||
Default command error handler. Called when a command raises an error.
|
||||
Users can override this method in a subclass of Client to implement custom error handling.
|
||||
|
||||
Args:
|
||||
ctx (CommandContext): The context of the command that raised the error.
|
||||
error (CommandError): The error that was raised.
|
||||
"""
|
||||
# Default behavior: print to console.
|
||||
# Users might want to send a message to ctx.channel or log to a file.
|
||||
print(
|
||||
f"Error in command '{ctx.command.name if ctx.command else 'unknown'}': {error}"
|
||||
)
|
||||
|
||||
# Need to import CommandInvokeError for this check if not already globally available
|
||||
# For now, assuming it's imported via TYPE_CHECKING or directly if needed at runtime
|
||||
from .ext.commands.errors import (
|
||||
CommandInvokeError as CIE,
|
||||
) # Local import for isinstance check
|
||||
|
||||
if isinstance(error, CIE):
|
||||
# Now it's safe to access error.original
|
||||
print(
|
||||
f"Original exception: {type(error.original).__name__}: {error.original}"
|
||||
)
|
||||
# import traceback
|
||||
# traceback.print_exception(type(error.original), error.original, error.original.__traceback__)
|
||||
|
||||
async def on_command_completion(self, ctx: "CommandContext") -> None:
|
||||
"""
|
||||
@ -1551,86 +1551,86 @@ class Client:
|
||||
self._views[interaction.message.id] = new_view
|
||||
asyncio.create_task(new_view._dispatch(interaction))
|
||||
return
|
||||
|
||||
await self.app_command_handler.process_interaction(interaction)
|
||||
|
||||
async def sync_application_commands(
|
||||
self, guild_id: Optional[Snowflake] = None
|
||||
) -> None:
|
||||
"""Synchronizes application commands with Discord."""
|
||||
|
||||
if not self.application_id:
|
||||
print(
|
||||
"Warning: Cannot sync application commands, application_id is not set. "
|
||||
"Ensure the client is connected and READY."
|
||||
)
|
||||
return
|
||||
if not self.is_ready():
|
||||
print(
|
||||
"Warning: Client is not ready. Waiting for client to be ready before syncing commands."
|
||||
)
|
||||
await self.wait_until_ready()
|
||||
if not self.application_id:
|
||||
print(
|
||||
"Error: application_id still not set after client is ready. Cannot sync commands."
|
||||
)
|
||||
return
|
||||
|
||||
await self.app_command_handler.sync_commands(
|
||||
application_id=self.application_id, guild_id=guild_id
|
||||
)
|
||||
|
||||
async def on_interaction_create(self, interaction: Interaction) -> None:
|
||||
"""|coro| Called when an interaction is created."""
|
||||
|
||||
pass
|
||||
|
||||
async def on_presence_update(self, presence) -> None:
|
||||
"""|coro| Called when a user's presence is updated."""
|
||||
|
||||
pass
|
||||
|
||||
async def on_typing_start(self, typing) -> None:
|
||||
"""|coro| Called when a user starts typing in a channel."""
|
||||
|
||||
pass
|
||||
|
||||
async def on_app_command_error(
|
||||
self, context: AppCommandContext, error: Exception
|
||||
) -> None:
|
||||
"""Default error handler for application commands."""
|
||||
|
||||
print(
|
||||
f"Error in application command '{context.command.name if context.command else 'unknown'}': {error}"
|
||||
)
|
||||
try:
|
||||
if not context._responded:
|
||||
await context.send(
|
||||
"An error occurred while running this command.", ephemeral=True
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to send error message for app command: {e}")
|
||||
|
||||
async def on_error(
|
||||
self, event_method: str, exc: Exception, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""Default event listener error handler."""
|
||||
|
||||
print(f"Unhandled exception in event listener for '{event_method}':")
|
||||
print(f"{type(exc).__name__}: {exc}")
|
||||
|
||||
|
||||
class AutoShardedClient(Client):
|
||||
"""A :class:`Client` that automatically determines the shard count.
|
||||
|
||||
If ``shard_count`` is not provided, the client will query the Discord API
|
||||
via :meth:`HTTPClient.get_gateway_bot` for the recommended shard count and
|
||||
use that when connecting.
|
||||
"""
|
||||
|
||||
async def connect(self, reconnect: bool = True) -> None: # type: ignore[override]
|
||||
if self.shard_count is None:
|
||||
data = await self._http.get_gateway_bot()
|
||||
self.shard_count = data.get("shards", 1)
|
||||
|
||||
await super().connect(reconnect=reconnect)
|
||||
|
||||
await self.app_command_handler.process_interaction(interaction)
|
||||
|
||||
async def sync_application_commands(
|
||||
self, guild_id: Optional[Snowflake] = None
|
||||
) -> None:
|
||||
"""Synchronizes application commands with Discord."""
|
||||
|
||||
if not self.application_id:
|
||||
print(
|
||||
"Warning: Cannot sync application commands, application_id is not set. "
|
||||
"Ensure the client is connected and READY."
|
||||
)
|
||||
return
|
||||
if not self.is_ready():
|
||||
print(
|
||||
"Warning: Client is not ready. Waiting for client to be ready before syncing commands."
|
||||
)
|
||||
await self.wait_until_ready()
|
||||
if not self.application_id:
|
||||
print(
|
||||
"Error: application_id still not set after client is ready. Cannot sync commands."
|
||||
)
|
||||
return
|
||||
|
||||
await self.app_command_handler.sync_commands(
|
||||
application_id=self.application_id, guild_id=guild_id
|
||||
)
|
||||
|
||||
async def on_interaction_create(self, interaction: Interaction) -> None:
|
||||
"""|coro| Called when an interaction is created."""
|
||||
|
||||
pass
|
||||
|
||||
async def on_presence_update(self, presence) -> None:
|
||||
"""|coro| Called when a user's presence is updated."""
|
||||
|
||||
pass
|
||||
|
||||
async def on_typing_start(self, typing) -> None:
|
||||
"""|coro| Called when a user starts typing in a channel."""
|
||||
|
||||
pass
|
||||
|
||||
async def on_app_command_error(
|
||||
self, context: AppCommandContext, error: Exception
|
||||
) -> None:
|
||||
"""Default error handler for application commands."""
|
||||
|
||||
print(
|
||||
f"Error in application command '{context.command.name if context.command else 'unknown'}': {error}"
|
||||
)
|
||||
try:
|
||||
if not context._responded:
|
||||
await context.send(
|
||||
"An error occurred while running this command.", ephemeral=True
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to send error message for app command: {e}")
|
||||
|
||||
async def on_error(
|
||||
self, event_method: str, exc: Exception, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""Default event listener error handler."""
|
||||
|
||||
print(f"Unhandled exception in event listener for '{event_method}':")
|
||||
print(f"{type(exc).__name__}: {exc}")
|
||||
|
||||
|
||||
class AutoShardedClient(Client):
|
||||
"""A :class:`Client` that automatically determines the shard count.
|
||||
|
||||
If ``shard_count`` is not provided, the client will query the Discord API
|
||||
via :meth:`HTTPClient.get_gateway_bot` for the recommended shard count and
|
||||
use that when connecting.
|
||||
"""
|
||||
|
||||
async def connect(self, reconnect: bool = True) -> None: # type: ignore[override]
|
||||
if self.shard_count is None:
|
||||
data = await self._http.get_gateway_bot()
|
||||
self.shard_count = data.get("shards", 1)
|
||||
|
||||
await super().connect(reconnect=reconnect)
|
||||
|
@ -1,65 +1,65 @@
|
||||
# disagreement/ext/commands/__init__.py
|
||||
|
||||
"""
|
||||
disagreement.ext.commands - A command framework extension for the Disagreement library.
|
||||
"""
|
||||
|
||||
from .cog import Cog
|
||||
from .core import (
|
||||
Command,
|
||||
CommandContext,
|
||||
CommandHandler,
|
||||
) # CommandHandler might be internal
|
||||
from .decorators import (
|
||||
command,
|
||||
listener,
|
||||
check,
|
||||
check_any,
|
||||
cooldown,
|
||||
max_concurrency,
|
||||
requires_permissions,
|
||||
# disagreement/ext/commands/__init__.py
|
||||
|
||||
"""
|
||||
disagreement.ext.commands - A command framework extension for the Disagreement library.
|
||||
"""
|
||||
|
||||
from .cog import Cog
|
||||
from .core import (
|
||||
Command,
|
||||
CommandContext,
|
||||
CommandHandler,
|
||||
) # CommandHandler might be internal
|
||||
from .decorators import (
|
||||
command,
|
||||
listener,
|
||||
check,
|
||||
check_any,
|
||||
cooldown,
|
||||
max_concurrency,
|
||||
requires_permissions,
|
||||
has_role,
|
||||
has_any_role,
|
||||
)
|
||||
from .errors import (
|
||||
CommandError,
|
||||
CommandNotFound,
|
||||
BadArgument,
|
||||
MissingRequiredArgument,
|
||||
ArgumentParsingError,
|
||||
CheckFailure,
|
||||
CheckAnyFailure,
|
||||
CommandOnCooldown,
|
||||
CommandInvokeError,
|
||||
MaxConcurrencyReached,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Cog
|
||||
"Cog",
|
||||
# Core
|
||||
"Command",
|
||||
"CommandContext",
|
||||
# "CommandHandler", # Usually not part of public API for direct use by bot devs
|
||||
# Decorators
|
||||
"command",
|
||||
"listener",
|
||||
"check",
|
||||
"check_any",
|
||||
"cooldown",
|
||||
"max_concurrency",
|
||||
"requires_permissions",
|
||||
)
|
||||
from .errors import (
|
||||
CommandError,
|
||||
CommandNotFound,
|
||||
BadArgument,
|
||||
MissingRequiredArgument,
|
||||
ArgumentParsingError,
|
||||
CheckFailure,
|
||||
CheckAnyFailure,
|
||||
CommandOnCooldown,
|
||||
CommandInvokeError,
|
||||
MaxConcurrencyReached,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Cog
|
||||
"Cog",
|
||||
# Core
|
||||
"Command",
|
||||
"CommandContext",
|
||||
# "CommandHandler", # Usually not part of public API for direct use by bot devs
|
||||
# Decorators
|
||||
"command",
|
||||
"listener",
|
||||
"check",
|
||||
"check_any",
|
||||
"cooldown",
|
||||
"max_concurrency",
|
||||
"requires_permissions",
|
||||
"has_role",
|
||||
"has_any_role",
|
||||
# Errors
|
||||
"CommandError",
|
||||
"CommandNotFound",
|
||||
"BadArgument",
|
||||
"MissingRequiredArgument",
|
||||
"ArgumentParsingError",
|
||||
"CheckFailure",
|
||||
"CheckAnyFailure",
|
||||
"CommandOnCooldown",
|
||||
"CommandInvokeError",
|
||||
"MaxConcurrencyReached",
|
||||
]
|
||||
# Errors
|
||||
"CommandError",
|
||||
"CommandNotFound",
|
||||
"BadArgument",
|
||||
"MissingRequiredArgument",
|
||||
"ArgumentParsingError",
|
||||
"CheckFailure",
|
||||
"CheckAnyFailure",
|
||||
"CommandOnCooldown",
|
||||
"CommandInvokeError",
|
||||
"MaxConcurrencyReached",
|
||||
]
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,222 +1,222 @@
|
||||
# disagreement/ext/commands/decorators.py
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import time
|
||||
from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core import Command, CommandContext
|
||||
from disagreement.permissions import Permissions
|
||||
from disagreement.models import Member, Guild, Channel
|
||||
|
||||
|
||||
def command(
|
||||
name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any
|
||||
) -> Callable:
|
||||
"""
|
||||
A decorator that transforms a function into a Command.
|
||||
|
||||
Args:
|
||||
name (Optional[str]): The name of the command. Defaults to the function name.
|
||||
aliases (Optional[List[str]]): Alternative names for the command.
|
||||
**attrs: Additional attributes to pass to the Command constructor
|
||||
(e.g., brief, description, hidden).
|
||||
|
||||
Returns:
|
||||
Callable: A decorator that registers the command.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError("Command callback must be a coroutine function.")
|
||||
|
||||
from .core import Command
|
||||
|
||||
cmd_name = name or func.__name__
|
||||
|
||||
if hasattr(func, "__command_attrs__"):
|
||||
raise TypeError("Function is already a command or has command attributes.")
|
||||
|
||||
cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
|
||||
func.__command_object__ = cmd # type: ignore
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def listener(
|
||||
name: Optional[str] = None,
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""
|
||||
A decorator that marks a function as an event listener within a Cog.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError("Listener callback must be a coroutine function.")
|
||||
|
||||
actual_event_name = name or func.__name__
|
||||
setattr(func, "__listener_name__", actual_event_name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def check(
|
||||
predicate: Callable[["CommandContext"], Awaitable[bool] | bool],
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Decorator to add a check to a command."""
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
checks = getattr(func, "__command_checks__", [])
|
||||
checks.append(predicate)
|
||||
setattr(func, "__command_checks__", checks)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def check_any(
|
||||
*predicates: Callable[["CommandContext"], Awaitable[bool] | bool]
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Decorator that passes if any predicate returns ``True``."""
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CheckAnyFailure, CheckFailure
|
||||
|
||||
errors = []
|
||||
for p in predicates:
|
||||
try:
|
||||
result = p(ctx)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if result:
|
||||
return True
|
||||
except CheckFailure as e:
|
||||
errors.append(e)
|
||||
raise CheckAnyFailure(errors)
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def max_concurrency(
|
||||
number: int, per: str = "user"
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Limit how many concurrent invocations of a command are allowed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
number:
|
||||
The maximum number of concurrent invocations.
|
||||
per:
|
||||
The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``.
|
||||
"""
|
||||
|
||||
if number < 1:
|
||||
raise ValueError("Concurrency number must be at least 1.")
|
||||
if per not in {"user", "guild", "global"}:
|
||||
raise ValueError("per must be 'user', 'guild', or 'global'.")
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
setattr(func, "__max_concurrency__", (number, per))
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def cooldown(
|
||||
rate: int, per: float
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Simple per-user cooldown decorator."""
|
||||
|
||||
buckets: dict[str, dict[str, float]] = {}
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CommandOnCooldown
|
||||
|
||||
now = time.monotonic()
|
||||
user_buckets = buckets.setdefault(ctx.command.name, {})
|
||||
reset = user_buckets.get(ctx.author.id, 0)
|
||||
if now < reset:
|
||||
raise CommandOnCooldown(reset - now)
|
||||
user_buckets[ctx.author.id] = now + per
|
||||
return True
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def _compute_permissions(
|
||||
member: "Member", channel: "Channel", guild: "Guild"
|
||||
) -> "Permissions":
|
||||
"""Compute the effective permissions for a member in a channel."""
|
||||
return channel.permissions_for(member)
|
||||
|
||||
|
||||
def requires_permissions(
|
||||
*perms: "Permissions",
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Check that the invoking member has the given permissions in the channel."""
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CheckFailure
|
||||
from disagreement.permissions import (
|
||||
has_permissions,
|
||||
missing_permissions,
|
||||
)
|
||||
from disagreement.models import Member
|
||||
|
||||
channel = getattr(ctx, "channel", None)
|
||||
if channel is None and hasattr(ctx.bot, "get_channel"):
|
||||
channel = ctx.bot.get_channel(ctx.message.channel_id)
|
||||
if channel is None and hasattr(ctx.bot, "fetch_channel"):
|
||||
channel = await ctx.bot.fetch_channel(ctx.message.channel_id)
|
||||
|
||||
if channel is None:
|
||||
raise CheckFailure("Channel for permission check not found.")
|
||||
|
||||
guild = getattr(channel, "guild", None)
|
||||
if not guild and hasattr(channel, "guild_id") and channel.guild_id:
|
||||
if hasattr(ctx.bot, "get_guild"):
|
||||
guild = ctx.bot.get_guild(channel.guild_id)
|
||||
if not guild and hasattr(ctx.bot, "fetch_guild"):
|
||||
guild = await ctx.bot.fetch_guild(channel.guild_id)
|
||||
|
||||
if not guild:
|
||||
is_dm = not hasattr(channel, "guild_id") or not channel.guild_id
|
||||
if is_dm:
|
||||
if perms:
|
||||
raise CheckFailure("Permission checks are not supported in DMs.")
|
||||
return True
|
||||
raise CheckFailure("Guild for permission check not found.")
|
||||
|
||||
member = ctx.author
|
||||
if not isinstance(member, Member):
|
||||
member = guild.get_member(ctx.author.id)
|
||||
if not member and hasattr(ctx.bot, "fetch_member"):
|
||||
member = await ctx.bot.fetch_member(guild.id, ctx.author.id)
|
||||
|
||||
if not member:
|
||||
raise CheckFailure("Could not resolve author to a guild member.")
|
||||
|
||||
perms_value = _compute_permissions(member, channel, guild)
|
||||
|
||||
if not has_permissions(perms_value, *perms):
|
||||
missing = missing_permissions(perms_value, *perms)
|
||||
missing_names = ", ".join(p.name for p in missing if p.name)
|
||||
raise CheckFailure(f"Missing permissions: {missing_names}")
|
||||
return True
|
||||
|
||||
return check(predicate)
|
||||
# disagreement/ext/commands/decorators.py
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import time
|
||||
from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core import Command, CommandContext
|
||||
from disagreement.permissions import Permissions
|
||||
from disagreement.models import Member, Guild, Channel
|
||||
|
||||
|
||||
def command(
|
||||
name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any
|
||||
) -> Callable:
|
||||
"""
|
||||
A decorator that transforms a function into a Command.
|
||||
|
||||
Args:
|
||||
name (Optional[str]): The name of the command. Defaults to the function name.
|
||||
aliases (Optional[List[str]]): Alternative names for the command.
|
||||
**attrs: Additional attributes to pass to the Command constructor
|
||||
(e.g., brief, description, hidden).
|
||||
|
||||
Returns:
|
||||
Callable: A decorator that registers the command.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError("Command callback must be a coroutine function.")
|
||||
|
||||
from .core import Command
|
||||
|
||||
cmd_name = name or func.__name__
|
||||
|
||||
if hasattr(func, "__command_attrs__"):
|
||||
raise TypeError("Function is already a command or has command attributes.")
|
||||
|
||||
cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
|
||||
func.__command_object__ = cmd # type: ignore
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def listener(
|
||||
name: Optional[str] = None,
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""
|
||||
A decorator that marks a function as an event listener within a Cog.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError("Listener callback must be a coroutine function.")
|
||||
|
||||
actual_event_name = name or func.__name__
|
||||
setattr(func, "__listener_name__", actual_event_name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def check(
|
||||
predicate: Callable[["CommandContext"], Awaitable[bool] | bool],
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Decorator to add a check to a command."""
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
checks = getattr(func, "__command_checks__", [])
|
||||
checks.append(predicate)
|
||||
setattr(func, "__command_checks__", checks)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def check_any(
|
||||
*predicates: Callable[["CommandContext"], Awaitable[bool] | bool]
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Decorator that passes if any predicate returns ``True``."""
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CheckAnyFailure, CheckFailure
|
||||
|
||||
errors = []
|
||||
for p in predicates:
|
||||
try:
|
||||
result = p(ctx)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if result:
|
||||
return True
|
||||
except CheckFailure as e:
|
||||
errors.append(e)
|
||||
raise CheckAnyFailure(errors)
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def max_concurrency(
|
||||
number: int, per: str = "user"
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Limit how many concurrent invocations of a command are allowed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
number:
|
||||
The maximum number of concurrent invocations.
|
||||
per:
|
||||
The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``.
|
||||
"""
|
||||
|
||||
if number < 1:
|
||||
raise ValueError("Concurrency number must be at least 1.")
|
||||
if per not in {"user", "guild", "global"}:
|
||||
raise ValueError("per must be 'user', 'guild', or 'global'.")
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
setattr(func, "__max_concurrency__", (number, per))
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def cooldown(
|
||||
rate: int, per: float
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Simple per-user cooldown decorator."""
|
||||
|
||||
buckets: dict[str, dict[str, float]] = {}
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CommandOnCooldown
|
||||
|
||||
now = time.monotonic()
|
||||
user_buckets = buckets.setdefault(ctx.command.name, {})
|
||||
reset = user_buckets.get(ctx.author.id, 0)
|
||||
if now < reset:
|
||||
raise CommandOnCooldown(reset - now)
|
||||
user_buckets[ctx.author.id] = now + per
|
||||
return True
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def _compute_permissions(
|
||||
member: "Member", channel: "Channel", guild: "Guild"
|
||||
) -> "Permissions":
|
||||
"""Compute the effective permissions for a member in a channel."""
|
||||
return channel.permissions_for(member)
|
||||
|
||||
|
||||
def requires_permissions(
|
||||
*perms: "Permissions",
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Check that the invoking member has the given permissions in the channel."""
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CheckFailure
|
||||
from disagreement.permissions import (
|
||||
has_permissions,
|
||||
missing_permissions,
|
||||
)
|
||||
from disagreement.models import Member
|
||||
|
||||
channel = getattr(ctx, "channel", None)
|
||||
if channel is None and hasattr(ctx.bot, "get_channel"):
|
||||
channel = ctx.bot.get_channel(ctx.message.channel_id)
|
||||
if channel is None and hasattr(ctx.bot, "fetch_channel"):
|
||||
channel = await ctx.bot.fetch_channel(ctx.message.channel_id)
|
||||
|
||||
if channel is None:
|
||||
raise CheckFailure("Channel for permission check not found.")
|
||||
|
||||
guild = getattr(channel, "guild", None)
|
||||
if not guild and hasattr(channel, "guild_id") and channel.guild_id:
|
||||
if hasattr(ctx.bot, "get_guild"):
|
||||
guild = ctx.bot.get_guild(channel.guild_id)
|
||||
if not guild and hasattr(ctx.bot, "fetch_guild"):
|
||||
guild = await ctx.bot.fetch_guild(channel.guild_id)
|
||||
|
||||
if not guild:
|
||||
is_dm = not hasattr(channel, "guild_id") or not channel.guild_id
|
||||
if is_dm:
|
||||
if perms:
|
||||
raise CheckFailure("Permission checks are not supported in DMs.")
|
||||
return True
|
||||
raise CheckFailure("Guild for permission check not found.")
|
||||
|
||||
member = ctx.author
|
||||
if not isinstance(member, Member):
|
||||
member = guild.get_member(ctx.author.id)
|
||||
if not member and hasattr(ctx.bot, "fetch_member"):
|
||||
member = await ctx.bot.fetch_member(guild.id, ctx.author.id)
|
||||
|
||||
if not member:
|
||||
raise CheckFailure("Could not resolve author to a guild member.")
|
||||
|
||||
perms_value = _compute_permissions(member, channel, guild)
|
||||
|
||||
if not has_permissions(perms_value, *perms):
|
||||
missing = missing_permissions(perms_value, *perms)
|
||||
missing_names = ", ".join(p.name for p in missing if p.name)
|
||||
raise CheckFailure(f"Missing permissions: {missing_names}")
|
||||
return True
|
||||
|
||||
return check(predicate)
|
||||
|
||||
def has_role(
|
||||
name_or_id: str | int,
|
||||
|
File diff suppressed because it is too large
Load Diff
2082
disagreement/http.py
2082
disagreement/http.py
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,167 +1,167 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from ..models import ActionRow
|
||||
from .item import Item
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import Client
|
||||
from ..interactions import Interaction
|
||||
|
||||
|
||||
class View:
|
||||
"""Represents a container for UI components that can be sent with a message.
|
||||
|
||||
Args:
|
||||
timeout (Optional[float]): The number of seconds to wait for an interaction before the view times out.
|
||||
Defaults to 180.
|
||||
"""
|
||||
|
||||
def __init__(self, *, timeout: Optional[float] = 180.0):
|
||||
self.timeout = timeout
|
||||
self.id = str(uuid.uuid4())
|
||||
self.__children: List[Item] = []
|
||||
self.__stopped = asyncio.Event()
|
||||
self._client: Optional[Client] = None
|
||||
self._message_id: Optional[str] = None
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from ..models import ActionRow
|
||||
from .item import Item
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import Client
|
||||
from ..interactions import Interaction
|
||||
|
||||
|
||||
class View:
|
||||
"""Represents a container for UI components that can be sent with a message.
|
||||
|
||||
Args:
|
||||
timeout (Optional[float]): The number of seconds to wait for an interaction before the view times out.
|
||||
Defaults to 180.
|
||||
"""
|
||||
|
||||
def __init__(self, *, timeout: Optional[float] = 180.0):
|
||||
self.timeout = timeout
|
||||
self.id = str(uuid.uuid4())
|
||||
self.__children: List[Item] = []
|
||||
self.__stopped = asyncio.Event()
|
||||
self._client: Optional[Client] = None
|
||||
self._message_id: Optional[str] = None
|
||||
|
||||
# The below is a bit of a hack to support items defined as class members
|
||||
# e.g. button = Button(...)
|
||||
for item in self.__class__.__dict__.values():
|
||||
if isinstance(item, Item):
|
||||
self.add_item(item)
|
||||
|
||||
@property
|
||||
def children(self) -> List[Item]:
|
||||
return self.__children
|
||||
|
||||
def add_item(self, item: Item):
|
||||
"""Adds an item to the view."""
|
||||
if not isinstance(item, Item):
|
||||
raise TypeError("Only instances of 'Item' can be added to a View.")
|
||||
|
||||
if len(self.__children) >= 25:
|
||||
raise ValueError("A view can only have a maximum of 25 components.")
|
||||
|
||||
for item in self.__class__.__dict__.values():
|
||||
if isinstance(item, Item):
|
||||
self.add_item(item)
|
||||
|
||||
@property
|
||||
def children(self) -> List[Item]:
|
||||
return self.__children
|
||||
|
||||
def add_item(self, item: Item):
|
||||
"""Adds an item to the view."""
|
||||
if not isinstance(item, Item):
|
||||
raise TypeError("Only instances of 'Item' can be added to a View.")
|
||||
|
||||
if len(self.__children) >= 25:
|
||||
raise ValueError("A view can only have a maximum of 25 components.")
|
||||
|
||||
if self.timeout is None and item.custom_id is None:
|
||||
raise ValueError(
|
||||
"All components in a persistent view must have a 'custom_id'."
|
||||
)
|
||||
|
||||
item._view = self
|
||||
self.__children.append(item)
|
||||
|
||||
@property
|
||||
def message_id(self) -> Optional[str]:
|
||||
return self._message_id
|
||||
|
||||
@message_id.setter
|
||||
def message_id(self, value: str):
|
||||
self._message_id = value
|
||||
|
||||
def to_components(self) -> List[ActionRow]:
|
||||
"""Converts the view's children into a list of ActionRow components.
|
||||
|
||||
This retains the original, simple layout behaviour where each item is
|
||||
placed in its own :class:`ActionRow` to ensure backward compatibility.
|
||||
"""
|
||||
|
||||
rows: List[ActionRow] = []
|
||||
|
||||
for item in self.children:
|
||||
item._view = self
|
||||
self.__children.append(item)
|
||||
|
||||
@property
|
||||
def message_id(self) -> Optional[str]:
|
||||
return self._message_id
|
||||
|
||||
@message_id.setter
|
||||
def message_id(self, value: str):
|
||||
self._message_id = value
|
||||
|
||||
def to_components(self) -> List[ActionRow]:
|
||||
"""Converts the view's children into a list of ActionRow components.
|
||||
|
||||
This retains the original, simple layout behaviour where each item is
|
||||
placed in its own :class:`ActionRow` to ensure backward compatibility.
|
||||
"""
|
||||
|
||||
rows: List[ActionRow] = []
|
||||
|
||||
for item in self.children:
|
||||
rows.append(ActionRow(components=[item]))
|
||||
|
||||
return rows
|
||||
|
||||
def layout_components_advanced(self) -> List[ActionRow]:
|
||||
"""Group compatible components into rows following Discord rules."""
|
||||
|
||||
rows: List[ActionRow] = []
|
||||
|
||||
for item in self.children:
|
||||
if item.custom_id is None:
|
||||
item.custom_id = (
|
||||
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
|
||||
)
|
||||
|
||||
target_row = item.row
|
||||
if target_row is not None:
|
||||
if not 0 <= target_row <= 4:
|
||||
raise ValueError("Row index must be between 0 and 4.")
|
||||
|
||||
while len(rows) <= target_row:
|
||||
if len(rows) >= 5:
|
||||
raise ValueError("A view can have at most 5 action rows.")
|
||||
rows.append(ActionRow())
|
||||
|
||||
rows[target_row].add_component(item)
|
||||
continue
|
||||
|
||||
placed = False
|
||||
for row in rows:
|
||||
try:
|
||||
row.add_component(item)
|
||||
placed = True
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if not placed:
|
||||
if len(rows) >= 5:
|
||||
raise ValueError("A view can have at most 5 action rows.")
|
||||
new_row = ActionRow([item])
|
||||
rows.append(new_row)
|
||||
|
||||
return rows
|
||||
|
||||
def to_components_payload(self) -> List[Dict[str, Any]]:
|
||||
"""Converts the view's children into a list of component dictionaries
|
||||
that can be sent to the Discord API."""
|
||||
return [row.to_dict() for row in self.to_components()]
|
||||
|
||||
async def _dispatch(self, interaction: Interaction):
|
||||
"""Called by the client to dispatch an interaction to the correct item."""
|
||||
if self.timeout is not None:
|
||||
self.__stopped.set() # Reset the timeout on each interaction
|
||||
self.__stopped.clear()
|
||||
|
||||
if interaction.data:
|
||||
custom_id = interaction.data.custom_id
|
||||
for child in self.children:
|
||||
if child.custom_id == custom_id:
|
||||
if child.callback:
|
||||
await child.callback(self, interaction)
|
||||
break
|
||||
|
||||
async def wait(self) -> bool:
|
||||
"""Waits until the view has stopped interacting."""
|
||||
return await self.__stopped.wait()
|
||||
|
||||
def stop(self):
|
||||
"""Stops the view from listening to interactions."""
|
||||
if not self.__stopped.is_set():
|
||||
self.__stopped.set()
|
||||
|
||||
async def on_timeout(self):
|
||||
"""Called when the view times out."""
|
||||
pass # User can override this
|
||||
|
||||
async def _start(self, client: Client):
|
||||
"""Starts the view's internal listener."""
|
||||
self._client = client
|
||||
if self.timeout is not None:
|
||||
asyncio.create_task(self._timeout_task())
|
||||
|
||||
async def _timeout_task(self):
|
||||
"""The task that waits for the timeout and then stops the view."""
|
||||
try:
|
||||
await asyncio.wait_for(self.wait(), timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self.stop()
|
||||
await self.on_timeout()
|
||||
if self._client and self._message_id:
|
||||
# Remove the view from the client's listeners
|
||||
self._client._views.pop(self._message_id, None)
|
||||
|
||||
return rows
|
||||
|
||||
def layout_components_advanced(self) -> List[ActionRow]:
|
||||
"""Group compatible components into rows following Discord rules."""
|
||||
|
||||
rows: List[ActionRow] = []
|
||||
|
||||
for item in self.children:
|
||||
if item.custom_id is None:
|
||||
item.custom_id = (
|
||||
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
|
||||
)
|
||||
|
||||
target_row = item.row
|
||||
if target_row is not None:
|
||||
if not 0 <= target_row <= 4:
|
||||
raise ValueError("Row index must be between 0 and 4.")
|
||||
|
||||
while len(rows) <= target_row:
|
||||
if len(rows) >= 5:
|
||||
raise ValueError("A view can have at most 5 action rows.")
|
||||
rows.append(ActionRow())
|
||||
|
||||
rows[target_row].add_component(item)
|
||||
continue
|
||||
|
||||
placed = False
|
||||
for row in rows:
|
||||
try:
|
||||
row.add_component(item)
|
||||
placed = True
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if not placed:
|
||||
if len(rows) >= 5:
|
||||
raise ValueError("A view can have at most 5 action rows.")
|
||||
new_row = ActionRow([item])
|
||||
rows.append(new_row)
|
||||
|
||||
return rows
|
||||
|
||||
def to_components_payload(self) -> List[Dict[str, Any]]:
|
||||
"""Converts the view's children into a list of component dictionaries
|
||||
that can be sent to the Discord API."""
|
||||
return [row.to_dict() for row in self.to_components()]
|
||||
|
||||
async def _dispatch(self, interaction: Interaction):
|
||||
"""Called by the client to dispatch an interaction to the correct item."""
|
||||
if self.timeout is not None:
|
||||
self.__stopped.set() # Reset the timeout on each interaction
|
||||
self.__stopped.clear()
|
||||
|
||||
if interaction.data:
|
||||
custom_id = interaction.data.custom_id
|
||||
for child in self.children:
|
||||
if child.custom_id == custom_id:
|
||||
if child.callback:
|
||||
await child.callback(self, interaction)
|
||||
break
|
||||
|
||||
async def wait(self) -> bool:
|
||||
"""Waits until the view has stopped interacting."""
|
||||
return await self.__stopped.wait()
|
||||
|
||||
def stop(self):
|
||||
"""Stops the view from listening to interactions."""
|
||||
if not self.__stopped.is_set():
|
||||
self.__stopped.set()
|
||||
|
||||
async def on_timeout(self):
|
||||
"""Called when the view times out."""
|
||||
pass # User can override this
|
||||
|
||||
async def _start(self, client: Client):
|
||||
"""Starts the view's internal listener."""
|
||||
self._client = client
|
||||
if self.timeout is not None:
|
||||
asyncio.create_task(self._timeout_task())
|
||||
|
||||
async def _timeout_task(self):
|
||||
"""The task that waits for the timeout and then stops the view."""
|
||||
try:
|
||||
await asyncio.wait_for(self.wait(), timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self.stop()
|
||||
await self.on_timeout()
|
||||
if self._client and self._message_id:
|
||||
# Remove the view from the client's listeners
|
||||
self._client._views.pop(self._message_id, None)
|
||||
|
@ -1,130 +1,130 @@
|
||||
# disagreement/voice_client.py
|
||||
"""Voice gateway and UDP audio client."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import socket
|
||||
# disagreement/voice_client.py
|
||||
"""Voice gateway and UDP audio client."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import socket
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Optional, Sequence
|
||||
|
||||
import aiohttp
|
||||
|
||||
import aiohttp
|
||||
# The following import is correct, but may be flagged by Pylance if the virtual
|
||||
# environment is not configured correctly.
|
||||
from nacl.secret import SecretBox
|
||||
|
||||
|
||||
from .audio import AudioSink, AudioSource, FFmpegAudioSource
|
||||
from .models import User
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client
|
||||
|
||||
|
||||
class VoiceClient:
|
||||
"""Handles the Discord voice WebSocket connection and UDP streaming."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
|
||||
class VoiceClient:
|
||||
"""Handles the Discord voice WebSocket connection and UDP streaming."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Client,
|
||||
endpoint: str,
|
||||
session_id: str,
|
||||
token: str,
|
||||
guild_id: int,
|
||||
user_id: int,
|
||||
*,
|
||||
ws=None,
|
||||
udp: Optional[socket.socket] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
endpoint: str,
|
||||
session_id: str,
|
||||
token: str,
|
||||
guild_id: int,
|
||||
user_id: int,
|
||||
*,
|
||||
ws=None,
|
||||
udp: Optional[socket.socket] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
self.client = client
|
||||
self.endpoint = endpoint
|
||||
self.session_id = session_id
|
||||
self.token = token
|
||||
self.guild_id = str(guild_id)
|
||||
self.user_id = str(user_id)
|
||||
self._ws: Optional[aiohttp.ClientWebSocketResponse] = ws
|
||||
self._udp = udp
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self.endpoint = endpoint
|
||||
self.session_id = session_id
|
||||
self.token = token
|
||||
self.guild_id = str(guild_id)
|
||||
self.user_id = str(user_id)
|
||||
self._ws: Optional[aiohttp.ClientWebSocketResponse] = ws
|
||||
self._udp = udp
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self._receive_task: Optional[asyncio.Task] = None
|
||||
self._udp_receive_thread: Optional[threading.Thread] = None
|
||||
self._heartbeat_interval: Optional[float] = None
|
||||
self._heartbeat_interval: Optional[float] = None
|
||||
try:
|
||||
self._loop = loop or asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self.verbose = verbose
|
||||
self.ssrc: Optional[int] = None
|
||||
self.secret_key: Optional[Sequence[int]] = None
|
||||
self._server_ip: Optional[str] = None
|
||||
self._server_port: Optional[int] = None
|
||||
self._current_source: Optional[AudioSource] = None
|
||||
self._play_task: Optional[asyncio.Task] = None
|
||||
self.verbose = verbose
|
||||
self.ssrc: Optional[int] = None
|
||||
self.secret_key: Optional[Sequence[int]] = None
|
||||
self._server_ip: Optional[str] = None
|
||||
self._server_port: Optional[int] = None
|
||||
self._current_source: Optional[AudioSource] = None
|
||||
self._play_task: Optional[asyncio.Task] = None
|
||||
self._sink: Optional[AudioSink] = None
|
||||
self._ssrc_map: dict[int, int] = {}
|
||||
self._ssrc_lock = threading.Lock()
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self._ws is None:
|
||||
self._session = aiohttp.ClientSession()
|
||||
self._ws = await self._session.ws_connect(self.endpoint)
|
||||
|
||||
hello = await self._ws.receive_json()
|
||||
self._heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000
|
||||
self._heartbeat_task = self._loop.create_task(self._heartbeat())
|
||||
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"op": 0,
|
||||
"d": {
|
||||
"server_id": self.guild_id,
|
||||
"user_id": self.user_id,
|
||||
"session_id": self.session_id,
|
||||
"token": self.token,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
ready = await self._ws.receive_json()
|
||||
data = ready["d"]
|
||||
self.ssrc = data["ssrc"]
|
||||
self._server_ip = data["ip"]
|
||||
self._server_port = data["port"]
|
||||
|
||||
if self._udp is None:
|
||||
self._udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self._udp.connect((self._server_ip, self._server_port))
|
||||
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"op": 1,
|
||||
"d": {
|
||||
"protocol": "udp",
|
||||
"data": {
|
||||
"address": self._udp.getsockname()[0],
|
||||
"port": self._udp.getsockname()[1],
|
||||
"mode": "xsalsa20_poly1305",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
session_desc = await self._ws.receive_json()
|
||||
self.secret_key = session_desc["d"].get("secret_key")
|
||||
|
||||
async def _heartbeat(self) -> None:
|
||||
assert self._ws is not None
|
||||
assert self._heartbeat_interval is not None
|
||||
try:
|
||||
while True:
|
||||
await self._ws.send_json({"op": 3, "d": int(self._loop.time() * 1000)})
|
||||
await asyncio.sleep(self._heartbeat_interval)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self._ws is None:
|
||||
self._session = aiohttp.ClientSession()
|
||||
self._ws = await self._session.ws_connect(self.endpoint)
|
||||
|
||||
hello = await self._ws.receive_json()
|
||||
self._heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000
|
||||
self._heartbeat_task = self._loop.create_task(self._heartbeat())
|
||||
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"op": 0,
|
||||
"d": {
|
||||
"server_id": self.guild_id,
|
||||
"user_id": self.user_id,
|
||||
"session_id": self.session_id,
|
||||
"token": self.token,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
ready = await self._ws.receive_json()
|
||||
data = ready["d"]
|
||||
self.ssrc = data["ssrc"]
|
||||
self._server_ip = data["ip"]
|
||||
self._server_port = data["port"]
|
||||
|
||||
if self._udp is None:
|
||||
self._udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self._udp.connect((self._server_ip, self._server_port))
|
||||
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"op": 1,
|
||||
"d": {
|
||||
"protocol": "udp",
|
||||
"data": {
|
||||
"address": self._udp.getsockname()[0],
|
||||
"port": self._udp.getsockname()[1],
|
||||
"mode": "xsalsa20_poly1305",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
session_desc = await self._ws.receive_json()
|
||||
self.secret_key = session_desc["d"].get("secret_key")
|
||||
|
||||
async def _heartbeat(self) -> None:
|
||||
assert self._ws is not None
|
||||
assert self._heartbeat_interval is not None
|
||||
try:
|
||||
while True:
|
||||
await self._ws.send_json({"op": 3, "d": int(self._loop.time() * 1000)})
|
||||
await asyncio.sleep(self._heartbeat_interval)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
assert self._ws is not None
|
||||
while True:
|
||||
@ -168,48 +168,48 @@ class VoiceClient:
|
||||
if self.verbose:
|
||||
print(f"Error in UDP receive loop: {e}")
|
||||
|
||||
async def send_audio_frame(self, frame: bytes) -> None:
|
||||
if not self._udp:
|
||||
raise RuntimeError("UDP socket not initialised")
|
||||
self._udp.send(frame)
|
||||
|
||||
async def _play_loop(self) -> None:
|
||||
assert self._current_source is not None
|
||||
try:
|
||||
while True:
|
||||
data = await self._current_source.read()
|
||||
if not data:
|
||||
break
|
||||
await self.send_audio_frame(data)
|
||||
finally:
|
||||
await self._current_source.close()
|
||||
self._current_source = None
|
||||
self._play_task = None
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._play_task:
|
||||
self._play_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._play_task
|
||||
self._play_task = None
|
||||
if self._current_source:
|
||||
await self._current_source.close()
|
||||
self._current_source = None
|
||||
|
||||
async def play(self, source: AudioSource, *, wait: bool = True) -> None:
|
||||
"""|coro| Play an :class:`AudioSource` on the voice connection."""
|
||||
|
||||
await self.stop()
|
||||
self._current_source = source
|
||||
self._play_task = self._loop.create_task(self._play_loop())
|
||||
if wait:
|
||||
await self._play_task
|
||||
|
||||
async def play_file(self, filename: str, *, wait: bool = True) -> None:
|
||||
"""|coro| Stream an audio file or URL using FFmpeg."""
|
||||
|
||||
await self.play(FFmpegAudioSource(filename), wait=wait)
|
||||
|
||||
async def send_audio_frame(self, frame: bytes) -> None:
|
||||
if not self._udp:
|
||||
raise RuntimeError("UDP socket not initialised")
|
||||
self._udp.send(frame)
|
||||
|
||||
async def _play_loop(self) -> None:
|
||||
assert self._current_source is not None
|
||||
try:
|
||||
while True:
|
||||
data = await self._current_source.read()
|
||||
if not data:
|
||||
break
|
||||
await self.send_audio_frame(data)
|
||||
finally:
|
||||
await self._current_source.close()
|
||||
self._current_source = None
|
||||
self._play_task = None
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._play_task:
|
||||
self._play_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._play_task
|
||||
self._play_task = None
|
||||
if self._current_source:
|
||||
await self._current_source.close()
|
||||
self._current_source = None
|
||||
|
||||
async def play(self, source: AudioSource, *, wait: bool = True) -> None:
|
||||
"""|coro| Play an :class:`AudioSource` on the voice connection."""
|
||||
|
||||
await self.stop()
|
||||
self._current_source = source
|
||||
self._play_task = self._loop.create_task(self._play_loop())
|
||||
if wait:
|
||||
await self._play_task
|
||||
|
||||
async def play_file(self, filename: str, *, wait: bool = True) -> None:
|
||||
"""|coro| Stream an audio file or URL using FFmpeg."""
|
||||
|
||||
await self.play(FFmpegAudioSource(filename), wait=wait)
|
||||
|
||||
def listen(self, sink: AudioSink) -> None:
|
||||
"""Start listening to voice and routing to a sink."""
|
||||
if not isinstance(sink, AudioSink):
|
||||
@ -222,22 +222,22 @@ class VoiceClient:
|
||||
)
|
||||
self._udp_receive_thread.start()
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.stop()
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._heartbeat_task
|
||||
async def close(self) -> None:
|
||||
await self.stop()
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._heartbeat_task
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._receive_task
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
if self._udp:
|
||||
self._udp.close()
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
if self._udp:
|
||||
self._udp.close()
|
||||
if self._udp_receive_thread:
|
||||
self._udp_receive_thread.join(timeout=1)
|
||||
if self._sink:
|
||||
|
112
pyproject.toml
112
pyproject.toml
@ -1,57 +1,57 @@
|
||||
[project]
|
||||
name = "disagreement"
|
||||
version = "0.2.0rc1"
|
||||
description = "A Python library for the Discord API."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
license = {text = "BSD 3-Clause"}
|
||||
authors = [
|
||||
{name = "Slipstream", email = "me@slipstreamm.dev"}
|
||||
]
|
||||
keywords = ["discord", "api", "bot", "async", "aiohttp"]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: BSD License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Topic :: Software Development :: Libraries",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Topic :: Internet",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
[project]
|
||||
name = "disagreement"
|
||||
version = "0.2.0rc1"
|
||||
description = "A Python library for the Discord API."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
license = {text = "BSD 3-Clause"}
|
||||
authors = [
|
||||
{name = "Slipstream", email = "me@slipstreamm.dev"}
|
||||
]
|
||||
keywords = ["discord", "api", "bot", "async", "aiohttp"]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: BSD License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Topic :: Software Development :: Libraries",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Topic :: Internet",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
"PyNaCl>=1.5.0,<2.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
"hypothesis>=6.132.0",
|
||||
]
|
||||
dev = [
|
||||
"python-dotenv>=1.0.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/Slipstreamm/disagreement"
|
||||
Issues = "https://github.com/Slipstreamm/disagreement/issues"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
# Optional: for linting/formatting, e.g., Ruff
|
||||
# [tool.ruff]
|
||||
# line-length = 88
|
||||
# select = ["E", "W", "F", "I", "UP", "C4", "B"] # Example rule set
|
||||
# ignore = []
|
||||
|
||||
# [tool.ruff.format]
|
||||
# quote-style = "double"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
"hypothesis>=6.132.0",
|
||||
]
|
||||
dev = [
|
||||
"python-dotenv>=1.0.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/Slipstreamm/disagreement"
|
||||
Issues = "https://github.com/Slipstreamm/disagreement/issues"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
# Optional: for linting/formatting, e.g., Ruff
|
||||
# [tool.ruff]
|
||||
# line-length = 88
|
||||
# select = ["E", "W", "F", "I", "UP", "C4", "B"] # Example rule set
|
||||
# ignore = []
|
||||
|
||||
# [tool.ruff.format]
|
||||
# quote-style = "double"
|
||||
|
Loading…
x
Reference in New Issue
Block a user