Compare commits
46 Commits
Author | SHA1 | Date | |
---|---|---|---|
380feddeeb | |||
3beaed8a1b | |||
e5ad932321 | |||
8e88aaec2f | |||
d710487fc2 | |||
506adeca20 | |||
e2061adc55 | |||
132521fa39 | |||
cec747a575 | |||
17751d3b09 | |||
4b3b6aeb45 | |||
aa55aa1d4c | |||
80f64c1f73 | |||
f5f8f6908c | |||
3437050f0e | |||
87d67eb63b | |||
9c10ab0f70 | |||
2008dd33d1 | |||
de40aa2c29 | |||
2056a3ddcf | |||
ccf55adba2 | |||
a335ed972c | |||
2586d3cd0d | |||
7f9647a442 | |||
a222dec661 | |||
3f7c286322 | |||
cc17d11509 | |||
9fabf1fbac | |||
223c86cb78 | |||
98afb89629 | |||
095e7e7192 | |||
c1c5cfb41a | |||
8be234c1f0 | |||
1464937f6f | |||
5d66eb79cc | |||
5d72643390 | |||
c7eb8563de | |||
a68bbe7826 | |||
6eff962682 | |||
f24c1befac | |||
c811e2b578 | |||
9f2fc0857b | |||
775dce0c80 | |||
a93ad432b7 | |||
3a264f4530 | |||
|
a41a301927 |
25
README.md
25
README.md
@ -13,6 +13,8 @@ A Python library for interacting with the Discord API, with a focus on bot devel
|
||||
- Message component helpers
|
||||
- `Message.jump_url` property for quick links to messages
|
||||
- Built-in caching layer
|
||||
- `Guild.me` property to access the bot's member object
|
||||
- Easy CDN asset handling via the `Asset` model
|
||||
- Experimental voice support
|
||||
- Helpful error handling utilities
|
||||
|
||||
@ -61,13 +63,9 @@ if not token:
|
||||
|
||||
intents = disagreement.GatewayIntent.default() | disagreement.GatewayIntent.MESSAGE_CONTENT
|
||||
client = disagreement.Client(token=token, command_prefix="!", intents=intents, mention_replies=True)
|
||||
async def main() -> None:
|
||||
|
||||
client.add_cog(Basics(client))
|
||||
await client.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
client.run()
|
||||
```
|
||||
|
||||
### Global Error Handling
|
||||
@ -114,6 +112,10 @@ These options are forwarded to ``HTTPClient`` when it creates the underlying
|
||||
``aiohttp.ClientSession``. You can specify a custom ``connector`` or any other
|
||||
session parameter supported by ``aiohttp``.
|
||||
|
||||
### Logging Out
|
||||
|
||||
Call ``Client.logout`` to disconnect from the Gateway and clear the current bot token while keeping the HTTP session alive. Assign a new token and call ``connect`` or ``run`` to log back in.
|
||||
|
||||
### Default Allowed Mentions
|
||||
|
||||
Specify default mention behaviour for all outgoing messages when constructing the client:
|
||||
@ -129,6 +131,17 @@ client = disagreement.Client(
|
||||
This dictionary is used whenever ``send_message`` or helpers like ``Message.reply``
|
||||
are called without an explicit ``allowed_mentions`` argument.
|
||||
|
||||
### Working With Assets
|
||||
|
||||
Properties like ``User.avatar`` and ``Guild.icon`` return :class:`disagreement.Asset` objects.
|
||||
Use ``read`` to get the bytes or ``save`` to write them to disk.
|
||||
|
||||
```python
|
||||
user = await client.fetch_user(123)
|
||||
data = await user.avatar.read()
|
||||
await user.avatar.save("avatar.png")
|
||||
```
|
||||
|
||||
### Defining Subcommands with `AppCommandGroup`
|
||||
|
||||
```python
|
||||
|
@ -12,9 +12,10 @@ __title__ = "disagreement"
|
||||
__author__ = "Slipstream"
|
||||
__license__ = "BSD 3-Clause License"
|
||||
__copyright__ = "Copyright 2025 Slipstream"
|
||||
__version__ = "0.8.0"
|
||||
__version__ = "0.8.1"
|
||||
|
||||
from .client import Client, AutoShardedClient
|
||||
from .asset import Asset
|
||||
from .models import (
|
||||
Message,
|
||||
User,
|
||||
@ -39,6 +40,7 @@ from .models import (
|
||||
Container,
|
||||
Guild,
|
||||
)
|
||||
from .object import Object
|
||||
from .voice_client import VoiceClient
|
||||
from .audio import AudioSource, FFmpegAudioSource
|
||||
from .typing import Typing
|
||||
@ -51,7 +53,15 @@ from .errors import (
|
||||
NotFound,
|
||||
)
|
||||
from .color import Color
|
||||
from .utils import utcnow, message_pager
|
||||
from .utils import (
|
||||
utcnow,
|
||||
message_pager,
|
||||
get,
|
||||
find,
|
||||
escape_markdown,
|
||||
escape_mentions,
|
||||
snowflake_time,
|
||||
)
|
||||
from .enums import (
|
||||
GatewayIntent,
|
||||
GatewayOpcode,
|
||||
@ -101,6 +111,7 @@ from .ext.commands import (
|
||||
command,
|
||||
cooldown,
|
||||
has_any_role,
|
||||
is_owner,
|
||||
has_role,
|
||||
listener,
|
||||
max_concurrency,
|
||||
@ -116,6 +127,7 @@ import logging
|
||||
__all__ = [
|
||||
"Client",
|
||||
"AutoShardedClient",
|
||||
"Asset",
|
||||
"Message",
|
||||
"User",
|
||||
"Reaction",
|
||||
@ -137,6 +149,7 @@ __all__ = [
|
||||
"MediaGallery",
|
||||
"MediaGalleryItem",
|
||||
"Container",
|
||||
"Object",
|
||||
"VoiceClient",
|
||||
"AudioSource",
|
||||
"FFmpegAudioSource",
|
||||
@ -149,7 +162,12 @@ __all__ = [
|
||||
"NotFound",
|
||||
"Color",
|
||||
"utcnow",
|
||||
"escape_markdown",
|
||||
"escape_mentions",
|
||||
"message_pager",
|
||||
"get",
|
||||
"find",
|
||||
"snowflake_time",
|
||||
"GatewayIntent",
|
||||
"GatewayOpcode",
|
||||
"ButtonStyle",
|
||||
@ -195,6 +213,7 @@ __all__ = [
|
||||
"command",
|
||||
"cooldown",
|
||||
"has_any_role",
|
||||
"is_owner",
|
||||
"has_role",
|
||||
"listener",
|
||||
"max_concurrency",
|
||||
|
51
disagreement/asset.py
Normal file
51
disagreement/asset.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""Utility class for Discord CDN assets."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import IO, Optional, Union, TYPE_CHECKING
|
||||
|
||||
import aiohttp # pylint: disable=import-error
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client
|
||||
|
||||
|
||||
class Asset:
|
||||
"""Represents a CDN asset such as an avatar or icon."""
|
||||
|
||||
def __init__(self, url: str, client_instance: Optional["Client"] = None) -> None:
|
||||
self.url = url
|
||||
self._client = client_instance
|
||||
|
||||
async def read(self) -> bytes:
|
||||
"""Read the asset's bytes."""
|
||||
|
||||
session: Optional[aiohttp.ClientSession] = None
|
||||
if self._client is not None:
|
||||
await self._client._http._ensure_session() # type: ignore[attr-defined]
|
||||
session = self._client._http._session # type: ignore[attr-defined]
|
||||
if session is None:
|
||||
session = aiohttp.ClientSession()
|
||||
close = True
|
||||
else:
|
||||
close = False
|
||||
async with session.get(self.url) as resp:
|
||||
data = await resp.read()
|
||||
if close:
|
||||
await session.close()
|
||||
return data
|
||||
|
||||
async def save(self, fp: Union[str, os.PathLike[str], IO[bytes]]) -> None:
|
||||
"""Save the asset to the given file path or file-like object."""
|
||||
|
||||
data = await self.read()
|
||||
if isinstance(fp, (str, os.PathLike)):
|
||||
path = os.fspath(fp)
|
||||
with open(path, "wb") as file:
|
||||
file.write(data)
|
||||
else:
|
||||
fp.write(data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Asset url='{self.url}'>"
|
@ -84,9 +84,9 @@ class MemberCache(Cache["Member"]):
|
||||
|
||||
def _should_cache(self, member: Member) -> bool:
|
||||
"""Determines if a member should be cached based on the flags."""
|
||||
if self.flags.all:
|
||||
if self.flags.all_enabled:
|
||||
return True
|
||||
if self.flags.none:
|
||||
if self.flags.no_flags:
|
||||
return False
|
||||
|
||||
if self.flags.online and member.status != "offline":
|
||||
|
@ -74,6 +74,14 @@ class MemberCacheFlags:
|
||||
for name in self.VALID_FLAGS:
|
||||
yield name, getattr(self, name)
|
||||
|
||||
@property
|
||||
def all_enabled(self) -> bool:
|
||||
return self.value == self.ALL_FLAGS
|
||||
|
||||
@property
|
||||
def no_flags(self) -> bool:
|
||||
return self.value == 0
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.value
|
||||
|
||||
|
@ -4,6 +4,9 @@ The main Client class for interacting with the Discord API.
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
import json
|
||||
import os
|
||||
import importlib
|
||||
from typing import (
|
||||
Optional,
|
||||
Callable,
|
||||
@ -18,6 +21,10 @@ from typing import (
|
||||
)
|
||||
from types import ModuleType
|
||||
|
||||
PERSISTENT_VIEWS_FILE = "persistent_views.json"
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from .http import HTTPClient
|
||||
from .gateway import GatewayClient
|
||||
from .shard_manager import ShardManager
|
||||
@ -36,6 +43,7 @@ from .interactions import Interaction, Snowflake
|
||||
from .error_handler import setup_global_error_handler
|
||||
from .voice_client import VoiceClient
|
||||
from .models import Activity
|
||||
from .utils import utcnow
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .models import (
|
||||
@ -65,6 +73,15 @@ if TYPE_CHECKING:
|
||||
from .ext.app_commands.commands import AppCommand, AppCommandGroup
|
||||
|
||||
|
||||
def _update_list(lst: List[Any], item: Any) -> None:
|
||||
"""Replace an item with the same ID in a list or append if missing."""
|
||||
for i, existing in enumerate(lst):
|
||||
if getattr(existing, "id", None) == getattr(item, "id", None):
|
||||
lst[i] = item
|
||||
return
|
||||
lst.append(item)
|
||||
|
||||
|
||||
class Client:
|
||||
"""
|
||||
Represents a client connection that connects to Discord.
|
||||
@ -90,6 +107,9 @@ class Client:
|
||||
:class:`aiohttp.ClientSession`.
|
||||
message_cache_maxlen (Optional[int]): Maximum number of messages to keep
|
||||
in the cache. When ``None``, the cache size is unlimited.
|
||||
sync_commands_on_ready (bool): If ``True``, automatically call
|
||||
:meth:`Client.sync_application_commands` after the ``READY`` event
|
||||
when :attr:`Client.application_id` is available.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -110,7 +130,10 @@ class Client:
|
||||
member_cache_flags: Optional[MemberCacheFlags] = None,
|
||||
message_cache_maxlen: Optional[int] = None,
|
||||
http_options: Optional[Dict[str, Any]] = None,
|
||||
owner_ids: Optional[List[Union[str, int]]] = None,
|
||||
sync_commands_on_ready: bool = True,
|
||||
):
|
||||
|
||||
if not token:
|
||||
raise ValueError("A bot token must be provided.")
|
||||
|
||||
@ -141,12 +164,13 @@ class Client:
|
||||
)
|
||||
self._event_dispatcher: EventDispatcher = EventDispatcher(client_instance=self)
|
||||
self._gateway: Optional[GatewayClient] = (
|
||||
None # Initialized in run() or connect()
|
||||
None # Initialized in start() or connect()
|
||||
)
|
||||
self.shard_count: Optional[int] = shard_count
|
||||
self.gateway_max_retries: int = gateway_max_retries
|
||||
self.gateway_max_backoff: float = gateway_max_backoff
|
||||
self._shard_manager: Optional[ShardManager] = None
|
||||
self.owner_ids: List[str] = [str(o) for o in owner_ids] if owner_ids else []
|
||||
|
||||
# Initialize CommandHandler
|
||||
self.command_handler: CommandHandler = CommandHandler(
|
||||
@ -164,6 +188,8 @@ class Client:
|
||||
None # The bot's own user object, populated on READY
|
||||
)
|
||||
|
||||
self.start_time: Optional[datetime] = None
|
||||
|
||||
# Internal Caches
|
||||
self._guilds: GuildCache = GuildCache()
|
||||
self._channels: ChannelCache = ChannelCache()
|
||||
@ -174,9 +200,13 @@ class Client:
|
||||
self._voice_clients: Dict[Snowflake, VoiceClient] = {}
|
||||
self._webhooks: Dict[Snowflake, "Webhook"] = {}
|
||||
|
||||
# Load persistent views stored on disk
|
||||
self._load_persistent_views()
|
||||
|
||||
# Default whether replies mention the user
|
||||
self.mention_replies: bool = mention_replies
|
||||
self.allowed_mentions: Optional[Dict[str, Any]] = allowed_mentions
|
||||
self.sync_commands_on_ready: bool = sync_commands_on_ready
|
||||
|
||||
# Basic signal handling for graceful shutdown
|
||||
# This might be better handled by the user's application code, but can be a nice default.
|
||||
@ -196,6 +226,39 @@ class Client:
|
||||
"Graceful shutdown via signals might not work as expected on this platform."
|
||||
)
|
||||
|
||||
def _load_persistent_views(self) -> None:
|
||||
"""Load registered persistent views from disk."""
|
||||
if not os.path.isfile(PERSISTENT_VIEWS_FILE):
|
||||
return
|
||||
try:
|
||||
with open(PERSISTENT_VIEWS_FILE, "r") as fp:
|
||||
mapping = json.load(fp)
|
||||
except Exception as e: # pragma: no cover - best effort load
|
||||
print(f"Failed to load persistent views: {e}")
|
||||
return
|
||||
|
||||
for custom_id, path in mapping.items():
|
||||
try:
|
||||
module_name, class_name = path.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
cls = getattr(module, class_name)
|
||||
view = cls()
|
||||
self._persistent_views[custom_id] = view
|
||||
except Exception as e: # pragma: no cover - best effort load
|
||||
print(f"Failed to initialize persistent view {path}: {e}")
|
||||
|
||||
def _save_persistent_views(self) -> None:
|
||||
"""Persist registered views to disk."""
|
||||
data = {}
|
||||
for custom_id, view in self._persistent_views.items():
|
||||
cls = view.__class__
|
||||
data[custom_id] = f"{cls.__module__}.{cls.__name__}"
|
||||
try:
|
||||
with open(PERSISTENT_VIEWS_FILE, "w") as fp:
|
||||
json.dump(data, fp)
|
||||
except Exception as e: # pragma: no cover - best effort save
|
||||
print(f"Failed to save persistent views: {e}")
|
||||
|
||||
async def _initialize_gateway(self):
|
||||
"""Initializes the GatewayClient if it doesn't exist."""
|
||||
if self._gateway is None:
|
||||
@ -239,6 +302,7 @@ class Client:
|
||||
f"Client connected using {self.shard_count} shards, waiting for READY signal..."
|
||||
)
|
||||
await self.wait_until_ready()
|
||||
self.start_time = utcnow()
|
||||
print("Client is READY!")
|
||||
return
|
||||
|
||||
@ -246,7 +310,7 @@ class Client:
|
||||
assert self._gateway is not None # Should be initialized by now
|
||||
|
||||
retry_delay = 5 # seconds
|
||||
max_retries = 5 # For initial connection attempts by Client.run, Gateway has its own internal retries for some cases.
|
||||
max_retries = 5 # For initial connection attempts by Client.start, Gateway has its own internal retries for some cases.
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
@ -255,6 +319,7 @@ class Client:
|
||||
# and its READY handler will set self._ready_event via dispatcher.
|
||||
print("Client connected to Gateway, waiting for READY signal...")
|
||||
await self.wait_until_ready() # Wait for the READY event from Gateway
|
||||
self.start_time = utcnow()
|
||||
print("Client is READY!")
|
||||
return # Successfully connected and ready
|
||||
except AuthenticationError: # Non-recoverable by retry here
|
||||
@ -276,11 +341,10 @@ class Client:
|
||||
if max_retries == 0: # If max_retries was 0, means no retries attempted
|
||||
raise DisagreementException("Connection failed with 0 retries allowed.")
|
||||
|
||||
async def run(self) -> None:
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
A blocking call that connects the client to Discord and runs until the client is closed.
|
||||
This method is a coroutine.
|
||||
It handles login, Gateway connection, and keeping the connection alive.
|
||||
Connect the client to Discord and run until the client is closed.
|
||||
This method is a coroutine containing the main run loop logic.
|
||||
"""
|
||||
if self._closed:
|
||||
raise DisagreementException("Client is already closed.")
|
||||
@ -348,6 +412,10 @@ class Client:
|
||||
if not self._closed:
|
||||
await self.close()
|
||||
|
||||
def run(self) -> None:
|
||||
"""Synchronously start the client using :func:`asyncio.run`."""
|
||||
asyncio.run(self.start())
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Closes the connection to Discord. This method is a coroutine.
|
||||
@ -368,6 +436,7 @@ class Client:
|
||||
await self._http.close()
|
||||
|
||||
self._ready_event.set() # Ensure any waiters for ready are unblocked
|
||||
self.start_time = None
|
||||
print("Client closed.")
|
||||
|
||||
async def __aenter__(self) -> "Client":
|
||||
@ -395,6 +464,14 @@ class Client:
|
||||
self._gateway = None
|
||||
self._ready_event.clear() # No longer ready if gateway is closed
|
||||
|
||||
async def logout(self) -> None:
|
||||
"""Invalidate the bot token and disconnect from the Gateway."""
|
||||
await self.close_gateway()
|
||||
self.token = ""
|
||||
self._http.token = ""
|
||||
self.user = None
|
||||
self.start_time = None
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""Indicates if the client has been closed."""
|
||||
return self._closed
|
||||
@ -416,6 +493,17 @@ class Client:
|
||||
latency = getattr(self._gateway, "latency_ms", None)
|
||||
return round(latency, 2) if latency is not None else None
|
||||
|
||||
@property
|
||||
def guilds(self) -> List["Guild"]:
|
||||
"""Returns all guilds from the internal cache."""
|
||||
return self._guilds.values()
|
||||
|
||||
def uptime(self) -> Optional[timedelta]:
|
||||
"""Return the duration since the client connected, or ``None`` if not connected."""
|
||||
if self.start_time is None:
|
||||
return None
|
||||
return utcnow() - self.start_time
|
||||
|
||||
async def wait_until_ready(self) -> None:
|
||||
"""|coro|
|
||||
Waits until the client is fully connected to Discord and the initial state is processed.
|
||||
@ -563,6 +651,11 @@ class Client:
|
||||
return
|
||||
await self.command_handler.process_commands(message)
|
||||
|
||||
async def get_context(self, message: "Message") -> Optional["CommandContext"]:
|
||||
"""Return a :class:`CommandContext` for ``message`` without executing the command."""
|
||||
|
||||
return await self.command_handler.get_context(message)
|
||||
|
||||
# --- Command Framework Methods ---
|
||||
|
||||
def add_cog(self, cog: Cog) -> None:
|
||||
@ -617,6 +710,11 @@ class Client:
|
||||
# For now, assuming name is sufficient for removal from the handler's flat list.
|
||||
return removed_cog
|
||||
|
||||
def get_cog(self, name: str) -> Optional[Cog]:
|
||||
"""Return a loaded cog by name if present."""
|
||||
|
||||
return self.command_handler.get_cog(name)
|
||||
|
||||
def check(self, coro: Callable[["CommandContext"], Awaitable[bool]]):
|
||||
"""
|
||||
A decorator that adds a global check to the bot.
|
||||
@ -762,7 +860,12 @@ class Client:
|
||||
def parse_channel(self, data: Dict[str, Any]) -> "Channel":
|
||||
"""Parses channel data and returns a Channel object, updating caches."""
|
||||
|
||||
from .models import channel_factory
|
||||
from .models import (
|
||||
channel_factory,
|
||||
TextChannel,
|
||||
VoiceChannel,
|
||||
CategoryChannel,
|
||||
)
|
||||
|
||||
channel = channel_factory(data, self)
|
||||
self._channels.set(channel.id, channel)
|
||||
@ -770,6 +873,12 @@ class Client:
|
||||
guild = self._guilds.get(channel.guild_id)
|
||||
if guild:
|
||||
guild._channels.set(channel.id, channel)
|
||||
if isinstance(channel, TextChannel):
|
||||
_update_list(guild.text_channels, channel)
|
||||
elif isinstance(channel, VoiceChannel):
|
||||
_update_list(guild.voice_channels, channel)
|
||||
elif isinstance(channel, CategoryChannel):
|
||||
_update_list(guild.category_channels, channel)
|
||||
return channel
|
||||
|
||||
def parse_message(self, data: Dict[str, Any]) -> "Message":
|
||||
@ -930,7 +1039,8 @@ class Client:
|
||||
"""Parses guild data and returns a Guild object, updating cache."""
|
||||
from .models import Guild
|
||||
|
||||
guild = Guild(data, client_instance=self)
|
||||
shard_id = data.get("shard_id")
|
||||
guild = Guild(data, client_instance=self, shard_id=shard_id)
|
||||
self._guilds.set(guild.id, guild)
|
||||
|
||||
presences = {p["user"]["id"]: p for p in data.get("presences", [])}
|
||||
@ -1343,6 +1453,26 @@ class Client:
|
||||
|
||||
return self._messages.get(message_id)
|
||||
|
||||
def get_all_channels(self) -> List["Channel"]:
|
||||
"""Return all channels cached in every guild."""
|
||||
|
||||
channels: List["Channel"] = []
|
||||
for guild in self._guilds.values():
|
||||
channels.extend(guild._channels.values())
|
||||
return channels
|
||||
|
||||
def get_all_members(self) -> List["Member"]:
|
||||
"""Return all cached members across all guilds.
|
||||
|
||||
When member caching is disabled via :class:`MemberCacheFlags.none`, this
|
||||
list will always be empty.
|
||||
"""
|
||||
|
||||
members: List["Member"] = []
|
||||
for guild in self._guilds.values():
|
||||
members.extend(guild._members.values())
|
||||
return members
|
||||
|
||||
async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]:
|
||||
"""Fetches a guild by ID from Discord and caches it."""
|
||||
|
||||
@ -1449,6 +1579,23 @@ class Client:
|
||||
|
||||
await self._http.delete_webhook(webhook_id)
|
||||
|
||||
async def fetch_webhook(self, webhook_id: Snowflake) -> Optional["Webhook"]:
|
||||
"""|coro| Fetch a webhook by ID."""
|
||||
|
||||
if self._closed:
|
||||
raise DisagreementException("Client is closed.")
|
||||
|
||||
cached = self._webhooks.get(webhook_id)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
try:
|
||||
data = await self._http.get_webhook(webhook_id)
|
||||
return self.parse_webhook(data)
|
||||
except DisagreementException as e:
|
||||
print(f"Failed to fetch webhook {webhook_id}: {e}")
|
||||
return None
|
||||
|
||||
async def fetch_templates(self, guild_id: Snowflake) -> List["GuildTemplate"]:
|
||||
"""|coro| Fetch all templates for a guild."""
|
||||
|
||||
@ -1582,6 +1729,19 @@ class Client:
|
||||
|
||||
await self._http.delete_invite(code)
|
||||
|
||||
async def fetch_invite(self, code: Snowflake) -> Optional["Invite"]:
|
||||
"""|coro| Fetch a single invite by code."""
|
||||
|
||||
if self._closed:
|
||||
raise DisagreementException("Client is closed.")
|
||||
|
||||
try:
|
||||
data = await self._http.get_invite(code)
|
||||
return self.parse_invite(data)
|
||||
except DisagreementException as e:
|
||||
print(f"Failed to fetch invite {code}: {e}")
|
||||
return None
|
||||
|
||||
async def fetch_invites(self, channel_id: Snowflake) -> List["Invite"]:
|
||||
"""|coro| Fetch all invites for a channel."""
|
||||
|
||||
@ -1622,6 +1782,8 @@ class Client:
|
||||
)
|
||||
self._persistent_views[item.custom_id] = view
|
||||
|
||||
self._save_persistent_views()
|
||||
|
||||
# --- Application Command Methods ---
|
||||
async def process_interaction(self, interaction: Interaction) -> None:
|
||||
"""Internal method to process an interaction from the gateway."""
|
||||
@ -1665,16 +1827,6 @@ class Client:
|
||||
"Ensure the client is connected and READY."
|
||||
)
|
||||
return
|
||||
if not self.is_ready():
|
||||
print(
|
||||
"Warning: Client is not ready. Waiting for client to be ready before syncing commands."
|
||||
)
|
||||
await self.wait_until_ready()
|
||||
if not self.application_id:
|
||||
print(
|
||||
"Error: application_id still not set after client is ready. Cannot sync commands."
|
||||
)
|
||||
return
|
||||
|
||||
await self.app_command_handler.sync_commands(
|
||||
application_id=self.application_id, guild_id=guild_id
|
||||
@ -1695,6 +1847,16 @@ class Client:
|
||||
|
||||
pass
|
||||
|
||||
async def on_connect(self) -> None:
|
||||
"""|coro| Called when the WebSocket connection opens."""
|
||||
|
||||
pass
|
||||
|
||||
async def on_disconnect(self) -> None:
|
||||
"""|coro| Called when the WebSocket connection closes."""
|
||||
|
||||
pass
|
||||
|
||||
async def on_app_command_error(
|
||||
self, context: AppCommandContext, error: Exception
|
||||
) -> None:
|
||||
|
@ -268,6 +268,8 @@ class GuildFeature(str, Enum): # Changed from IntEnum to Enum
|
||||
VERIFIED = "VERIFIED"
|
||||
VIP_REGIONS = "VIP_REGIONS"
|
||||
WELCOME_SCREEN_ENABLED = "WELCOME_SCREEN_ENABLED"
|
||||
SOUNDBOARD = "SOUNDBOARD"
|
||||
VIDEO_QUALITY_720_60FPS = "VIDEO_QUALITY_720_60FPS"
|
||||
# Add more as they become known or needed
|
||||
|
||||
# This allows GuildFeature("UNKNOWN_FEATURE_STRING") to work
|
||||
|
@ -61,6 +61,11 @@ class EventDispatcher:
|
||||
"GUILD_ROLE_UPDATE": self._parse_guild_role_update,
|
||||
"TYPING_START": self._parse_typing_start,
|
||||
"VOICE_STATE_UPDATE": self._parse_voice_state_update,
|
||||
"THREAD_CREATE": self._parse_thread_create,
|
||||
"THREAD_UPDATE": self._parse_thread_update,
|
||||
"THREAD_DELETE": self._parse_thread_delete,
|
||||
"INVITE_CREATE": self._parse_invite_create,
|
||||
"INVITE_DELETE": self._parse_invite_delete,
|
||||
}
|
||||
|
||||
def _parse_message_create(self, data: Dict[str, Any]) -> Message:
|
||||
@ -165,6 +170,43 @@ class EventDispatcher:
|
||||
|
||||
return GuildRoleUpdate(data, client_instance=self._client)
|
||||
|
||||
def _parse_thread_create(self, data: Dict[str, Any]):
|
||||
"""Parses THREAD_CREATE into a Thread object and updates caches."""
|
||||
|
||||
return self._client.parse_channel(data)
|
||||
|
||||
def _parse_thread_update(self, data: Dict[str, Any]):
|
||||
"""Parses THREAD_UPDATE into a Thread object."""
|
||||
|
||||
return self._client.parse_channel(data)
|
||||
|
||||
def _parse_thread_delete(self, data: Dict[str, Any]):
|
||||
"""Parses THREAD_DELETE, removing the thread from caches."""
|
||||
|
||||
thread = self._client.parse_channel(data)
|
||||
thread_id = data.get("id")
|
||||
if thread_id:
|
||||
self._client._channels.invalidate(thread_id)
|
||||
guild_id = data.get("guild_id")
|
||||
if guild_id:
|
||||
guild = self._client._guilds.get(guild_id)
|
||||
if guild:
|
||||
guild._channels.invalidate(thread_id)
|
||||
guild._threads.pop(thread_id, None)
|
||||
return thread
|
||||
|
||||
def _parse_invite_create(self, data: Dict[str, Any]):
|
||||
"""Parses INVITE_CREATE into an Invite object."""
|
||||
|
||||
return self._client.parse_invite(data)
|
||||
|
||||
def _parse_invite_delete(self, data: Dict[str, Any]):
|
||||
"""Parses INVITE_DELETE into an InviteDelete model."""
|
||||
|
||||
from .models import InviteDelete
|
||||
|
||||
return InviteDelete(data)
|
||||
|
||||
# Potentially add _parse_user for events that directly provide a full user object
|
||||
# def _parse_user_update(self, data: Dict[str, Any]) -> User:
|
||||
# return User(data=data)
|
||||
|
@ -18,6 +18,7 @@ from .decorators import (
|
||||
requires_permissions,
|
||||
has_role,
|
||||
has_any_role,
|
||||
is_owner,
|
||||
)
|
||||
from .errors import (
|
||||
CommandError,
|
||||
@ -49,6 +50,7 @@ __all__ = [
|
||||
"requires_permissions",
|
||||
"has_role",
|
||||
"has_any_role",
|
||||
"is_owner",
|
||||
# Errors
|
||||
"CommandError",
|
||||
"CommandNotFound",
|
||||
|
@ -6,7 +6,16 @@ import re
|
||||
import inspect
|
||||
|
||||
from .errors import BadArgument
|
||||
from disagreement.models import Member, Guild, Role
|
||||
from disagreement.models import (
|
||||
Member,
|
||||
Guild,
|
||||
Role,
|
||||
User,
|
||||
TextChannel,
|
||||
VoiceChannel,
|
||||
Emoji,
|
||||
PartialEmoji,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core import CommandContext
|
||||
@ -143,6 +152,97 @@ class GuildConverter(Converter["Guild"]):
|
||||
raise BadArgument(f"Guild '{argument}' not found.")
|
||||
|
||||
|
||||
class UserConverter(Converter["User"]):
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> "User":
|
||||
match = re.match(r"<@!?(\d+)>$", argument)
|
||||
user_id = match.group(1) if match else argument
|
||||
|
||||
user = ctx.bot._users.get(user_id)
|
||||
if user:
|
||||
return user
|
||||
|
||||
user = await ctx.bot.fetch_user(user_id)
|
||||
if user:
|
||||
return user
|
||||
raise BadArgument(f"User '{argument}' not found.")
|
||||
|
||||
|
||||
class TextChannelConverter(Converter["TextChannel"]):
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> "TextChannel":
|
||||
if not ctx.message.guild_id:
|
||||
raise BadArgument("TextChannel converter requires guild context.")
|
||||
|
||||
match = re.match(r"<#(?P<id>\d+)>$", argument)
|
||||
channel_id = match.group("id") if match else argument
|
||||
|
||||
guild = ctx.bot.get_guild(ctx.message.guild_id)
|
||||
if guild:
|
||||
channel = guild.get_channel(channel_id)
|
||||
if isinstance(channel, TextChannel):
|
||||
return channel
|
||||
|
||||
channel = (
|
||||
ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None
|
||||
)
|
||||
if isinstance(channel, TextChannel):
|
||||
return channel
|
||||
|
||||
if hasattr(ctx.bot, "fetch_channel"):
|
||||
channel = await ctx.bot.fetch_channel(channel_id)
|
||||
if isinstance(channel, TextChannel):
|
||||
return channel
|
||||
|
||||
raise BadArgument(f"Text channel '{argument}' not found.")
|
||||
|
||||
|
||||
class VoiceChannelConverter(Converter["VoiceChannel"]):
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> "VoiceChannel":
|
||||
if not ctx.message.guild_id:
|
||||
raise BadArgument("VoiceChannel converter requires guild context.")
|
||||
|
||||
match = re.match(r"<#(?P<id>\d+)>$", argument)
|
||||
channel_id = match.group("id") if match else argument
|
||||
|
||||
guild = ctx.bot.get_guild(ctx.message.guild_id)
|
||||
if guild:
|
||||
channel = guild.get_channel(channel_id)
|
||||
if isinstance(channel, VoiceChannel):
|
||||
return channel
|
||||
|
||||
channel = (
|
||||
ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None
|
||||
)
|
||||
if isinstance(channel, VoiceChannel):
|
||||
return channel
|
||||
|
||||
if hasattr(ctx.bot, "fetch_channel"):
|
||||
channel = await ctx.bot.fetch_channel(channel_id)
|
||||
if isinstance(channel, VoiceChannel):
|
||||
return channel
|
||||
|
||||
raise BadArgument(f"Voice channel '{argument}' not found.")
|
||||
|
||||
|
||||
class EmojiConverter(Converter["PartialEmoji"]):
|
||||
_CUSTOM_RE = re.compile(r"<(?P<animated>a)?:(?P<name>[^:]+):(?P<id>\d+)>$")
|
||||
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> "PartialEmoji":
|
||||
match = self._CUSTOM_RE.match(argument)
|
||||
if match:
|
||||
return PartialEmoji(
|
||||
{
|
||||
"id": match.group("id"),
|
||||
"name": match.group("name"),
|
||||
"animated": bool(match.group("animated")),
|
||||
}
|
||||
)
|
||||
|
||||
if argument:
|
||||
return PartialEmoji({"id": None, "name": argument})
|
||||
|
||||
raise BadArgument(f"Emoji '{argument}' not found.")
|
||||
|
||||
|
||||
# Default converters mapping
|
||||
DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
|
||||
int: IntConverter(),
|
||||
@ -152,7 +252,11 @@ DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
|
||||
Member: MemberConverter(),
|
||||
Guild: GuildConverter(),
|
||||
Role: RoleConverter(),
|
||||
# User: UserConverter(), # Add when User model and converter are ready
|
||||
User: UserConverter(),
|
||||
TextChannel: TextChannelConverter(),
|
||||
VoiceChannel: VoiceChannelConverter(),
|
||||
PartialEmoji: EmojiConverter(),
|
||||
Emoji: EmojiConverter(),
|
||||
}
|
||||
|
||||
|
||||
|
@ -82,6 +82,13 @@ class GroupMixin:
|
||||
def get_command(self, name: str) -> Optional["Command"]:
|
||||
return self.commands.get(name.lower())
|
||||
|
||||
def walk_commands(self):
|
||||
"""Yield all commands in this group recursively."""
|
||||
for cmd in dict.fromkeys(self.commands.values()):
|
||||
yield cmd
|
||||
if isinstance(cmd, Group):
|
||||
yield from cmd.walk_commands()
|
||||
|
||||
|
||||
class Command(GroupMixin):
|
||||
"""
|
||||
@ -366,6 +373,18 @@ class CommandHandler:
|
||||
def get_command(self, name: str) -> Optional[Command]:
|
||||
return self.commands.get(name.lower())
|
||||
|
||||
def walk_commands(self):
|
||||
"""Yield every registered command, including subcommands."""
|
||||
for cmd in dict.fromkeys(self.commands.values()):
|
||||
yield cmd
|
||||
if isinstance(cmd, Group):
|
||||
yield from cmd.walk_commands()
|
||||
|
||||
def get_cog(self, name: str) -> Optional["Cog"]:
|
||||
"""Return a loaded cog by name if present."""
|
||||
|
||||
return self.cogs.get(name)
|
||||
|
||||
def add_cog(self, cog_to_add: "Cog") -> None:
|
||||
from .cog import Cog
|
||||
|
||||
@ -638,6 +657,78 @@ class CommandHandler:
|
||||
|
||||
return args_list, kwargs_dict
|
||||
|
||||
async def get_context(self, message: "Message") -> Optional[CommandContext]:
|
||||
"""Parse a message and return a :class:`CommandContext` without executing the command.
|
||||
|
||||
Returns ``None`` if the message does not invoke a command."""
|
||||
|
||||
if not message.content:
|
||||
return None
|
||||
|
||||
prefix_to_use = await self.get_prefix(message)
|
||||
if not prefix_to_use:
|
||||
return None
|
||||
|
||||
actual_prefix: Optional[str] = None
|
||||
if isinstance(prefix_to_use, list):
|
||||
for p in prefix_to_use:
|
||||
if message.content.startswith(p):
|
||||
actual_prefix = p
|
||||
break
|
||||
if not actual_prefix:
|
||||
return None
|
||||
elif isinstance(prefix_to_use, str):
|
||||
if message.content.startswith(prefix_to_use):
|
||||
actual_prefix = prefix_to_use
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
if actual_prefix is None:
|
||||
return None
|
||||
|
||||
view = StringView(message.content[len(actual_prefix) :])
|
||||
|
||||
command_name = view.get_word()
|
||||
if not command_name:
|
||||
return None
|
||||
|
||||
command = self.get_command(command_name)
|
||||
if not command:
|
||||
return None
|
||||
|
||||
invoked_with = command_name
|
||||
|
||||
if isinstance(command, Group):
|
||||
view.skip_whitespace()
|
||||
potential_subcommand = view.get_word()
|
||||
if potential_subcommand:
|
||||
subcommand = command.get_command(potential_subcommand)
|
||||
if subcommand:
|
||||
command = subcommand
|
||||
invoked_with += f" {potential_subcommand}"
|
||||
elif command.invoke_without_command:
|
||||
view.index -= len(potential_subcommand) + view.previous
|
||||
else:
|
||||
raise CommandNotFound(
|
||||
f"Subcommand '{potential_subcommand}' not found."
|
||||
)
|
||||
|
||||
ctx = CommandContext(
|
||||
message=message,
|
||||
bot=self.client,
|
||||
prefix=actual_prefix,
|
||||
command=command,
|
||||
invoked_with=invoked_with,
|
||||
cog=command.cog,
|
||||
)
|
||||
|
||||
parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view)
|
||||
ctx.args = parsed_args
|
||||
ctx.kwargs = parsed_kwargs
|
||||
return ctx
|
||||
|
||||
async def process_commands(self, message: "Message") -> None:
|
||||
if not message.content:
|
||||
return
|
||||
|
@ -292,3 +292,19 @@ def has_any_role(
|
||||
)
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def is_owner() -> (
|
||||
Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]
|
||||
):
|
||||
"""Check that the invoking user is listed as a bot owner."""
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CheckFailure
|
||||
|
||||
owner_ids = getattr(ctx.bot, "owner_ids", [])
|
||||
if str(ctx.author.id) not in {str(o) for o in owner_ids}:
|
||||
raise CheckFailure("This command can only be used by the bot owner.")
|
||||
return True
|
||||
|
||||
return check(predicate)
|
||||
|
@ -1,6 +1,8 @@
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional
|
||||
|
||||
from .core import Command, CommandContext, CommandHandler
|
||||
from ...utils import Paginator
|
||||
from .core import Command, CommandContext, CommandHandler, Group
|
||||
|
||||
|
||||
class HelpCommand(Command):
|
||||
@ -15,17 +17,22 @@ class HelpCommand(Command):
|
||||
if not cmd or cmd.name.lower() != command.lower():
|
||||
await ctx.send(f"Command '{command}' not found.")
|
||||
return
|
||||
if isinstance(cmd, Group):
|
||||
await self.send_group_help(ctx, cmd)
|
||||
elif cmd:
|
||||
description = cmd.description or cmd.brief or "No description provided."
|
||||
await ctx.send(f"**{ctx.prefix}{cmd.name}**\n{description}")
|
||||
else:
|
||||
lines: List[str] = []
|
||||
for registered in dict.fromkeys(handler.commands.values()):
|
||||
for registered in handler.walk_commands():
|
||||
brief = registered.brief or registered.description or ""
|
||||
lines.append(f"{ctx.prefix}{registered.name} - {brief}".strip())
|
||||
if lines:
|
||||
await ctx.send("\n".join(lines))
|
||||
else:
|
||||
await ctx.send("No commands available.")
|
||||
await self.send_command_help(ctx, cmd)
|
||||
else:
|
||||
await self.send_bot_help(ctx)
|
||||
|
||||
super().__init__(
|
||||
callback,
|
||||
@ -33,3 +40,42 @@ class HelpCommand(Command):
|
||||
brief="Show command help.",
|
||||
description="Displays help for commands.",
|
||||
)
|
||||
|
||||
async def send_bot_help(self, ctx: CommandContext) -> None:
|
||||
groups = defaultdict(list)
|
||||
for cmd in dict.fromkeys(self.handler.commands.values()):
|
||||
key = cmd.cog.cog_name if cmd.cog else "No Category"
|
||||
groups[key].append(cmd)
|
||||
|
||||
paginator = Paginator()
|
||||
for cog_name, cmds in groups.items():
|
||||
paginator.add_line(f"**{cog_name}**")
|
||||
for cmd in cmds:
|
||||
brief = cmd.brief or cmd.description or ""
|
||||
paginator.add_line(f"{ctx.prefix}{cmd.name} - {brief}".strip())
|
||||
paginator.add_line("")
|
||||
|
||||
pages = paginator.pages
|
||||
if not pages:
|
||||
await ctx.send("No commands available.")
|
||||
return
|
||||
for page in pages:
|
||||
await ctx.send(page)
|
||||
|
||||
async def send_command_help(self, ctx: CommandContext, command: Command) -> None:
|
||||
description = command.description or command.brief or "No description provided."
|
||||
await ctx.send(f"**{ctx.prefix}{command.name}**\n{description}")
|
||||
|
||||
async def send_group_help(self, ctx: CommandContext, group: Group) -> None:
|
||||
paginator = Paginator()
|
||||
description = group.description or group.brief or "No description provided."
|
||||
paginator.add_line(f"**{ctx.prefix}{group.name}**\n{description}")
|
||||
if group.commands:
|
||||
for sub in dict.fromkeys(group.commands.values()):
|
||||
brief = sub.brief or sub.description or ""
|
||||
paginator.add_line(
|
||||
f"{ctx.prefix}{group.name} {sub.name} - {brief}".strip()
|
||||
)
|
||||
|
||||
for page in paginator.pages:
|
||||
await ctx.send(page)
|
||||
|
@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from importlib import import_module
|
||||
import inspect
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import Dict
|
||||
from typing import Any, Coroutine, Dict, cast
|
||||
|
||||
__all__ = ["load_extension", "unload_extension", "reload_extension"]
|
||||
|
||||
@ -25,7 +27,20 @@ def load_extension(name: str) -> ModuleType:
|
||||
if not hasattr(module, "setup"):
|
||||
raise ImportError(f"Extension '{name}' does not define a setup function")
|
||||
|
||||
module.setup()
|
||||
result = module.setup()
|
||||
if inspect.isawaitable(result):
|
||||
coro = cast(Coroutine[Any, Any, Any], result)
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
asyncio.run(coro)
|
||||
else:
|
||||
if loop.is_running():
|
||||
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||
future.result()
|
||||
else:
|
||||
loop.run_until_complete(coro)
|
||||
|
||||
_loaded_extensions[name] = module
|
||||
return module
|
||||
|
||||
@ -38,7 +53,19 @@ def unload_extension(name: str) -> None:
|
||||
raise ValueError(f"Extension '{name}' is not loaded")
|
||||
|
||||
if hasattr(module, "teardown"):
|
||||
module.teardown()
|
||||
result = module.teardown()
|
||||
if inspect.isawaitable(result):
|
||||
coro = cast(Coroutine[Any, Any, Any], result)
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
asyncio.run(coro)
|
||||
else:
|
||||
if loop.is_running():
|
||||
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||
future.result()
|
||||
else:
|
||||
loop.run_until_complete(coro)
|
||||
|
||||
sys.modules.pop(name, None)
|
||||
|
||||
|
@ -23,6 +23,7 @@ class Task:
|
||||
) -> None:
|
||||
self._coro = coro
|
||||
self._task: Optional[asyncio.Task[None]] = None
|
||||
self._current_loop = 0
|
||||
if time_of_day is not None and (
|
||||
seconds or minutes or hours or delta is not None
|
||||
):
|
||||
@ -68,6 +69,7 @@ class Task:
|
||||
await _maybe_call(self._on_error, exc)
|
||||
else:
|
||||
raise
|
||||
self._current_loop += 1
|
||||
|
||||
first = False
|
||||
except asyncio.CancelledError:
|
||||
@ -78,6 +80,7 @@ class Task:
|
||||
|
||||
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
||||
if self._task is None or self._task.done():
|
||||
self._current_loop = 0
|
||||
self._task = asyncio.create_task(self._run(*args, **kwargs))
|
||||
return self._task
|
||||
|
||||
@ -90,6 +93,34 @@ class Task:
|
||||
def running(self) -> bool:
|
||||
return self._task is not None and not self._task.done()
|
||||
|
||||
@property
|
||||
def current_loop(self) -> int:
|
||||
return self._current_loop
|
||||
|
||||
def change_interval(
|
||||
self,
|
||||
*,
|
||||
seconds: float = 0.0,
|
||||
minutes: float = 0.0,
|
||||
hours: float = 0.0,
|
||||
delta: Optional[datetime.timedelta] = None,
|
||||
time_of_day: Optional[datetime.time] = None,
|
||||
) -> None:
|
||||
if time_of_day is not None and (
|
||||
seconds or minutes or hours or delta is not None
|
||||
):
|
||||
raise ValueError("time_of_day cannot be used with an interval")
|
||||
|
||||
if delta is not None:
|
||||
if not isinstance(delta, datetime.timedelta):
|
||||
raise TypeError("delta must be a datetime.timedelta")
|
||||
interval_seconds = delta.total_seconds()
|
||||
else:
|
||||
interval_seconds = seconds + minutes * 60.0 + hours * 3600.0
|
||||
|
||||
self._seconds = float(interval_seconds)
|
||||
self._time_of_day = time_of_day
|
||||
|
||||
|
||||
async def _maybe_call(
|
||||
func: Callable[[Exception], Awaitable[None] | None], exc: Exception
|
||||
@ -181,10 +212,37 @@ class _Loop:
|
||||
if self._task is not None:
|
||||
self._task.stop()
|
||||
|
||||
def change_interval(
|
||||
self,
|
||||
*,
|
||||
seconds: float = 0.0,
|
||||
minutes: float = 0.0,
|
||||
hours: float = 0.0,
|
||||
delta: Optional[datetime.timedelta] = None,
|
||||
time_of_day: Optional[datetime.time] = None,
|
||||
) -> None:
|
||||
self.seconds = seconds
|
||||
self.minutes = minutes
|
||||
self.hours = hours
|
||||
self.delta = delta
|
||||
self.time_of_day = time_of_day
|
||||
if self._task is not None:
|
||||
self._task.change_interval(
|
||||
seconds=seconds,
|
||||
minutes=minutes,
|
||||
hours=hours,
|
||||
delta=delta,
|
||||
time_of_day=time_of_day,
|
||||
)
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
return self._task.running if self._task else False
|
||||
|
||||
@property
|
||||
def current_loop(self) -> int:
|
||||
return self._task.current_loop if self._task else 0
|
||||
|
||||
|
||||
class _BoundLoop:
|
||||
def __init__(self, parent: _Loop, owner: Any) -> None:
|
||||
@ -202,6 +260,27 @@ class _BoundLoop:
|
||||
def running(self) -> bool:
|
||||
return self._parent.running
|
||||
|
||||
def change_interval(
|
||||
self,
|
||||
*,
|
||||
seconds: float = 0.0,
|
||||
minutes: float = 0.0,
|
||||
hours: float = 0.0,
|
||||
delta: Optional[datetime.timedelta] = None,
|
||||
time_of_day: Optional[datetime.time] = None,
|
||||
) -> None:
|
||||
self._parent.change_interval(
|
||||
seconds=seconds,
|
||||
minutes=minutes,
|
||||
hours=hours,
|
||||
delta=delta,
|
||||
time_of_day=time_of_day,
|
||||
)
|
||||
|
||||
@property
|
||||
def current_loop(self) -> int:
|
||||
return self._parent.current_loop
|
||||
|
||||
|
||||
def loop(
|
||||
*,
|
||||
|
@ -334,7 +334,19 @@ class GatewayClient:
|
||||
self._resume_gateway_url,
|
||||
)
|
||||
|
||||
# The client is now ready for operations. Set the event before dispatching to user code.
|
||||
self._client_instance._ready_event.set()
|
||||
logger.info("Client is now marked as ready.")
|
||||
|
||||
if isinstance(raw_event_d_payload, dict) and self._shard_id is not None:
|
||||
raw_event_d_payload["shard_id"] = self._shard_id
|
||||
await self._dispatcher.dispatch(event_name, raw_event_d_payload)
|
||||
|
||||
if (
|
||||
getattr(self._client_instance, "sync_commands_on_ready", True)
|
||||
and self._client_instance.application_id
|
||||
):
|
||||
asyncio.create_task(self._client_instance.sync_application_commands())
|
||||
elif event_name == "GUILD_MEMBERS_CHUNK":
|
||||
if isinstance(raw_event_d_payload, dict):
|
||||
nonce = raw_event_d_payload.get("nonce")
|
||||
@ -384,6 +396,8 @@ class GatewayClient:
|
||||
event_data_to_dispatch = (
|
||||
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
|
||||
)
|
||||
if isinstance(event_data_to_dispatch, dict) and self._shard_id is not None:
|
||||
event_data_to_dispatch["shard_id"] = self._shard_id
|
||||
await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
|
||||
await self._dispatcher.dispatch(
|
||||
"SHARD_RESUME", {"shard_id": self._shard_id}
|
||||
@ -394,6 +408,8 @@ class GatewayClient:
|
||||
event_data_to_dispatch = (
|
||||
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
|
||||
)
|
||||
if isinstance(event_data_to_dispatch, dict) and self._shard_id is not None:
|
||||
event_data_to_dispatch["shard_id"] = self._shard_id
|
||||
|
||||
await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
|
||||
else:
|
||||
@ -553,6 +569,7 @@ class GatewayClient:
|
||||
await self._dispatcher.dispatch(
|
||||
"SHARD_CONNECT", {"shard_id": self._shard_id}
|
||||
)
|
||||
await self._dispatcher.dispatch("CONNECT", {"shard_id": self._shard_id})
|
||||
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
raise GatewayException(
|
||||
@ -608,6 +625,7 @@ class GatewayClient:
|
||||
await self._dispatcher.dispatch(
|
||||
"SHARD_DISCONNECT", {"shard_id": self._shard_id}
|
||||
)
|
||||
await self._dispatcher.dispatch("DISCONNECT", {"shard_id": self._shard_id})
|
||||
|
||||
@property
|
||||
def latency(self) -> Optional[float]:
|
||||
|
@ -663,6 +663,15 @@ class HTTPClient:
|
||||
|
||||
await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}")
|
||||
|
||||
async def crosspost_message(
|
||||
self, channel_id: "Snowflake", message_id: "Snowflake"
|
||||
) -> Dict[str, Any]:
|
||||
"""Crossposts a message to any following channels."""
|
||||
|
||||
return await self.request(
|
||||
"POST", f"/channels/{channel_id}/messages/{message_id}/crosspost"
|
||||
)
|
||||
|
||||
async def delete_channel(
|
||||
self, channel_id: str, reason: Optional[str] = None
|
||||
) -> None:
|
||||
@ -702,6 +711,22 @@ class HTTPClient:
|
||||
"""Fetches a channel by ID."""
|
||||
return await self.request("GET", f"/channels/{channel_id}")
|
||||
|
||||
async def create_guild_channel(
|
||||
self,
|
||||
guild_id: "Snowflake",
|
||||
payload: Dict[str, Any],
|
||||
reason: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Creates a new channel in the specified guild."""
|
||||
|
||||
headers = {"X-Audit-Log-Reason": reason} if reason else None
|
||||
return await self.request(
|
||||
"POST",
|
||||
f"/guilds/{guild_id}/channels",
|
||||
payload=payload,
|
||||
custom_headers=headers,
|
||||
)
|
||||
|
||||
async def get_channel_invites(
|
||||
self, channel_id: "Snowflake"
|
||||
) -> List[Dict[str, Any]]:
|
||||
@ -721,11 +746,36 @@ class HTTPClient:
|
||||
|
||||
return Invite.from_dict(data)
|
||||
|
||||
async def create_channel_invite(
|
||||
self,
|
||||
channel_id: "Snowflake",
|
||||
payload: Dict[str, Any],
|
||||
*,
|
||||
reason: Optional[str] = None,
|
||||
) -> "Invite":
|
||||
"""Creates an invite for a channel with an optional audit log reason."""
|
||||
|
||||
headers = {"X-Audit-Log-Reason": reason} if reason else None
|
||||
data = await self.request(
|
||||
"POST",
|
||||
f"/channels/{channel_id}/invites",
|
||||
payload=payload,
|
||||
custom_headers=headers,
|
||||
)
|
||||
from .models import Invite
|
||||
|
||||
return Invite.from_dict(data)
|
||||
|
||||
async def delete_invite(self, code: str) -> None:
|
||||
"""Deletes an invite by code."""
|
||||
|
||||
await self.request("DELETE", f"/invites/{code}")
|
||||
|
||||
async def get_invite(self, code: "Snowflake") -> Dict[str, Any]:
|
||||
"""Fetches a single invite by its code."""
|
||||
|
||||
return await self.request("GET", f"/invites/{code}")
|
||||
|
||||
async def create_webhook(
|
||||
self, channel_id: "Snowflake", payload: Dict[str, Any]
|
||||
) -> "Webhook":
|
||||
@ -738,6 +788,11 @@ class HTTPClient:
|
||||
|
||||
return Webhook(data)
|
||||
|
||||
async def get_webhook(self, webhook_id: "Snowflake") -> Dict[str, Any]:
|
||||
"""Fetches a webhook by ID and returns the raw payload."""
|
||||
|
||||
return await self.request("GET", f"/webhooks/{webhook_id}")
|
||||
|
||||
async def edit_webhook(
|
||||
self, webhook_id: "Snowflake", payload: Dict[str, Any]
|
||||
) -> "Webhook":
|
||||
@ -753,6 +808,24 @@ class HTTPClient:
|
||||
|
||||
await self.request("DELETE", f"/webhooks/{webhook_id}")
|
||||
|
||||
async def get_webhook_with_token(
|
||||
self, webhook_id: "Snowflake", token: Optional[str] = None
|
||||
) -> "Webhook":
|
||||
"""Fetches a webhook by ID, optionally using its token."""
|
||||
|
||||
endpoint = f"/webhooks/{webhook_id}"
|
||||
use_auth = True
|
||||
if token is not None:
|
||||
endpoint += f"/{token}"
|
||||
use_auth = False
|
||||
if use_auth:
|
||||
data = await self.request("GET", endpoint)
|
||||
else:
|
||||
data = await self.request("GET", endpoint, use_auth_header=False)
|
||||
from .models import Webhook
|
||||
|
||||
return Webhook(data)
|
||||
|
||||
async def execute_webhook(
|
||||
self,
|
||||
webhook_id: "Snowflake",
|
||||
@ -902,6 +975,29 @@ class HTTPClient:
|
||||
custom_headers=headers,
|
||||
)
|
||||
|
||||
async def get_guild_prune_count(self, guild_id: "Snowflake", *, days: int) -> int:
|
||||
"""Returns the number of members that would be pruned."""
|
||||
|
||||
data = await self.request(
|
||||
"GET",
|
||||
f"/guilds/{guild_id}/prune",
|
||||
params={"days": days},
|
||||
)
|
||||
return int(data.get("pruned", 0))
|
||||
|
||||
async def begin_guild_prune(
|
||||
self, guild_id: "Snowflake", *, days: int, compute_count: bool = True
|
||||
) -> int:
|
||||
"""Begins a prune operation for the guild and returns the count."""
|
||||
|
||||
payload = {"days": days, "compute_prune_count": compute_count}
|
||||
data = await self.request(
|
||||
"POST",
|
||||
f"/guilds/{guild_id}/prune",
|
||||
payload=payload,
|
||||
)
|
||||
return int(data.get("pruned", 0))
|
||||
|
||||
async def get_guild_roles(self, guild_id: "Snowflake") -> List[Dict[str, Any]]:
|
||||
"""Returns a list of role objects for the guild."""
|
||||
return await self.request("GET", f"/guilds/{guild_id}/roles")
|
||||
|
@ -2,11 +2,26 @@
|
||||
Data models for Discord objects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
cast,
|
||||
IO,
|
||||
)
|
||||
|
||||
from .cache import ChannelCache, MemberCache
|
||||
from .caching import MemberCacheFlags
|
||||
@ -40,22 +55,43 @@ if TYPE_CHECKING:
|
||||
from .ui.view import View
|
||||
from .interactions import Snowflake
|
||||
from .typing import Typing
|
||||
from .shard_manager import Shard
|
||||
from .asset import Asset
|
||||
|
||||
# Forward reference Message if it were used in type hints before its definition
|
||||
# from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc.
|
||||
from .components import component_factory
|
||||
|
||||
|
||||
class User:
|
||||
class HashableById:
|
||||
"""Mixin providing equality and hashing based on the ``id`` attribute."""
|
||||
|
||||
id: str
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, self.__class__) and self.id == other.id # type: ignore[attr-defined]
|
||||
|
||||
def __hash__(self) -> int: # pragma: no cover - trivial
|
||||
return hash(self.id)
|
||||
|
||||
|
||||
class User(HashableById):
|
||||
"""Represents a Discord User."""
|
||||
|
||||
def __init__(self, data: dict, client_instance: Optional["Client"] = None) -> None:
|
||||
self._client = client_instance
|
||||
if "id" not in data and "user" in data:
|
||||
data = data["user"]
|
||||
self.id: str = data["id"]
|
||||
self.username: Optional[str] = data.get("username")
|
||||
self.discriminator: Optional[str] = data.get("discriminator")
|
||||
self.bot: bool = data.get("bot", False)
|
||||
self.avatar: Optional[str] = data.get("avatar")
|
||||
avatar_hash = data.get("avatar")
|
||||
self._avatar: Optional[str] = (
|
||||
f"https://cdn.discordapp.com/avatars/{self.id}/{avatar_hash}.png"
|
||||
if avatar_hash
|
||||
else None
|
||||
)
|
||||
|
||||
@property
|
||||
def mention(self) -> str:
|
||||
@ -67,6 +103,25 @@ class User:
|
||||
disc = self.discriminator or "????"
|
||||
return f"<User id='{self.id}' username='{username}' discriminator='{disc}'>"
|
||||
|
||||
@property
|
||||
def avatar(self) -> Optional["Asset"]:
|
||||
"""Return the user's avatar as an :class:`Asset`."""
|
||||
|
||||
if self._avatar:
|
||||
from .asset import Asset
|
||||
|
||||
return Asset(self._avatar, self._client)
|
||||
return None
|
||||
|
||||
@avatar.setter
|
||||
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||
if isinstance(value, str):
|
||||
self._avatar = value
|
||||
elif value is None:
|
||||
self._avatar = None
|
||||
else:
|
||||
self._avatar = value.url
|
||||
|
||||
async def send(
|
||||
self,
|
||||
content: Optional[str] = None,
|
||||
@ -82,7 +137,7 @@ class User:
|
||||
return await target_client.send_dm(self.id, content=content, **kwargs)
|
||||
|
||||
|
||||
class Message:
|
||||
class Message(HashableById):
|
||||
"""Represents a message sent in a channel on Discord.
|
||||
|
||||
Attributes:
|
||||
@ -108,6 +163,7 @@ class Message:
|
||||
self.author: User = User(data["author"], client_instance)
|
||||
self.content: str = data["content"]
|
||||
self.timestamp: str = data["timestamp"]
|
||||
self.edited_timestamp: Optional[str] = data.get("edited_timestamp")
|
||||
if data.get("components"):
|
||||
self.components: Optional[List[ActionRow]] = [
|
||||
ActionRow.from_dict(c, client_instance)
|
||||
@ -139,6 +195,20 @@ class Message:
|
||||
cleaned = pattern.sub("", self.content)
|
||||
return " ".join(cleaned.split())
|
||||
|
||||
@property
|
||||
def created_at(self) -> datetime.datetime:
|
||||
"""Return message timestamp as a :class:`~datetime.datetime`."""
|
||||
|
||||
return datetime.datetime.fromisoformat(self.timestamp)
|
||||
|
||||
@property
|
||||
def edited_at(self) -> Optional[datetime.datetime]:
|
||||
"""Return edited timestamp as :class:`~datetime.datetime` if present."""
|
||||
|
||||
if self.edited_timestamp is None:
|
||||
return None
|
||||
return datetime.datetime.fromisoformat(self.edited_timestamp)
|
||||
|
||||
async def pin(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
@ -165,6 +235,15 @@ class Message:
|
||||
await self._client._http.unpin_message(self.channel_id, self.id)
|
||||
self.pinned = False
|
||||
|
||||
async def crosspost(self) -> "Message":
|
||||
"""|coro|
|
||||
|
||||
Crossposts this message to all follower channels and return the resulting message.
|
||||
"""
|
||||
|
||||
data = await self._client._http.crosspost_message(self.channel_id, self.id)
|
||||
return self._client.parse_message(data)
|
||||
|
||||
async def reply(
|
||||
self,
|
||||
content: Optional[str] = None,
|
||||
@ -633,11 +712,45 @@ class Attachment:
|
||||
|
||||
|
||||
class File:
|
||||
"""Represents a file to be uploaded."""
|
||||
"""Represents a file to be uploaded.
|
||||
|
||||
def __init__(self, filename: str, data: bytes):
|
||||
self.filename = filename
|
||||
self.data = data
|
||||
Parameters
|
||||
----------
|
||||
fp:
|
||||
A file path, file-like object, or bytes-like object containing the
|
||||
data to upload.
|
||||
filename:
|
||||
Optional name of the file. If not provided and ``fp`` is a path or has
|
||||
a ``name`` attribute, the name will be inferred.
|
||||
spoiler:
|
||||
When ``True`` the filename will be prefixed with ``"SPOILER_"``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fp: Union[str, bytes, os.PathLike[Any], IO[bytes]],
|
||||
*,
|
||||
filename: Optional[str] = None,
|
||||
spoiler: bool = False,
|
||||
) -> None:
|
||||
if isinstance(fp, (str, os.PathLike)):
|
||||
self.data = open(fp, "rb")
|
||||
inferred = os.path.basename(fp)
|
||||
elif isinstance(fp, bytes):
|
||||
self.data = io.BytesIO(fp)
|
||||
inferred = None
|
||||
else:
|
||||
self.data = fp
|
||||
inferred = getattr(fp, "name", None)
|
||||
|
||||
name = filename or inferred
|
||||
if name is None:
|
||||
raise ValueError("filename could not be inferred")
|
||||
|
||||
if spoiler and not name.startswith("SPOILER_"):
|
||||
name = f"SPOILER_{name}"
|
||||
self.filename = name
|
||||
self.spoiler = spoiler
|
||||
|
||||
|
||||
class AllowedMentions:
|
||||
@ -706,7 +819,12 @@ class Role:
|
||||
self.name: str = data["name"]
|
||||
self.color: int = data["color"]
|
||||
self.hoist: bool = data["hoist"]
|
||||
self.icon: Optional[str] = data.get("icon")
|
||||
icon_hash = data.get("icon")
|
||||
self._icon: Optional[str] = (
|
||||
f"https://cdn.discordapp.com/role-icons/{self.id}/{icon_hash}.png"
|
||||
if icon_hash
|
||||
else None
|
||||
)
|
||||
self.unicode_emoji: Optional[str] = data.get("unicode_emoji")
|
||||
self.position: int = data["position"]
|
||||
self.permissions: str = data["permissions"] # String of bitwise permissions
|
||||
@ -724,6 +842,23 @@ class Role:
|
||||
def __repr__(self) -> str:
|
||||
return f"<Role id='{self.id}' name='{self.name}'>"
|
||||
|
||||
@property
|
||||
def icon(self) -> Optional["Asset"]:
|
||||
if self._icon:
|
||||
from .asset import Asset
|
||||
|
||||
return Asset(self._icon, None)
|
||||
return None
|
||||
|
||||
@icon.setter
|
||||
def icon(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||
if isinstance(value, str):
|
||||
self._icon = value
|
||||
elif value is None:
|
||||
self._icon = None
|
||||
else:
|
||||
self._icon = value.url
|
||||
|
||||
|
||||
class Member(User): # Member inherits from User
|
||||
"""Represents a Guild Member.
|
||||
@ -752,12 +887,18 @@ class Member(User): # Member inherits from User
|
||||
) # Pass user_data or data if user_data is empty
|
||||
|
||||
self.nick: Optional[str] = data.get("nick")
|
||||
self.avatar: Optional[str] = data.get("avatar") # Guild-specific avatar hash
|
||||
self.roles: List[str] = data.get("roles", []) # List of role IDs
|
||||
self.joined_at: str = data["joined_at"] # ISO8601 timestamp
|
||||
self.premium_since: Optional[str] = data.get(
|
||||
"premium_since"
|
||||
) # ISO8601 timestamp
|
||||
avatar_hash = data.get("avatar")
|
||||
if avatar_hash:
|
||||
guild_id = data.get("guild_id")
|
||||
if guild_id:
|
||||
self._avatar = f"https://cdn.discordapp.com/guilds/{guild_id}/users/{self.id}/avatars/{avatar_hash}.png"
|
||||
else:
|
||||
self._avatar = (
|
||||
f"https://cdn.discordapp.com/avatars/{self.id}/{avatar_hash}.png"
|
||||
)
|
||||
self.roles: List[str] = data.get("roles", [])
|
||||
self.joined_at: str = data["joined_at"]
|
||||
self.premium_since: Optional[str] = data.get("premium_since")
|
||||
self.deaf: bool = data.get("deaf", False)
|
||||
self.mute: bool = data.get("mute", False)
|
||||
self.pending: bool = data.get("pending", False)
|
||||
@ -767,6 +908,7 @@ class Member(User): # Member inherits from User
|
||||
self.communication_disabled_until: Optional[str] = data.get(
|
||||
"communication_disabled_until"
|
||||
) # ISO8601 timestamp
|
||||
self.voice_state = data.get("voice_state")
|
||||
|
||||
# If 'user' object was present, ensure User attributes are from there
|
||||
if user_data:
|
||||
@ -781,6 +923,25 @@ class Member(User): # Member inherits from User
|
||||
def __repr__(self) -> str:
|
||||
return f"<Member id='{self.id}' username='{self.username}' nick='{self.nick}'>"
|
||||
|
||||
@property
|
||||
def avatar(self) -> Optional["Asset"]:
|
||||
"""Return the member's avatar as an :class:`Asset`."""
|
||||
|
||||
if self._avatar:
|
||||
from .asset import Asset
|
||||
|
||||
return Asset(self._avatar, self._client)
|
||||
return None
|
||||
|
||||
@avatar.setter
|
||||
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||
if isinstance(value, str):
|
||||
self._avatar = value
|
||||
elif value is None:
|
||||
self._avatar = None
|
||||
else:
|
||||
self._avatar = value.url
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""Return the nickname if set, otherwise the username."""
|
||||
@ -847,6 +1008,41 @@ class Member(User): # Member inherits from User
|
||||
|
||||
return max(role_objects, key=lambda r: r.position)
|
||||
|
||||
@property
|
||||
def guild_permissions(self) -> "Permissions":
|
||||
"""Return the member's guild-level permissions."""
|
||||
|
||||
if not self.guild_id or not self._client:
|
||||
return Permissions(0)
|
||||
|
||||
guild = self._client.get_guild(self.guild_id)
|
||||
if guild is None:
|
||||
return Permissions(0)
|
||||
|
||||
base = Permissions(0)
|
||||
|
||||
everyone = guild.get_role(guild.id)
|
||||
if everyone is not None:
|
||||
base |= Permissions(int(everyone.permissions))
|
||||
|
||||
for rid in self.roles:
|
||||
role = guild.get_role(rid)
|
||||
if role is not None:
|
||||
base |= Permissions(int(role.permissions))
|
||||
|
||||
if base & Permissions.ADMINISTRATOR:
|
||||
return Permissions(~0)
|
||||
|
||||
return base
|
||||
|
||||
@property
|
||||
def voice(self) -> Optional["VoiceState"]:
|
||||
"""Return the member's cached voice state as a :class:`VoiceState`."""
|
||||
|
||||
if self.voice_state is None:
|
||||
return None
|
||||
return VoiceState.from_dict(self.voice_state)
|
||||
|
||||
|
||||
class PartialEmoji:
|
||||
"""Represents a partial emoji, often used in components or reactions.
|
||||
@ -1044,7 +1240,7 @@ class PermissionOverwrite:
|
||||
return f"<PermissionOverwrite id='{self.id}' type='{self.type.name if hasattr(self.type, 'name') else self._type_val}' allow='{self.allow}' deny='{self.deny}'>"
|
||||
|
||||
|
||||
class Guild:
|
||||
class Guild(HashableById):
|
||||
"""Represents a Discord Guild (Server).
|
||||
|
||||
Attributes:
|
||||
@ -1084,15 +1280,42 @@ class Guild:
|
||||
nsfw_level (GuildNSFWLevel): Guild NSFW level.
|
||||
stickers (Optional[List[Dict]]): Custom stickers in the guild. (Consider a Sticker model)
|
||||
premium_progress_bar_enabled (bool): Whether the guild has the premium progress bar enabled.
|
||||
text_channels (List[TextChannel]): List of text-based channels in this guild.
|
||||
voice_channels (List[VoiceChannel]): List of voice-based channels in this guild.
|
||||
category_channels (List[CategoryChannel]): List of category channels in this guild.
|
||||
"""
|
||||
|
||||
def __init__(self, data: Dict[str, Any], client_instance: "Client"):
|
||||
def __init__(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
client_instance: "Client",
|
||||
*,
|
||||
shard_id: Optional[int] = None,
|
||||
):
|
||||
self._client: "Client" = client_instance
|
||||
self._shard_id: Optional[int] = (
|
||||
shard_id if shard_id is not None else data.get("shard_id")
|
||||
)
|
||||
self.id: str = data["id"]
|
||||
self.name: str = data["name"]
|
||||
self.icon: Optional[str] = data.get("icon")
|
||||
self.splash: Optional[str] = data.get("splash")
|
||||
self.discovery_splash: Optional[str] = data.get("discovery_splash")
|
||||
icon_hash = data.get("icon")
|
||||
self._icon: Optional[str] = (
|
||||
f"https://cdn.discordapp.com/icons/{self.id}/{icon_hash}.png"
|
||||
if icon_hash
|
||||
else None
|
||||
)
|
||||
splash_hash = data.get("splash")
|
||||
self._splash: Optional[str] = (
|
||||
f"https://cdn.discordapp.com/splashes/{self.id}/{splash_hash}.png"
|
||||
if splash_hash
|
||||
else None
|
||||
)
|
||||
discovery_hash = data.get("discovery_splash")
|
||||
self._discovery_splash: Optional[str] = (
|
||||
f"https://cdn.discordapp.com/discovery-splashes/{self.id}/{discovery_hash}.png"
|
||||
if discovery_hash
|
||||
else None
|
||||
)
|
||||
self.owner: Optional[bool] = data.get("owner")
|
||||
self.owner_id: str = data["owner_id"]
|
||||
self.permissions: Optional[str] = data.get("permissions")
|
||||
@ -1129,7 +1352,12 @@ class Guild:
|
||||
self.max_members: Optional[int] = data.get("max_members")
|
||||
self.vanity_url_code: Optional[str] = data.get("vanity_url_code")
|
||||
self.description: Optional[str] = data.get("description")
|
||||
self.banner: Optional[str] = data.get("banner")
|
||||
banner_hash = data.get("banner")
|
||||
self._banner: Optional[str] = (
|
||||
f"https://cdn.discordapp.com/banners/{self.id}/{banner_hash}.png"
|
||||
if banner_hash
|
||||
else None
|
||||
)
|
||||
self.premium_tier: PremiumTier = PremiumTier(data["premium_tier"])
|
||||
self.premium_subscription_count: Optional[int] = data.get(
|
||||
"premium_subscription_count"
|
||||
@ -1168,6 +1396,28 @@ class Guild:
|
||||
getattr(client_instance, "member_cache_flags", MemberCacheFlags())
|
||||
)
|
||||
self._threads: Dict[str, "Thread"] = {}
|
||||
self.text_channels: List["TextChannel"] = []
|
||||
self.voice_channels: List["VoiceChannel"] = []
|
||||
self.category_channels: List["CategoryChannel"] = []
|
||||
|
||||
@property
|
||||
def shard_id(self) -> Optional[int]:
|
||||
"""ID of the shard that received this guild, if any."""
|
||||
|
||||
return self._shard_id
|
||||
|
||||
@property
|
||||
def shard(self) -> Optional["Shard"]:
|
||||
"""The :class:`Shard` this guild belongs to."""
|
||||
|
||||
if self._shard_id is None:
|
||||
return None
|
||||
manager = getattr(self._client, "_shard_manager", None)
|
||||
if not manager:
|
||||
return None
|
||||
if 0 <= self._shard_id < len(manager.shards):
|
||||
return manager.shards[self._shard_id]
|
||||
return None
|
||||
|
||||
def get_channel(self, channel_id: str) -> Optional["Channel"]:
|
||||
return self._channels.get(channel_id)
|
||||
@ -1203,9 +1453,86 @@ class Guild:
|
||||
def get_role(self, role_id: str) -> Optional[Role]:
|
||||
return next((role for role in self.roles if role.id == role_id), None)
|
||||
|
||||
@property
|
||||
def me(self) -> Optional[Member]:
|
||||
"""The member object for the connected bot in this guild, if present."""
|
||||
|
||||
client_user = getattr(self._client, "user", None)
|
||||
if not client_user:
|
||||
return None
|
||||
return self.get_member(client_user.id)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Guild id='{self.id}' name='{self.name}'>"
|
||||
|
||||
@property
|
||||
def icon(self) -> Optional["Asset"]:
|
||||
if self._icon:
|
||||
from .asset import Asset
|
||||
|
||||
return Asset(self._icon, self._client)
|
||||
return None
|
||||
|
||||
@icon.setter
|
||||
def icon(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||
if isinstance(value, str):
|
||||
self._icon = value
|
||||
elif value is None:
|
||||
self._icon = None
|
||||
else:
|
||||
self._icon = value.url
|
||||
|
||||
@property
|
||||
def splash(self) -> Optional["Asset"]:
|
||||
if self._splash:
|
||||
from .asset import Asset
|
||||
|
||||
return Asset(self._splash, self._client)
|
||||
return None
|
||||
|
||||
@splash.setter
|
||||
def splash(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||
if isinstance(value, str):
|
||||
self._splash = value
|
||||
elif value is None:
|
||||
self._splash = None
|
||||
else:
|
||||
self._splash = value.url
|
||||
|
||||
@property
|
||||
def discovery_splash(self) -> Optional["Asset"]:
|
||||
if self._discovery_splash:
|
||||
from .asset import Asset
|
||||
|
||||
return Asset(self._discovery_splash, self._client)
|
||||
return None
|
||||
|
||||
@discovery_splash.setter
|
||||
def discovery_splash(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||
if isinstance(value, str):
|
||||
self._discovery_splash = value
|
||||
elif value is None:
|
||||
self._discovery_splash = None
|
||||
else:
|
||||
self._discovery_splash = value.url
|
||||
|
||||
@property
|
||||
def banner(self) -> Optional["Asset"]:
|
||||
if self._banner:
|
||||
from .asset import Asset
|
||||
|
||||
return Asset(self._banner, self._client)
|
||||
return None
|
||||
|
||||
@banner.setter
|
||||
def banner(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||
if isinstance(value, str):
|
||||
self._banner = value
|
||||
elif value is None:
|
||||
self._banner = None
|
||||
else:
|
||||
self._banner = value.url
|
||||
|
||||
async def fetch_widget(self) -> Dict[str, Any]:
|
||||
"""|coro| Fetch this guild's widget settings."""
|
||||
|
||||
@ -1259,8 +1586,82 @@ class Guild:
|
||||
del self._client._gateway._member_chunk_requests[nonce]
|
||||
raise
|
||||
|
||||
async def prune_members(self, days: int, *, compute_count: bool = True) -> int:
|
||||
"""|coro| Remove inactive members from the guild.
|
||||
|
||||
class Channel:
|
||||
Parameters
|
||||
----------
|
||||
days: int
|
||||
Number of days of inactivity required to be pruned.
|
||||
compute_count: bool
|
||||
Whether to return the number of members pruned.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The number of members pruned.
|
||||
"""
|
||||
|
||||
return await self._client._http.begin_guild_prune(
|
||||
self.id, days=days, compute_count=compute_count
|
||||
)
|
||||
|
||||
async def create_text_channel(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
reason: Optional[str] = None,
|
||||
**options: Any,
|
||||
) -> "TextChannel":
|
||||
"""|coro| Create a new text channel in this guild."""
|
||||
|
||||
payload: Dict[str, Any] = {"name": name, "type": ChannelType.GUILD_TEXT.value}
|
||||
payload.update(options)
|
||||
data = await self._client._http.create_guild_channel(
|
||||
self.id, payload, reason=reason
|
||||
)
|
||||
return cast("TextChannel", self._client.parse_channel(data))
|
||||
|
||||
async def create_voice_channel(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
reason: Optional[str] = None,
|
||||
**options: Any,
|
||||
) -> "VoiceChannel":
|
||||
"""|coro| Create a new voice channel in this guild."""
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"name": name,
|
||||
"type": ChannelType.GUILD_VOICE.value,
|
||||
}
|
||||
payload.update(options)
|
||||
data = await self._client._http.create_guild_channel(
|
||||
self.id, payload, reason=reason
|
||||
)
|
||||
return cast("VoiceChannel", self._client.parse_channel(data))
|
||||
|
||||
async def create_category(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
reason: Optional[str] = None,
|
||||
**options: Any,
|
||||
) -> "CategoryChannel":
|
||||
"""|coro| Create a new category channel in this guild."""
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"name": name,
|
||||
"type": ChannelType.GUILD_CATEGORY.value,
|
||||
}
|
||||
payload.update(options)
|
||||
data = await self._client._http.create_guild_channel(
|
||||
self.id, payload, reason=reason
|
||||
)
|
||||
return cast("CategoryChannel", self._client.parse_channel(data))
|
||||
|
||||
|
||||
class Channel(HashableById):
|
||||
"""Base class for Discord channels."""
|
||||
|
||||
def __init__(self, data: Dict[str, Any], client_instance: "Client"):
|
||||
@ -1540,6 +1941,31 @@ class TextChannel(Channel, Messageable):
|
||||
data = await self._client._http.start_thread_without_message(self.id, payload)
|
||||
return cast("Thread", self._client.parse_channel(data))
|
||||
|
||||
async def create_invite(
|
||||
self,
|
||||
*,
|
||||
max_age: Optional[int] = None,
|
||||
max_uses: Optional[int] = None,
|
||||
temporary: Optional[bool] = None,
|
||||
unique: Optional[bool] = None,
|
||||
reason: Optional[str] = None,
|
||||
) -> "Invite":
|
||||
"""|coro| Create an invite to this channel."""
|
||||
|
||||
payload: Dict[str, Any] = {}
|
||||
if max_age is not None:
|
||||
payload["max_age"] = max_age
|
||||
if max_uses is not None:
|
||||
payload["max_uses"] = max_uses
|
||||
if temporary is not None:
|
||||
payload["temporary"] = temporary
|
||||
if unique is not None:
|
||||
payload["unique"] = unique
|
||||
|
||||
return await self._client._http.create_channel_invite(
|
||||
self.id, payload, reason=reason
|
||||
)
|
||||
|
||||
|
||||
class VoiceChannel(Channel):
|
||||
"""Represents a guild voice channel or stage voice channel."""
|
||||
@ -1839,7 +2265,12 @@ class Webhook:
|
||||
self.guild_id: Optional[str] = data.get("guild_id")
|
||||
self.channel_id: Optional[str] = data.get("channel_id")
|
||||
self.name: Optional[str] = data.get("name")
|
||||
self.avatar: Optional[str] = data.get("avatar")
|
||||
avatar_hash = data.get("avatar")
|
||||
self._avatar: Optional[str] = (
|
||||
f"https://cdn.discordapp.com/webhooks/{self.id}/{avatar_hash}.png"
|
||||
if avatar_hash
|
||||
else None
|
||||
)
|
||||
self.token: Optional[str] = data.get("token")
|
||||
self.application_id: Optional[str] = data.get("application_id")
|
||||
self.url: Optional[str] = data.get("url")
|
||||
@ -1848,6 +2279,25 @@ class Webhook:
|
||||
def __repr__(self) -> str:
|
||||
return f"<Webhook id='{self.id}' name='{self.name}'>"
|
||||
|
||||
@property
|
||||
def avatar(self) -> Optional["Asset"]:
|
||||
"""Return the webhook's avatar as an :class:`Asset`."""
|
||||
|
||||
if self._avatar:
|
||||
from .asset import Asset
|
||||
|
||||
return Asset(self._avatar, self._client)
|
||||
return None
|
||||
|
||||
@avatar.setter
|
||||
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||
if isinstance(value, str):
|
||||
self._avatar = value
|
||||
elif value is None:
|
||||
self._avatar = None
|
||||
else:
|
||||
self._avatar = value.url
|
||||
|
||||
@classmethod
|
||||
def from_url(
|
||||
cls, url: str, session: Optional[aiohttp.ClientSession] = None
|
||||
@ -1875,6 +2325,33 @@ class Webhook:
|
||||
|
||||
return cls({"id": webhook_id, "token": token, "url": url})
|
||||
|
||||
@classmethod
|
||||
def from_token(
|
||||
cls,
|
||||
webhook_id: str,
|
||||
token: str,
|
||||
session: Optional[aiohttp.ClientSession] = None,
|
||||
) -> "Webhook":
|
||||
"""Create a minimal :class:`Webhook` from an ID and token.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
webhook_id:
|
||||
The ID of the webhook.
|
||||
token:
|
||||
The webhook token.
|
||||
session:
|
||||
Unused for now. Present for API compatibility.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Webhook
|
||||
A webhook instance containing only the ``id``, ``token`` and ``url``.
|
||||
"""
|
||||
|
||||
url = f"https://discord.com/api/webhooks/{webhook_id}/{token}"
|
||||
return cls({"id": webhook_id, "token": token, "url": url})
|
||||
|
||||
async def send(
|
||||
self,
|
||||
content: Optional[str] = None,
|
||||
@ -2548,6 +3025,45 @@ class VoiceStateUpdate:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceState:
|
||||
"""Represents a cached voice state for a member."""
|
||||
|
||||
guild_id: Optional[str]
|
||||
channel_id: Optional[str]
|
||||
user_id: Optional[str]
|
||||
session_id: Optional[str]
|
||||
deaf: bool = False
|
||||
mute: bool = False
|
||||
self_deaf: bool = False
|
||||
self_mute: bool = False
|
||||
self_stream: Optional[bool] = None
|
||||
self_video: bool = False
|
||||
suppress: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "VoiceState":
|
||||
return cls(
|
||||
guild_id=data.get("guild_id"),
|
||||
channel_id=data.get("channel_id"),
|
||||
user_id=data.get("user_id"),
|
||||
session_id=data.get("session_id"),
|
||||
deaf=data.get("deaf", False),
|
||||
mute=data.get("mute", False),
|
||||
self_deaf=data.get("self_deaf", False),
|
||||
self_mute=data.get("self_mute", False),
|
||||
self_stream=data.get("self_stream"),
|
||||
self_video=data.get("self_video", False),
|
||||
suppress=data.get("suppress", False),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<VoiceState guild_id='{self.guild_id}' user_id='{self.user_id}' "
|
||||
f"channel_id='{self.channel_id}'>"
|
||||
)
|
||||
|
||||
|
||||
class Reaction:
|
||||
"""Represents a message reaction event."""
|
||||
|
||||
@ -2640,6 +3156,18 @@ class Invite:
|
||||
return f"<Invite code='{self.code}' guild_id='{self.guild_id}' channel_id='{self.channel_id}'>"
|
||||
|
||||
|
||||
class InviteDelete:
|
||||
"""Represents an INVITE_DELETE event."""
|
||||
|
||||
def __init__(self, data: Dict[str, Any]):
|
||||
self.channel_id: str = data["channel_id"]
|
||||
self.guild_id: Optional[str] = data.get("guild_id")
|
||||
self.code: str = data["code"]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<InviteDelete code='{self.code}' guild_id='{self.guild_id}' channel_id='{self.channel_id}'>"
|
||||
|
||||
|
||||
class GuildMemberRemove:
|
||||
"""Represents a GUILD_MEMBER_REMOVE event."""
|
||||
|
||||
|
19
disagreement/object.py
Normal file
19
disagreement/object.py
Normal file
@ -0,0 +1,19 @@
|
||||
class Object:
|
||||
"""A minimal wrapper around a Discord snowflake ID."""
|
||||
|
||||
__slots__ = ("id",)
|
||||
|
||||
def __init__(self, object_id: int) -> None:
|
||||
self.id = int(object_id)
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.id)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, Object) and self.id == other.id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Object id={self.id}>"
|
@ -57,6 +57,15 @@ class Permissions(IntFlag):
|
||||
USE_EXTERNAL_SOUNDS = 1 << 45
|
||||
SEND_VOICE_MESSAGES = 1 << 46
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> "Permissions":
|
||||
"""Return a ``Permissions`` object with every permission bit enabled."""
|
||||
|
||||
value = 0
|
||||
for perm in cls:
|
||||
value |= perm.value
|
||||
return cls(value)
|
||||
|
||||
|
||||
def permissions_value(*perms: Permissions | int | Iterable[Permissions | int]) -> int:
|
||||
"""Return a combined integer value for multiple permissions."""
|
||||
|
@ -3,7 +3,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, AsyncIterator, Dict, Optional, TYPE_CHECKING
|
||||
from typing import Any, AsyncIterator, Dict, Iterable, Optional, TYPE_CHECKING, Callable
|
||||
import re
|
||||
|
||||
# Discord epoch in milliseconds (2015-01-01T00:00:00Z)
|
||||
DISCORD_EPOCH = 1420070400000
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - for type hinting only
|
||||
from .models import Message, TextChannel
|
||||
@ -14,6 +18,27 @@ def utcnow() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def find(predicate: Callable[[Any], bool], iterable: Iterable[Any]) -> Optional[Any]:
|
||||
"""Return the first element in ``iterable`` matching the ``predicate``."""
|
||||
for element in iterable:
|
||||
if predicate(element):
|
||||
return element
|
||||
return None
|
||||
|
||||
|
||||
def get(iterable: Iterable[Any], **attrs: Any) -> Optional[Any]:
|
||||
"""Return the first element with matching attributes."""
|
||||
def predicate(elem: Any) -> bool:
|
||||
return all(getattr(elem, attr, None) == value for attr, value in attrs.items())
|
||||
return find(predicate, iterable)
|
||||
|
||||
|
||||
def snowflake_time(snowflake: int) -> datetime:
|
||||
"""Return the creation time of a Discord snowflake."""
|
||||
timestamp_ms = (snowflake >> 22) + DISCORD_EPOCH
|
||||
return datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc)
|
||||
|
||||
|
||||
async def message_pager(
|
||||
channel: "TextChannel",
|
||||
*,
|
||||
@ -21,32 +46,11 @@ async def message_pager(
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
) -> AsyncIterator["Message"]:
|
||||
"""Asynchronously paginate a channel's messages.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
channel:
|
||||
The :class:`TextChannel` to fetch messages from.
|
||||
limit:
|
||||
The maximum number of messages to yield. ``None`` fetches until no
|
||||
more messages are returned.
|
||||
before:
|
||||
Fetch messages with IDs less than this snowflake.
|
||||
after:
|
||||
Fetch messages with IDs greater than this snowflake.
|
||||
|
||||
Yields
|
||||
------
|
||||
Message
|
||||
Messages in the channel, oldest first.
|
||||
"""
|
||||
|
||||
"""Asynchronously paginate a channel's messages."""
|
||||
remaining = limit
|
||||
last_id = before
|
||||
while remaining is None or remaining > 0:
|
||||
fetch_limit = 100
|
||||
if remaining is not None:
|
||||
fetch_limit = min(fetch_limit, remaining)
|
||||
fetch_limit = min(100, remaining) if remaining is not None else 100
|
||||
|
||||
params: Dict[str, Any] = {"limit": fetch_limit}
|
||||
if last_id is not None:
|
||||
@ -71,3 +75,52 @@ async def message_pager(
|
||||
remaining -= 1
|
||||
if remaining == 0:
|
||||
return
|
||||
|
||||
|
||||
class Paginator:
|
||||
"""Helper to split text into pages under a character limit."""
|
||||
|
||||
def __init__(self, limit: int = 2000) -> None:
|
||||
self.limit = limit
|
||||
self._pages: list[str] = []
|
||||
self._current = ""
|
||||
|
||||
def add_line(self, line: str) -> None:
|
||||
"""Add a line of text to the paginator."""
|
||||
if len(line) > self.limit:
|
||||
if self._current:
|
||||
self._pages.append(self._current)
|
||||
self._current = ""
|
||||
for i in range(0, len(line), self.limit):
|
||||
chunk = line[i : i + self.limit]
|
||||
if len(chunk) == self.limit:
|
||||
self._pages.append(chunk)
|
||||
else:
|
||||
self._current = chunk
|
||||
return
|
||||
|
||||
if not self._current:
|
||||
self._current = line
|
||||
elif len(self._current) + 1 + len(line) <= self.limit:
|
||||
self._current += "\n" + line
|
||||
else:
|
||||
self._pages.append(self._current)
|
||||
self._current = line
|
||||
|
||||
@property
|
||||
def pages(self) -> list[str]:
|
||||
"""Return the accumulated pages."""
|
||||
pages = list(self._pages)
|
||||
if self._current:
|
||||
pages.append(self._current)
|
||||
return pages
|
||||
|
||||
|
||||
def escape_markdown(text: str) -> str:
|
||||
"""Escape Discord markdown formatting in ``text``."""
|
||||
return re.sub(r"([\\*_~`>|])", r"\\\1", text)
|
||||
|
||||
|
||||
def escape_mentions(text: str) -> str:
|
||||
"""Escape Discord mentions in ``text``."""
|
||||
return text.replace("@", "@\u200b")
|
||||
|
@ -79,6 +79,9 @@ class VoiceClient:
|
||||
self._server_port: Optional[int] = None
|
||||
self._current_source: Optional[AudioSource] = None
|
||||
self._play_task: Optional[asyncio.Task] = None
|
||||
self._pause_event = asyncio.Event()
|
||||
self._pause_event.set()
|
||||
self._is_playing = False
|
||||
self._sink: Optional[AudioSink] = None
|
||||
self._ssrc_map: dict[int, int] = {}
|
||||
self._ssrc_lock = threading.Lock()
|
||||
@ -191,8 +194,10 @@ class VoiceClient:
|
||||
|
||||
async def _play_loop(self) -> None:
|
||||
assert self._current_source is not None
|
||||
self._is_playing = True
|
||||
try:
|
||||
while True:
|
||||
await self._pause_event.wait()
|
||||
data = await self._current_source.read()
|
||||
if not data:
|
||||
break
|
||||
@ -204,13 +209,17 @@ class VoiceClient:
|
||||
await self._current_source.close()
|
||||
self._current_source = None
|
||||
self._play_task = None
|
||||
self._is_playing = False
|
||||
self._pause_event.set()
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._play_task:
|
||||
self._play_task.cancel()
|
||||
self._pause_event.set()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._play_task
|
||||
self._play_task = None
|
||||
self._is_playing = False
|
||||
if self._current_source:
|
||||
await self._current_source.close()
|
||||
self._current_source = None
|
||||
@ -229,6 +238,27 @@ class VoiceClient:
|
||||
|
||||
await self.play(FFmpegAudioSource(filename), wait=wait)
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Pause the current audio source."""
|
||||
|
||||
if self._play_task and not self._play_task.done():
|
||||
self._pause_event.clear()
|
||||
|
||||
def resume(self) -> None:
|
||||
"""Resume playback of a paused source."""
|
||||
|
||||
if self._play_task and not self._play_task.done():
|
||||
self._pause_event.set()
|
||||
|
||||
def is_paused(self) -> bool:
|
||||
"""Return ``True`` if playback is currently paused."""
|
||||
|
||||
return bool(self._play_task and not self._pause_event.is_set())
|
||||
|
||||
def is_playing(self) -> bool:
|
||||
"""Return ``True`` if audio is actively being played."""
|
||||
return self._is_playing and self._pause_event.is_set()
|
||||
|
||||
def listen(self, sink: AudioSink) -> None:
|
||||
"""Start listening to voice and routing to a sink."""
|
||||
if not isinstance(sink, AudioSink):
|
||||
|
@ -13,12 +13,28 @@ if member:
|
||||
print(member.display_name)
|
||||
```
|
||||
|
||||
To access the bot's own member object, use the ``Guild.me`` property. It returns
|
||||
``None`` if the bot is not in the guild or its user data hasn't been loaded:
|
||||
|
||||
```python
|
||||
bot_member = guild.me
|
||||
if bot_member:
|
||||
print(bot_member.joined_at)
|
||||
```
|
||||
|
||||
The cache can be cleared manually if needed:
|
||||
|
||||
```python
|
||||
client.cache.clear()
|
||||
```
|
||||
|
||||
## Partial Objects
|
||||
|
||||
Some events only include minimal data for related resources. When only an ``id``
|
||||
is available, Disagreement represents the resource using :class:`~disagreement.Object`.
|
||||
These objects can be compared and used in sets or dictionaries and can be passed
|
||||
to API methods to fetch the full data when needed.
|
||||
|
||||
## Next Steps
|
||||
|
||||
- [Components](using_components.md)
|
||||
|
@ -11,7 +11,11 @@ The command handler registers a `help` command automatically. Use it to list all
|
||||
!help ping # shows help for the "ping" command
|
||||
```
|
||||
|
||||
The help command will show each command's brief description if provided.
|
||||
Commands are grouped by their Cog name and paginated so that long help
|
||||
lists are split into multiple messages using the `Paginator` utility.
|
||||
|
||||
If you need custom formatting you can subclass
|
||||
`HelpCommand` and override `send_command_help` or `send_group_help`.
|
||||
|
||||
## Checks
|
||||
|
||||
|
@ -132,6 +132,28 @@ async def on_shard_resume(info: dict):
|
||||
...
|
||||
```
|
||||
|
||||
## CONNECT
|
||||
|
||||
Dispatched when the WebSocket connection opens. The callback receives a
|
||||
dictionary with the shard ID.
|
||||
|
||||
```python
|
||||
@client.event
|
||||
async def on_connect(info: dict):
|
||||
print("connected", info.get("shard_id"))
|
||||
```
|
||||
|
||||
## DISCONNECT
|
||||
|
||||
Fired when the WebSocket connection closes. The callback receives a dictionary
|
||||
with the shard ID.
|
||||
|
||||
```python
|
||||
@client.event
|
||||
async def on_disconnect(info: dict):
|
||||
...
|
||||
```
|
||||
|
||||
## VOICE_STATE_UPDATE
|
||||
|
||||
Triggered when a user's voice connection state changes, such as joining or leaving a voice channel. The callback receives a `VoiceStateUpdate` model.
|
||||
|
@ -14,6 +14,7 @@ A Python library for interacting with the Discord API, with a focus on bot devel
|
||||
- Built-in caching layer
|
||||
- Experimental voice support
|
||||
- Helpful error handling utilities
|
||||
- Paginator utility for splitting long messages
|
||||
|
||||
## Installation
|
||||
|
||||
@ -60,13 +61,9 @@ if not token:
|
||||
|
||||
intents = GatewayIntent.default() | GatewayIntent.MESSAGE_CONTENT
|
||||
client = Client(token=token, command_prefix="!", intents=intents, mention_replies=True)
|
||||
async def main() -> None:
|
||||
|
||||
client.add_cog(Basics(client))
|
||||
await client.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
client.run()
|
||||
```
|
||||
|
||||
### Global Error Handling
|
||||
|
@ -15,6 +15,12 @@ from disagreement import Permissions
|
||||
value = Permissions.SEND_MESSAGES | Permissions.MANAGE_MESSAGES
|
||||
```
|
||||
|
||||
You can also get a bitmask containing **every** permission:
|
||||
|
||||
```python
|
||||
all_perms = Permissions.all()
|
||||
```
|
||||
|
||||
## Helper Functions
|
||||
|
||||
### ``permissions_value``
|
||||
|
@ -8,13 +8,8 @@ manually.
|
||||
and configures the `ShardManager` automatically.
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import disagreement
|
||||
|
||||
bot = disagreement.AutoShardedClient(token="YOUR_TOKEN")
|
||||
|
||||
async def main():
|
||||
await bot.run()
|
||||
|
||||
asyncio.run(main())
|
||||
bot.run()
|
||||
```
|
||||
|
@ -157,6 +157,22 @@ container = Container(
|
||||
A container can itself contain layout and content components, letting you build complex messages.
|
||||
|
||||
|
||||
## Persistent Views
|
||||
|
||||
Views with ``timeout=None`` are persistent. Their ``custom_id`` components are saved to ``persistent_views.json`` so they survive bot restarts.
|
||||
|
||||
```python
|
||||
class MyView(View):
|
||||
@button(label="Press", custom_id="press")
|
||||
async def handle(self, view, inter):
|
||||
await inter.respond("Pressed!")
|
||||
|
||||
client.add_persistent_view(MyView())
|
||||
```
|
||||
|
||||
When the client starts, it loads this file and registers each view again. Remove
|
||||
the file to clear stored views.
|
||||
|
||||
## Next Steps
|
||||
|
||||
- [Slash Commands](slash_commands.md)
|
||||
|
20
docs/utils.md
Normal file
20
docs/utils.md
Normal file
@ -0,0 +1,20 @@
|
||||
# Utility Helpers
|
||||
|
||||
Disagreement provides a few small utility functions for working with Discord data.
|
||||
|
||||
## `utcnow`
|
||||
|
||||
Returns the current timezone-aware UTC `datetime`.
|
||||
|
||||
## `snowflake_time`
|
||||
|
||||
Converts a Discord snowflake ID into the UTC timestamp when it was generated.
|
||||
|
||||
```python
|
||||
from disagreement.utils import snowflake_time
|
||||
|
||||
created_at = snowflake_time(175928847299117063)
|
||||
print(created_at.isoformat())
|
||||
```
|
||||
|
||||
The function extracts the timestamp from the snowflake and returns a `datetime` in UTC.
|
@ -6,6 +6,10 @@ Disagreement includes experimental support for connecting to voice channels. You
|
||||
voice = await client.join_voice(guild_id, channel_id)
|
||||
await voice.play_file("welcome.mp3")
|
||||
await voice.play_file("another.mp3") # switch sources while connected
|
||||
voice.pause()
|
||||
voice.resume()
|
||||
if voice.is_playing():
|
||||
print("audio is playing")
|
||||
await voice.close()
|
||||
```
|
||||
|
||||
|
@ -67,9 +67,7 @@ BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
||||
# --- Intents Configuration ---
|
||||
# Define the intents your bot needs. For basic message reading and responding:
|
||||
intents = (
|
||||
GatewayIntent.GUILDS
|
||||
| GatewayIntent.GUILD_MESSAGES
|
||||
| GatewayIntent.MESSAGE_CONTENT
|
||||
GatewayIntent.GUILDS | GatewayIntent.GUILD_MESSAGES | GatewayIntent.MESSAGE_CONTENT
|
||||
) # MESSAGE_CONTENT is privileged!
|
||||
|
||||
# If you don't need message content and only react to commands/mentions,
|
||||
@ -210,14 +208,14 @@ async def on_guild_available(guild: Guild):
|
||||
|
||||
|
||||
# --- Main Execution ---
|
||||
async def main():
|
||||
def main():
|
||||
print("Starting Disagreement Bot...")
|
||||
try:
|
||||
# Add the Cog to the client
|
||||
client.add_cog(ExampleCog(client)) # Pass client instance to Cog constructor
|
||||
# client.add_cog is synchronous, but it schedules cog.cog_load() if it's async.
|
||||
|
||||
await client.run()
|
||||
client.run()
|
||||
except AuthenticationError:
|
||||
print(
|
||||
"Authentication failed. Please check your bot token and ensure it's correct."
|
||||
@ -232,7 +230,7 @@ async def main():
|
||||
finally:
|
||||
if not client.is_closed():
|
||||
print("Ensuring client is closed...")
|
||||
await client.close()
|
||||
asyncio.run(client.close())
|
||||
print("Bot has been shut down.")
|
||||
|
||||
|
||||
@ -244,4 +242,4 @@ if __name__ == "__main__":
|
||||
# if os.name == 'nt':
|
||||
# asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
asyncio.run(main())
|
||||
main()
|
||||
|
@ -263,7 +263,7 @@ class ComponentCommandsCog(Cog):
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
def main():
|
||||
@client.event
|
||||
async def on_ready():
|
||||
if client.user:
|
||||
@ -283,8 +283,8 @@ async def main():
|
||||
)
|
||||
|
||||
client.add_cog(ComponentCommandsCog(client))
|
||||
await client.run()
|
||||
client.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
main()
|
||||
|
@ -65,11 +65,9 @@ client.app_command_handler.add_command(user_info)
|
||||
client.app_command_handler.add_command(quote)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
await client.run()
|
||||
def main() -> None:
|
||||
client.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
main()
|
||||
|
@ -27,10 +27,10 @@ intents = GatewayIntent.default() | GatewayIntent.MESSAGE_CONTENT
|
||||
client = Client(token=token, command_prefix="!", intents=intents, mention_replies=True)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
def main() -> None:
|
||||
client.add_cog(Basics(client))
|
||||
await client.run()
|
||||
client.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
main()
|
||||
|
@ -230,7 +230,7 @@ class TestCog(Cog):
|
||||
|
||||
|
||||
# --- Main Bot Script ---
|
||||
async def main():
|
||||
def main():
|
||||
bot_token = os.getenv("DISCORD_BOT_TOKEN")
|
||||
application_id = os.getenv("DISCORD_APPLICATION_ID")
|
||||
|
||||
@ -291,7 +291,7 @@ async def main():
|
||||
client.add_cog(TestCog(client))
|
||||
|
||||
try:
|
||||
await client.run()
|
||||
client.run()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Bot shutting down...")
|
||||
except Exception as e:
|
||||
@ -300,7 +300,7 @@ async def main():
|
||||
)
|
||||
finally:
|
||||
if not client.is_closed():
|
||||
await client.close()
|
||||
asyncio.run(client.close())
|
||||
logger.info("Bot has been closed.")
|
||||
|
||||
|
||||
@ -310,6 +310,6 @@ if __name__ == "__main__":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
try:
|
||||
asyncio.run(main())
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Main loop interrupted. Exiting.")
|
||||
|
@ -62,9 +62,9 @@ async def on_ready():
|
||||
print("------")
|
||||
|
||||
|
||||
async def main():
|
||||
await client.run()
|
||||
def main():
|
||||
client.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
main()
|
||||
|
@ -63,6 +63,4 @@ async def on_ready():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(client.run())
|
||||
client.run()
|
||||
|
@ -9,7 +9,15 @@ from typing import Set
|
||||
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from disagreement import Client, GatewayIntent, Member, Message, Cog, command, CommandContext
|
||||
from disagreement import (
|
||||
Client,
|
||||
GatewayIntent,
|
||||
Member,
|
||||
Message,
|
||||
Cog,
|
||||
command,
|
||||
CommandContext,
|
||||
)
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
@ -26,9 +34,7 @@ if not BOT_TOKEN:
|
||||
sys.exit(1)
|
||||
|
||||
intents = (
|
||||
GatewayIntent.GUILDS
|
||||
| GatewayIntent.GUILD_MESSAGES
|
||||
| GatewayIntent.MESSAGE_CONTENT
|
||||
GatewayIntent.GUILDS | GatewayIntent.GUILD_MESSAGES | GatewayIntent.MESSAGE_CONTENT
|
||||
)
|
||||
client = Client(token=BOT_TOKEN, command_prefix="!", intents=intents)
|
||||
|
||||
@ -78,10 +84,10 @@ async def on_message(message: Message) -> None:
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
def main() -> None:
|
||||
client.add_cog(ModerationCog(client))
|
||||
await client.run()
|
||||
client.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
main()
|
||||
|
@ -137,11 +137,11 @@ async def on_reaction_remove(reaction: Reaction, user: User | Member):
|
||||
|
||||
|
||||
# --- Main Execution ---
|
||||
async def main():
|
||||
def main():
|
||||
print("Starting Reaction Bot...")
|
||||
try:
|
||||
client.add_cog(ReactionCog(client))
|
||||
await client.run()
|
||||
client.run()
|
||||
except AuthenticationError:
|
||||
print("Authentication failed. Check your bot token.")
|
||||
except Exception as e:
|
||||
@ -149,9 +149,9 @@ async def main():
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
if not client.is_closed():
|
||||
await client.close()
|
||||
asyncio.run(client.close())
|
||||
print("Bot has been shut down.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
main()
|
||||
|
@ -34,12 +34,12 @@ async def on_ready():
|
||||
print("Shard bot ready")
|
||||
|
||||
|
||||
async def main():
|
||||
def main():
|
||||
if not TOKEN:
|
||||
print("DISCORD_BOT_TOKEN environment variable not set")
|
||||
return
|
||||
await client.run()
|
||||
client.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
main()
|
||||
|
@ -53,9 +53,7 @@ BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
||||
|
||||
# --- Intents Configuration ---
|
||||
intents = (
|
||||
GatewayIntent.GUILDS
|
||||
| GatewayIntent.GUILD_MESSAGES
|
||||
| GatewayIntent.MESSAGE_CONTENT
|
||||
GatewayIntent.GUILDS | GatewayIntent.GUILD_MESSAGES | GatewayIntent.MESSAGE_CONTENT
|
||||
)
|
||||
|
||||
# --- Initialize the Client ---
|
||||
@ -106,11 +104,11 @@ async def on_ready():
|
||||
|
||||
|
||||
# --- Main Execution ---
|
||||
async def main():
|
||||
def main():
|
||||
print("Starting Typing Indicator Bot...")
|
||||
try:
|
||||
client.add_cog(TypingCog(client))
|
||||
await client.run()
|
||||
client.run()
|
||||
except AuthenticationError:
|
||||
print("Authentication failed. Check your bot token.")
|
||||
except Exception as e:
|
||||
@ -118,9 +116,9 @@ async def main():
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
if not client.is_closed():
|
||||
await client.close()
|
||||
asyncio.run(client.close())
|
||||
print("Bot has been shut down.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
main()
|
||||
|
@ -63,3 +63,4 @@ nav:
|
||||
- 'OAuth2': 'oauth2.md'
|
||||
- 'Presence': 'presence.md'
|
||||
- 'Voice Client': 'voice_client.md'
|
||||
- 'Utility Helpers': 'utils.md'
|
||||
|
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "disagreement"
|
||||
version = "0.8.0"
|
||||
version = "0.8.1"
|
||||
description = "A Python library for the Discord API."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
@ -3,7 +3,16 @@ import pytest
|
||||
from disagreement.ext.commands.converters import run_converters
|
||||
from disagreement.ext.commands.core import CommandContext, Command
|
||||
from disagreement.ext.commands.errors import BadArgument
|
||||
from disagreement.models import Message, Member, Role, Guild
|
||||
from disagreement.models import (
|
||||
Message,
|
||||
Member,
|
||||
Role,
|
||||
Guild,
|
||||
User,
|
||||
TextChannel,
|
||||
VoiceChannel,
|
||||
PartialEmoji,
|
||||
)
|
||||
from disagreement.enums import (
|
||||
VerificationLevel,
|
||||
MessageNotificationLevel,
|
||||
@ -11,21 +20,27 @@ from disagreement.enums import (
|
||||
MFALevel,
|
||||
GuildNSFWLevel,
|
||||
PremiumTier,
|
||||
ChannelType,
|
||||
)
|
||||
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.cache import GuildCache
|
||||
from disagreement.cache import GuildCache, Cache, ChannelCache
|
||||
|
||||
|
||||
class DummyBot(Client):
|
||||
def __init__(self):
|
||||
super().__init__(token="test")
|
||||
self._guilds = GuildCache()
|
||||
self._users = Cache()
|
||||
self._channels = ChannelCache()
|
||||
|
||||
def get_guild(self, guild_id):
|
||||
return self._guilds.get(guild_id)
|
||||
|
||||
def get_channel(self, channel_id):
|
||||
return self._channels.get(channel_id)
|
||||
|
||||
async def fetch_member(self, guild_id, member_id):
|
||||
guild = self._guilds.get(guild_id)
|
||||
return guild.get_member(member_id) if guild else None
|
||||
@ -37,6 +52,12 @@ class DummyBot(Client):
|
||||
async def fetch_guild(self, guild_id):
|
||||
return self._guilds.get(guild_id)
|
||||
|
||||
async def fetch_user(self, user_id):
|
||||
return self._users.get(user_id)
|
||||
|
||||
async def fetch_channel(self, channel_id):
|
||||
return self._channels.get(channel_id)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def guild_objects():
|
||||
@ -60,6 +81,9 @@ def guild_objects():
|
||||
guild = Guild(guild_data, client_instance=bot)
|
||||
bot._guilds.set(guild.id, guild)
|
||||
|
||||
user = User({"id": "7", "username": "u", "discriminator": "0001"})
|
||||
bot._users.set(user.id, user)
|
||||
|
||||
member = Member(
|
||||
{
|
||||
"user": {"id": "3", "username": "m", "discriminator": "0001"},
|
||||
@ -86,12 +110,38 @@ def guild_objects():
|
||||
guild._members.set(member.id, member)
|
||||
guild.roles.append(role)
|
||||
|
||||
return guild, member, role
|
||||
text_channel = TextChannel(
|
||||
{
|
||||
"id": "20",
|
||||
"type": ChannelType.GUILD_TEXT.value,
|
||||
"guild_id": guild.id,
|
||||
"permission_overwrites": [],
|
||||
},
|
||||
client_instance=bot,
|
||||
)
|
||||
voice_channel = VoiceChannel(
|
||||
{
|
||||
"id": "21",
|
||||
"type": ChannelType.GUILD_VOICE.value,
|
||||
"guild_id": guild.id,
|
||||
"permission_overwrites": [],
|
||||
},
|
||||
client_instance=bot,
|
||||
)
|
||||
|
||||
guild._channels.set(text_channel.id, text_channel)
|
||||
guild.text_channels.append(text_channel)
|
||||
guild._channels.set(voice_channel.id, voice_channel)
|
||||
guild.voice_channels.append(voice_channel)
|
||||
bot._channels.set(text_channel.id, text_channel)
|
||||
bot._channels.set(voice_channel.id, voice_channel)
|
||||
|
||||
return guild, member, role, user, text_channel, voice_channel
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def command_context(guild_objects):
|
||||
guild, member, role = guild_objects
|
||||
guild, member, role, _, _, _ = guild_objects
|
||||
bot = guild._client
|
||||
message_data = {
|
||||
"id": "10",
|
||||
@ -114,7 +164,7 @@ def command_context(guild_objects):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_member_converter(command_context, guild_objects):
|
||||
_, member, _ = guild_objects
|
||||
_, member, _, _, _, _ = guild_objects
|
||||
mention = f"<@!{member.id}>"
|
||||
result = await run_converters(command_context, Member, mention)
|
||||
assert result is member
|
||||
@ -124,7 +174,7 @@ async def test_member_converter(command_context, guild_objects):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_role_converter(command_context, guild_objects):
|
||||
_, _, role = guild_objects
|
||||
_, _, role, _, _, _ = guild_objects
|
||||
mention = f"<@&{role.id}>"
|
||||
result = await run_converters(command_context, Role, mention)
|
||||
assert result is role
|
||||
@ -132,13 +182,55 @@ async def test_role_converter(command_context, guild_objects):
|
||||
assert result is role
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_converter(command_context, guild_objects):
|
||||
_, _, _, user, _, _ = guild_objects
|
||||
mention = f"<@{user.id}>"
|
||||
result = await run_converters(command_context, User, mention)
|
||||
assert result is user
|
||||
result = await run_converters(command_context, User, user.id)
|
||||
assert result is user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guild_converter(command_context, guild_objects):
|
||||
guild, _, _ = guild_objects
|
||||
guild, _, _, _, _, _ = guild_objects
|
||||
result = await run_converters(command_context, Guild, guild.id)
|
||||
assert result is guild
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_channel_converter(command_context, guild_objects):
|
||||
_, _, _, _, text_channel, _ = guild_objects
|
||||
mention = f"<#{text_channel.id}>"
|
||||
result = await run_converters(command_context, TextChannel, mention)
|
||||
assert result is text_channel
|
||||
result = await run_converters(command_context, TextChannel, text_channel.id)
|
||||
assert result is text_channel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_channel_converter(command_context, guild_objects):
|
||||
_, _, _, _, _, voice_channel = guild_objects
|
||||
mention = f"<#{voice_channel.id}>"
|
||||
result = await run_converters(command_context, VoiceChannel, mention)
|
||||
assert result is voice_channel
|
||||
result = await run_converters(command_context, VoiceChannel, voice_channel.id)
|
||||
assert result is voice_channel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emoji_converter(command_context):
|
||||
result = await run_converters(command_context, PartialEmoji, "<:smile:1>")
|
||||
assert isinstance(result, PartialEmoji)
|
||||
assert result.id == "1"
|
||||
assert result.name == "smile"
|
||||
|
||||
result = await run_converters(command_context, PartialEmoji, "😄")
|
||||
assert result.id is None
|
||||
assert result.name == "😄"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_member_converter_no_guild():
|
||||
guild_data = {
|
||||
|
14
tests/test_asset.py
Normal file
14
tests/test_asset.py
Normal file
@ -0,0 +1,14 @@
|
||||
from disagreement.models import User
|
||||
from disagreement.asset import Asset
|
||||
|
||||
|
||||
def test_user_avatar_returns_asset():
|
||||
user = User({"id": "1", "username": "u", "discriminator": "0001", "avatar": "abc"})
|
||||
avatar = user.avatar
|
||||
assert isinstance(avatar, Asset)
|
||||
assert avatar.url == "https://cdn.discordapp.com/avatars/1/abc.png"
|
||||
|
||||
|
||||
def test_user_avatar_none():
|
||||
user = User({"id": "1", "username": "u", "discriminator": "0001"})
|
||||
assert user.avatar is None
|
83
tests/test_auto_sync_commands.py
Normal file
83
tests/test_auto_sync_commands.py
Normal file
@ -0,0 +1,83 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.gateway import GatewayClient
|
||||
from disagreement.event_dispatcher import EventDispatcher
|
||||
|
||||
|
||||
class DummyHTTP:
|
||||
pass
|
||||
|
||||
|
||||
class DummyUser:
|
||||
username = "u"
|
||||
discriminator = "0001"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_sync_on_ready(monkeypatch):
|
||||
client = Client(token="t", application_id="123")
|
||||
http = DummyHTTP()
|
||||
dispatcher = EventDispatcher(client)
|
||||
gw = GatewayClient(
|
||||
http_client=http,
|
||||
event_dispatcher=dispatcher,
|
||||
token="t",
|
||||
intents=0,
|
||||
client_instance=client,
|
||||
)
|
||||
monkeypatch.setattr(client, "parse_user", lambda d: DummyUser())
|
||||
monkeypatch.setattr(gw._dispatcher, "dispatch", AsyncMock())
|
||||
sync_mock = AsyncMock()
|
||||
monkeypatch.setattr(client, "sync_application_commands", sync_mock)
|
||||
|
||||
data = {
|
||||
"t": "READY",
|
||||
"s": 1,
|
||||
"d": {
|
||||
"session_id": "s1",
|
||||
"resume_gateway_url": "url",
|
||||
"application": {"id": "123"},
|
||||
"user": {"id": "1"},
|
||||
},
|
||||
}
|
||||
|
||||
await gw._handle_dispatch(data)
|
||||
await asyncio.sleep(0)
|
||||
sync_mock.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_sync_disabled(monkeypatch):
|
||||
client = Client(token="t", application_id="123", sync_commands_on_ready=False)
|
||||
http = DummyHTTP()
|
||||
dispatcher = EventDispatcher(client)
|
||||
gw = GatewayClient(
|
||||
http_client=http,
|
||||
event_dispatcher=dispatcher,
|
||||
token="t",
|
||||
intents=0,
|
||||
client_instance=client,
|
||||
)
|
||||
monkeypatch.setattr(client, "parse_user", lambda d: DummyUser())
|
||||
monkeypatch.setattr(gw._dispatcher, "dispatch", AsyncMock())
|
||||
sync_mock = AsyncMock()
|
||||
monkeypatch.setattr(client, "sync_application_commands", sync_mock)
|
||||
|
||||
data = {
|
||||
"t": "READY",
|
||||
"s": 1,
|
||||
"d": {
|
||||
"session_id": "s1",
|
||||
"resume_gateway_url": "url",
|
||||
"application": {"id": "123"},
|
||||
"user": {"id": "1"},
|
||||
},
|
||||
}
|
||||
|
||||
await gw._handle_dispatch(data)
|
||||
await asyncio.sleep(0)
|
||||
sync_mock.assert_not_called()
|
@ -1,6 +1,60 @@
|
||||
import time
|
||||
|
||||
from disagreement.cache import Cache
|
||||
from disagreement.client import Client
|
||||
from disagreement.caching import MemberCacheFlags
|
||||
from disagreement.enums import (
|
||||
ChannelType,
|
||||
ExplicitContentFilterLevel,
|
||||
GuildNSFWLevel,
|
||||
MFALevel,
|
||||
MessageNotificationLevel,
|
||||
PremiumTier,
|
||||
VerificationLevel,
|
||||
)
|
||||
|
||||
|
||||
def _guild_payload(gid: str, channel_count: int, member_count: int) -> dict:
|
||||
base = {
|
||||
"id": gid,
|
||||
"name": f"g{gid}",
|
||||
"owner_id": "1",
|
||||
"afk_timeout": 60,
|
||||
"verification_level": VerificationLevel.NONE.value,
|
||||
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||
"roles": [],
|
||||
"emojis": [],
|
||||
"features": [],
|
||||
"mfa_level": MFALevel.NONE.value,
|
||||
"system_channel_flags": 0,
|
||||
"premium_tier": PremiumTier.NONE.value,
|
||||
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||
"channels": [],
|
||||
"members": [],
|
||||
}
|
||||
for i in range(channel_count):
|
||||
base["channels"].append(
|
||||
{
|
||||
"id": f"{gid}-c{i}",
|
||||
"type": ChannelType.GUILD_TEXT.value,
|
||||
"guild_id": gid,
|
||||
"permission_overwrites": [],
|
||||
}
|
||||
)
|
||||
for i in range(member_count):
|
||||
base["members"].append(
|
||||
{
|
||||
"user": {
|
||||
"id": f"{gid}-m{i}",
|
||||
"username": f"u{i}",
|
||||
"discriminator": "0001",
|
||||
},
|
||||
"joined_at": "t",
|
||||
"roles": [],
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
|
||||
def test_cache_store_and_get():
|
||||
@ -65,3 +119,22 @@ def test_get_or_fetch_fetches_expired_item():
|
||||
|
||||
assert cache.get_or_fetch("c", fetch) == 3
|
||||
assert called
|
||||
|
||||
|
||||
def test_client_get_all_channels_and_members():
|
||||
client = Client(token="t")
|
||||
client.parse_guild(_guild_payload("1", 2, 2))
|
||||
client.parse_guild(_guild_payload("2", 1, 1))
|
||||
|
||||
channels = {c.id for c in client.get_all_channels()}
|
||||
members = {m.id for m in client.get_all_members()}
|
||||
|
||||
assert channels == {"1-c0", "1-c1", "2-c0"}
|
||||
assert members == {"1-m0", "1-m1", "2-m0"}
|
||||
|
||||
|
||||
def test_client_get_all_members_disabled_cache():
|
||||
client = Client(token="t", member_cache_flags=MemberCacheFlags.none())
|
||||
client.parse_guild(_guild_payload("1", 1, 2))
|
||||
|
||||
assert client.get_all_members() == []
|
||||
|
@ -1,7 +1,10 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
# pylint: disable=no-member
|
||||
|
||||
from disagreement.client import Client
|
||||
|
||||
|
||||
|
42
tests/test_client_uptime.py
Normal file
42
tests/test_client_uptime.py
Normal file
@ -0,0 +1,42 @@
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.client import Client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_records_start_time(monkeypatch):
|
||||
start = datetime(2020, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
monkeypatch.setattr("disagreement.client.utcnow", lambda: start)
|
||||
|
||||
client = Client(token="t")
|
||||
monkeypatch.setattr(client, "_initialize_gateway", AsyncMock())
|
||||
client._gateway = SimpleNamespace(connect=AsyncMock())
|
||||
monkeypatch.setattr(client, "wait_until_ready", AsyncMock())
|
||||
|
||||
assert client.start_time is None
|
||||
await client.connect()
|
||||
assert client.start_time == start
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_uptime(monkeypatch):
|
||||
start = datetime(2020, 1, 1, tzinfo=timezone.utc)
|
||||
end = start + timedelta(seconds=5)
|
||||
times = [start, end]
|
||||
|
||||
def fake_now():
|
||||
return times.pop(0)
|
||||
|
||||
monkeypatch.setattr("disagreement.client.utcnow", fake_now)
|
||||
|
||||
client = Client(token="t")
|
||||
monkeypatch.setattr(client, "_initialize_gateway", AsyncMock())
|
||||
client._gateway = SimpleNamespace(connect=AsyncMock())
|
||||
monkeypatch.setattr(client, "wait_until_ready", AsyncMock())
|
||||
|
||||
await client.connect()
|
||||
assert client.uptime() == timedelta(seconds=5)
|
@ -6,6 +6,7 @@ from disagreement.ext.commands.decorators import (
|
||||
check,
|
||||
cooldown,
|
||||
requires_permissions,
|
||||
is_owner,
|
||||
)
|
||||
from disagreement.ext.commands.errors import CheckFailure, CommandOnCooldown
|
||||
from disagreement.permissions import Permissions
|
||||
@ -133,3 +134,44 @@ async def test_requires_permissions_fail(message):
|
||||
|
||||
with pytest.raises(CheckFailure):
|
||||
await cmd.invoke(ctx)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_owner_pass(message):
|
||||
message._client.owner_ids = ["2"]
|
||||
|
||||
@is_owner()
|
||||
async def cb(ctx):
|
||||
pass
|
||||
|
||||
cmd = Command(cb)
|
||||
ctx = CommandContext(
|
||||
message=message,
|
||||
bot=message._client,
|
||||
prefix="!",
|
||||
command=cmd,
|
||||
invoked_with="test",
|
||||
)
|
||||
|
||||
await cmd.invoke(ctx)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_owner_fail(message):
|
||||
message._client.owner_ids = ["1"]
|
||||
|
||||
@is_owner()
|
||||
async def cb(ctx):
|
||||
pass
|
||||
|
||||
cmd = Command(cb)
|
||||
ctx = CommandContext(
|
||||
message=message,
|
||||
bot=message._client,
|
||||
prefix="!",
|
||||
command=cmd,
|
||||
invoked_with="test",
|
||||
)
|
||||
|
||||
with pytest.raises(CheckFailure):
|
||||
await cmd.invoke(ctx)
|
||||
|
59
tests/test_connect_events.py
Normal file
59
tests/test_connect_events.py
Normal file
@ -0,0 +1,59 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.shard_manager import ShardManager
|
||||
from disagreement.event_dispatcher import EventDispatcher
|
||||
|
||||
|
||||
class DummyGateway:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.connect = AsyncMock()
|
||||
self.close = AsyncMock()
|
||||
|
||||
dispatcher = kwargs.get("event_dispatcher")
|
||||
shard_id = kwargs.get("shard_id")
|
||||
|
||||
async def emit_connect():
|
||||
await dispatcher.dispatch("CONNECT", {"shard_id": shard_id})
|
||||
|
||||
async def emit_close():
|
||||
await dispatcher.dispatch("DISCONNECT", {"shard_id": shard_id})
|
||||
|
||||
self.connect.side_effect = emit_connect
|
||||
self.close.side_effect = emit_close
|
||||
|
||||
|
||||
class DummyClient:
|
||||
def __init__(self):
|
||||
self._http = object()
|
||||
self._event_dispatcher = EventDispatcher(self)
|
||||
self.token = "t"
|
||||
self.intents = 0
|
||||
self.verbose = False
|
||||
self.gateway_max_retries = 5
|
||||
self.gateway_max_backoff = 60.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_disconnect_events(monkeypatch):
|
||||
monkeypatch.setattr("disagreement.shard_manager.GatewayClient", DummyGateway)
|
||||
client = DummyClient()
|
||||
manager = ShardManager(client, shard_count=1)
|
||||
|
||||
events: list[tuple[str, int | None]] = []
|
||||
|
||||
async def on_connect(info):
|
||||
events.append(("connect", info.get("shard_id")))
|
||||
|
||||
async def on_disconnect(info):
|
||||
events.append(("disconnect", info.get("shard_id")))
|
||||
|
||||
client._event_dispatcher.register("CONNECT", on_connect)
|
||||
client._event_dispatcher.register("DISCONNECT", on_disconnect)
|
||||
|
||||
await manager.start()
|
||||
await manager.close()
|
||||
|
||||
assert ("connect", 0) in events
|
||||
assert ("disconnect", 0) in events
|
41
tests/test_crosspost_message.py
Normal file
41
tests/test_crosspost_message.py
Normal file
@ -0,0 +1,41 @@
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.http import HTTPClient
|
||||
from disagreement.client import Client
|
||||
from disagreement.models import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_crosspost_message_calls_request():
|
||||
http = HTTPClient(token="t")
|
||||
http.request = AsyncMock(return_value={"id": "m"})
|
||||
data = await http.crosspost_message("c", "m")
|
||||
http.request.assert_called_once_with(
|
||||
"POST",
|
||||
"/channels/c/messages/m/crosspost",
|
||||
)
|
||||
assert data == {"id": "m"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_crosspost_returns_message():
|
||||
payload = {
|
||||
"id": "2",
|
||||
"channel_id": "1",
|
||||
"author": {"id": "3", "username": "u", "discriminator": "0001"},
|
||||
"content": "hi",
|
||||
"timestamp": "t",
|
||||
}
|
||||
http = SimpleNamespace(crosspost_message=AsyncMock(return_value=payload))
|
||||
client = Client.__new__(Client)
|
||||
client._http = http
|
||||
client.parse_message = lambda d: Message(d, client_instance=client)
|
||||
message = Message(payload, client_instance=client)
|
||||
|
||||
new_msg = await message.crosspost()
|
||||
|
||||
http.crosspost_message.assert_awaited_once_with("1", "2")
|
||||
assert isinstance(new_msg, Message)
|
||||
assert new_msg._client is client
|
@ -22,6 +22,38 @@ def create_dummy_module(name):
|
||||
return called
|
||||
|
||||
|
||||
def create_async_module(name):
|
||||
mod = types.ModuleType(name)
|
||||
called = {"setup": False, "teardown": False}
|
||||
|
||||
async def setup():
|
||||
called["setup"] = True
|
||||
|
||||
def teardown():
|
||||
called["teardown"] = True
|
||||
|
||||
mod.setup = setup
|
||||
mod.teardown = teardown
|
||||
sys.modules[name] = mod
|
||||
return called
|
||||
|
||||
|
||||
def create_async_teardown_module(name):
|
||||
mod = types.ModuleType(name)
|
||||
called = {"setup": False, "teardown": False}
|
||||
|
||||
def setup():
|
||||
called["setup"] = True
|
||||
|
||||
async def teardown():
|
||||
called["teardown"] = True
|
||||
|
||||
mod.setup = setup
|
||||
mod.teardown = teardown
|
||||
sys.modules[name] = mod
|
||||
return called
|
||||
|
||||
|
||||
def test_load_and_unload_extension():
|
||||
called = create_dummy_module("dummy_ext")
|
||||
|
||||
@ -75,3 +107,23 @@ def test_reload_extension(monkeypatch):
|
||||
|
||||
loader.unload_extension("reload_ext")
|
||||
assert called_second["teardown"] is True
|
||||
|
||||
|
||||
def test_async_setup():
|
||||
called = create_async_module("async_ext")
|
||||
|
||||
loader.load_extension("async_ext")
|
||||
assert called["setup"] is True
|
||||
|
||||
loader.unload_extension("async_ext")
|
||||
assert called["teardown"] is True
|
||||
|
||||
|
||||
def test_async_teardown():
|
||||
called = create_async_teardown_module("async_teardown_ext")
|
||||
|
||||
loader.load_extension("async_teardown_ext")
|
||||
assert called["setup"] is True
|
||||
|
||||
loader.unload_extension("async_teardown_ext")
|
||||
assert called["teardown"] is True
|
||||
|
29
tests/test_get_cog.py
Normal file
29
tests/test_get_cog.py
Normal file
@ -0,0 +1,29 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.ext import commands
|
||||
|
||||
|
||||
class DummyCog(commands.Cog):
|
||||
def __init__(self, client: Client) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_command_handler_get_cog():
|
||||
bot = object()
|
||||
handler = commands.core.CommandHandler(client=bot, prefix="!")
|
||||
cog = DummyCog(bot) # type: ignore[arg-type]
|
||||
handler.add_cog(cog)
|
||||
await asyncio.sleep(0) # allow any scheduled tasks to start
|
||||
assert handler.get_cog("DummyCog") is cog
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_client_get_cog():
|
||||
client = Client(token="t")
|
||||
cog = DummyCog(client)
|
||||
client.add_cog(cog)
|
||||
await asyncio.sleep(0)
|
||||
assert client.get_cog("DummyCog") is cog
|
59
tests/test_get_context.py
Normal file
59
tests/test_get_context.py
Normal file
@ -0,0 +1,59 @@
|
||||
import pytest
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.ext.commands.core import Command, CommandHandler
|
||||
from disagreement.models import Message
|
||||
|
||||
|
||||
class DummyBot:
|
||||
def __init__(self):
|
||||
self.executed = False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_context_parses_without_execution():
|
||||
bot = DummyBot()
|
||||
handler = CommandHandler(client=bot, prefix="!")
|
||||
|
||||
async def foo(ctx, number: int, word: str):
|
||||
bot.executed = True
|
||||
|
||||
handler.add_command(Command(foo, name="foo"))
|
||||
|
||||
msg_data = {
|
||||
"id": "1",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||
"content": "!foo 1 bar",
|
||||
"timestamp": "t",
|
||||
}
|
||||
msg = Message(msg_data, client_instance=bot)
|
||||
|
||||
ctx = await handler.get_context(msg)
|
||||
assert ctx is not None
|
||||
assert ctx.command.name == "foo"
|
||||
assert ctx.args == [1, "bar"]
|
||||
assert bot.executed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_get_context():
|
||||
client = Client(token="t")
|
||||
|
||||
async def foo(ctx):
|
||||
raise RuntimeError("should not run")
|
||||
|
||||
client.command_handler.add_command(Command(foo, name="foo"))
|
||||
|
||||
msg_data = {
|
||||
"id": "1",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||
"content": "!foo",
|
||||
"timestamp": "t",
|
||||
}
|
||||
msg = Message(msg_data, client_instance=client)
|
||||
|
||||
ctx = await client.get_context(msg)
|
||||
assert ctx is not None
|
||||
assert ctx.command.name == "foo"
|
126
tests/test_guild_channel_create.py
Normal file
126
tests/test_guild_channel_create.py
Normal file
@ -0,0 +1,126 @@
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.http import HTTPClient
|
||||
from disagreement.client import Client
|
||||
from disagreement.models import Guild, TextChannel, VoiceChannel, CategoryChannel
|
||||
from disagreement.enums import (
|
||||
VerificationLevel,
|
||||
MessageNotificationLevel,
|
||||
ExplicitContentFilterLevel,
|
||||
MFALevel,
|
||||
GuildNSFWLevel,
|
||||
PremiumTier,
|
||||
ChannelType,
|
||||
)
|
||||
|
||||
|
||||
def _guild_data():
|
||||
return {
|
||||
"id": "1",
|
||||
"name": "g",
|
||||
"owner_id": "1",
|
||||
"afk_timeout": 60,
|
||||
"verification_level": VerificationLevel.NONE.value,
|
||||
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||
"roles": [],
|
||||
"emojis": [],
|
||||
"features": [],
|
||||
"mfa_level": MFALevel.NONE.value,
|
||||
"system_channel_flags": 0,
|
||||
"premium_tier": PremiumTier.NONE.value,
|
||||
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_create_guild_channel_calls_request():
|
||||
http = HTTPClient(token="t")
|
||||
http.request = AsyncMock(return_value={})
|
||||
payload = {"name": "chan", "type": ChannelType.GUILD_TEXT.value}
|
||||
|
||||
await http.create_guild_channel("1", payload, reason="r")
|
||||
|
||||
http.request.assert_called_once_with(
|
||||
"POST",
|
||||
"/guilds/1/channels",
|
||||
payload=payload,
|
||||
custom_headers={"X-Audit-Log-Reason": "r"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guild_create_text_channel_returns_channel():
|
||||
http = SimpleNamespace(
|
||||
create_guild_channel=AsyncMock(
|
||||
return_value={
|
||||
"id": "10",
|
||||
"type": ChannelType.GUILD_TEXT.value,
|
||||
"guild_id": "1",
|
||||
"permission_overwrites": [],
|
||||
}
|
||||
)
|
||||
)
|
||||
client = Client(token="t")
|
||||
client._http = http
|
||||
guild = Guild(_guild_data(), client_instance=client)
|
||||
|
||||
channel = await guild.create_text_channel("general")
|
||||
|
||||
http.create_guild_channel.assert_awaited_once_with(
|
||||
"1", {"name": "general", "type": ChannelType.GUILD_TEXT.value}, reason=None
|
||||
)
|
||||
assert isinstance(channel, TextChannel)
|
||||
assert client._channels.get("10") is channel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guild_create_voice_channel_returns_channel():
|
||||
http = SimpleNamespace(
|
||||
create_guild_channel=AsyncMock(
|
||||
return_value={
|
||||
"id": "11",
|
||||
"type": ChannelType.GUILD_VOICE.value,
|
||||
"guild_id": "1",
|
||||
"permission_overwrites": [],
|
||||
}
|
||||
)
|
||||
)
|
||||
client = Client(token="t")
|
||||
client._http = http
|
||||
guild = Guild(_guild_data(), client_instance=client)
|
||||
|
||||
channel = await guild.create_voice_channel("Voice")
|
||||
|
||||
http.create_guild_channel.assert_awaited_once_with(
|
||||
"1", {"name": "Voice", "type": ChannelType.GUILD_VOICE.value}, reason=None
|
||||
)
|
||||
assert isinstance(channel, VoiceChannel)
|
||||
assert client._channels.get("11") is channel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guild_create_category_returns_channel():
|
||||
http = SimpleNamespace(
|
||||
create_guild_channel=AsyncMock(
|
||||
return_value={
|
||||
"id": "12",
|
||||
"type": ChannelType.GUILD_CATEGORY.value,
|
||||
"guild_id": "1",
|
||||
"permission_overwrites": [],
|
||||
}
|
||||
)
|
||||
)
|
||||
client = Client(token="t")
|
||||
client._http = http
|
||||
guild = Guild(_guild_data(), client_instance=client)
|
||||
|
||||
channel = await guild.create_category("Cat")
|
||||
|
||||
http.create_guild_channel.assert_awaited_once_with(
|
||||
"1", {"name": "Cat", "type": ChannelType.GUILD_CATEGORY.value}, reason=None
|
||||
)
|
||||
assert isinstance(channel, CategoryChannel)
|
||||
assert client._channels.get("12") is channel
|
63
tests/test_guild_channel_lists.py
Normal file
63
tests/test_guild_channel_lists.py
Normal file
@ -0,0 +1,63 @@
|
||||
import pytest
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.enums import (
|
||||
ChannelType,
|
||||
VerificationLevel,
|
||||
MessageNotificationLevel,
|
||||
ExplicitContentFilterLevel,
|
||||
MFALevel,
|
||||
GuildNSFWLevel,
|
||||
PremiumTier,
|
||||
)
|
||||
from disagreement.models import TextChannel, VoiceChannel, CategoryChannel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guild_channel_lists_populated():
|
||||
client = Client(token="t")
|
||||
guild_data = {
|
||||
"id": "1",
|
||||
"name": "g",
|
||||
"owner_id": "1",
|
||||
"afk_timeout": 60,
|
||||
"verification_level": VerificationLevel.NONE.value,
|
||||
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||
"roles": [],
|
||||
"emojis": [],
|
||||
"features": [],
|
||||
"mfa_level": MFALevel.NONE.value,
|
||||
"system_channel_flags": 0,
|
||||
"premium_tier": PremiumTier.NONE.value,
|
||||
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||
"channels": [
|
||||
{
|
||||
"id": "10",
|
||||
"type": ChannelType.GUILD_TEXT.value,
|
||||
"guild_id": "1",
|
||||
"permission_overwrites": [],
|
||||
},
|
||||
{
|
||||
"id": "11",
|
||||
"type": ChannelType.GUILD_VOICE.value,
|
||||
"guild_id": "1",
|
||||
"permission_overwrites": [],
|
||||
},
|
||||
{
|
||||
"id": "12",
|
||||
"type": ChannelType.GUILD_CATEGORY.value,
|
||||
"guild_id": "1",
|
||||
"permission_overwrites": [],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
guild = client.parse_guild(guild_data)
|
||||
|
||||
assert len(guild.text_channels) == 1
|
||||
assert isinstance(guild.text_channels[0], TextChannel)
|
||||
assert len(guild.voice_channels) == 1
|
||||
assert isinstance(guild.voice_channels[0], VoiceChannel)
|
||||
assert len(guild.category_channels) == 1
|
||||
assert isinstance(guild.category_channels[0], CategoryChannel)
|
64
tests/test_guild_prune.py
Normal file
64
tests/test_guild_prune.py
Normal file
@ -0,0 +1,64 @@
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.http import HTTPClient
|
||||
from disagreement.client import Client
|
||||
from disagreement.enums import (
|
||||
VerificationLevel,
|
||||
MessageNotificationLevel,
|
||||
ExplicitContentFilterLevel,
|
||||
MFALevel,
|
||||
GuildNSFWLevel,
|
||||
PremiumTier,
|
||||
)
|
||||
from disagreement.models import Guild
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_get_guild_prune_count_calls_request():
|
||||
http = HTTPClient(token="t")
|
||||
http.request = AsyncMock(return_value={"pruned": 3})
|
||||
count = await http.get_guild_prune_count("1", days=7)
|
||||
http.request.assert_called_once_with("GET", f"/guilds/1/prune", params={"days": 7})
|
||||
assert count == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_begin_guild_prune_calls_request():
|
||||
http = HTTPClient(token="t")
|
||||
http.request = AsyncMock(return_value={"pruned": 2})
|
||||
count = await http.begin_guild_prune("1", days=1, compute_count=True)
|
||||
http.request.assert_called_once_with(
|
||||
"POST",
|
||||
f"/guilds/1/prune",
|
||||
payload={"days": 1, "compute_prune_count": True},
|
||||
)
|
||||
assert count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guild_prune_members_calls_http():
|
||||
http = SimpleNamespace(begin_guild_prune=AsyncMock(return_value=1))
|
||||
client = Client(token="t")
|
||||
client._http = http
|
||||
guild_data = {
|
||||
"id": "1",
|
||||
"name": "g",
|
||||
"owner_id": "1",
|
||||
"afk_timeout": 60,
|
||||
"verification_level": VerificationLevel.NONE.value,
|
||||
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||
"roles": [],
|
||||
"emojis": [],
|
||||
"features": [],
|
||||
"mfa_level": MFALevel.NONE.value,
|
||||
"system_channel_flags": 0,
|
||||
"premium_tier": PremiumTier.NONE.value,
|
||||
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||
}
|
||||
guild = Guild(guild_data, client_instance=client)
|
||||
count = await guild.prune_members(2)
|
||||
http.begin_guild_prune.assert_awaited_once_with("1", days=2, compute_count=True)
|
||||
assert count == 1
|
56
tests/test_guild_shard_property.py
Normal file
56
tests/test_guild_shard_property.py
Normal file
@ -0,0 +1,56 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
|
||||
from disagreement.models import Guild
|
||||
from disagreement.enums import (
|
||||
VerificationLevel,
|
||||
MessageNotificationLevel,
|
||||
ExplicitContentFilterLevel,
|
||||
MFALevel,
|
||||
GuildNSFWLevel,
|
||||
PremiumTier,
|
||||
)
|
||||
|
||||
|
||||
class DummyShard:
|
||||
def __init__(self, shard_id):
|
||||
self.id = shard_id
|
||||
self.count = 1
|
||||
self.gateway = Mock()
|
||||
|
||||
|
||||
class DummyManager:
|
||||
def __init__(self):
|
||||
self.shards = [DummyShard(0)]
|
||||
|
||||
|
||||
class DummyClient:
|
||||
pass
|
||||
|
||||
|
||||
def _guild_data():
|
||||
return {
|
||||
"id": "1",
|
||||
"name": "g",
|
||||
"owner_id": "1",
|
||||
"afk_timeout": 60,
|
||||
"verification_level": VerificationLevel.NONE.value,
|
||||
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||
"roles": [],
|
||||
"emojis": [],
|
||||
"features": [],
|
||||
"mfa_level": MFALevel.NONE.value,
|
||||
"system_channel_flags": 0,
|
||||
"premium_tier": PremiumTier.NONE.value,
|
||||
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||
"shard_id": 0,
|
||||
}
|
||||
|
||||
|
||||
def test_guild_shard_property():
|
||||
client = DummyClient()
|
||||
client._shard_manager = DummyManager()
|
||||
guild = Guild(_guild_data(), client_instance=client, shard_id=0)
|
||||
assert guild.shard_id == 0
|
||||
assert guild.shard is client._shard_manager.shards[0]
|
86
tests/test_hashable_mixin.py
Normal file
86
tests/test_hashable_mixin.py
Normal file
@ -0,0 +1,86 @@
|
||||
import types
|
||||
from disagreement.models import User, Guild, Channel, Message
|
||||
from disagreement.enums import (
|
||||
VerificationLevel,
|
||||
MessageNotificationLevel,
|
||||
ExplicitContentFilterLevel,
|
||||
MFALevel,
|
||||
GuildNSFWLevel,
|
||||
PremiumTier,
|
||||
ChannelType,
|
||||
)
|
||||
|
||||
|
||||
def _guild_data(gid="1"):
|
||||
return {
|
||||
"id": gid,
|
||||
"name": "g",
|
||||
"owner_id": gid,
|
||||
"afk_timeout": 60,
|
||||
"verification_level": VerificationLevel.NONE.value,
|
||||
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||
"roles": [],
|
||||
"emojis": [],
|
||||
"features": [],
|
||||
"mfa_level": MFALevel.NONE.value,
|
||||
"system_channel_flags": 0,
|
||||
"premium_tier": PremiumTier.NONE.value,
|
||||
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||
}
|
||||
|
||||
|
||||
def _user(uid="1"):
|
||||
return User({"id": uid, "username": "u", "discriminator": "0001"})
|
||||
|
||||
|
||||
def _message(mid="1"):
|
||||
data = {
|
||||
"id": mid,
|
||||
"channel_id": "c",
|
||||
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||
"content": "hi",
|
||||
"timestamp": "t",
|
||||
}
|
||||
return Message(data, client_instance=types.SimpleNamespace())
|
||||
|
||||
|
||||
def _channel(cid="1"):
|
||||
data = {"id": cid, "type": ChannelType.GUILD_TEXT.value}
|
||||
return Channel(data, client_instance=types.SimpleNamespace())
|
||||
|
||||
|
||||
def test_user_hash_and_eq():
|
||||
a = _user()
|
||||
b = _user()
|
||||
c = _user("2")
|
||||
assert a == b
|
||||
assert hash(a) == hash(b)
|
||||
assert a != c
|
||||
|
||||
|
||||
def test_guild_hash_and_eq():
|
||||
a = Guild(_guild_data(), client_instance=types.SimpleNamespace())
|
||||
b = Guild(_guild_data(), client_instance=types.SimpleNamespace())
|
||||
c = Guild(_guild_data("2"), client_instance=types.SimpleNamespace())
|
||||
assert a == b
|
||||
assert hash(a) == hash(b)
|
||||
assert a != c
|
||||
|
||||
|
||||
def test_channel_hash_and_eq():
|
||||
a = _channel()
|
||||
b = _channel()
|
||||
c = _channel("2")
|
||||
assert a == b
|
||||
assert hash(a) == hash(b)
|
||||
assert a != c
|
||||
|
||||
|
||||
def test_message_hash_and_eq():
|
||||
a = _message()
|
||||
b = _message()
|
||||
c = _message("2")
|
||||
assert a == b
|
||||
assert hash(a) == hash(b)
|
||||
assert a != c
|
@ -1,7 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from disagreement.ext.commands.core import CommandHandler, Command
|
||||
from disagreement.ext import commands
|
||||
from disagreement.ext.commands.core import CommandHandler, Command, Group
|
||||
from disagreement.models import Message
|
||||
from disagreement.ext.commands.help import HelpCommand
|
||||
|
||||
|
||||
class DummyBot:
|
||||
@ -13,15 +15,21 @@ class DummyBot:
|
||||
return {"id": "1", "channel_id": channel_id, "content": content}
|
||||
|
||||
|
||||
class MyCog(commands.Cog):
|
||||
def __init__(self, client) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
@commands.command()
|
||||
async def foo(self, ctx: commands.CommandContext) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help_lists_commands():
|
||||
bot = DummyBot()
|
||||
handler = CommandHandler(client=bot, prefix="!")
|
||||
|
||||
async def foo(ctx):
|
||||
pass
|
||||
|
||||
handler.add_command(Command(foo, name="foo", brief="Foo cmd"))
|
||||
handler.add_cog(MyCog(bot))
|
||||
|
||||
msg_data = {
|
||||
"id": "1",
|
||||
@ -33,6 +41,7 @@ async def test_help_lists_commands():
|
||||
msg = Message(msg_data, client_instance=bot)
|
||||
await handler.process_commands(msg)
|
||||
assert any("foo" in m for m in bot.sent)
|
||||
assert any("MyCog" in m for m in bot.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -55,3 +64,65 @@ async def test_help_specific_command():
|
||||
msg = Message(msg_data, client_instance=bot)
|
||||
await handler.process_commands(msg)
|
||||
assert any("Bar desc" in m for m in bot.sent)
|
||||
|
||||
|
||||
class CustomHelp(HelpCommand):
|
||||
async def send_command_help(self, ctx, command):
|
||||
await ctx.send(f"custom {command.name}")
|
||||
|
||||
async def send_group_help(self, ctx, group):
|
||||
await ctx.send(f"group {group.name}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_help_methods():
|
||||
bot = DummyBot()
|
||||
handler = CommandHandler(client=bot, prefix="!")
|
||||
handler.remove_command("help")
|
||||
handler.add_command(CustomHelp(handler))
|
||||
|
||||
async def sub(ctx):
|
||||
pass
|
||||
|
||||
group = Group(sub, name="grp")
|
||||
handler.add_command(group)
|
||||
|
||||
msg_data = {
|
||||
"id": "1",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||
"content": "!help grp",
|
||||
"timestamp": "t",
|
||||
}
|
||||
msg = Message(msg_data, client_instance=bot)
|
||||
await handler.process_commands(msg)
|
||||
assert any("group grp" in m for m in bot.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help_lists_subcommands():
|
||||
bot = DummyBot()
|
||||
handler = CommandHandler(client=bot, prefix="!")
|
||||
|
||||
async def root(ctx):
|
||||
pass
|
||||
|
||||
group = Group(root, name="root")
|
||||
|
||||
@group.command(name="child")
|
||||
async def child(ctx):
|
||||
pass
|
||||
|
||||
handler.add_command(group)
|
||||
|
||||
msg_data = {
|
||||
"id": "1",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||
"content": "!help",
|
||||
"timestamp": "t",
|
||||
}
|
||||
msg = Message(msg_data, client_instance=bot)
|
||||
await handler.process_commands(msg)
|
||||
assert any("root" in m for m in bot.sent)
|
||||
assert any("child" in m for m in bot.sent)
|
||||
|
31
tests/test_invite.py
Normal file
31
tests/test_invite.py
Normal file
@ -0,0 +1,31 @@
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.http import HTTPClient
|
||||
from disagreement.client import Client
|
||||
from disagreement.models import Invite
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_get_invite_calls_request():
|
||||
http = HTTPClient(token="t")
|
||||
http.request = AsyncMock(return_value={"code": "abc"})
|
||||
|
||||
result = await http.get_invite("abc")
|
||||
|
||||
http.request.assert_called_once_with("GET", "/invites/abc")
|
||||
assert result == {"code": "abc"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_fetch_invite_returns_invite():
|
||||
http = SimpleNamespace(get_invite=AsyncMock(return_value={"code": "abc"}))
|
||||
client = Client.__new__(Client)
|
||||
client._http = http
|
||||
client._closed = False
|
||||
|
||||
invite = await client.fetch_invite("abc")
|
||||
|
||||
http.get_invite.assert_awaited_once_with("abc")
|
||||
assert isinstance(invite, Invite)
|
39
tests/test_invites.py
Normal file
39
tests/test_invites.py
Normal file
@ -0,0 +1,39 @@
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.http import HTTPClient
|
||||
from disagreement.models import TextChannel, Invite
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_channel_invite_calls_request_and_returns_model():
|
||||
http = HTTPClient(token="t")
|
||||
http.request = AsyncMock(return_value={"code": "abc"})
|
||||
invite = await http.create_channel_invite("123", {"max_age": 60}, reason="r")
|
||||
|
||||
http.request.assert_called_once_with(
|
||||
"POST",
|
||||
"/channels/123/invites",
|
||||
payload={"max_age": 60},
|
||||
custom_headers={"X-Audit-Log-Reason": "r"},
|
||||
)
|
||||
assert isinstance(invite, Invite)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_textchannel_create_invite_uses_http():
|
||||
http = SimpleNamespace(
|
||||
create_channel_invite=AsyncMock(return_value=Invite.from_dict({"code": "a"}))
|
||||
)
|
||||
client = Client(token="t")
|
||||
client._http = http
|
||||
|
||||
channel = TextChannel({"id": "c", "type": 0}, client)
|
||||
invite = await channel.create_invite(max_age=30, reason="why")
|
||||
|
||||
http.create_channel_invite.assert_awaited_once_with(
|
||||
"c", {"max_age": 30}, reason="why"
|
||||
)
|
||||
assert isinstance(invite, Invite)
|
@ -1,4 +1,21 @@
|
||||
from disagreement.models import Member
|
||||
import pytest # pylint: disable=E0401
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.enums import (
|
||||
VerificationLevel,
|
||||
MessageNotificationLevel,
|
||||
ExplicitContentFilterLevel,
|
||||
MFALevel,
|
||||
GuildNSFWLevel,
|
||||
PremiumTier,
|
||||
)
|
||||
from disagreement.models import Member, Guild, Role
|
||||
from disagreement.permissions import Permissions
|
||||
|
||||
|
||||
class DummyClient(Client):
|
||||
def __init__(self):
|
||||
super().__init__(token="test")
|
||||
|
||||
|
||||
def _make_member(member_id: str, username: str, nick: str | None):
|
||||
@ -12,6 +29,58 @@ def _make_member(member_id: str, username: str, nick: str | None):
|
||||
return Member(data, client_instance=None)
|
||||
|
||||
|
||||
def _base_guild(client: Client) -> Guild:
|
||||
data = {
|
||||
"id": "1",
|
||||
"name": "g",
|
||||
"owner_id": "1",
|
||||
"afk_timeout": 60,
|
||||
"verification_level": VerificationLevel.NONE.value,
|
||||
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||
"roles": [],
|
||||
"emojis": [],
|
||||
"features": [],
|
||||
"mfa_level": MFALevel.NONE.value,
|
||||
"system_channel_flags": 0,
|
||||
"premium_tier": PremiumTier.NONE.value,
|
||||
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||
}
|
||||
guild = Guild(data, client_instance=client)
|
||||
client._guilds.set(guild.id, guild)
|
||||
return guild
|
||||
|
||||
|
||||
def _role(guild: Guild, rid: str, perms: Permissions) -> Role:
|
||||
role = Role(
|
||||
{
|
||||
"id": rid,
|
||||
"name": f"r{rid}",
|
||||
"color": 0,
|
||||
"hoist": False,
|
||||
"position": 0,
|
||||
"permissions": str(int(perms)),
|
||||
"managed": False,
|
||||
"mentionable": False,
|
||||
}
|
||||
)
|
||||
guild.roles.append(role)
|
||||
return role
|
||||
|
||||
|
||||
def _member(guild: Guild, client: Client, *roles: Role) -> Member:
|
||||
data = {
|
||||
"user": {"id": "10", "username": "u", "discriminator": "0001"},
|
||||
"joined_at": "t",
|
||||
"roles": [r.id for r in roles] or [guild.id],
|
||||
}
|
||||
member = Member(data, client_instance=client)
|
||||
member.guild_id = guild.id
|
||||
member._client = client
|
||||
guild._members.set(member.id, member)
|
||||
return member
|
||||
|
||||
|
||||
def test_display_name_prefers_nick():
|
||||
member = _make_member("1", "u", "nickname")
|
||||
assert member.display_name == "nickname"
|
||||
@ -20,3 +89,25 @@ def test_display_name_prefers_nick():
|
||||
def test_display_name_falls_back_to_username():
|
||||
member = _make_member("2", "u2", None)
|
||||
assert member.display_name == "u2"
|
||||
|
||||
|
||||
def test_guild_permissions_from_roles():
|
||||
client = DummyClient()
|
||||
guild = _base_guild(client)
|
||||
everyone = _role(guild, guild.id, Permissions.VIEW_CHANNEL)
|
||||
mod = _role(guild, "2", Permissions.MANAGE_MESSAGES)
|
||||
member = _member(guild, client, everyone, mod)
|
||||
|
||||
perms = member.guild_permissions
|
||||
assert perms & Permissions.VIEW_CHANNEL
|
||||
assert perms & Permissions.MANAGE_MESSAGES
|
||||
assert not perms & Permissions.BAN_MEMBERS
|
||||
|
||||
|
||||
def test_guild_permissions_administrator_role_grants_all():
|
||||
client = DummyClient()
|
||||
guild = _base_guild(client)
|
||||
admin = _role(guild, "2", Permissions.ADMINISTRATOR)
|
||||
member = _member(guild, client, admin)
|
||||
|
||||
assert member.guild_permissions == Permissions(~0)
|
||||
|
36
tests/test_member_voice.py
Normal file
36
tests/test_member_voice.py
Normal file
@ -0,0 +1,36 @@
|
||||
from disagreement.models import Member, VoiceState
|
||||
|
||||
|
||||
def test_member_voice_dataclass():
|
||||
data = {
|
||||
"user": {"id": "1", "username": "u", "discriminator": "0001"},
|
||||
"joined_at": "t",
|
||||
"roles": [],
|
||||
"voice_state": {
|
||||
"guild_id": "g",
|
||||
"channel_id": "c",
|
||||
"user_id": "1",
|
||||
"session_id": "s",
|
||||
"deaf": False,
|
||||
"mute": True,
|
||||
"self_deaf": False,
|
||||
"self_mute": False,
|
||||
"self_video": False,
|
||||
"suppress": False,
|
||||
},
|
||||
}
|
||||
member = Member(data, client_instance=None)
|
||||
voice = member.voice
|
||||
assert isinstance(voice, VoiceState)
|
||||
assert voice.channel_id == "c"
|
||||
assert voice.mute is True
|
||||
|
||||
|
||||
def test_member_voice_none():
|
||||
data = {
|
||||
"user": {"id": "2", "username": "u2", "discriminator": "0001"},
|
||||
"joined_at": "t",
|
||||
"roles": [],
|
||||
}
|
||||
member = Member(data, client_instance=None)
|
||||
assert member.voice is None
|
@ -21,3 +21,19 @@ def test_clean_content_removes_mentions():
|
||||
def test_clean_content_no_mentions():
|
||||
msg = make_message("Just text")
|
||||
assert msg.clean_content == "Just text"
|
||||
|
||||
|
||||
def test_created_at_parses_timestamp():
|
||||
ts = "2024-05-04T12:34:56+00:00"
|
||||
msg = make_message("hi")
|
||||
msg.timestamp = ts
|
||||
assert msg.created_at.isoformat() == ts
|
||||
|
||||
|
||||
def test_edited_at_parses_timestamp_or_none():
|
||||
ts = "2024-05-04T12:35:56+00:00"
|
||||
msg = make_message("hi")
|
||||
msg.timestamp = ts
|
||||
assert msg.edited_at is None
|
||||
msg.edited_timestamp = ts
|
||||
assert msg.edited_at.isoformat() == ts
|
||||
|
15
tests/test_object.py
Normal file
15
tests/test_object.py
Normal file
@ -0,0 +1,15 @@
|
||||
from disagreement.object import Object
|
||||
|
||||
|
||||
def test_object_int():
|
||||
obj = Object(123)
|
||||
assert int(obj) == 123
|
||||
|
||||
|
||||
def test_object_equality_and_hash():
|
||||
a = Object(1)
|
||||
b = Object(1)
|
||||
c = Object(2)
|
||||
assert a == b
|
||||
assert a != c
|
||||
assert hash(a) == hash(b)
|
23
tests/test_paginator.py
Normal file
23
tests/test_paginator.py
Normal file
@ -0,0 +1,23 @@
|
||||
from disagreement.utils import Paginator
|
||||
|
||||
|
||||
def test_paginator_single_page():
|
||||
p = Paginator(limit=10)
|
||||
p.add_line("hi")
|
||||
p.add_line("there")
|
||||
assert p.pages == ["hi\nthere"]
|
||||
|
||||
|
||||
def test_paginator_splits_pages():
|
||||
p = Paginator(limit=10)
|
||||
p.add_line("12345")
|
||||
p.add_line("67890")
|
||||
assert p.pages == ["12345", "67890"]
|
||||
p.add_line("xyz")
|
||||
assert p.pages == ["12345", "67890\nxyz"]
|
||||
|
||||
|
||||
def test_paginator_handles_long_line():
|
||||
p = Paginator(limit=5)
|
||||
p.add_line("abcdef")
|
||||
assert p.pages == ["abcde", "f"]
|
@ -32,3 +32,11 @@ def test_missing_permissions():
|
||||
current, Permissions.SEND_MESSAGES, Permissions.MANAGE_MESSAGES
|
||||
)
|
||||
assert missing == [Permissions.MANAGE_MESSAGES]
|
||||
|
||||
|
||||
def test_permissions_all():
|
||||
all_value = Permissions.all()
|
||||
union = Permissions(0)
|
||||
for perm in Permissions:
|
||||
union |= perm
|
||||
assert all_value == union
|
||||
|
@ -1,3 +1,4 @@
|
||||
import io
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
@ -38,7 +39,9 @@ async def test_http_send_message_with_files_uses_formdata():
|
||||
"timestamp": "t",
|
||||
}
|
||||
)
|
||||
await http.send_message("c", "hi", files=[File("f.txt", b"data")])
|
||||
await http.send_message(
|
||||
"c", "hi", files=[File(io.BytesIO(b"data"), filename="f.txt")]
|
||||
)
|
||||
args, kwargs = http.request.call_args
|
||||
assert kwargs["is_json"] is False
|
||||
|
||||
@ -75,7 +78,33 @@ async def test_client_send_message_passes_files():
|
||||
"timestamp": "t",
|
||||
}
|
||||
)
|
||||
await client.send_message("c", "hi", files=[File("f.txt", b"data")])
|
||||
await client.send_message(
|
||||
"c", "hi", files=[File(io.BytesIO(b"data"), filename="f.txt")]
|
||||
)
|
||||
client._http.send_message.assert_awaited_once()
|
||||
kwargs = client._http.send_message.call_args.kwargs
|
||||
assert kwargs["files"][0].filename == "f.txt"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_from_path(tmp_path):
|
||||
file_path = tmp_path / "path.txt"
|
||||
file_path.write_bytes(b"ok")
|
||||
http = HTTPClient(token="t")
|
||||
http.request = AsyncMock(
|
||||
return_value={
|
||||
"id": "1",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||
"content": "hi",
|
||||
"timestamp": "t",
|
||||
}
|
||||
)
|
||||
await http.send_message("c", "hi", files=[File(file_path)])
|
||||
_, kwargs = http.request.call_args
|
||||
assert kwargs["is_json"] is False
|
||||
|
||||
|
||||
def test_file_spoiler():
|
||||
f = File(io.BytesIO(b"d"), filename="a.txt", spoiler=True)
|
||||
assert f.filename == "SPOILER_a.txt"
|
||||
|
@ -82,3 +82,24 @@ async def test_before_after_loop_callbacks() -> None:
|
||||
await asyncio.sleep(0.01)
|
||||
assert events and events[0] == "before"
|
||||
assert "after" in events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_interval_and_current_loop() -> None:
|
||||
count = 0
|
||||
|
||||
@tasks.loop(seconds=0.01)
|
||||
async def ticker() -> None:
|
||||
nonlocal count
|
||||
count += 1
|
||||
|
||||
ticker.start()
|
||||
await asyncio.sleep(0.03)
|
||||
initial = ticker.current_loop
|
||||
ticker.change_interval(seconds=0.02)
|
||||
await asyncio.sleep(0.05)
|
||||
ticker.stop()
|
||||
|
||||
assert initial >= 2
|
||||
assert ticker.current_loop > initial
|
||||
assert count == ticker.current_loop
|
||||
|
@ -1,8 +1,54 @@
|
||||
from datetime import timezone
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
|
||||
from disagreement.utils import utcnow
|
||||
from disagreement.utils import (
|
||||
escape_markdown,
|
||||
escape_mentions,
|
||||
utcnow,
|
||||
snowflake_time,
|
||||
find,
|
||||
get
|
||||
)
|
||||
|
||||
|
||||
def test_utcnow_timezone():
|
||||
now = utcnow()
|
||||
assert now.tzinfo == timezone.utc
|
||||
|
||||
|
||||
def test_find_returns_matching_element():
|
||||
seq = [1, 2, 3]
|
||||
assert find(lambda x: x > 1, seq) == 2
|
||||
assert find(lambda x: x > 3, seq) is None
|
||||
|
||||
|
||||
def test_get_matches_attributes():
|
||||
items = [
|
||||
SimpleNamespace(id=1, name="a"),
|
||||
SimpleNamespace(id=2, name="b"),
|
||||
]
|
||||
assert get(items, id=2) is items[1]
|
||||
assert get(items, id=1, name="a") is items[0]
|
||||
assert get(items, name="c") is None
|
||||
|
||||
|
||||
def test_snowflake_time():
|
||||
dt = datetime(2020, 1, 1, tzinfo=timezone.utc)
|
||||
ms = int(dt.timestamp() * 1000) - 1420070400000
|
||||
snowflake = ms << 22
|
||||
assert snowflake_time(snowflake) == dt
|
||||
|
||||
|
||||
def test_escape_markdown():
|
||||
text = "**bold** _under_ ~strike~ `code` > quote | pipe"
|
||||
escaped = escape_markdown(text)
|
||||
assert (
|
||||
escaped
|
||||
== "\\*\\*bold\\*\\* \\_under\\_ \\~strike\\~ \\`code\\` \\> quote \\| pipe"
|
||||
)
|
||||
|
||||
|
||||
def test_escape_mentions():
|
||||
text = "Hello @everyone and <@123>!"
|
||||
escaped = escape_mentions(text)
|
||||
assert escaped == "Hello @\u200beveryone and <@\u200b123>!"
|
||||
|
@ -59,6 +59,17 @@ class DummySource(AudioSource):
|
||||
return b""
|
||||
|
||||
|
||||
class SlowSource(AudioSource):
|
||||
def __init__(self, chunks):
|
||||
self.chunks = list(chunks)
|
||||
|
||||
async def read(self) -> bytes:
|
||||
await asyncio.sleep(0)
|
||||
if self.chunks:
|
||||
return self.chunks.pop(0)
|
||||
return b""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_client_handshake():
|
||||
hello = {"d": {"heartbeat_interval": 50}}
|
||||
@ -205,3 +216,49 @@ async def test_voice_client_volume_scaling(monkeypatch):
|
||||
samples[1] = int(samples[1] * 0.5)
|
||||
expected = samples.tobytes()
|
||||
assert udp.sent == [expected]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_resume_and_status():
|
||||
ws = DummyWebSocket(
|
||||
[
|
||||
{"d": {"heartbeat_interval": 50}},
|
||||
{"d": {"ssrc": 1, "ip": "127.0.0.1", "port": 4000}},
|
||||
{"d": {"secret_key": []}},
|
||||
]
|
||||
)
|
||||
udp = DummyUDP()
|
||||
vc = VoiceClient(
|
||||
client=DummyVoiceClient(),
|
||||
endpoint="ws://localhost",
|
||||
session_id="sess",
|
||||
token="tok",
|
||||
guild_id=1,
|
||||
user_id=2,
|
||||
ws=ws,
|
||||
udp=udp,
|
||||
)
|
||||
await vc.connect()
|
||||
vc._heartbeat_task.cancel()
|
||||
|
||||
src = SlowSource([b"a", b"b", b"c"])
|
||||
await vc.play(src, wait=False)
|
||||
|
||||
while not udp.sent:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert vc.is_playing()
|
||||
vc.pause()
|
||||
assert vc.is_paused()
|
||||
await asyncio.sleep(0)
|
||||
sent = len(udp.sent)
|
||||
await asyncio.sleep(0.01)
|
||||
assert len(udp.sent) == sent
|
||||
assert not vc.is_playing()
|
||||
|
||||
vc.resume()
|
||||
assert not vc.is_paused()
|
||||
|
||||
await vc._play_task
|
||||
assert udp.sent == [b"a", b"b", b"c"]
|
||||
assert not vc.is_playing()
|
||||
|
35
tests/test_walk_commands.py
Normal file
35
tests/test_walk_commands.py
Normal file
@ -0,0 +1,35 @@
|
||||
import pytest
|
||||
|
||||
from disagreement.ext.commands.core import CommandHandler, Command, Group
|
||||
|
||||
|
||||
class DummyBot:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_walk_commands_recurses_groups():
|
||||
bot = DummyBot()
|
||||
handler = CommandHandler(client=bot, prefix="!")
|
||||
|
||||
async def root(ctx):
|
||||
pass
|
||||
|
||||
root_group = Group(root, name="root")
|
||||
|
||||
@root_group.command(name="child")
|
||||
async def child(ctx):
|
||||
pass
|
||||
|
||||
@root_group.group(name="sub")
|
||||
async def sub(ctx):
|
||||
pass
|
||||
|
||||
@sub.command(name="leaf")
|
||||
async def leaf(ctx):
|
||||
pass
|
||||
|
||||
handler.add_command(root_group)
|
||||
|
||||
names = [cmd.name for cmd in handler.walk_commands()]
|
||||
assert set(names) == {"help", "root", "child", "sub", "leaf"}
|
@ -146,6 +146,16 @@ def test_webhook_from_url_parses_id_and_token():
|
||||
assert webhook.url == url
|
||||
|
||||
|
||||
def test_webhook_from_token_builds_url_and_fields():
|
||||
from disagreement.models import Webhook
|
||||
|
||||
webhook = Webhook.from_token("123", "token")
|
||||
|
||||
assert webhook.id == "123"
|
||||
assert webhook.token == "token"
|
||||
assert webhook.url == "https://discord.com/api/webhooks/123/token"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_webhook_calls_request():
|
||||
http = HTTPClient(token="t")
|
||||
@ -185,3 +195,31 @@ async def test_webhook_send_uses_http():
|
||||
|
||||
http.execute_webhook.assert_awaited_once()
|
||||
assert isinstance(msg, Message)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_webhook_calls_request():
|
||||
http = HTTPClient(token="t")
|
||||
http.request = AsyncMock(return_value={"id": "1"})
|
||||
|
||||
await http.get_webhook("1")
|
||||
|
||||
http.request.assert_called_once_with("GET", "/webhooks/1", use_auth_header=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_fetch_webhook_returns_model():
|
||||
from types import SimpleNamespace
|
||||
from disagreement.client import Client
|
||||
from disagreement.models import Webhook
|
||||
|
||||
http = SimpleNamespace(get_webhook=AsyncMock(return_value={"id": "1"}))
|
||||
client = Client(token="test")
|
||||
client._http = http
|
||||
client._closed = False
|
||||
|
||||
webhook = await client.fetch_webhook("1")
|
||||
|
||||
http.get_webhook.assert_awaited_once_with("1")
|
||||
assert isinstance(webhook, Webhook)
|
||||
assert client._webhooks.get("1") is webhook
|
||||
|
Loading…
x
Reference in New Issue
Block a user