Merge branch 'master' of https://github.com/Slipstreamm/disagreement
This commit is contained in:
commit
380feddeeb
16
README.md
16
README.md
@ -14,6 +14,7 @@ A Python library for interacting with the Discord API, with a focus on bot devel
|
||||
- `Message.jump_url` property for quick links to messages
|
||||
- Built-in caching layer
|
||||
- `Guild.me` property to access the bot's member object
|
||||
- Easy CDN asset handling via the `Asset` model
|
||||
- Experimental voice support
|
||||
- Helpful error handling utilities
|
||||
|
||||
@ -111,6 +112,10 @@ These options are forwarded to ``HTTPClient`` when it creates the underlying
|
||||
``aiohttp.ClientSession``. You can specify a custom ``connector`` or any other
|
||||
session parameter supported by ``aiohttp``.
|
||||
|
||||
### Logging Out
|
||||
|
||||
Call ``Client.logout`` to disconnect from the Gateway and clear the current bot token while keeping the HTTP session alive. Assign a new token and call ``connect`` or ``run`` to log back in.
|
||||
|
||||
### Default Allowed Mentions
|
||||
|
||||
Specify default mention behaviour for all outgoing messages when constructing the client:
|
||||
@ -126,6 +131,17 @@ client = disagreement.Client(
|
||||
This dictionary is used whenever ``send_message`` or helpers like ``Message.reply``
|
||||
are called without an explicit ``allowed_mentions`` argument.
|
||||
|
||||
### Working With Assets
|
||||
|
||||
Properties like ``User.avatar`` and ``Guild.icon`` return :class:`disagreement.Asset` objects.
|
||||
Use ``read`` to get the bytes or ``save`` to write them to disk.
|
||||
|
||||
```python
|
||||
user = await client.fetch_user(123)
|
||||
data = await user.avatar.read()
|
||||
await user.avatar.save("avatar.png")
|
||||
```
|
||||
|
||||
### Defining Subcommands with `AppCommandGroup`
|
||||
|
||||
```python
|
||||
|
@ -15,6 +15,7 @@ __copyright__ = "Copyright 2025 Slipstream"
|
||||
__version__ = "0.8.1"
|
||||
|
||||
from .client import Client, AutoShardedClient
|
||||
from .asset import Asset
|
||||
from .models import (
|
||||
Message,
|
||||
User,
|
||||
@ -39,6 +40,7 @@ from .models import (
|
||||
Container,
|
||||
Guild,
|
||||
)
|
||||
from .object import Object
|
||||
from .voice_client import VoiceClient
|
||||
from .audio import AudioSource, FFmpegAudioSource
|
||||
from .typing import Typing
|
||||
@ -125,6 +127,7 @@ import logging
|
||||
__all__ = [
|
||||
"Client",
|
||||
"AutoShardedClient",
|
||||
"Asset",
|
||||
"Message",
|
||||
"User",
|
||||
"Reaction",
|
||||
@ -146,6 +149,7 @@ __all__ = [
|
||||
"MediaGallery",
|
||||
"MediaGalleryItem",
|
||||
"Container",
|
||||
"Object",
|
||||
"VoiceClient",
|
||||
"AudioSource",
|
||||
"FFmpegAudioSource",
|
||||
|
51
disagreement/asset.py
Normal file
51
disagreement/asset.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""Utility class for Discord CDN assets."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import IO, Optional, Union, TYPE_CHECKING
|
||||
|
||||
import aiohttp # pylint: disable=import-error
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client
|
||||
|
||||
|
||||
class Asset:
|
||||
"""Represents a CDN asset such as an avatar or icon."""
|
||||
|
||||
def __init__(self, url: str, client_instance: Optional["Client"] = None) -> None:
|
||||
self.url = url
|
||||
self._client = client_instance
|
||||
|
||||
async def read(self) -> bytes:
|
||||
"""Read the asset's bytes."""
|
||||
|
||||
session: Optional[aiohttp.ClientSession] = None
|
||||
if self._client is not None:
|
||||
await self._client._http._ensure_session() # type: ignore[attr-defined]
|
||||
session = self._client._http._session # type: ignore[attr-defined]
|
||||
if session is None:
|
||||
session = aiohttp.ClientSession()
|
||||
close = True
|
||||
else:
|
||||
close = False
|
||||
async with session.get(self.url) as resp:
|
||||
data = await resp.read()
|
||||
if close:
|
||||
await session.close()
|
||||
return data
|
||||
|
||||
async def save(self, fp: Union[str, os.PathLike[str], IO[bytes]]) -> None:
|
||||
"""Save the asset to the given file path or file-like object."""
|
||||
|
||||
data = await self.read()
|
||||
if isinstance(fp, (str, os.PathLike)):
|
||||
path = os.fspath(fp)
|
||||
with open(path, "wb") as file:
|
||||
file.write(data)
|
||||
else:
|
||||
fp.write(data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Asset url='{self.url}'>"
|
@ -84,9 +84,9 @@ class MemberCache(Cache["Member"]):
|
||||
|
||||
def _should_cache(self, member: Member) -> bool:
|
||||
"""Determines if a member should be cached based on the flags."""
|
||||
if self.flags.all:
|
||||
if self.flags.all_enabled:
|
||||
return True
|
||||
if self.flags.none:
|
||||
if self.flags.no_flags:
|
||||
return False
|
||||
|
||||
if self.flags.online and member.status != "offline":
|
||||
|
@ -74,6 +74,14 @@ class MemberCacheFlags:
|
||||
for name in self.VALID_FLAGS:
|
||||
yield name, getattr(self, name)
|
||||
|
||||
@property
|
||||
def all_enabled(self) -> bool:
|
||||
return self.value == self.ALL_FLAGS
|
||||
|
||||
@property
|
||||
def no_flags(self) -> bool:
|
||||
return self.value == 0
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.value
|
||||
|
||||
|
@ -2,8 +2,11 @@
|
||||
The main Client class for interacting with the Discord API.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
import asyncio
|
||||
import signal
|
||||
import json
|
||||
import os
|
||||
import importlib
|
||||
from typing import (
|
||||
Optional,
|
||||
Callable,
|
||||
@ -16,7 +19,9 @@ from typing import (
|
||||
Dict,
|
||||
cast,
|
||||
)
|
||||
from types import ModuleType
|
||||
from types import ModuleType
|
||||
|
||||
PERSISTENT_VIEWS_FILE = "persistent_views.json"
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
@ -77,7 +82,7 @@ def _update_list(lst: List[Any], item: Any) -> None:
|
||||
lst.append(item)
|
||||
|
||||
|
||||
class Client:
|
||||
class Client:
|
||||
"""
|
||||
Represents a client connection that connects to Discord.
|
||||
This class is used to interact with the Discord WebSocket and API.
|
||||
@ -193,7 +198,10 @@ class Client:
|
||||
self._views: Dict[Snowflake, "View"] = {}
|
||||
self._persistent_views: Dict[str, "View"] = {}
|
||||
self._voice_clients: Dict[Snowflake, VoiceClient] = {}
|
||||
self._webhooks: Dict[Snowflake, "Webhook"] = {}
|
||||
self._webhooks: Dict[Snowflake, "Webhook"] = {}
|
||||
|
||||
# Load persistent views stored on disk
|
||||
self._load_persistent_views()
|
||||
|
||||
# Default whether replies mention the user
|
||||
self.mention_replies: bool = mention_replies
|
||||
@ -210,13 +218,46 @@ class Client:
|
||||
self.loop.add_signal_handler(
|
||||
signal.SIGTERM, lambda: self.loop.create_task(self.close())
|
||||
)
|
||||
except NotImplementedError:
|
||||
# add_signal_handler is not available on all platforms (e.g., Windows default event loop policy)
|
||||
# Users on these platforms would need to handle shutdown differently.
|
||||
print(
|
||||
"Warning: Signal handlers for SIGINT/SIGTERM could not be added. "
|
||||
"Graceful shutdown via signals might not work as expected on this platform."
|
||||
)
|
||||
except NotImplementedError:
|
||||
# add_signal_handler is not available on all platforms (e.g., Windows default event loop policy)
|
||||
# Users on these platforms would need to handle shutdown differently.
|
||||
print(
|
||||
"Warning: Signal handlers for SIGINT/SIGTERM could not be added. "
|
||||
"Graceful shutdown via signals might not work as expected on this platform."
|
||||
)
|
||||
|
||||
def _load_persistent_views(self) -> None:
|
||||
"""Load registered persistent views from disk."""
|
||||
if not os.path.isfile(PERSISTENT_VIEWS_FILE):
|
||||
return
|
||||
try:
|
||||
with open(PERSISTENT_VIEWS_FILE, "r") as fp:
|
||||
mapping = json.load(fp)
|
||||
except Exception as e: # pragma: no cover - best effort load
|
||||
print(f"Failed to load persistent views: {e}")
|
||||
return
|
||||
|
||||
for custom_id, path in mapping.items():
|
||||
try:
|
||||
module_name, class_name = path.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
cls = getattr(module, class_name)
|
||||
view = cls()
|
||||
self._persistent_views[custom_id] = view
|
||||
except Exception as e: # pragma: no cover - best effort load
|
||||
print(f"Failed to initialize persistent view {path}: {e}")
|
||||
|
||||
def _save_persistent_views(self) -> None:
|
||||
"""Persist registered views to disk."""
|
||||
data = {}
|
||||
for custom_id, view in self._persistent_views.items():
|
||||
cls = view.__class__
|
||||
data[custom_id] = f"{cls.__module__}.{cls.__name__}"
|
||||
try:
|
||||
with open(PERSISTENT_VIEWS_FILE, "w") as fp:
|
||||
json.dump(data, fp)
|
||||
except Exception as e: # pragma: no cover - best effort save
|
||||
print(f"Failed to save persistent views: {e}")
|
||||
|
||||
async def _initialize_gateway(self):
|
||||
"""Initializes the GatewayClient if it doesn't exist."""
|
||||
@ -413,15 +454,23 @@ class Client:
|
||||
await self.close()
|
||||
return False
|
||||
|
||||
async def close_gateway(self, code: int = 1000) -> None:
|
||||
"""Closes only the gateway connection, allowing for potential reconnect."""
|
||||
if self._shard_manager:
|
||||
await self._shard_manager.close()
|
||||
self._shard_manager = None
|
||||
if self._gateway:
|
||||
await self._gateway.close(code=code)
|
||||
self._gateway = None
|
||||
self._ready_event.clear() # No longer ready if gateway is closed
|
||||
async def close_gateway(self, code: int = 1000) -> None:
|
||||
"""Closes only the gateway connection, allowing for potential reconnect."""
|
||||
if self._shard_manager:
|
||||
await self._shard_manager.close()
|
||||
self._shard_manager = None
|
||||
if self._gateway:
|
||||
await self._gateway.close(code=code)
|
||||
self._gateway = None
|
||||
self._ready_event.clear() # No longer ready if gateway is closed
|
||||
|
||||
async def logout(self) -> None:
|
||||
"""Invalidate the bot token and disconnect from the Gateway."""
|
||||
await self.close_gateway()
|
||||
self.token = ""
|
||||
self._http.token = ""
|
||||
self.user = None
|
||||
self.start_time = None
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""Indicates if the client has been closed."""
|
||||
@ -1399,10 +1448,30 @@ class Client:
|
||||
|
||||
return self._channels.get(channel_id)
|
||||
|
||||
def get_message(self, message_id: Snowflake) -> Optional["Message"]:
|
||||
"""Returns a message from the internal cache."""
|
||||
|
||||
return self._messages.get(message_id)
|
||||
def get_message(self, message_id: Snowflake) -> Optional["Message"]:
|
||||
"""Returns a message from the internal cache."""
|
||||
|
||||
return self._messages.get(message_id)
|
||||
|
||||
def get_all_channels(self) -> List["Channel"]:
|
||||
"""Return all channels cached in every guild."""
|
||||
|
||||
channels: List["Channel"] = []
|
||||
for guild in self._guilds.values():
|
||||
channels.extend(guild._channels.values())
|
||||
return channels
|
||||
|
||||
def get_all_members(self) -> List["Member"]:
|
||||
"""Return all cached members across all guilds.
|
||||
|
||||
When member caching is disabled via :class:`MemberCacheFlags.none`, this
|
||||
list will always be empty.
|
||||
"""
|
||||
|
||||
members: List["Member"] = []
|
||||
for guild in self._guilds.values():
|
||||
members.extend(guild._members.values())
|
||||
return members
|
||||
|
||||
async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]:
|
||||
"""Fetches a guild by ID from Discord and caches it."""
|
||||
@ -1707,11 +1776,13 @@ class Client:
|
||||
|
||||
for item in view.children:
|
||||
if item.custom_id: # Ensure custom_id is not None
|
||||
if item.custom_id in self._persistent_views:
|
||||
raise ValueError(
|
||||
f"A component with custom_id '{item.custom_id}' is already registered."
|
||||
)
|
||||
self._persistent_views[item.custom_id] = view
|
||||
if item.custom_id in self._persistent_views:
|
||||
raise ValueError(
|
||||
f"A component with custom_id '{item.custom_id}' is already registered."
|
||||
)
|
||||
self._persistent_views[item.custom_id] = view
|
||||
|
||||
self._save_persistent_views()
|
||||
|
||||
# --- Application Command Methods ---
|
||||
async def process_interaction(self, interaction: Interaction) -> None:
|
||||
|
@ -6,7 +6,16 @@ import re
|
||||
import inspect
|
||||
|
||||
from .errors import BadArgument
|
||||
from disagreement.models import Member, Guild, Role, User
|
||||
from disagreement.models import (
|
||||
Member,
|
||||
Guild,
|
||||
Role,
|
||||
User,
|
||||
TextChannel,
|
||||
VoiceChannel,
|
||||
Emoji,
|
||||
PartialEmoji,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core import CommandContext
|
||||
@ -158,6 +167,82 @@ class UserConverter(Converter["User"]):
|
||||
raise BadArgument(f"User '{argument}' not found.")
|
||||
|
||||
|
||||
class TextChannelConverter(Converter["TextChannel"]):
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> "TextChannel":
|
||||
if not ctx.message.guild_id:
|
||||
raise BadArgument("TextChannel converter requires guild context.")
|
||||
|
||||
match = re.match(r"<#(?P<id>\d+)>$", argument)
|
||||
channel_id = match.group("id") if match else argument
|
||||
|
||||
guild = ctx.bot.get_guild(ctx.message.guild_id)
|
||||
if guild:
|
||||
channel = guild.get_channel(channel_id)
|
||||
if isinstance(channel, TextChannel):
|
||||
return channel
|
||||
|
||||
channel = (
|
||||
ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None
|
||||
)
|
||||
if isinstance(channel, TextChannel):
|
||||
return channel
|
||||
|
||||
if hasattr(ctx.bot, "fetch_channel"):
|
||||
channel = await ctx.bot.fetch_channel(channel_id)
|
||||
if isinstance(channel, TextChannel):
|
||||
return channel
|
||||
|
||||
raise BadArgument(f"Text channel '{argument}' not found.")
|
||||
|
||||
|
||||
class VoiceChannelConverter(Converter["VoiceChannel"]):
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> "VoiceChannel":
|
||||
if not ctx.message.guild_id:
|
||||
raise BadArgument("VoiceChannel converter requires guild context.")
|
||||
|
||||
match = re.match(r"<#(?P<id>\d+)>$", argument)
|
||||
channel_id = match.group("id") if match else argument
|
||||
|
||||
guild = ctx.bot.get_guild(ctx.message.guild_id)
|
||||
if guild:
|
||||
channel = guild.get_channel(channel_id)
|
||||
if isinstance(channel, VoiceChannel):
|
||||
return channel
|
||||
|
||||
channel = (
|
||||
ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None
|
||||
)
|
||||
if isinstance(channel, VoiceChannel):
|
||||
return channel
|
||||
|
||||
if hasattr(ctx.bot, "fetch_channel"):
|
||||
channel = await ctx.bot.fetch_channel(channel_id)
|
||||
if isinstance(channel, VoiceChannel):
|
||||
return channel
|
||||
|
||||
raise BadArgument(f"Voice channel '{argument}' not found.")
|
||||
|
||||
|
||||
class EmojiConverter(Converter["PartialEmoji"]):
|
||||
_CUSTOM_RE = re.compile(r"<(?P<animated>a)?:(?P<name>[^:]+):(?P<id>\d+)>$")
|
||||
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> "PartialEmoji":
|
||||
match = self._CUSTOM_RE.match(argument)
|
||||
if match:
|
||||
return PartialEmoji(
|
||||
{
|
||||
"id": match.group("id"),
|
||||
"name": match.group("name"),
|
||||
"animated": bool(match.group("animated")),
|
||||
}
|
||||
)
|
||||
|
||||
if argument:
|
||||
return PartialEmoji({"id": None, "name": argument})
|
||||
|
||||
raise BadArgument(f"Emoji '{argument}' not found.")
|
||||
|
||||
|
||||
# Default converters mapping
|
||||
DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
|
||||
int: IntConverter(),
|
||||
@ -168,6 +253,10 @@ DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
|
||||
Guild: GuildConverter(),
|
||||
Role: RoleConverter(),
|
||||
User: UserConverter(),
|
||||
TextChannel: TextChannelConverter(),
|
||||
VoiceChannel: VoiceChannelConverter(),
|
||||
PartialEmoji: EmojiConverter(),
|
||||
Emoji: EmojiConverter(),
|
||||
}
|
||||
|
||||
|
||||
|
@ -79,8 +79,15 @@ class GroupMixin:
|
||||
)
|
||||
self.commands[alias.lower()] = command
|
||||
|
||||
def get_command(self, name: str) -> Optional["Command"]:
|
||||
return self.commands.get(name.lower())
|
||||
def get_command(self, name: str) -> Optional["Command"]:
|
||||
return self.commands.get(name.lower())
|
||||
|
||||
def walk_commands(self):
|
||||
"""Yield all commands in this group recursively."""
|
||||
for cmd in dict.fromkeys(self.commands.values()):
|
||||
yield cmd
|
||||
if isinstance(cmd, Group):
|
||||
yield from cmd.walk_commands()
|
||||
|
||||
|
||||
class Command(GroupMixin):
|
||||
@ -366,6 +373,13 @@ class CommandHandler:
|
||||
def get_command(self, name: str) -> Optional[Command]:
|
||||
return self.commands.get(name.lower())
|
||||
|
||||
def walk_commands(self):
|
||||
"""Yield every registered command, including subcommands."""
|
||||
for cmd in dict.fromkeys(self.commands.values()):
|
||||
yield cmd
|
||||
if isinstance(cmd, Group):
|
||||
yield from cmd.walk_commands()
|
||||
|
||||
def get_cog(self, name: str) -> Optional["Cog"]:
|
||||
"""Return a loaded cog by name if present."""
|
||||
|
||||
|
@ -1,6 +1,8 @@
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional
|
||||
|
||||
from .core import Command, CommandContext, CommandHandler
|
||||
from ...utils import Paginator
|
||||
from .core import Command, CommandContext, CommandHandler, Group
|
||||
|
||||
|
||||
class HelpCommand(Command):
|
||||
@ -15,17 +17,22 @@ class HelpCommand(Command):
|
||||
if not cmd or cmd.name.lower() != command.lower():
|
||||
await ctx.send(f"Command '{command}' not found.")
|
||||
return
|
||||
description = cmd.description or cmd.brief or "No description provided."
|
||||
await ctx.send(f"**{ctx.prefix}{cmd.name}**\n{description}")
|
||||
else:
|
||||
lines: List[str] = []
|
||||
for registered in dict.fromkeys(handler.commands.values()):
|
||||
brief = registered.brief or registered.description or ""
|
||||
lines.append(f"{ctx.prefix}{registered.name} - {brief}".strip())
|
||||
if lines:
|
||||
await ctx.send("\n".join(lines))
|
||||
if isinstance(cmd, Group):
|
||||
await self.send_group_help(ctx, cmd)
|
||||
elif cmd:
|
||||
description = cmd.description or cmd.brief or "No description provided."
|
||||
await ctx.send(f"**{ctx.prefix}{cmd.name}**\n{description}")
|
||||
else:
|
||||
await ctx.send("No commands available.")
|
||||
lines: List[str] = []
|
||||
for registered in handler.walk_commands():
|
||||
brief = registered.brief or registered.description or ""
|
||||
lines.append(f"{ctx.prefix}{registered.name} - {brief}".strip())
|
||||
if lines:
|
||||
await ctx.send("\n".join(lines))
|
||||
else:
|
||||
await self.send_command_help(ctx, cmd)
|
||||
else:
|
||||
await self.send_bot_help(ctx)
|
||||
|
||||
super().__init__(
|
||||
callback,
|
||||
@ -33,3 +40,42 @@ class HelpCommand(Command):
|
||||
brief="Show command help.",
|
||||
description="Displays help for commands.",
|
||||
)
|
||||
|
||||
async def send_bot_help(self, ctx: CommandContext) -> None:
|
||||
groups = defaultdict(list)
|
||||
for cmd in dict.fromkeys(self.handler.commands.values()):
|
||||
key = cmd.cog.cog_name if cmd.cog else "No Category"
|
||||
groups[key].append(cmd)
|
||||
|
||||
paginator = Paginator()
|
||||
for cog_name, cmds in groups.items():
|
||||
paginator.add_line(f"**{cog_name}**")
|
||||
for cmd in cmds:
|
||||
brief = cmd.brief or cmd.description or ""
|
||||
paginator.add_line(f"{ctx.prefix}{cmd.name} - {brief}".strip())
|
||||
paginator.add_line("")
|
||||
|
||||
pages = paginator.pages
|
||||
if not pages:
|
||||
await ctx.send("No commands available.")
|
||||
return
|
||||
for page in pages:
|
||||
await ctx.send(page)
|
||||
|
||||
async def send_command_help(self, ctx: CommandContext, command: Command) -> None:
|
||||
description = command.description or command.brief or "No description provided."
|
||||
await ctx.send(f"**{ctx.prefix}{command.name}**\n{description}")
|
||||
|
||||
async def send_group_help(self, ctx: CommandContext, group: Group) -> None:
|
||||
paginator = Paginator()
|
||||
description = group.description or group.brief or "No description provided."
|
||||
paginator.add_line(f"**{ctx.prefix}{group.name}**\n{description}")
|
||||
if group.commands:
|
||||
for sub in dict.fromkeys(group.commands.values()):
|
||||
brief = sub.brief or sub.description or ""
|
||||
paginator.add_line(
|
||||
f"{ctx.prefix}{group.name} {sub.name} - {brief}".strip()
|
||||
)
|
||||
|
||||
for page in paginator.pages:
|
||||
await ctx.send(page)
|
||||
|
@ -656,21 +656,21 @@ class HTTPClient:
|
||||
|
||||
await self.request("PUT", f"/channels/{channel_id}/pins/{message_id}")
|
||||
|
||||
async def unpin_message(
|
||||
self, channel_id: "Snowflake", message_id: "Snowflake"
|
||||
) -> None:
|
||||
"""Unpins a message from a channel."""
|
||||
|
||||
await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}")
|
||||
|
||||
async def crosspost_message(
|
||||
self, channel_id: "Snowflake", message_id: "Snowflake"
|
||||
) -> Dict[str, Any]:
|
||||
"""Crossposts a message to any following channels."""
|
||||
|
||||
return await self.request(
|
||||
"POST", f"/channels/{channel_id}/messages/{message_id}/crosspost"
|
||||
)
|
||||
async def unpin_message(
|
||||
self, channel_id: "Snowflake", message_id: "Snowflake"
|
||||
) -> None:
|
||||
"""Unpins a message from a channel."""
|
||||
|
||||
await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}")
|
||||
|
||||
async def crosspost_message(
|
||||
self, channel_id: "Snowflake", message_id: "Snowflake"
|
||||
) -> Dict[str, Any]:
|
||||
"""Crossposts a message to any following channels."""
|
||||
|
||||
return await self.request(
|
||||
"POST", f"/channels/{channel_id}/messages/{message_id}/crosspost"
|
||||
)
|
||||
|
||||
async def delete_channel(
|
||||
self, channel_id: str, reason: Optional[str] = None
|
||||
@ -734,68 +734,68 @@ class HTTPClient:
|
||||
|
||||
return await self.request("GET", f"/channels/{channel_id}/invites")
|
||||
|
||||
async def create_invite(
|
||||
self, channel_id: "Snowflake", payload: Dict[str, Any]
|
||||
) -> "Invite":
|
||||
"""Creates an invite for a channel."""
|
||||
async def create_invite(
|
||||
self, channel_id: "Snowflake", payload: Dict[str, Any]
|
||||
) -> "Invite":
|
||||
"""Creates an invite for a channel."""
|
||||
|
||||
data = await self.request(
|
||||
"POST", f"/channels/{channel_id}/invites", payload=payload
|
||||
)
|
||||
from .models import Invite
|
||||
|
||||
return Invite.from_dict(data)
|
||||
|
||||
async def create_channel_invite(
|
||||
self,
|
||||
channel_id: "Snowflake",
|
||||
payload: Dict[str, Any],
|
||||
*,
|
||||
reason: Optional[str] = None,
|
||||
) -> "Invite":
|
||||
"""Creates an invite for a channel with an optional audit log reason."""
|
||||
|
||||
headers = {"X-Audit-Log-Reason": reason} if reason else None
|
||||
data = await self.request(
|
||||
"POST",
|
||||
f"/channels/{channel_id}/invites",
|
||||
payload=payload,
|
||||
custom_headers=headers,
|
||||
)
|
||||
from .models import Invite
|
||||
|
||||
return Invite.from_dict(data)
|
||||
from .models import Invite
|
||||
|
||||
async def delete_invite(self, code: str) -> None:
|
||||
"""Deletes an invite by code."""
|
||||
|
||||
await self.request("DELETE", f"/invites/{code}")
|
||||
|
||||
async def get_invite(self, code: "Snowflake") -> Dict[str, Any]:
|
||||
"""Fetches a single invite by its code."""
|
||||
|
||||
return await self.request("GET", f"/invites/{code}")
|
||||
return Invite.from_dict(data)
|
||||
|
||||
async def create_webhook(
|
||||
self, channel_id: "Snowflake", payload: Dict[str, Any]
|
||||
) -> "Webhook":
|
||||
"""Creates a webhook in the specified channel."""
|
||||
async def create_channel_invite(
|
||||
self,
|
||||
channel_id: "Snowflake",
|
||||
payload: Dict[str, Any],
|
||||
*,
|
||||
reason: Optional[str] = None,
|
||||
) -> "Invite":
|
||||
"""Creates an invite for a channel with an optional audit log reason."""
|
||||
|
||||
headers = {"X-Audit-Log-Reason": reason} if reason else None
|
||||
data = await self.request(
|
||||
"POST",
|
||||
f"/channels/{channel_id}/invites",
|
||||
payload=payload,
|
||||
custom_headers=headers,
|
||||
)
|
||||
from .models import Invite
|
||||
|
||||
return Invite.from_dict(data)
|
||||
|
||||
async def delete_invite(self, code: str) -> None:
|
||||
"""Deletes an invite by code."""
|
||||
|
||||
await self.request("DELETE", f"/invites/{code}")
|
||||
|
||||
async def get_invite(self, code: "Snowflake") -> Dict[str, Any]:
|
||||
"""Fetches a single invite by its code."""
|
||||
|
||||
return await self.request("GET", f"/invites/{code}")
|
||||
|
||||
async def create_webhook(
|
||||
self, channel_id: "Snowflake", payload: Dict[str, Any]
|
||||
) -> "Webhook":
|
||||
"""Creates a webhook in the specified channel."""
|
||||
|
||||
data = await self.request(
|
||||
"POST", f"/channels/{channel_id}/webhooks", payload=payload
|
||||
)
|
||||
from .models import Webhook
|
||||
|
||||
return Webhook(data)
|
||||
|
||||
async def get_webhook(self, webhook_id: "Snowflake") -> Dict[str, Any]:
|
||||
"""Fetches a webhook by ID."""
|
||||
|
||||
return await self.request("GET", f"/webhooks/{webhook_id}")
|
||||
|
||||
async def edit_webhook(
|
||||
self, webhook_id: "Snowflake", payload: Dict[str, Any]
|
||||
) -> "Webhook":
|
||||
from .models import Webhook
|
||||
|
||||
return Webhook(data)
|
||||
|
||||
async def get_webhook(self, webhook_id: "Snowflake") -> Dict[str, Any]:
|
||||
"""Fetches a webhook by ID and returns the raw payload."""
|
||||
|
||||
return await self.request("GET", f"/webhooks/{webhook_id}")
|
||||
|
||||
async def edit_webhook(
|
||||
self, webhook_id: "Snowflake", payload: Dict[str, Any]
|
||||
) -> "Webhook":
|
||||
"""Edits an existing webhook."""
|
||||
|
||||
data = await self.request("PATCH", f"/webhooks/{webhook_id}", payload=payload)
|
||||
@ -803,25 +803,28 @@ class HTTPClient:
|
||||
|
||||
return Webhook(data)
|
||||
|
||||
async def delete_webhook(self, webhook_id: "Snowflake") -> None:
|
||||
"""Deletes a webhook."""
|
||||
|
||||
await self.request("DELETE", f"/webhooks/{webhook_id}")
|
||||
|
||||
async def get_webhook(
|
||||
self, webhook_id: "Snowflake", token: Optional[str] = None
|
||||
) -> "Webhook":
|
||||
"""Fetches a webhook by ID, optionally using its token."""
|
||||
|
||||
endpoint = f"/webhooks/{webhook_id}"
|
||||
use_auth = True
|
||||
if token is not None:
|
||||
endpoint += f"/{token}"
|
||||
use_auth = False
|
||||
data = await self.request("GET", endpoint, use_auth_header=use_auth)
|
||||
from .models import Webhook
|
||||
|
||||
return Webhook(data)
|
||||
async def delete_webhook(self, webhook_id: "Snowflake") -> None:
|
||||
"""Deletes a webhook."""
|
||||
|
||||
await self.request("DELETE", f"/webhooks/{webhook_id}")
|
||||
|
||||
async def get_webhook_with_token(
|
||||
self, webhook_id: "Snowflake", token: Optional[str] = None
|
||||
) -> "Webhook":
|
||||
"""Fetches a webhook by ID, optionally using its token."""
|
||||
|
||||
endpoint = f"/webhooks/{webhook_id}"
|
||||
use_auth = True
|
||||
if token is not None:
|
||||
endpoint += f"/{token}"
|
||||
use_auth = False
|
||||
if use_auth:
|
||||
data = await self.request("GET", endpoint)
|
||||
else:
|
||||
data = await self.request("GET", endpoint, use_auth_header=False)
|
||||
from .models import Webhook
|
||||
|
||||
return Webhook(data)
|
||||
|
||||
async def execute_webhook(
|
||||
self,
|
||||
|
@ -2,6 +2,8 @@
|
||||
Data models for Discord objects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import io
|
||||
@ -54,13 +56,26 @@ if TYPE_CHECKING:
|
||||
from .interactions import Snowflake
|
||||
from .typing import Typing
|
||||
from .shard_manager import Shard
|
||||
from .asset import Asset
|
||||
|
||||
# Forward reference Message if it were used in type hints before its definition
|
||||
# from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc.
|
||||
from .components import component_factory
|
||||
|
||||
|
||||
class User:
|
||||
class HashableById:
|
||||
"""Mixin providing equality and hashing based on the ``id`` attribute."""
|
||||
|
||||
id: str
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, self.__class__) and self.id == other.id # type: ignore[attr-defined]
|
||||
|
||||
def __hash__(self) -> int: # pragma: no cover - trivial
|
||||
return hash(self.id)
|
||||
|
||||
|
||||
class User(HashableById):
|
||||
"""Represents a Discord User."""
|
||||
|
||||
def __init__(self, data: dict, client_instance: Optional["Client"] = None) -> None:
|
||||
@ -71,7 +86,12 @@ class User:
|
||||
self.username: Optional[str] = data.get("username")
|
||||
self.discriminator: Optional[str] = data.get("discriminator")
|
||||
self.bot: bool = data.get("bot", False)
|
||||
self.avatar: Optional[str] = data.get("avatar")
|
||||
avatar_hash = data.get("avatar")
|
||||
self._avatar: Optional[str] = (
|
||||
f"https://cdn.discordapp.com/avatars/{self.id}/{avatar_hash}.png"
|
||||
if avatar_hash
|
||||
else None
|
||||
)
|
||||
|
||||
@property
|
||||
def mention(self) -> str:
|
||||
@ -83,6 +103,25 @@ class User:
|
||||
disc = self.discriminator or "????"
|
||||
return f"<User id='{self.id}' username='{username}' discriminator='{disc}'>"
|
||||
|
||||
@property
|
||||
def avatar(self) -> Optional["Asset"]:
|
||||
"""Return the user's avatar as an :class:`Asset`."""
|
||||
|
||||
if self._avatar:
|
||||
from .asset import Asset
|
||||
|
||||
return Asset(self._avatar, self._client)
|
||||
return None
|
||||
|
||||
@avatar.setter
|
||||
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||
if isinstance(value, str):
|
||||
self._avatar = value
|
||||
elif value is None:
|
||||
self._avatar = None
|
||||
else:
|
||||
self._avatar = value.url
|
||||
|
||||
async def send(
|
||||
self,
|
||||
content: Optional[str] = None,
|
||||
@ -98,7 +137,7 @@ class User:
|
||||
return await target_client.send_dm(self.id, content=content, **kwargs)
|
||||
|
||||
|
||||
class Message:
|
||||
class Message(HashableById):
|
||||
"""Represents a message sent in a channel on Discord.
|
||||
|
||||
Attributes:
|
||||
@ -780,7 +819,12 @@ class Role:
|
||||
self.name: str = data["name"]
|
||||
self.color: int = data["color"]
|
||||
self.hoist: bool = data["hoist"]
|
||||
self.icon: Optional[str] = data.get("icon")
|
||||
icon_hash = data.get("icon")
|
||||
self._icon: Optional[str] = (
|
||||
f"https://cdn.discordapp.com/role-icons/{self.id}/{icon_hash}.png"
|
||||
if icon_hash
|
||||
else None
|
||||
)
|
||||
self.unicode_emoji: Optional[str] = data.get("unicode_emoji")
|
||||
self.position: int = data["position"]
|
||||
self.permissions: str = data["permissions"] # String of bitwise permissions
|
||||
@ -798,6 +842,23 @@ class Role:
|
||||
def __repr__(self) -> str:
|
||||
return f"<Role id='{self.id}' name='{self.name}'>"
|
||||
|
||||
@property
|
||||
def icon(self) -> Optional["Asset"]:
|
||||
if self._icon:
|
||||
from .asset import Asset
|
||||
|
||||
return Asset(self._icon, None)
|
||||
return None
|
||||
|
||||
@icon.setter
|
||||
def icon(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||
if isinstance(value, str):
|
||||
self._icon = value
|
||||
elif value is None:
|
||||
self._icon = None
|
||||
else:
|
||||
self._icon = value.url
|
||||
|
||||
|
||||
class Member(User): # Member inherits from User
|
||||
"""Represents a Guild Member.
|
||||
@ -826,7 +887,15 @@ class Member(User): # Member inherits from User
|
||||
) # Pass user_data or data if user_data is empty
|
||||
|
||||
self.nick: Optional[str] = data.get("nick")
|
||||
self.avatar: Optional[str] = data.get("avatar")
|
||||
avatar_hash = data.get("avatar")
|
||||
if avatar_hash:
|
||||
guild_id = data.get("guild_id")
|
||||
if guild_id:
|
||||
self._avatar = f"https://cdn.discordapp.com/guilds/{guild_id}/users/{self.id}/avatars/{avatar_hash}.png"
|
||||
else:
|
||||
self._avatar = (
|
||||
f"https://cdn.discordapp.com/avatars/{self.id}/{avatar_hash}.png"
|
||||
)
|
||||
self.roles: List[str] = data.get("roles", [])
|
||||
self.joined_at: str = data["joined_at"]
|
||||
self.premium_since: Optional[str] = data.get("premium_since")
|
||||
@ -854,6 +923,25 @@ class Member(User): # Member inherits from User
|
||||
def __repr__(self) -> str:
|
||||
return f"<Member id='{self.id}' username='{self.username}' nick='{self.nick}'>"
|
||||
|
||||
@property
|
||||
def avatar(self) -> Optional["Asset"]:
|
||||
"""Return the member's avatar as an :class:`Asset`."""
|
||||
|
||||
if self._avatar:
|
||||
from .asset import Asset
|
||||
|
||||
return Asset(self._avatar, self._client)
|
||||
return None
|
||||
|
||||
@avatar.setter
|
||||
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||
if isinstance(value, str):
|
||||
self._avatar = value
|
||||
elif value is None:
|
||||
self._avatar = None
|
||||
else:
|
||||
self._avatar = value.url
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""Return the nickname if set, otherwise the username."""
|
||||
@ -921,7 +1009,6 @@ class Member(User): # Member inherits from User
|
||||
return max(role_objects, key=lambda r: r.position)
|
||||
|
||||
@property
|
||||
|
||||
def guild_permissions(self) -> "Permissions":
|
||||
"""Return the member's guild-level permissions."""
|
||||
|
||||
@ -947,7 +1034,8 @@ class Member(User): # Member inherits from User
|
||||
return Permissions(~0)
|
||||
|
||||
return base
|
||||
|
||||
|
||||
@property
|
||||
def voice(self) -> Optional["VoiceState"]:
|
||||
"""Return the member's cached voice state as a :class:`VoiceState`."""
|
||||
|
||||
@ -1152,7 +1240,7 @@ class PermissionOverwrite:
|
||||
return f"<PermissionOverwrite id='{self.id}' type='{self.type.name if hasattr(self.type, 'name') else self._type_val}' allow='{self.allow}' deny='{self.deny}'>"
|
||||
|
||||
|
||||
class Guild:
|
||||
class Guild(HashableById):
|
||||
"""Represents a Discord Guild (Server).
|
||||
|
||||
Attributes:
|
||||
@ -1210,9 +1298,24 @@ class Guild:
|
||||
)
|
||||
self.id: str = data["id"]
|
||||
self.name: str = data["name"]
|
||||
self.icon: Optional[str] = data.get("icon")
|
||||
self.splash: Optional[str] = data.get("splash")
|
||||
self.discovery_splash: Optional[str] = data.get("discovery_splash")
|
||||
icon_hash = data.get("icon")
|
||||
self._icon: Optional[str] = (
|
||||
f"https://cdn.discordapp.com/icons/{self.id}/{icon_hash}.png"
|
||||
if icon_hash
|
||||
else None
|
||||
)
|
||||
splash_hash = data.get("splash")
|
||||
self._splash: Optional[str] = (
|
||||
f"https://cdn.discordapp.com/splashes/{self.id}/{splash_hash}.png"
|
||||
if splash_hash
|
||||
else None
|
||||
)
|
||||
discovery_hash = data.get("discovery_splash")
|
||||
self._discovery_splash: Optional[str] = (
|
||||
f"https://cdn.discordapp.com/discovery-splashes/{self.id}/{discovery_hash}.png"
|
||||
if discovery_hash
|
||||
else None
|
||||
)
|
||||
self.owner: Optional[bool] = data.get("owner")
|
||||
self.owner_id: str = data["owner_id"]
|
||||
self.permissions: Optional[str] = data.get("permissions")
|
||||
@ -1249,7 +1352,12 @@ class Guild:
|
||||
self.max_members: Optional[int] = data.get("max_members")
|
||||
self.vanity_url_code: Optional[str] = data.get("vanity_url_code")
|
||||
self.description: Optional[str] = data.get("description")
|
||||
self.banner: Optional[str] = data.get("banner")
|
||||
banner_hash = data.get("banner")
|
||||
self._banner: Optional[str] = (
|
||||
f"https://cdn.discordapp.com/banners/{self.id}/{banner_hash}.png"
|
||||
if banner_hash
|
||||
else None
|
||||
)
|
||||
self.premium_tier: PremiumTier = PremiumTier(data["premium_tier"])
|
||||
self.premium_subscription_count: Optional[int] = data.get(
|
||||
"premium_subscription_count"
|
||||
@ -1357,6 +1465,74 @@ class Guild:
|
||||
def __repr__(self) -> str:
|
||||
return f"<Guild id='{self.id}' name='{self.name}'>"
|
||||
|
||||
@property
|
||||
def icon(self) -> Optional["Asset"]:
|
||||
if self._icon:
|
||||
from .asset import Asset
|
||||
|
||||
return Asset(self._icon, self._client)
|
||||
return None
|
||||
|
||||
@icon.setter
|
||||
def icon(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||
if isinstance(value, str):
|
||||
self._icon = value
|
||||
elif value is None:
|
||||
self._icon = None
|
||||
else:
|
||||
self._icon = value.url
|
||||
|
||||
@property
|
||||
def splash(self) -> Optional["Asset"]:
|
||||
if self._splash:
|
||||
from .asset import Asset
|
||||
|
||||
return Asset(self._splash, self._client)
|
||||
return None
|
||||
|
||||
@splash.setter
|
||||
def splash(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||
if isinstance(value, str):
|
||||
self._splash = value
|
||||
elif value is None:
|
||||
self._splash = None
|
||||
else:
|
||||
self._splash = value.url
|
||||
|
||||
@property
|
||||
def discovery_splash(self) -> Optional["Asset"]:
|
||||
if self._discovery_splash:
|
||||
from .asset import Asset
|
||||
|
||||
return Asset(self._discovery_splash, self._client)
|
||||
return None
|
||||
|
||||
@discovery_splash.setter
|
||||
def discovery_splash(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||
if isinstance(value, str):
|
||||
self._discovery_splash = value
|
||||
elif value is None:
|
||||
self._discovery_splash = None
|
||||
else:
|
||||
self._discovery_splash = value.url
|
||||
|
||||
@property
|
||||
def banner(self) -> Optional["Asset"]:
|
||||
if self._banner:
|
||||
from .asset import Asset
|
||||
|
||||
return Asset(self._banner, self._client)
|
||||
return None
|
||||
|
||||
@banner.setter
|
||||
def banner(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||
if isinstance(value, str):
|
||||
self._banner = value
|
||||
elif value is None:
|
||||
self._banner = None
|
||||
else:
|
||||
self._banner = value.url
|
||||
|
||||
async def fetch_widget(self) -> Dict[str, Any]:
|
||||
"""|coro| Fetch this guild's widget settings."""
|
||||
|
||||
@ -1485,7 +1661,7 @@ class Guild:
|
||||
return cast("CategoryChannel", self._client.parse_channel(data))
|
||||
|
||||
|
||||
class Channel:
|
||||
class Channel(HashableById):
|
||||
"""Base class for Discord channels."""
|
||||
|
||||
def __init__(self, data: Dict[str, Any], client_instance: "Client"):
|
||||
@ -2089,7 +2265,12 @@ class Webhook:
|
||||
self.guild_id: Optional[str] = data.get("guild_id")
|
||||
self.channel_id: Optional[str] = data.get("channel_id")
|
||||
self.name: Optional[str] = data.get("name")
|
||||
self.avatar: Optional[str] = data.get("avatar")
|
||||
avatar_hash = data.get("avatar")
|
||||
self._avatar: Optional[str] = (
|
||||
f"https://cdn.discordapp.com/webhooks/{self.id}/{avatar_hash}.png"
|
||||
if avatar_hash
|
||||
else None
|
||||
)
|
||||
self.token: Optional[str] = data.get("token")
|
||||
self.application_id: Optional[str] = data.get("application_id")
|
||||
self.url: Optional[str] = data.get("url")
|
||||
@ -2098,6 +2279,25 @@ class Webhook:
|
||||
def __repr__(self) -> str:
|
||||
return f"<Webhook id='{self.id}' name='{self.name}'>"
|
||||
|
||||
@property
|
||||
def avatar(self) -> Optional["Asset"]:
|
||||
"""Return the webhook's avatar as an :class:`Asset`."""
|
||||
|
||||
if self._avatar:
|
||||
from .asset import Asset
|
||||
|
||||
return Asset(self._avatar, self._client)
|
||||
return None
|
||||
|
||||
@avatar.setter
|
||||
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||
if isinstance(value, str):
|
||||
self._avatar = value
|
||||
elif value is None:
|
||||
self._avatar = None
|
||||
else:
|
||||
self._avatar = value.url
|
||||
|
||||
@classmethod
|
||||
def from_url(
|
||||
cls, url: str, session: Optional[aiohttp.ClientSession] = None
|
||||
|
19
disagreement/object.py
Normal file
19
disagreement/object.py
Normal file
@ -0,0 +1,19 @@
|
||||
class Object:
|
||||
"""A minimal wrapper around a Discord snowflake ID."""
|
||||
|
||||
__slots__ = ("id",)
|
||||
|
||||
def __init__(self, object_id: int) -> None:
|
||||
self.id = int(object_id)
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.id)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, Object) and self.id == other.id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Object id={self.id}>"
|
@ -77,11 +77,14 @@ class VoiceClient:
|
||||
self.secret_key: Optional[Sequence[int]] = None
|
||||
self._server_ip: Optional[str] = None
|
||||
self._server_port: Optional[int] = None
|
||||
self._current_source: Optional[AudioSource] = None
|
||||
self._play_task: Optional[asyncio.Task] = None
|
||||
self._sink: Optional[AudioSink] = None
|
||||
self._ssrc_map: dict[int, int] = {}
|
||||
self._ssrc_lock = threading.Lock()
|
||||
self._current_source: Optional[AudioSource] = None
|
||||
self._play_task: Optional[asyncio.Task] = None
|
||||
self._pause_event = asyncio.Event()
|
||||
self._pause_event.set()
|
||||
self._is_playing = False
|
||||
self._sink: Optional[AudioSink] = None
|
||||
self._ssrc_map: dict[int, int] = {}
|
||||
self._ssrc_lock = threading.Lock()
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self._ws is None:
|
||||
@ -189,31 +192,37 @@ class VoiceClient:
|
||||
raise RuntimeError("UDP socket not initialised")
|
||||
self._udp.send(frame)
|
||||
|
||||
async def _play_loop(self) -> None:
|
||||
assert self._current_source is not None
|
||||
try:
|
||||
while True:
|
||||
data = await self._current_source.read()
|
||||
if not data:
|
||||
break
|
||||
volume = getattr(self._current_source, "volume", 1.0)
|
||||
if volume != 1.0:
|
||||
data = _apply_volume(data, volume)
|
||||
await self.send_audio_frame(data)
|
||||
finally:
|
||||
await self._current_source.close()
|
||||
self._current_source = None
|
||||
self._play_task = None
|
||||
async def _play_loop(self) -> None:
|
||||
assert self._current_source is not None
|
||||
self._is_playing = True
|
||||
try:
|
||||
while True:
|
||||
await self._pause_event.wait()
|
||||
data = await self._current_source.read()
|
||||
if not data:
|
||||
break
|
||||
volume = getattr(self._current_source, "volume", 1.0)
|
||||
if volume != 1.0:
|
||||
data = _apply_volume(data, volume)
|
||||
await self.send_audio_frame(data)
|
||||
finally:
|
||||
await self._current_source.close()
|
||||
self._current_source = None
|
||||
self._play_task = None
|
||||
self._is_playing = False
|
||||
self._pause_event.set()
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._play_task:
|
||||
self._play_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._play_task
|
||||
self._play_task = None
|
||||
if self._current_source:
|
||||
await self._current_source.close()
|
||||
self._current_source = None
|
||||
async def stop(self) -> None:
|
||||
if self._play_task:
|
||||
self._play_task.cancel()
|
||||
self._pause_event.set()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._play_task
|
||||
self._play_task = None
|
||||
self._is_playing = False
|
||||
if self._current_source:
|
||||
await self._current_source.close()
|
||||
self._current_source = None
|
||||
|
||||
async def play(self, source: AudioSource, *, wait: bool = True) -> None:
|
||||
"""|coro| Play an :class:`AudioSource` on the voice connection."""
|
||||
@ -224,10 +233,31 @@ class VoiceClient:
|
||||
if wait:
|
||||
await self._play_task
|
||||
|
||||
async def play_file(self, filename: str, *, wait: bool = True) -> None:
|
||||
"""|coro| Stream an audio file or URL using FFmpeg."""
|
||||
|
||||
await self.play(FFmpegAudioSource(filename), wait=wait)
|
||||
async def play_file(self, filename: str, *, wait: bool = True) -> None:
|
||||
"""|coro| Stream an audio file or URL using FFmpeg."""
|
||||
|
||||
await self.play(FFmpegAudioSource(filename), wait=wait)
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Pause the current audio source."""
|
||||
|
||||
if self._play_task and not self._play_task.done():
|
||||
self._pause_event.clear()
|
||||
|
||||
def resume(self) -> None:
|
||||
"""Resume playback of a paused source."""
|
||||
|
||||
if self._play_task and not self._play_task.done():
|
||||
self._pause_event.set()
|
||||
|
||||
def is_paused(self) -> bool:
|
||||
"""Return ``True`` if playback is currently paused."""
|
||||
|
||||
return bool(self._play_task and not self._pause_event.is_set())
|
||||
|
||||
def is_playing(self) -> bool:
|
||||
"""Return ``True`` if audio is actively being played."""
|
||||
return self._is_playing and self._pause_event.is_set()
|
||||
|
||||
def listen(self, sink: AudioSink) -> None:
|
||||
"""Start listening to voice and routing to a sink."""
|
||||
|
@ -28,6 +28,13 @@ The cache can be cleared manually if needed:
|
||||
client.cache.clear()
|
||||
```
|
||||
|
||||
## Partial Objects
|
||||
|
||||
Some events only include minimal data for related resources. When only an ``id``
|
||||
is available, Disagreement represents the resource using :class:`~disagreement.Object`.
|
||||
These objects can be compared and used in sets or dictionaries and can be passed
|
||||
to API methods to fetch the full data when needed.
|
||||
|
||||
## Next Steps
|
||||
|
||||
- [Components](using_components.md)
|
||||
|
@ -11,7 +11,11 @@ The command handler registers a `help` command automatically. Use it to list all
|
||||
!help ping # shows help for the "ping" command
|
||||
```
|
||||
|
||||
The help command will show each command's brief description if provided.
|
||||
Commands are grouped by their Cog name and paginated so that long help
|
||||
lists are split into multiple messages using the `Paginator` utility.
|
||||
|
||||
If you need custom formatting you can subclass
|
||||
`HelpCommand` and override `send_command_help` or `send_group_help`.
|
||||
|
||||
## Checks
|
||||
|
||||
|
@ -157,6 +157,22 @@ container = Container(
|
||||
A container can itself contain layout and content components, letting you build complex messages.
|
||||
|
||||
|
||||
## Persistent Views
|
||||
|
||||
Views with ``timeout=None`` are persistent. Their ``custom_id`` components are saved to ``persistent_views.json`` so they survive bot restarts.
|
||||
|
||||
```python
|
||||
class MyView(View):
|
||||
@button(label="Press", custom_id="press")
|
||||
async def handle(self, view, inter):
|
||||
await inter.respond("Pressed!")
|
||||
|
||||
client.add_persistent_view(MyView())
|
||||
```
|
||||
|
||||
When the client starts, it loads this file and registers each view again. Remove
|
||||
the file to clear stored views.
|
||||
|
||||
## Next Steps
|
||||
|
||||
- [Slash Commands](slash_commands.md)
|
||||
|
@ -6,6 +6,10 @@ Disagreement includes experimental support for connecting to voice channels. You
|
||||
voice = await client.join_voice(guild_id, channel_id)
|
||||
await voice.play_file("welcome.mp3")
|
||||
await voice.play_file("another.mp3") # switch sources while connected
|
||||
voice.pause()
|
||||
voice.resume()
|
||||
if voice.is_playing():
|
||||
print("audio is playing")
|
||||
await voice.close()
|
||||
```
|
||||
|
||||
|
@ -3,7 +3,16 @@ import pytest
|
||||
from disagreement.ext.commands.converters import run_converters
|
||||
from disagreement.ext.commands.core import CommandContext, Command
|
||||
from disagreement.ext.commands.errors import BadArgument
|
||||
from disagreement.models import Message, Member, Role, Guild, User
|
||||
from disagreement.models import (
|
||||
Message,
|
||||
Member,
|
||||
Role,
|
||||
Guild,
|
||||
User,
|
||||
TextChannel,
|
||||
VoiceChannel,
|
||||
PartialEmoji,
|
||||
)
|
||||
from disagreement.enums import (
|
||||
VerificationLevel,
|
||||
MessageNotificationLevel,
|
||||
@ -11,11 +20,12 @@ from disagreement.enums import (
|
||||
MFALevel,
|
||||
GuildNSFWLevel,
|
||||
PremiumTier,
|
||||
ChannelType,
|
||||
)
|
||||
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.cache import GuildCache, Cache
|
||||
from disagreement.cache import GuildCache, Cache, ChannelCache
|
||||
|
||||
|
||||
class DummyBot(Client):
|
||||
@ -23,10 +33,14 @@ class DummyBot(Client):
|
||||
super().__init__(token="test")
|
||||
self._guilds = GuildCache()
|
||||
self._users = Cache()
|
||||
self._channels = ChannelCache()
|
||||
|
||||
def get_guild(self, guild_id):
|
||||
return self._guilds.get(guild_id)
|
||||
|
||||
def get_channel(self, channel_id):
|
||||
return self._channels.get(channel_id)
|
||||
|
||||
async def fetch_member(self, guild_id, member_id):
|
||||
guild = self._guilds.get(guild_id)
|
||||
return guild.get_member(member_id) if guild else None
|
||||
@ -41,6 +55,9 @@ class DummyBot(Client):
|
||||
async def fetch_user(self, user_id):
|
||||
return self._users.get(user_id)
|
||||
|
||||
async def fetch_channel(self, channel_id):
|
||||
return self._channels.get(channel_id)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def guild_objects():
|
||||
@ -93,12 +110,38 @@ def guild_objects():
|
||||
guild._members.set(member.id, member)
|
||||
guild.roles.append(role)
|
||||
|
||||
return guild, member, role, user
|
||||
text_channel = TextChannel(
|
||||
{
|
||||
"id": "20",
|
||||
"type": ChannelType.GUILD_TEXT.value,
|
||||
"guild_id": guild.id,
|
||||
"permission_overwrites": [],
|
||||
},
|
||||
client_instance=bot,
|
||||
)
|
||||
voice_channel = VoiceChannel(
|
||||
{
|
||||
"id": "21",
|
||||
"type": ChannelType.GUILD_VOICE.value,
|
||||
"guild_id": guild.id,
|
||||
"permission_overwrites": [],
|
||||
},
|
||||
client_instance=bot,
|
||||
)
|
||||
|
||||
guild._channels.set(text_channel.id, text_channel)
|
||||
guild.text_channels.append(text_channel)
|
||||
guild._channels.set(voice_channel.id, voice_channel)
|
||||
guild.voice_channels.append(voice_channel)
|
||||
bot._channels.set(text_channel.id, text_channel)
|
||||
bot._channels.set(voice_channel.id, voice_channel)
|
||||
|
||||
return guild, member, role, user, text_channel, voice_channel
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def command_context(guild_objects):
|
||||
guild, member, role, _ = guild_objects
|
||||
guild, member, role, _, _, _ = guild_objects
|
||||
bot = guild._client
|
||||
message_data = {
|
||||
"id": "10",
|
||||
@ -121,7 +164,7 @@ def command_context(guild_objects):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_member_converter(command_context, guild_objects):
|
||||
_, member, _, _ = guild_objects
|
||||
_, member, _, _, _, _ = guild_objects
|
||||
mention = f"<@!{member.id}>"
|
||||
result = await run_converters(command_context, Member, mention)
|
||||
assert result is member
|
||||
@ -131,7 +174,7 @@ async def test_member_converter(command_context, guild_objects):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_role_converter(command_context, guild_objects):
|
||||
_, _, role, _ = guild_objects
|
||||
_, _, role, _, _, _ = guild_objects
|
||||
mention = f"<@&{role.id}>"
|
||||
result = await run_converters(command_context, Role, mention)
|
||||
assert result is role
|
||||
@ -141,7 +184,7 @@ async def test_role_converter(command_context, guild_objects):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_converter(command_context, guild_objects):
|
||||
_, _, _, user = guild_objects
|
||||
_, _, _, user, _, _ = guild_objects
|
||||
mention = f"<@{user.id}>"
|
||||
result = await run_converters(command_context, User, mention)
|
||||
assert result is user
|
||||
@ -151,11 +194,43 @@ async def test_user_converter(command_context, guild_objects):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guild_converter(command_context, guild_objects):
|
||||
guild, _, _, _ = guild_objects
|
||||
guild, _, _, _, _, _ = guild_objects
|
||||
result = await run_converters(command_context, Guild, guild.id)
|
||||
assert result is guild
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_channel_converter(command_context, guild_objects):
|
||||
_, _, _, _, text_channel, _ = guild_objects
|
||||
mention = f"<#{text_channel.id}>"
|
||||
result = await run_converters(command_context, TextChannel, mention)
|
||||
assert result is text_channel
|
||||
result = await run_converters(command_context, TextChannel, text_channel.id)
|
||||
assert result is text_channel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_channel_converter(command_context, guild_objects):
|
||||
_, _, _, _, _, voice_channel = guild_objects
|
||||
mention = f"<#{voice_channel.id}>"
|
||||
result = await run_converters(command_context, VoiceChannel, mention)
|
||||
assert result is voice_channel
|
||||
result = await run_converters(command_context, VoiceChannel, voice_channel.id)
|
||||
assert result is voice_channel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emoji_converter(command_context):
|
||||
result = await run_converters(command_context, PartialEmoji, "<:smile:1>")
|
||||
assert isinstance(result, PartialEmoji)
|
||||
assert result.id == "1"
|
||||
assert result.name == "smile"
|
||||
|
||||
result = await run_converters(command_context, PartialEmoji, "😄")
|
||||
assert result.id is None
|
||||
assert result.name == "😄"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_member_converter_no_guild():
|
||||
guild_data = {
|
||||
|
14
tests/test_asset.py
Normal file
14
tests/test_asset.py
Normal file
@ -0,0 +1,14 @@
|
||||
from disagreement.models import User
|
||||
from disagreement.asset import Asset
|
||||
|
||||
|
||||
def test_user_avatar_returns_asset():
|
||||
user = User({"id": "1", "username": "u", "discriminator": "0001", "avatar": "abc"})
|
||||
avatar = user.avatar
|
||||
assert isinstance(avatar, Asset)
|
||||
assert avatar.url == "https://cdn.discordapp.com/avatars/1/abc.png"
|
||||
|
||||
|
||||
def test_user_avatar_none():
|
||||
user = User({"id": "1", "username": "u", "discriminator": "0001"})
|
||||
assert user.avatar is None
|
@ -1,6 +1,60 @@
|
||||
import time
|
||||
|
||||
from disagreement.cache import Cache
|
||||
from disagreement.client import Client
|
||||
from disagreement.caching import MemberCacheFlags
|
||||
from disagreement.enums import (
|
||||
ChannelType,
|
||||
ExplicitContentFilterLevel,
|
||||
GuildNSFWLevel,
|
||||
MFALevel,
|
||||
MessageNotificationLevel,
|
||||
PremiumTier,
|
||||
VerificationLevel,
|
||||
)
|
||||
|
||||
|
||||
def _guild_payload(gid: str, channel_count: int, member_count: int) -> dict:
|
||||
base = {
|
||||
"id": gid,
|
||||
"name": f"g{gid}",
|
||||
"owner_id": "1",
|
||||
"afk_timeout": 60,
|
||||
"verification_level": VerificationLevel.NONE.value,
|
||||
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||
"roles": [],
|
||||
"emojis": [],
|
||||
"features": [],
|
||||
"mfa_level": MFALevel.NONE.value,
|
||||
"system_channel_flags": 0,
|
||||
"premium_tier": PremiumTier.NONE.value,
|
||||
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||
"channels": [],
|
||||
"members": [],
|
||||
}
|
||||
for i in range(channel_count):
|
||||
base["channels"].append(
|
||||
{
|
||||
"id": f"{gid}-c{i}",
|
||||
"type": ChannelType.GUILD_TEXT.value,
|
||||
"guild_id": gid,
|
||||
"permission_overwrites": [],
|
||||
}
|
||||
)
|
||||
for i in range(member_count):
|
||||
base["members"].append(
|
||||
{
|
||||
"user": {
|
||||
"id": f"{gid}-m{i}",
|
||||
"username": f"u{i}",
|
||||
"discriminator": "0001",
|
||||
},
|
||||
"joined_at": "t",
|
||||
"roles": [],
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
|
||||
def test_cache_store_and_get():
|
||||
@ -65,3 +119,22 @@ def test_get_or_fetch_fetches_expired_item():
|
||||
|
||||
assert cache.get_or_fetch("c", fetch) == 3
|
||||
assert called
|
||||
|
||||
|
||||
def test_client_get_all_channels_and_members():
|
||||
client = Client(token="t")
|
||||
client.parse_guild(_guild_payload("1", 2, 2))
|
||||
client.parse_guild(_guild_payload("2", 1, 1))
|
||||
|
||||
channels = {c.id for c in client.get_all_channels()}
|
||||
members = {m.id for m in client.get_all_members()}
|
||||
|
||||
assert channels == {"1-c0", "1-c1", "2-c0"}
|
||||
assert members == {"1-m0", "1-m1", "2-m0"}
|
||||
|
||||
|
||||
def test_client_get_all_members_disabled_cache():
|
||||
client = Client(token="t", member_cache_flags=MemberCacheFlags.none())
|
||||
client.parse_guild(_guild_payload("1", 1, 2))
|
||||
|
||||
assert client.get_all_members() == []
|
||||
|
86
tests/test_hashable_mixin.py
Normal file
86
tests/test_hashable_mixin.py
Normal file
@ -0,0 +1,86 @@
|
||||
import types
|
||||
from disagreement.models import User, Guild, Channel, Message
|
||||
from disagreement.enums import (
|
||||
VerificationLevel,
|
||||
MessageNotificationLevel,
|
||||
ExplicitContentFilterLevel,
|
||||
MFALevel,
|
||||
GuildNSFWLevel,
|
||||
PremiumTier,
|
||||
ChannelType,
|
||||
)
|
||||
|
||||
|
||||
def _guild_data(gid="1"):
|
||||
return {
|
||||
"id": gid,
|
||||
"name": "g",
|
||||
"owner_id": gid,
|
||||
"afk_timeout": 60,
|
||||
"verification_level": VerificationLevel.NONE.value,
|
||||
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||
"roles": [],
|
||||
"emojis": [],
|
||||
"features": [],
|
||||
"mfa_level": MFALevel.NONE.value,
|
||||
"system_channel_flags": 0,
|
||||
"premium_tier": PremiumTier.NONE.value,
|
||||
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||
}
|
||||
|
||||
|
||||
def _user(uid="1"):
|
||||
return User({"id": uid, "username": "u", "discriminator": "0001"})
|
||||
|
||||
|
||||
def _message(mid="1"):
|
||||
data = {
|
||||
"id": mid,
|
||||
"channel_id": "c",
|
||||
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||
"content": "hi",
|
||||
"timestamp": "t",
|
||||
}
|
||||
return Message(data, client_instance=types.SimpleNamespace())
|
||||
|
||||
|
||||
def _channel(cid="1"):
|
||||
data = {"id": cid, "type": ChannelType.GUILD_TEXT.value}
|
||||
return Channel(data, client_instance=types.SimpleNamespace())
|
||||
|
||||
|
||||
def test_user_hash_and_eq():
|
||||
a = _user()
|
||||
b = _user()
|
||||
c = _user("2")
|
||||
assert a == b
|
||||
assert hash(a) == hash(b)
|
||||
assert a != c
|
||||
|
||||
|
||||
def test_guild_hash_and_eq():
|
||||
a = Guild(_guild_data(), client_instance=types.SimpleNamespace())
|
||||
b = Guild(_guild_data(), client_instance=types.SimpleNamespace())
|
||||
c = Guild(_guild_data("2"), client_instance=types.SimpleNamespace())
|
||||
assert a == b
|
||||
assert hash(a) == hash(b)
|
||||
assert a != c
|
||||
|
||||
|
||||
def test_channel_hash_and_eq():
|
||||
a = _channel()
|
||||
b = _channel()
|
||||
c = _channel("2")
|
||||
assert a == b
|
||||
assert hash(a) == hash(b)
|
||||
assert a != c
|
||||
|
||||
|
||||
def test_message_hash_and_eq():
|
||||
a = _message()
|
||||
b = _message()
|
||||
c = _message("2")
|
||||
assert a == b
|
||||
assert hash(a) == hash(b)
|
||||
assert a != c
|
@ -1,7 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from disagreement.ext.commands.core import CommandHandler, Command
|
||||
from disagreement.ext import commands
|
||||
from disagreement.ext.commands.core import CommandHandler, Command, Group
|
||||
from disagreement.models import Message
|
||||
from disagreement.ext.commands.help import HelpCommand
|
||||
|
||||
|
||||
class DummyBot:
|
||||
@ -13,15 +15,21 @@ class DummyBot:
|
||||
return {"id": "1", "channel_id": channel_id, "content": content}
|
||||
|
||||
|
||||
class MyCog(commands.Cog):
|
||||
def __init__(self, client) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
@commands.command()
|
||||
async def foo(self, ctx: commands.CommandContext) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help_lists_commands():
|
||||
bot = DummyBot()
|
||||
handler = CommandHandler(client=bot, prefix="!")
|
||||
|
||||
async def foo(ctx):
|
||||
pass
|
||||
|
||||
handler.add_command(Command(foo, name="foo", brief="Foo cmd"))
|
||||
handler.add_cog(MyCog(bot))
|
||||
|
||||
msg_data = {
|
||||
"id": "1",
|
||||
@ -33,6 +41,7 @@ async def test_help_lists_commands():
|
||||
msg = Message(msg_data, client_instance=bot)
|
||||
await handler.process_commands(msg)
|
||||
assert any("foo" in m for m in bot.sent)
|
||||
assert any("MyCog" in m for m in bot.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -55,3 +64,65 @@ async def test_help_specific_command():
|
||||
msg = Message(msg_data, client_instance=bot)
|
||||
await handler.process_commands(msg)
|
||||
assert any("Bar desc" in m for m in bot.sent)
|
||||
|
||||
|
||||
class CustomHelp(HelpCommand):
|
||||
async def send_command_help(self, ctx, command):
|
||||
await ctx.send(f"custom {command.name}")
|
||||
|
||||
async def send_group_help(self, ctx, group):
|
||||
await ctx.send(f"group {group.name}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_help_methods():
|
||||
bot = DummyBot()
|
||||
handler = CommandHandler(client=bot, prefix="!")
|
||||
handler.remove_command("help")
|
||||
handler.add_command(CustomHelp(handler))
|
||||
|
||||
async def sub(ctx):
|
||||
pass
|
||||
|
||||
group = Group(sub, name="grp")
|
||||
handler.add_command(group)
|
||||
|
||||
msg_data = {
|
||||
"id": "1",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||
"content": "!help grp",
|
||||
"timestamp": "t",
|
||||
}
|
||||
msg = Message(msg_data, client_instance=bot)
|
||||
await handler.process_commands(msg)
|
||||
assert any("group grp" in m for m in bot.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help_lists_subcommands():
|
||||
bot = DummyBot()
|
||||
handler = CommandHandler(client=bot, prefix="!")
|
||||
|
||||
async def root(ctx):
|
||||
pass
|
||||
|
||||
group = Group(root, name="root")
|
||||
|
||||
@group.command(name="child")
|
||||
async def child(ctx):
|
||||
pass
|
||||
|
||||
handler.add_command(group)
|
||||
|
||||
msg_data = {
|
||||
"id": "1",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||
"content": "!help",
|
||||
"timestamp": "t",
|
||||
}
|
||||
msg = Message(msg_data, client_instance=bot)
|
||||
await handler.process_commands(msg)
|
||||
assert any("root" in m for m in bot.sent)
|
||||
assert any("child" in m for m in bot.sent)
|
||||
|
15
tests/test_object.py
Normal file
15
tests/test_object.py
Normal file
@ -0,0 +1,15 @@
|
||||
from disagreement.object import Object
|
||||
|
||||
|
||||
def test_object_int():
|
||||
obj = Object(123)
|
||||
assert int(obj) == 123
|
||||
|
||||
|
||||
def test_object_equality_and_hash():
|
||||
a = Object(1)
|
||||
b = Object(1)
|
||||
c = Object(2)
|
||||
assert a == b
|
||||
assert a != c
|
||||
assert hash(a) == hash(b)
|
@ -59,6 +59,17 @@ class DummySource(AudioSource):
|
||||
return b""
|
||||
|
||||
|
||||
class SlowSource(AudioSource):
|
||||
def __init__(self, chunks):
|
||||
self.chunks = list(chunks)
|
||||
|
||||
async def read(self) -> bytes:
|
||||
await asyncio.sleep(0)
|
||||
if self.chunks:
|
||||
return self.chunks.pop(0)
|
||||
return b""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_client_handshake():
|
||||
hello = {"d": {"heartbeat_interval": 50}}
|
||||
@ -205,3 +216,49 @@ async def test_voice_client_volume_scaling(monkeypatch):
|
||||
samples[1] = int(samples[1] * 0.5)
|
||||
expected = samples.tobytes()
|
||||
assert udp.sent == [expected]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_resume_and_status():
|
||||
ws = DummyWebSocket(
|
||||
[
|
||||
{"d": {"heartbeat_interval": 50}},
|
||||
{"d": {"ssrc": 1, "ip": "127.0.0.1", "port": 4000}},
|
||||
{"d": {"secret_key": []}},
|
||||
]
|
||||
)
|
||||
udp = DummyUDP()
|
||||
vc = VoiceClient(
|
||||
client=DummyVoiceClient(),
|
||||
endpoint="ws://localhost",
|
||||
session_id="sess",
|
||||
token="tok",
|
||||
guild_id=1,
|
||||
user_id=2,
|
||||
ws=ws,
|
||||
udp=udp,
|
||||
)
|
||||
await vc.connect()
|
||||
vc._heartbeat_task.cancel()
|
||||
|
||||
src = SlowSource([b"a", b"b", b"c"])
|
||||
await vc.play(src, wait=False)
|
||||
|
||||
while not udp.sent:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert vc.is_playing()
|
||||
vc.pause()
|
||||
assert vc.is_paused()
|
||||
await asyncio.sleep(0)
|
||||
sent = len(udp.sent)
|
||||
await asyncio.sleep(0.01)
|
||||
assert len(udp.sent) == sent
|
||||
assert not vc.is_playing()
|
||||
|
||||
vc.resume()
|
||||
assert not vc.is_paused()
|
||||
|
||||
await vc._play_task
|
||||
assert udp.sent == [b"a", b"b", b"c"]
|
||||
assert not vc.is_playing()
|
||||
|
35
tests/test_walk_commands.py
Normal file
35
tests/test_walk_commands.py
Normal file
@ -0,0 +1,35 @@
|
||||
import pytest
|
||||
|
||||
from disagreement.ext.commands.core import CommandHandler, Command, Group
|
||||
|
||||
|
||||
class DummyBot:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_walk_commands_recurses_groups():
|
||||
bot = DummyBot()
|
||||
handler = CommandHandler(client=bot, prefix="!")
|
||||
|
||||
async def root(ctx):
|
||||
pass
|
||||
|
||||
root_group = Group(root, name="root")
|
||||
|
||||
@root_group.command(name="child")
|
||||
async def child(ctx):
|
||||
pass
|
||||
|
||||
@root_group.group(name="sub")
|
||||
async def sub(ctx):
|
||||
pass
|
||||
|
||||
@sub.command(name="leaf")
|
||||
async def leaf(ctx):
|
||||
pass
|
||||
|
||||
handler.add_command(root_group)
|
||||
|
||||
names = [cmd.name for cmd in handler.walk_commands()]
|
||||
assert set(names) == {"help", "root", "child", "sub", "leaf"}
|
@ -204,7 +204,7 @@ async def test_get_webhook_calls_request():
|
||||
|
||||
await http.get_webhook("1")
|
||||
|
||||
http.request.assert_called_once_with("GET", "/webhooks/1")
|
||||
http.request.assert_called_once_with("GET", "/webhooks/1", use_auth_header=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
Loading…
x
Reference in New Issue
Block a user