Compare commits
51 Commits
Author | SHA1 | Date | |
---|---|---|---|
380feddeeb | |||
3beaed8a1b | |||
e5ad932321 | |||
8e88aaec2f | |||
d710487fc2 | |||
506adeca20 | |||
e2061adc55 | |||
132521fa39 | |||
cec747a575 | |||
17751d3b09 | |||
4b3b6aeb45 | |||
aa55aa1d4c | |||
80f64c1f73 | |||
f5f8f6908c | |||
3437050f0e | |||
87d67eb63b | |||
9c10ab0f70 | |||
2008dd33d1 | |||
de40aa2c29 | |||
2056a3ddcf | |||
ccf55adba2 | |||
a335ed972c | |||
2586d3cd0d | |||
7f9647a442 | |||
a222dec661 | |||
3f7c286322 | |||
cc17d11509 | |||
9fabf1fbac | |||
223c86cb78 | |||
98afb89629 | |||
095e7e7192 | |||
c1c5cfb41a | |||
8be234c1f0 | |||
1464937f6f | |||
5d66eb79cc | |||
5d72643390 | |||
c7eb8563de | |||
a68bbe7826 | |||
6eff962682 | |||
f24c1befac | |||
c811e2b578 | |||
9f2fc0857b | |||
775dce0c80 | |||
a93ad432b7 | |||
3a264f4530 | |||
|
a41a301927 | ||
|
bd16b1c026 | ||
460583ef30 | |||
f1ca18a62a | |||
|
2c8e426353 | ||
c9aec0dc7e |
25
README.md
25
README.md
@ -13,6 +13,8 @@ A Python library for interacting with the Discord API, with a focus on bot devel
|
|||||||
- Message component helpers
|
- Message component helpers
|
||||||
- `Message.jump_url` property for quick links to messages
|
- `Message.jump_url` property for quick links to messages
|
||||||
- Built-in caching layer
|
- Built-in caching layer
|
||||||
|
- `Guild.me` property to access the bot's member object
|
||||||
|
- Easy CDN asset handling via the `Asset` model
|
||||||
- Experimental voice support
|
- Experimental voice support
|
||||||
- Helpful error handling utilities
|
- Helpful error handling utilities
|
||||||
|
|
||||||
@ -61,13 +63,9 @@ if not token:
|
|||||||
|
|
||||||
intents = disagreement.GatewayIntent.default() | disagreement.GatewayIntent.MESSAGE_CONTENT
|
intents = disagreement.GatewayIntent.default() | disagreement.GatewayIntent.MESSAGE_CONTENT
|
||||||
client = disagreement.Client(token=token, command_prefix="!", intents=intents, mention_replies=True)
|
client = disagreement.Client(token=token, command_prefix="!", intents=intents, mention_replies=True)
|
||||||
async def main() -> None:
|
|
||||||
client.add_cog(Basics(client))
|
|
||||||
await client.run()
|
|
||||||
|
|
||||||
|
client.add_cog(Basics(client))
|
||||||
if __name__ == "__main__":
|
client.run()
|
||||||
asyncio.run(main())
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Global Error Handling
|
### Global Error Handling
|
||||||
@ -114,6 +112,10 @@ These options are forwarded to ``HTTPClient`` when it creates the underlying
|
|||||||
``aiohttp.ClientSession``. You can specify a custom ``connector`` or any other
|
``aiohttp.ClientSession``. You can specify a custom ``connector`` or any other
|
||||||
session parameter supported by ``aiohttp``.
|
session parameter supported by ``aiohttp``.
|
||||||
|
|
||||||
|
### Logging Out
|
||||||
|
|
||||||
|
Call ``Client.logout`` to disconnect from the Gateway and clear the current bot token while keeping the HTTP session alive. Assign a new token and call ``connect`` or ``run`` to log back in.
|
||||||
|
|
||||||
### Default Allowed Mentions
|
### Default Allowed Mentions
|
||||||
|
|
||||||
Specify default mention behaviour for all outgoing messages when constructing the client:
|
Specify default mention behaviour for all outgoing messages when constructing the client:
|
||||||
@ -129,6 +131,17 @@ client = disagreement.Client(
|
|||||||
This dictionary is used whenever ``send_message`` or helpers like ``Message.reply``
|
This dictionary is used whenever ``send_message`` or helpers like ``Message.reply``
|
||||||
are called without an explicit ``allowed_mentions`` argument.
|
are called without an explicit ``allowed_mentions`` argument.
|
||||||
|
|
||||||
|
### Working With Assets
|
||||||
|
|
||||||
|
Properties like ``User.avatar`` and ``Guild.icon`` return :class:`disagreement.Asset` objects.
|
||||||
|
Use ``read`` to get the bytes or ``save`` to write them to disk.
|
||||||
|
|
||||||
|
```python
|
||||||
|
user = await client.fetch_user(123)
|
||||||
|
data = await user.avatar.read()
|
||||||
|
await user.avatar.save("avatar.png")
|
||||||
|
```
|
||||||
|
|
||||||
### Defining Subcommands with `AppCommandGroup`
|
### Defining Subcommands with `AppCommandGroup`
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
@ -12,9 +12,10 @@ __title__ = "disagreement"
|
|||||||
__author__ = "Slipstream"
|
__author__ = "Slipstream"
|
||||||
__license__ = "BSD 3-Clause License"
|
__license__ = "BSD 3-Clause License"
|
||||||
__copyright__ = "Copyright 2025 Slipstream"
|
__copyright__ = "Copyright 2025 Slipstream"
|
||||||
__version__ = "0.6.0"
|
__version__ = "0.8.1"
|
||||||
|
|
||||||
from .client import Client, AutoShardedClient
|
from .client import Client, AutoShardedClient
|
||||||
|
from .asset import Asset
|
||||||
from .models import (
|
from .models import (
|
||||||
Message,
|
Message,
|
||||||
User,
|
User,
|
||||||
@ -39,6 +40,7 @@ from .models import (
|
|||||||
Container,
|
Container,
|
||||||
Guild,
|
Guild,
|
||||||
)
|
)
|
||||||
|
from .object import Object
|
||||||
from .voice_client import VoiceClient
|
from .voice_client import VoiceClient
|
||||||
from .audio import AudioSource, FFmpegAudioSource
|
from .audio import AudioSource, FFmpegAudioSource
|
||||||
from .typing import Typing
|
from .typing import Typing
|
||||||
@ -51,7 +53,15 @@ from .errors import (
|
|||||||
NotFound,
|
NotFound,
|
||||||
)
|
)
|
||||||
from .color import Color
|
from .color import Color
|
||||||
from .utils import utcnow, message_pager
|
from .utils import (
|
||||||
|
utcnow,
|
||||||
|
message_pager,
|
||||||
|
get,
|
||||||
|
find,
|
||||||
|
escape_markdown,
|
||||||
|
escape_mentions,
|
||||||
|
snowflake_time,
|
||||||
|
)
|
||||||
from .enums import (
|
from .enums import (
|
||||||
GatewayIntent,
|
GatewayIntent,
|
||||||
GatewayOpcode,
|
GatewayOpcode,
|
||||||
@ -101,6 +111,7 @@ from .ext.commands import (
|
|||||||
command,
|
command,
|
||||||
cooldown,
|
cooldown,
|
||||||
has_any_role,
|
has_any_role,
|
||||||
|
is_owner,
|
||||||
has_role,
|
has_role,
|
||||||
listener,
|
listener,
|
||||||
max_concurrency,
|
max_concurrency,
|
||||||
@ -116,6 +127,7 @@ import logging
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"Client",
|
"Client",
|
||||||
"AutoShardedClient",
|
"AutoShardedClient",
|
||||||
|
"Asset",
|
||||||
"Message",
|
"Message",
|
||||||
"User",
|
"User",
|
||||||
"Reaction",
|
"Reaction",
|
||||||
@ -137,6 +149,7 @@ __all__ = [
|
|||||||
"MediaGallery",
|
"MediaGallery",
|
||||||
"MediaGalleryItem",
|
"MediaGalleryItem",
|
||||||
"Container",
|
"Container",
|
||||||
|
"Object",
|
||||||
"VoiceClient",
|
"VoiceClient",
|
||||||
"AudioSource",
|
"AudioSource",
|
||||||
"FFmpegAudioSource",
|
"FFmpegAudioSource",
|
||||||
@ -149,7 +162,12 @@ __all__ = [
|
|||||||
"NotFound",
|
"NotFound",
|
||||||
"Color",
|
"Color",
|
||||||
"utcnow",
|
"utcnow",
|
||||||
|
"escape_markdown",
|
||||||
|
"escape_mentions",
|
||||||
"message_pager",
|
"message_pager",
|
||||||
|
"get",
|
||||||
|
"find",
|
||||||
|
"snowflake_time",
|
||||||
"GatewayIntent",
|
"GatewayIntent",
|
||||||
"GatewayOpcode",
|
"GatewayOpcode",
|
||||||
"ButtonStyle",
|
"ButtonStyle",
|
||||||
@ -195,6 +213,7 @@ __all__ = [
|
|||||||
"command",
|
"command",
|
||||||
"cooldown",
|
"cooldown",
|
||||||
"has_any_role",
|
"has_any_role",
|
||||||
|
"is_owner",
|
||||||
"has_role",
|
"has_role",
|
||||||
"listener",
|
"listener",
|
||||||
"max_concurrency",
|
"max_concurrency",
|
||||||
|
51
disagreement/asset.py
Normal file
51
disagreement/asset.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
"""Utility class for Discord CDN assets."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import IO, Optional, Union, TYPE_CHECKING
|
||||||
|
|
||||||
|
import aiohttp # pylint: disable=import-error
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import Client
|
||||||
|
|
||||||
|
|
||||||
|
class Asset:
|
||||||
|
"""Represents a CDN asset such as an avatar or icon."""
|
||||||
|
|
||||||
|
def __init__(self, url: str, client_instance: Optional["Client"] = None) -> None:
|
||||||
|
self.url = url
|
||||||
|
self._client = client_instance
|
||||||
|
|
||||||
|
async def read(self) -> bytes:
|
||||||
|
"""Read the asset's bytes."""
|
||||||
|
|
||||||
|
session: Optional[aiohttp.ClientSession] = None
|
||||||
|
if self._client is not None:
|
||||||
|
await self._client._http._ensure_session() # type: ignore[attr-defined]
|
||||||
|
session = self._client._http._session # type: ignore[attr-defined]
|
||||||
|
if session is None:
|
||||||
|
session = aiohttp.ClientSession()
|
||||||
|
close = True
|
||||||
|
else:
|
||||||
|
close = False
|
||||||
|
async with session.get(self.url) as resp:
|
||||||
|
data = await resp.read()
|
||||||
|
if close:
|
||||||
|
await session.close()
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def save(self, fp: Union[str, os.PathLike[str], IO[bytes]]) -> None:
|
||||||
|
"""Save the asset to the given file path or file-like object."""
|
||||||
|
|
||||||
|
data = await self.read()
|
||||||
|
if isinstance(fp, (str, os.PathLike)):
|
||||||
|
path = os.fspath(fp)
|
||||||
|
with open(path, "wb") as file:
|
||||||
|
file.write(data)
|
||||||
|
else:
|
||||||
|
fp.write(data)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<Asset url='{self.url}'>"
|
@ -84,9 +84,9 @@ class MemberCache(Cache["Member"]):
|
|||||||
|
|
||||||
def _should_cache(self, member: Member) -> bool:
|
def _should_cache(self, member: Member) -> bool:
|
||||||
"""Determines if a member should be cached based on the flags."""
|
"""Determines if a member should be cached based on the flags."""
|
||||||
if self.flags.all:
|
if self.flags.all_enabled:
|
||||||
return True
|
return True
|
||||||
if self.flags.none:
|
if self.flags.no_flags:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if self.flags.online and member.status != "offline":
|
if self.flags.online and member.status != "offline":
|
||||||
|
@ -74,6 +74,14 @@ class MemberCacheFlags:
|
|||||||
for name in self.VALID_FLAGS:
|
for name in self.VALID_FLAGS:
|
||||||
yield name, getattr(self, name)
|
yield name, getattr(self, name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_enabled(self) -> bool:
|
||||||
|
return self.value == self.ALL_FLAGS
|
||||||
|
|
||||||
|
@property
|
||||||
|
def no_flags(self) -> bool:
|
||||||
|
return self.value == 0
|
||||||
|
|
||||||
def __int__(self) -> int:
|
def __int__(self) -> int:
|
||||||
return self.value
|
return self.value
|
||||||
|
|
||||||
|
@ -2,8 +2,11 @@
|
|||||||
The main Client class for interacting with the Discord API.
|
The main Client class for interacting with the Discord API.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import signal
|
import signal
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import importlib
|
||||||
from typing import (
|
from typing import (
|
||||||
Optional,
|
Optional,
|
||||||
Callable,
|
Callable,
|
||||||
@ -14,8 +17,13 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
List,
|
List,
|
||||||
Dict,
|
Dict,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
|
|
||||||
|
PERSISTENT_VIEWS_FILE = "persistent_views.json"
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from .http import HTTPClient
|
from .http import HTTPClient
|
||||||
from .gateway import GatewayClient
|
from .gateway import GatewayClient
|
||||||
@ -35,6 +43,7 @@ from .interactions import Interaction, Snowflake
|
|||||||
from .error_handler import setup_global_error_handler
|
from .error_handler import setup_global_error_handler
|
||||||
from .voice_client import VoiceClient
|
from .voice_client import VoiceClient
|
||||||
from .models import Activity
|
from .models import Activity
|
||||||
|
from .utils import utcnow
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .models import (
|
from .models import (
|
||||||
@ -64,7 +73,16 @@ if TYPE_CHECKING:
|
|||||||
from .ext.app_commands.commands import AppCommand, AppCommandGroup
|
from .ext.app_commands.commands import AppCommand, AppCommandGroup
|
||||||
|
|
||||||
|
|
||||||
class Client:
|
def _update_list(lst: List[Any], item: Any) -> None:
|
||||||
|
"""Replace an item with the same ID in a list or append if missing."""
|
||||||
|
for i, existing in enumerate(lst):
|
||||||
|
if getattr(existing, "id", None) == getattr(item, "id", None):
|
||||||
|
lst[i] = item
|
||||||
|
return
|
||||||
|
lst.append(item)
|
||||||
|
|
||||||
|
|
||||||
|
class Client:
|
||||||
"""
|
"""
|
||||||
Represents a client connection that connects to Discord.
|
Represents a client connection that connects to Discord.
|
||||||
This class is used to interact with the Discord WebSocket and API.
|
This class is used to interact with the Discord WebSocket and API.
|
||||||
@ -89,6 +107,9 @@ class Client:
|
|||||||
:class:`aiohttp.ClientSession`.
|
:class:`aiohttp.ClientSession`.
|
||||||
message_cache_maxlen (Optional[int]): Maximum number of messages to keep
|
message_cache_maxlen (Optional[int]): Maximum number of messages to keep
|
||||||
in the cache. When ``None``, the cache size is unlimited.
|
in the cache. When ``None``, the cache size is unlimited.
|
||||||
|
sync_commands_on_ready (bool): If ``True``, automatically call
|
||||||
|
:meth:`Client.sync_application_commands` after the ``READY`` event
|
||||||
|
when :attr:`Client.application_id` is available.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -109,7 +130,10 @@ class Client:
|
|||||||
member_cache_flags: Optional[MemberCacheFlags] = None,
|
member_cache_flags: Optional[MemberCacheFlags] = None,
|
||||||
message_cache_maxlen: Optional[int] = None,
|
message_cache_maxlen: Optional[int] = None,
|
||||||
http_options: Optional[Dict[str, Any]] = None,
|
http_options: Optional[Dict[str, Any]] = None,
|
||||||
|
owner_ids: Optional[List[Union[str, int]]] = None,
|
||||||
|
sync_commands_on_ready: bool = True,
|
||||||
):
|
):
|
||||||
|
|
||||||
if not token:
|
if not token:
|
||||||
raise ValueError("A bot token must be provided.")
|
raise ValueError("A bot token must be provided.")
|
||||||
|
|
||||||
@ -139,13 +163,14 @@ class Client:
|
|||||||
**(http_options or {}),
|
**(http_options or {}),
|
||||||
)
|
)
|
||||||
self._event_dispatcher: EventDispatcher = EventDispatcher(client_instance=self)
|
self._event_dispatcher: EventDispatcher = EventDispatcher(client_instance=self)
|
||||||
self._gateway: Optional[GatewayClient] = (
|
self._gateway: Optional[GatewayClient] = (
|
||||||
None # Initialized in run() or connect()
|
None # Initialized in start() or connect()
|
||||||
)
|
)
|
||||||
self.shard_count: Optional[int] = shard_count
|
self.shard_count: Optional[int] = shard_count
|
||||||
self.gateway_max_retries: int = gateway_max_retries
|
self.gateway_max_retries: int = gateway_max_retries
|
||||||
self.gateway_max_backoff: float = gateway_max_backoff
|
self.gateway_max_backoff: float = gateway_max_backoff
|
||||||
self._shard_manager: Optional[ShardManager] = None
|
self._shard_manager: Optional[ShardManager] = None
|
||||||
|
self.owner_ids: List[str] = [str(o) for o in owner_ids] if owner_ids else []
|
||||||
|
|
||||||
# Initialize CommandHandler
|
# Initialize CommandHandler
|
||||||
self.command_handler: CommandHandler = CommandHandler(
|
self.command_handler: CommandHandler = CommandHandler(
|
||||||
@ -163,6 +188,8 @@ class Client:
|
|||||||
None # The bot's own user object, populated on READY
|
None # The bot's own user object, populated on READY
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.start_time: Optional[datetime] = None
|
||||||
|
|
||||||
# Internal Caches
|
# Internal Caches
|
||||||
self._guilds: GuildCache = GuildCache()
|
self._guilds: GuildCache = GuildCache()
|
||||||
self._channels: ChannelCache = ChannelCache()
|
self._channels: ChannelCache = ChannelCache()
|
||||||
@ -171,11 +198,15 @@ class Client:
|
|||||||
self._views: Dict[Snowflake, "View"] = {}
|
self._views: Dict[Snowflake, "View"] = {}
|
||||||
self._persistent_views: Dict[str, "View"] = {}
|
self._persistent_views: Dict[str, "View"] = {}
|
||||||
self._voice_clients: Dict[Snowflake, VoiceClient] = {}
|
self._voice_clients: Dict[Snowflake, VoiceClient] = {}
|
||||||
self._webhooks: Dict[Snowflake, "Webhook"] = {}
|
self._webhooks: Dict[Snowflake, "Webhook"] = {}
|
||||||
|
|
||||||
|
# Load persistent views stored on disk
|
||||||
|
self._load_persistent_views()
|
||||||
|
|
||||||
# Default whether replies mention the user
|
# Default whether replies mention the user
|
||||||
self.mention_replies: bool = mention_replies
|
self.mention_replies: bool = mention_replies
|
||||||
self.allowed_mentions: Optional[Dict[str, Any]] = allowed_mentions
|
self.allowed_mentions: Optional[Dict[str, Any]] = allowed_mentions
|
||||||
|
self.sync_commands_on_ready: bool = sync_commands_on_ready
|
||||||
|
|
||||||
# Basic signal handling for graceful shutdown
|
# Basic signal handling for graceful shutdown
|
||||||
# This might be better handled by the user's application code, but can be a nice default.
|
# This might be better handled by the user's application code, but can be a nice default.
|
||||||
@ -187,13 +218,46 @@ class Client:
|
|||||||
self.loop.add_signal_handler(
|
self.loop.add_signal_handler(
|
||||||
signal.SIGTERM, lambda: self.loop.create_task(self.close())
|
signal.SIGTERM, lambda: self.loop.create_task(self.close())
|
||||||
)
|
)
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
# add_signal_handler is not available on all platforms (e.g., Windows default event loop policy)
|
# add_signal_handler is not available on all platforms (e.g., Windows default event loop policy)
|
||||||
# Users on these platforms would need to handle shutdown differently.
|
# Users on these platforms would need to handle shutdown differently.
|
||||||
print(
|
print(
|
||||||
"Warning: Signal handlers for SIGINT/SIGTERM could not be added. "
|
"Warning: Signal handlers for SIGINT/SIGTERM could not be added. "
|
||||||
"Graceful shutdown via signals might not work as expected on this platform."
|
"Graceful shutdown via signals might not work as expected on this platform."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _load_persistent_views(self) -> None:
|
||||||
|
"""Load registered persistent views from disk."""
|
||||||
|
if not os.path.isfile(PERSISTENT_VIEWS_FILE):
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
with open(PERSISTENT_VIEWS_FILE, "r") as fp:
|
||||||
|
mapping = json.load(fp)
|
||||||
|
except Exception as e: # pragma: no cover - best effort load
|
||||||
|
print(f"Failed to load persistent views: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
for custom_id, path in mapping.items():
|
||||||
|
try:
|
||||||
|
module_name, class_name = path.rsplit(".", 1)
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
cls = getattr(module, class_name)
|
||||||
|
view = cls()
|
||||||
|
self._persistent_views[custom_id] = view
|
||||||
|
except Exception as e: # pragma: no cover - best effort load
|
||||||
|
print(f"Failed to initialize persistent view {path}: {e}")
|
||||||
|
|
||||||
|
def _save_persistent_views(self) -> None:
|
||||||
|
"""Persist registered views to disk."""
|
||||||
|
data = {}
|
||||||
|
for custom_id, view in self._persistent_views.items():
|
||||||
|
cls = view.__class__
|
||||||
|
data[custom_id] = f"{cls.__module__}.{cls.__name__}"
|
||||||
|
try:
|
||||||
|
with open(PERSISTENT_VIEWS_FILE, "w") as fp:
|
||||||
|
json.dump(data, fp)
|
||||||
|
except Exception as e: # pragma: no cover - best effort save
|
||||||
|
print(f"Failed to save persistent views: {e}")
|
||||||
|
|
||||||
async def _initialize_gateway(self):
|
async def _initialize_gateway(self):
|
||||||
"""Initializes the GatewayClient if it doesn't exist."""
|
"""Initializes the GatewayClient if it doesn't exist."""
|
||||||
@ -238,14 +302,15 @@ class Client:
|
|||||||
f"Client connected using {self.shard_count} shards, waiting for READY signal..."
|
f"Client connected using {self.shard_count} shards, waiting for READY signal..."
|
||||||
)
|
)
|
||||||
await self.wait_until_ready()
|
await self.wait_until_ready()
|
||||||
|
self.start_time = utcnow()
|
||||||
print("Client is READY!")
|
print("Client is READY!")
|
||||||
return
|
return
|
||||||
|
|
||||||
await self._initialize_gateway()
|
await self._initialize_gateway()
|
||||||
assert self._gateway is not None # Should be initialized by now
|
assert self._gateway is not None # Should be initialized by now
|
||||||
|
|
||||||
retry_delay = 5 # seconds
|
retry_delay = 5 # seconds
|
||||||
max_retries = 5 # For initial connection attempts by Client.run, Gateway has its own internal retries for some cases.
|
max_retries = 5 # For initial connection attempts by Client.start, Gateway has its own internal retries for some cases.
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
for attempt in range(max_retries):
|
||||||
try:
|
try:
|
||||||
@ -254,6 +319,7 @@ class Client:
|
|||||||
# and its READY handler will set self._ready_event via dispatcher.
|
# and its READY handler will set self._ready_event via dispatcher.
|
||||||
print("Client connected to Gateway, waiting for READY signal...")
|
print("Client connected to Gateway, waiting for READY signal...")
|
||||||
await self.wait_until_ready() # Wait for the READY event from Gateway
|
await self.wait_until_ready() # Wait for the READY event from Gateway
|
||||||
|
self.start_time = utcnow()
|
||||||
print("Client is READY!")
|
print("Client is READY!")
|
||||||
return # Successfully connected and ready
|
return # Successfully connected and ready
|
||||||
except AuthenticationError: # Non-recoverable by retry here
|
except AuthenticationError: # Non-recoverable by retry here
|
||||||
@ -262,25 +328,24 @@ class Client:
|
|||||||
raise
|
raise
|
||||||
except DisagreementException as e: # Includes GatewayException
|
except DisagreementException as e: # Includes GatewayException
|
||||||
print(f"Failed to connect (Attempt {attempt + 1}/{max_retries}): {e}")
|
print(f"Failed to connect (Attempt {attempt + 1}/{max_retries}): {e}")
|
||||||
if attempt < max_retries - 1:
|
if attempt < max_retries - 1:
|
||||||
print(f"Retrying in {retry_delay} seconds...")
|
print(f"Retrying in {retry_delay} seconds...")
|
||||||
await asyncio.sleep(retry_delay)
|
await asyncio.sleep(retry_delay)
|
||||||
retry_delay = min(
|
retry_delay = min(
|
||||||
retry_delay * 2, 60
|
retry_delay * 2, 60
|
||||||
) # Exponential backoff up to 60s
|
) # Exponential backoff up to 60s
|
||||||
else:
|
else:
|
||||||
print("Max connection retries reached. Giving up.")
|
print("Max connection retries reached. Giving up.")
|
||||||
await self.close() # Ensure cleanup
|
await self.close() # Ensure cleanup
|
||||||
raise
|
raise
|
||||||
if max_retries == 0: # If max_retries was 0, means no retries attempted
|
if max_retries == 0: # If max_retries was 0, means no retries attempted
|
||||||
raise DisagreementException("Connection failed with 0 retries allowed.")
|
raise DisagreementException("Connection failed with 0 retries allowed.")
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def start(self) -> None:
|
||||||
"""
|
"""
|
||||||
A blocking call that connects the client to Discord and runs until the client is closed.
|
Connect the client to Discord and run until the client is closed.
|
||||||
This method is a coroutine.
|
This method is a coroutine containing the main run loop logic.
|
||||||
It handles login, Gateway connection, and keeping the connection alive.
|
"""
|
||||||
"""
|
|
||||||
if self._closed:
|
if self._closed:
|
||||||
raise DisagreementException("Client is already closed.")
|
raise DisagreementException("Client is already closed.")
|
||||||
|
|
||||||
@ -337,15 +402,19 @@ class Client:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error checking gateway receive task: {e}")
|
print(f"Error checking gateway receive task: {e}")
|
||||||
break # Exit on other errors
|
break # Exit on other errors
|
||||||
await asyncio.sleep(1) # Main loop check interval
|
await asyncio.sleep(1) # Main loop check interval
|
||||||
except DisagreementException as e:
|
except DisagreementException as e:
|
||||||
print(f"Client run loop encountered an error: {e}")
|
print(f"Client run loop encountered an error: {e}")
|
||||||
# Error already logged by connect or other methods
|
# Error already logged by connect or other methods
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
print("Client run loop was cancelled.")
|
print("Client run loop was cancelled.")
|
||||||
finally:
|
finally:
|
||||||
if not self._closed:
|
if not self._closed:
|
||||||
await self.close()
|
await self.close()
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
"""Synchronously start the client using :func:`asyncio.run`."""
|
||||||
|
asyncio.run(self.start())
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""
|
"""
|
||||||
@ -367,6 +436,7 @@ class Client:
|
|||||||
await self._http.close()
|
await self._http.close()
|
||||||
|
|
||||||
self._ready_event.set() # Ensure any waiters for ready are unblocked
|
self._ready_event.set() # Ensure any waiters for ready are unblocked
|
||||||
|
self.start_time = None
|
||||||
print("Client closed.")
|
print("Client closed.")
|
||||||
|
|
||||||
async def __aenter__(self) -> "Client":
|
async def __aenter__(self) -> "Client":
|
||||||
@ -384,15 +454,23 @@ class Client:
|
|||||||
await self.close()
|
await self.close()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def close_gateway(self, code: int = 1000) -> None:
|
async def close_gateway(self, code: int = 1000) -> None:
|
||||||
"""Closes only the gateway connection, allowing for potential reconnect."""
|
"""Closes only the gateway connection, allowing for potential reconnect."""
|
||||||
if self._shard_manager:
|
if self._shard_manager:
|
||||||
await self._shard_manager.close()
|
await self._shard_manager.close()
|
||||||
self._shard_manager = None
|
self._shard_manager = None
|
||||||
if self._gateway:
|
if self._gateway:
|
||||||
await self._gateway.close(code=code)
|
await self._gateway.close(code=code)
|
||||||
self._gateway = None
|
self._gateway = None
|
||||||
self._ready_event.clear() # No longer ready if gateway is closed
|
self._ready_event.clear() # No longer ready if gateway is closed
|
||||||
|
|
||||||
|
async def logout(self) -> None:
|
||||||
|
"""Invalidate the bot token and disconnect from the Gateway."""
|
||||||
|
await self.close_gateway()
|
||||||
|
self.token = ""
|
||||||
|
self._http.token = ""
|
||||||
|
self.user = None
|
||||||
|
self.start_time = None
|
||||||
|
|
||||||
def is_closed(self) -> bool:
|
def is_closed(self) -> bool:
|
||||||
"""Indicates if the client has been closed."""
|
"""Indicates if the client has been closed."""
|
||||||
@ -415,6 +493,17 @@ class Client:
|
|||||||
latency = getattr(self._gateway, "latency_ms", None)
|
latency = getattr(self._gateway, "latency_ms", None)
|
||||||
return round(latency, 2) if latency is not None else None
|
return round(latency, 2) if latency is not None else None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def guilds(self) -> List["Guild"]:
|
||||||
|
"""Returns all guilds from the internal cache."""
|
||||||
|
return self._guilds.values()
|
||||||
|
|
||||||
|
def uptime(self) -> Optional[timedelta]:
|
||||||
|
"""Return the duration since the client connected, or ``None`` if not connected."""
|
||||||
|
if self.start_time is None:
|
||||||
|
return None
|
||||||
|
return utcnow() - self.start_time
|
||||||
|
|
||||||
async def wait_until_ready(self) -> None:
|
async def wait_until_ready(self) -> None:
|
||||||
"""|coro|
|
"""|coro|
|
||||||
Waits until the client is fully connected to Discord and the initial state is processed.
|
Waits until the client is fully connected to Discord and the initial state is processed.
|
||||||
@ -529,38 +618,43 @@ class Client:
|
|||||||
print(f"Message: {message.content}")
|
print(f"Message: {message.content}")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(
|
def decorator(
|
||||||
coro: Callable[..., Awaitable[None]],
|
coro: Callable[..., Awaitable[None]],
|
||||||
) -> Callable[..., Awaitable[None]]:
|
) -> Callable[..., Awaitable[None]]:
|
||||||
if not asyncio.iscoroutinefunction(coro):
|
if not asyncio.iscoroutinefunction(coro):
|
||||||
raise TypeError("Event registered must be a coroutine function.")
|
raise TypeError("Event registered must be a coroutine function.")
|
||||||
self._event_dispatcher.register(event_name.upper(), coro)
|
self._event_dispatcher.register(event_name.upper(), coro)
|
||||||
return coro
|
return coro
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
def add_listener(
|
|
||||||
self, event_name: str, coro: Callable[..., Awaitable[None]]
|
|
||||||
) -> None:
|
|
||||||
"""Register ``coro`` to listen for ``event_name``."""
|
|
||||||
|
|
||||||
self._event_dispatcher.register(event_name, coro)
|
|
||||||
|
|
||||||
def remove_listener(
|
|
||||||
self, event_name: str, coro: Callable[..., Awaitable[None]]
|
|
||||||
) -> None:
|
|
||||||
"""Remove ``coro`` from ``event_name`` listeners."""
|
|
||||||
|
|
||||||
self._event_dispatcher.unregister(event_name, coro)
|
|
||||||
|
|
||||||
async def _process_message_for_commands(self, message: "Message") -> None:
|
return decorator
|
||||||
"""Internal listener to process messages for commands."""
|
|
||||||
# Make sure message object is valid and not from a bot (optional, common check)
|
def add_listener(
|
||||||
if (
|
self, event_name: str, coro: Callable[..., Awaitable[None]]
|
||||||
not message or not message.author or message.author.bot
|
) -> None:
|
||||||
): # Add .bot check to User model
|
"""Register ``coro`` to listen for ``event_name``."""
|
||||||
return
|
|
||||||
await self.command_handler.process_commands(message)
|
self._event_dispatcher.register(event_name, coro)
|
||||||
|
|
||||||
|
def remove_listener(
|
||||||
|
self, event_name: str, coro: Callable[..., Awaitable[None]]
|
||||||
|
) -> None:
|
||||||
|
"""Remove ``coro`` from ``event_name`` listeners."""
|
||||||
|
|
||||||
|
self._event_dispatcher.unregister(event_name, coro)
|
||||||
|
|
||||||
|
async def _process_message_for_commands(self, message: "Message") -> None:
|
||||||
|
"""Internal listener to process messages for commands."""
|
||||||
|
# Make sure message object is valid and not from a bot (optional, common check)
|
||||||
|
if (
|
||||||
|
not message or not message.author or message.author.bot
|
||||||
|
): # Add .bot check to User model
|
||||||
|
return
|
||||||
|
await self.command_handler.process_commands(message)
|
||||||
|
|
||||||
|
async def get_context(self, message: "Message") -> Optional["CommandContext"]:
|
||||||
|
"""Return a :class:`CommandContext` for ``message`` without executing the command."""
|
||||||
|
|
||||||
|
return await self.command_handler.get_context(message)
|
||||||
|
|
||||||
# --- Command Framework Methods ---
|
# --- Command Framework Methods ---
|
||||||
|
|
||||||
@ -587,7 +681,7 @@ class Client:
|
|||||||
f"Registered app command/group '{app_cmd_obj.name}' from cog '{cog.cog_name}'."
|
f"Registered app command/group '{app_cmd_obj.name}' from cog '{cog.cog_name}'."
|
||||||
)
|
)
|
||||||
|
|
||||||
def remove_cog(self, cog_name: str) -> Optional[Cog]:
|
def remove_cog(self, cog_name: str) -> Optional[Cog]:
|
||||||
"""
|
"""
|
||||||
Removes a Cog from the bot.
|
Removes a Cog from the bot.
|
||||||
|
|
||||||
@ -614,7 +708,12 @@ class Client:
|
|||||||
# Note: AppCommandHandler.remove_command might need to be more specific if names aren't globally unique
|
# Note: AppCommandHandler.remove_command might need to be more specific if names aren't globally unique
|
||||||
# (e.g. if it needs type or if groups and commands can share names).
|
# (e.g. if it needs type or if groups and commands can share names).
|
||||||
# For now, assuming name is sufficient for removal from the handler's flat list.
|
# For now, assuming name is sufficient for removal from the handler's flat list.
|
||||||
return removed_cog
|
return removed_cog
|
||||||
|
|
||||||
|
def get_cog(self, name: str) -> Optional[Cog]:
|
||||||
|
"""Return a loaded cog by name if present."""
|
||||||
|
|
||||||
|
return self.command_handler.get_cog(name)
|
||||||
|
|
||||||
def check(self, coro: Callable[["CommandContext"], Awaitable[bool]]):
|
def check(self, coro: Callable[["CommandContext"], Awaitable[bool]]):
|
||||||
"""
|
"""
|
||||||
@ -754,14 +853,19 @@ class Client:
|
|||||||
"""Parses user data and returns a User object, updating cache."""
|
"""Parses user data and returns a User object, updating cache."""
|
||||||
from .models import User # Ensure User model is available
|
from .models import User # Ensure User model is available
|
||||||
|
|
||||||
user = User(data)
|
user = User(data, client_instance=self)
|
||||||
self._users.set(user.id, user) # Cache the user
|
self._users.set(user.id, user) # Cache the user
|
||||||
return user
|
return user
|
||||||
|
|
||||||
def parse_channel(self, data: Dict[str, Any]) -> "Channel":
|
def parse_channel(self, data: Dict[str, Any]) -> "Channel":
|
||||||
"""Parses channel data and returns a Channel object, updating caches."""
|
"""Parses channel data and returns a Channel object, updating caches."""
|
||||||
|
|
||||||
from .models import channel_factory
|
from .models import (
|
||||||
|
channel_factory,
|
||||||
|
TextChannel,
|
||||||
|
VoiceChannel,
|
||||||
|
CategoryChannel,
|
||||||
|
)
|
||||||
|
|
||||||
channel = channel_factory(data, self)
|
channel = channel_factory(data, self)
|
||||||
self._channels.set(channel.id, channel)
|
self._channels.set(channel.id, channel)
|
||||||
@ -769,6 +873,12 @@ class Client:
|
|||||||
guild = self._guilds.get(channel.guild_id)
|
guild = self._guilds.get(channel.guild_id)
|
||||||
if guild:
|
if guild:
|
||||||
guild._channels.set(channel.id, channel)
|
guild._channels.set(channel.id, channel)
|
||||||
|
if isinstance(channel, TextChannel):
|
||||||
|
_update_list(guild.text_channels, channel)
|
||||||
|
elif isinstance(channel, VoiceChannel):
|
||||||
|
_update_list(guild.voice_channels, channel)
|
||||||
|
elif isinstance(channel, CategoryChannel):
|
||||||
|
_update_list(guild.category_channels, channel)
|
||||||
return channel
|
return channel
|
||||||
|
|
||||||
def parse_message(self, data: Dict[str, Any]) -> "Message":
|
def parse_message(self, data: Dict[str, Any]) -> "Message":
|
||||||
@ -929,7 +1039,8 @@ class Client:
|
|||||||
"""Parses guild data and returns a Guild object, updating cache."""
|
"""Parses guild data and returns a Guild object, updating cache."""
|
||||||
from .models import Guild
|
from .models import Guild
|
||||||
|
|
||||||
guild = Guild(data, client_instance=self)
|
shard_id = data.get("shard_id")
|
||||||
|
guild = Guild(data, client_instance=self, shard_id=shard_id)
|
||||||
self._guilds.set(guild.id, guild)
|
self._guilds.set(guild.id, guild)
|
||||||
|
|
||||||
presences = {p["user"]["id"]: p for p in data.get("presences", [])}
|
presences = {p["user"]["id"]: p for p in data.get("presences", [])}
|
||||||
@ -1107,6 +1218,23 @@ class Client:
|
|||||||
|
|
||||||
return self.parse_message(message_data)
|
return self.parse_message(message_data)
|
||||||
|
|
||||||
|
async def create_dm(self, user_id: Snowflake) -> "DMChannel":
|
||||||
|
"""|coro| Create or fetch a DM channel with a user."""
|
||||||
|
from .models import DMChannel
|
||||||
|
|
||||||
|
dm_data = await self._http.create_dm(user_id)
|
||||||
|
return cast(DMChannel, self.parse_channel(dm_data))
|
||||||
|
|
||||||
|
async def send_dm(
|
||||||
|
self,
|
||||||
|
user_id: Snowflake,
|
||||||
|
content: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "Message":
|
||||||
|
"""|coro| Convenience method to send a direct message to a user."""
|
||||||
|
channel = await self.create_dm(user_id)
|
||||||
|
return await self.send_message(channel.id, content=content, **kwargs)
|
||||||
|
|
||||||
def typing(self, channel_id: str) -> Typing:
|
def typing(self, channel_id: str) -> Typing:
|
||||||
"""Return a context manager to show a typing indicator in a channel."""
|
"""Return a context manager to show a typing indicator in a channel."""
|
||||||
|
|
||||||
@ -1320,13 +1448,33 @@ class Client:
|
|||||||
|
|
||||||
return self._channels.get(channel_id)
|
return self._channels.get(channel_id)
|
||||||
|
|
||||||
def get_message(self, message_id: Snowflake) -> Optional["Message"]:
|
def get_message(self, message_id: Snowflake) -> Optional["Message"]:
|
||||||
"""Returns a message from the internal cache."""
|
"""Returns a message from the internal cache."""
|
||||||
|
|
||||||
|
return self._messages.get(message_id)
|
||||||
|
|
||||||
|
def get_all_channels(self) -> List["Channel"]:
|
||||||
|
"""Return all channels cached in every guild."""
|
||||||
|
|
||||||
|
channels: List["Channel"] = []
|
||||||
|
for guild in self._guilds.values():
|
||||||
|
channels.extend(guild._channels.values())
|
||||||
|
return channels
|
||||||
|
|
||||||
|
def get_all_members(self) -> List["Member"]:
|
||||||
|
"""Return all cached members across all guilds.
|
||||||
|
|
||||||
|
When member caching is disabled via :class:`MemberCacheFlags.none`, this
|
||||||
|
list will always be empty.
|
||||||
|
"""
|
||||||
|
|
||||||
|
members: List["Member"] = []
|
||||||
|
for guild in self._guilds.values():
|
||||||
|
members.extend(guild._members.values())
|
||||||
|
return members
|
||||||
|
|
||||||
return self._messages.get(message_id)
|
async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]:
|
||||||
|
"""Fetches a guild by ID from Discord and caches it."""
|
||||||
async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]:
|
|
||||||
"""Fetches a guild by ID from Discord and caches it."""
|
|
||||||
|
|
||||||
if self._closed:
|
if self._closed:
|
||||||
raise DisagreementException("Client is closed.")
|
raise DisagreementException("Client is closed.")
|
||||||
@ -1340,19 +1488,19 @@ class Client:
|
|||||||
return self.parse_guild(guild_data)
|
return self.parse_guild(guild_data)
|
||||||
except DisagreementException as e:
|
except DisagreementException as e:
|
||||||
print(f"Failed to fetch guild {guild_id}: {e}")
|
print(f"Failed to fetch guild {guild_id}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def fetch_guilds(self) -> List["Guild"]:
|
async def fetch_guilds(self) -> List["Guild"]:
|
||||||
"""Fetch all guilds the current user is in."""
|
"""Fetch all guilds the current user is in."""
|
||||||
|
|
||||||
if self._closed:
|
if self._closed:
|
||||||
raise DisagreementException("Client is closed.")
|
raise DisagreementException("Client is closed.")
|
||||||
|
|
||||||
data = await self._http.get_current_user_guilds()
|
data = await self._http.get_current_user_guilds()
|
||||||
guilds: List["Guild"] = []
|
guilds: List["Guild"] = []
|
||||||
for guild_data in data:
|
for guild_data in data:
|
||||||
guilds.append(self.parse_guild(guild_data))
|
guilds.append(self.parse_guild(guild_data))
|
||||||
return guilds
|
return guilds
|
||||||
|
|
||||||
async def fetch_channel(self, channel_id: Snowflake) -> Optional["Channel"]:
|
async def fetch_channel(self, channel_id: Snowflake) -> Optional["Channel"]:
|
||||||
"""Fetches a channel from Discord by its ID and updates the cache."""
|
"""Fetches a channel from Discord by its ID and updates the cache."""
|
||||||
@ -1423,16 +1571,33 @@ class Client:
|
|||||||
data = await self._http.edit_webhook(webhook_id, payload)
|
data = await self._http.edit_webhook(webhook_id, payload)
|
||||||
return self.parse_webhook(data)
|
return self.parse_webhook(data)
|
||||||
|
|
||||||
async def delete_webhook(self, webhook_id: Snowflake) -> None:
|
async def delete_webhook(self, webhook_id: Snowflake) -> None:
|
||||||
"""|coro| Delete a webhook by ID."""
|
"""|coro| Delete a webhook by ID."""
|
||||||
|
|
||||||
if self._closed:
|
if self._closed:
|
||||||
raise DisagreementException("Client is closed.")
|
raise DisagreementException("Client is closed.")
|
||||||
|
|
||||||
await self._http.delete_webhook(webhook_id)
|
await self._http.delete_webhook(webhook_id)
|
||||||
|
|
||||||
async def fetch_templates(self, guild_id: Snowflake) -> List["GuildTemplate"]:
|
async def fetch_webhook(self, webhook_id: Snowflake) -> Optional["Webhook"]:
|
||||||
"""|coro| Fetch all templates for a guild."""
|
"""|coro| Fetch a webhook by ID."""
|
||||||
|
|
||||||
|
if self._closed:
|
||||||
|
raise DisagreementException("Client is closed.")
|
||||||
|
|
||||||
|
cached = self._webhooks.get(webhook_id)
|
||||||
|
if cached:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await self._http.get_webhook(webhook_id)
|
||||||
|
return self.parse_webhook(data)
|
||||||
|
except DisagreementException as e:
|
||||||
|
print(f"Failed to fetch webhook {webhook_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def fetch_templates(self, guild_id: Snowflake) -> List["GuildTemplate"]:
|
||||||
|
"""|coro| Fetch all templates for a guild."""
|
||||||
|
|
||||||
if self._closed:
|
if self._closed:
|
||||||
raise DisagreementException("Client is closed.")
|
raise DisagreementException("Client is closed.")
|
||||||
@ -1562,7 +1727,20 @@ class Client:
|
|||||||
if self._closed:
|
if self._closed:
|
||||||
raise DisagreementException("Client is closed.")
|
raise DisagreementException("Client is closed.")
|
||||||
|
|
||||||
await self._http.delete_invite(code)
|
await self._http.delete_invite(code)
|
||||||
|
|
||||||
|
async def fetch_invite(self, code: Snowflake) -> Optional["Invite"]:
|
||||||
|
"""|coro| Fetch a single invite by code."""
|
||||||
|
|
||||||
|
if self._closed:
|
||||||
|
raise DisagreementException("Client is closed.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await self._http.get_invite(code)
|
||||||
|
return self.parse_invite(data)
|
||||||
|
except DisagreementException as e:
|
||||||
|
print(f"Failed to fetch invite {code}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
async def fetch_invites(self, channel_id: Snowflake) -> List["Invite"]:
|
async def fetch_invites(self, channel_id: Snowflake) -> List["Invite"]:
|
||||||
"""|coro| Fetch all invites for a channel."""
|
"""|coro| Fetch all invites for a channel."""
|
||||||
@ -1598,11 +1776,13 @@ class Client:
|
|||||||
|
|
||||||
for item in view.children:
|
for item in view.children:
|
||||||
if item.custom_id: # Ensure custom_id is not None
|
if item.custom_id: # Ensure custom_id is not None
|
||||||
if item.custom_id in self._persistent_views:
|
if item.custom_id in self._persistent_views:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"A component with custom_id '{item.custom_id}' is already registered."
|
f"A component with custom_id '{item.custom_id}' is already registered."
|
||||||
)
|
)
|
||||||
self._persistent_views[item.custom_id] = view
|
self._persistent_views[item.custom_id] = view
|
||||||
|
|
||||||
|
self._save_persistent_views()
|
||||||
|
|
||||||
# --- Application Command Methods ---
|
# --- Application Command Methods ---
|
||||||
async def process_interaction(self, interaction: Interaction) -> None:
|
async def process_interaction(self, interaction: Interaction) -> None:
|
||||||
@ -1647,16 +1827,6 @@ class Client:
|
|||||||
"Ensure the client is connected and READY."
|
"Ensure the client is connected and READY."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
if not self.is_ready():
|
|
||||||
print(
|
|
||||||
"Warning: Client is not ready. Waiting for client to be ready before syncing commands."
|
|
||||||
)
|
|
||||||
await self.wait_until_ready()
|
|
||||||
if not self.application_id:
|
|
||||||
print(
|
|
||||||
"Error: application_id still not set after client is ready. Cannot sync commands."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
await self.app_command_handler.sync_commands(
|
await self.app_command_handler.sync_commands(
|
||||||
application_id=self.application_id, guild_id=guild_id
|
application_id=self.application_id, guild_id=guild_id
|
||||||
@ -1677,6 +1847,16 @@ class Client:
|
|||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def on_connect(self) -> None:
|
||||||
|
"""|coro| Called when the WebSocket connection opens."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def on_disconnect(self) -> None:
|
||||||
|
"""|coro| Called when the WebSocket connection closes."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
async def on_app_command_error(
|
async def on_app_command_error(
|
||||||
self, context: AppCommandContext, error: Exception
|
self, context: AppCommandContext, error: Exception
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -268,12 +268,19 @@ class GuildFeature(str, Enum): # Changed from IntEnum to Enum
|
|||||||
VERIFIED = "VERIFIED"
|
VERIFIED = "VERIFIED"
|
||||||
VIP_REGIONS = "VIP_REGIONS"
|
VIP_REGIONS = "VIP_REGIONS"
|
||||||
WELCOME_SCREEN_ENABLED = "WELCOME_SCREEN_ENABLED"
|
WELCOME_SCREEN_ENABLED = "WELCOME_SCREEN_ENABLED"
|
||||||
|
SOUNDBOARD = "SOUNDBOARD"
|
||||||
|
VIDEO_QUALITY_720_60FPS = "VIDEO_QUALITY_720_60FPS"
|
||||||
# Add more as they become known or needed
|
# Add more as they become known or needed
|
||||||
|
|
||||||
# This allows GuildFeature("UNKNOWN_FEATURE_STRING") to work
|
# This allows GuildFeature("UNKNOWN_FEATURE_STRING") to work
|
||||||
@classmethod
|
@classmethod
|
||||||
def _missing_(cls, value): # type: ignore
|
def _missing_(cls, value): # type: ignore
|
||||||
return str(value)
|
member = object.__new__(cls)
|
||||||
|
member._name_ = str(value)
|
||||||
|
member._value_ = str(value)
|
||||||
|
cls._value2member_map_[member._value_] = member # pylint: disable=no-member
|
||||||
|
cls._member_map_[member._name_] = member # pylint: disable=no-member
|
||||||
|
return member
|
||||||
|
|
||||||
|
|
||||||
# --- Guild Scheduled Event Enums ---
|
# --- Guild Scheduled Event Enums ---
|
||||||
@ -329,7 +336,12 @@ class VoiceRegion(str, Enum):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _missing_(cls, value): # type: ignore
|
def _missing_(cls, value): # type: ignore
|
||||||
return str(value)
|
member = object.__new__(cls)
|
||||||
|
member._name_ = str(value)
|
||||||
|
member._value_ = str(value)
|
||||||
|
cls._value2member_map_[member._value_] = member # pylint: disable=no-member
|
||||||
|
cls._member_map_[member._name_] = member # pylint: disable=no-member
|
||||||
|
return member
|
||||||
|
|
||||||
|
|
||||||
# --- Channel Enums ---
|
# --- Channel Enums ---
|
||||||
|
@ -61,6 +61,11 @@ class EventDispatcher:
|
|||||||
"GUILD_ROLE_UPDATE": self._parse_guild_role_update,
|
"GUILD_ROLE_UPDATE": self._parse_guild_role_update,
|
||||||
"TYPING_START": self._parse_typing_start,
|
"TYPING_START": self._parse_typing_start,
|
||||||
"VOICE_STATE_UPDATE": self._parse_voice_state_update,
|
"VOICE_STATE_UPDATE": self._parse_voice_state_update,
|
||||||
|
"THREAD_CREATE": self._parse_thread_create,
|
||||||
|
"THREAD_UPDATE": self._parse_thread_update,
|
||||||
|
"THREAD_DELETE": self._parse_thread_delete,
|
||||||
|
"INVITE_CREATE": self._parse_invite_create,
|
||||||
|
"INVITE_DELETE": self._parse_invite_delete,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _parse_message_create(self, data: Dict[str, Any]) -> Message:
|
def _parse_message_create(self, data: Dict[str, Any]) -> Message:
|
||||||
@ -165,6 +170,43 @@ class EventDispatcher:
|
|||||||
|
|
||||||
return GuildRoleUpdate(data, client_instance=self._client)
|
return GuildRoleUpdate(data, client_instance=self._client)
|
||||||
|
|
||||||
|
def _parse_thread_create(self, data: Dict[str, Any]):
|
||||||
|
"""Parses THREAD_CREATE into a Thread object and updates caches."""
|
||||||
|
|
||||||
|
return self._client.parse_channel(data)
|
||||||
|
|
||||||
|
def _parse_thread_update(self, data: Dict[str, Any]):
|
||||||
|
"""Parses THREAD_UPDATE into a Thread object."""
|
||||||
|
|
||||||
|
return self._client.parse_channel(data)
|
||||||
|
|
||||||
|
def _parse_thread_delete(self, data: Dict[str, Any]):
|
||||||
|
"""Parses THREAD_DELETE, removing the thread from caches."""
|
||||||
|
|
||||||
|
thread = self._client.parse_channel(data)
|
||||||
|
thread_id = data.get("id")
|
||||||
|
if thread_id:
|
||||||
|
self._client._channels.invalidate(thread_id)
|
||||||
|
guild_id = data.get("guild_id")
|
||||||
|
if guild_id:
|
||||||
|
guild = self._client._guilds.get(guild_id)
|
||||||
|
if guild:
|
||||||
|
guild._channels.invalidate(thread_id)
|
||||||
|
guild._threads.pop(thread_id, None)
|
||||||
|
return thread
|
||||||
|
|
||||||
|
def _parse_invite_create(self, data: Dict[str, Any]):
|
||||||
|
"""Parses INVITE_CREATE into an Invite object."""
|
||||||
|
|
||||||
|
return self._client.parse_invite(data)
|
||||||
|
|
||||||
|
def _parse_invite_delete(self, data: Dict[str, Any]):
|
||||||
|
"""Parses INVITE_DELETE into an InviteDelete model."""
|
||||||
|
|
||||||
|
from .models import InviteDelete
|
||||||
|
|
||||||
|
return InviteDelete(data)
|
||||||
|
|
||||||
# Potentially add _parse_user for events that directly provide a full user object
|
# Potentially add _parse_user for events that directly provide a full user object
|
||||||
# def _parse_user_update(self, data: Dict[str, Any]) -> User:
|
# def _parse_user_update(self, data: Dict[str, Any]) -> User:
|
||||||
# return User(data=data)
|
# return User(data=data)
|
||||||
|
@ -18,51 +18,23 @@ from typing import (
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from disagreement.client import Client
|
from disagreement.client import Client
|
||||||
from disagreement.interactions import Interaction, ResolvedData, Snowflake
|
from disagreement.interactions import Interaction, ResolvedData, Snowflake
|
||||||
from disagreement.enums import (
|
|
||||||
ApplicationCommandType,
|
|
||||||
ApplicationCommandOptionType,
|
|
||||||
InteractionType,
|
|
||||||
)
|
|
||||||
from .commands import (
|
|
||||||
AppCommand,
|
|
||||||
SlashCommand,
|
|
||||||
UserCommand,
|
|
||||||
MessageCommand,
|
|
||||||
AppCommandGroup,
|
|
||||||
)
|
|
||||||
from .context import AppCommandContext
|
|
||||||
from disagreement.models import (
|
|
||||||
User,
|
|
||||||
Member,
|
|
||||||
Role,
|
|
||||||
Attachment,
|
|
||||||
Message,
|
|
||||||
) # For resolved data
|
|
||||||
|
|
||||||
# Channel models would also go here
|
from disagreement.enums import (
|
||||||
|
ApplicationCommandType,
|
||||||
|
ApplicationCommandOptionType,
|
||||||
|
InteractionType,
|
||||||
|
)
|
||||||
|
from .commands import (
|
||||||
|
AppCommand,
|
||||||
|
SlashCommand,
|
||||||
|
UserCommand,
|
||||||
|
MessageCommand,
|
||||||
|
AppCommandGroup,
|
||||||
|
)
|
||||||
|
from .context import AppCommandContext
|
||||||
|
from disagreement.models import User, Member, Role, Attachment, Message
|
||||||
|
|
||||||
# Placeholder for models not yet fully defined or imported
|
Channel = Any
|
||||||
if not TYPE_CHECKING:
|
|
||||||
from disagreement.enums import (
|
|
||||||
ApplicationCommandType,
|
|
||||||
ApplicationCommandOptionType,
|
|
||||||
InteractionType,
|
|
||||||
)
|
|
||||||
from .commands import (
|
|
||||||
AppCommand,
|
|
||||||
SlashCommand,
|
|
||||||
UserCommand,
|
|
||||||
MessageCommand,
|
|
||||||
AppCommandGroup,
|
|
||||||
)
|
|
||||||
from .context import AppCommandContext
|
|
||||||
|
|
||||||
User = Any
|
|
||||||
Member = Any
|
|
||||||
Role = Any
|
|
||||||
Attachment = Any
|
|
||||||
Channel = Any
|
|
||||||
Message = Any
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -587,12 +559,19 @@ class AppCommandHandler:
|
|||||||
# print(f"Failed to send error message for app command: {send_e}")
|
# print(f"Failed to send error message for app command: {send_e}")
|
||||||
|
|
||||||
async def sync_commands(
|
async def sync_commands(
|
||||||
self, application_id: "Snowflake", guild_id: Optional["Snowflake"] = None
|
self,
|
||||||
|
application_id: Optional["Snowflake"] = None,
|
||||||
|
guild_id: Optional["Snowflake"] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Synchronizes (registers/updates) all application commands with Discord.
|
Synchronizes (registers/updates) all application commands with Discord.
|
||||||
If guild_id is provided, syncs commands for that guild. Otherwise, syncs global commands.
|
If guild_id is provided, syncs commands for that guild. Otherwise, syncs global commands.
|
||||||
"""
|
"""
|
||||||
|
if application_id is None:
|
||||||
|
application_id = self.client.application_id
|
||||||
|
if application_id is None:
|
||||||
|
raise ValueError("application_id must be provided to sync commands")
|
||||||
|
|
||||||
cache = self._load_cached_ids()
|
cache = self._load_cached_ids()
|
||||||
scope_key = str(guild_id) if guild_id else "global"
|
scope_key = str(guild_id) if guild_id else "global"
|
||||||
stored = cache.get(scope_key, {})
|
stored = cache.get(scope_key, {})
|
||||||
|
@ -14,11 +14,12 @@ from .decorators import (
|
|||||||
check,
|
check,
|
||||||
check_any,
|
check_any,
|
||||||
cooldown,
|
cooldown,
|
||||||
max_concurrency,
|
max_concurrency,
|
||||||
requires_permissions,
|
requires_permissions,
|
||||||
has_role,
|
has_role,
|
||||||
has_any_role,
|
has_any_role,
|
||||||
)
|
is_owner,
|
||||||
|
)
|
||||||
from .errors import (
|
from .errors import (
|
||||||
CommandError,
|
CommandError,
|
||||||
CommandNotFound,
|
CommandNotFound,
|
||||||
@ -47,9 +48,10 @@ __all__ = [
|
|||||||
"cooldown",
|
"cooldown",
|
||||||
"max_concurrency",
|
"max_concurrency",
|
||||||
"requires_permissions",
|
"requires_permissions",
|
||||||
"has_role",
|
"has_role",
|
||||||
"has_any_role",
|
"has_any_role",
|
||||||
# Errors
|
"is_owner",
|
||||||
|
# Errors
|
||||||
"CommandError",
|
"CommandError",
|
||||||
"CommandNotFound",
|
"CommandNotFound",
|
||||||
"BadArgument",
|
"BadArgument",
|
||||||
|
@ -6,7 +6,16 @@ import re
|
|||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
from .errors import BadArgument
|
from .errors import BadArgument
|
||||||
from disagreement.models import Member, Guild, Role
|
from disagreement.models import (
|
||||||
|
Member,
|
||||||
|
Guild,
|
||||||
|
Role,
|
||||||
|
User,
|
||||||
|
TextChannel,
|
||||||
|
VoiceChannel,
|
||||||
|
Emoji,
|
||||||
|
PartialEmoji,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .core import CommandContext
|
from .core import CommandContext
|
||||||
@ -143,6 +152,97 @@ class GuildConverter(Converter["Guild"]):
|
|||||||
raise BadArgument(f"Guild '{argument}' not found.")
|
raise BadArgument(f"Guild '{argument}' not found.")
|
||||||
|
|
||||||
|
|
||||||
|
class UserConverter(Converter["User"]):
|
||||||
|
async def convert(self, ctx: "CommandContext", argument: str) -> "User":
|
||||||
|
match = re.match(r"<@!?(\d+)>$", argument)
|
||||||
|
user_id = match.group(1) if match else argument
|
||||||
|
|
||||||
|
user = ctx.bot._users.get(user_id)
|
||||||
|
if user:
|
||||||
|
return user
|
||||||
|
|
||||||
|
user = await ctx.bot.fetch_user(user_id)
|
||||||
|
if user:
|
||||||
|
return user
|
||||||
|
raise BadArgument(f"User '{argument}' not found.")
|
||||||
|
|
||||||
|
|
||||||
|
class TextChannelConverter(Converter["TextChannel"]):
|
||||||
|
async def convert(self, ctx: "CommandContext", argument: str) -> "TextChannel":
|
||||||
|
if not ctx.message.guild_id:
|
||||||
|
raise BadArgument("TextChannel converter requires guild context.")
|
||||||
|
|
||||||
|
match = re.match(r"<#(?P<id>\d+)>$", argument)
|
||||||
|
channel_id = match.group("id") if match else argument
|
||||||
|
|
||||||
|
guild = ctx.bot.get_guild(ctx.message.guild_id)
|
||||||
|
if guild:
|
||||||
|
channel = guild.get_channel(channel_id)
|
||||||
|
if isinstance(channel, TextChannel):
|
||||||
|
return channel
|
||||||
|
|
||||||
|
channel = (
|
||||||
|
ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None
|
||||||
|
)
|
||||||
|
if isinstance(channel, TextChannel):
|
||||||
|
return channel
|
||||||
|
|
||||||
|
if hasattr(ctx.bot, "fetch_channel"):
|
||||||
|
channel = await ctx.bot.fetch_channel(channel_id)
|
||||||
|
if isinstance(channel, TextChannel):
|
||||||
|
return channel
|
||||||
|
|
||||||
|
raise BadArgument(f"Text channel '{argument}' not found.")
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceChannelConverter(Converter["VoiceChannel"]):
|
||||||
|
async def convert(self, ctx: "CommandContext", argument: str) -> "VoiceChannel":
|
||||||
|
if not ctx.message.guild_id:
|
||||||
|
raise BadArgument("VoiceChannel converter requires guild context.")
|
||||||
|
|
||||||
|
match = re.match(r"<#(?P<id>\d+)>$", argument)
|
||||||
|
channel_id = match.group("id") if match else argument
|
||||||
|
|
||||||
|
guild = ctx.bot.get_guild(ctx.message.guild_id)
|
||||||
|
if guild:
|
||||||
|
channel = guild.get_channel(channel_id)
|
||||||
|
if isinstance(channel, VoiceChannel):
|
||||||
|
return channel
|
||||||
|
|
||||||
|
channel = (
|
||||||
|
ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None
|
||||||
|
)
|
||||||
|
if isinstance(channel, VoiceChannel):
|
||||||
|
return channel
|
||||||
|
|
||||||
|
if hasattr(ctx.bot, "fetch_channel"):
|
||||||
|
channel = await ctx.bot.fetch_channel(channel_id)
|
||||||
|
if isinstance(channel, VoiceChannel):
|
||||||
|
return channel
|
||||||
|
|
||||||
|
raise BadArgument(f"Voice channel '{argument}' not found.")
|
||||||
|
|
||||||
|
|
||||||
|
class EmojiConverter(Converter["PartialEmoji"]):
|
||||||
|
_CUSTOM_RE = re.compile(r"<(?P<animated>a)?:(?P<name>[^:]+):(?P<id>\d+)>$")
|
||||||
|
|
||||||
|
async def convert(self, ctx: "CommandContext", argument: str) -> "PartialEmoji":
|
||||||
|
match = self._CUSTOM_RE.match(argument)
|
||||||
|
if match:
|
||||||
|
return PartialEmoji(
|
||||||
|
{
|
||||||
|
"id": match.group("id"),
|
||||||
|
"name": match.group("name"),
|
||||||
|
"animated": bool(match.group("animated")),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if argument:
|
||||||
|
return PartialEmoji({"id": None, "name": argument})
|
||||||
|
|
||||||
|
raise BadArgument(f"Emoji '{argument}' not found.")
|
||||||
|
|
||||||
|
|
||||||
# Default converters mapping
|
# Default converters mapping
|
||||||
DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
|
DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
|
||||||
int: IntConverter(),
|
int: IntConverter(),
|
||||||
@ -152,7 +252,11 @@ DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
|
|||||||
Member: MemberConverter(),
|
Member: MemberConverter(),
|
||||||
Guild: GuildConverter(),
|
Guild: GuildConverter(),
|
||||||
Role: RoleConverter(),
|
Role: RoleConverter(),
|
||||||
# User: UserConverter(), # Add when User model and converter are ready
|
User: UserConverter(),
|
||||||
|
TextChannel: TextChannelConverter(),
|
||||||
|
VoiceChannel: VoiceChannelConverter(),
|
||||||
|
PartialEmoji: EmojiConverter(),
|
||||||
|
Emoji: EmojiConverter(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -79,8 +79,15 @@ class GroupMixin:
|
|||||||
)
|
)
|
||||||
self.commands[alias.lower()] = command
|
self.commands[alias.lower()] = command
|
||||||
|
|
||||||
def get_command(self, name: str) -> Optional["Command"]:
|
def get_command(self, name: str) -> Optional["Command"]:
|
||||||
return self.commands.get(name.lower())
|
return self.commands.get(name.lower())
|
||||||
|
|
||||||
|
def walk_commands(self):
|
||||||
|
"""Yield all commands in this group recursively."""
|
||||||
|
for cmd in dict.fromkeys(self.commands.values()):
|
||||||
|
yield cmd
|
||||||
|
if isinstance(cmd, Group):
|
||||||
|
yield from cmd.walk_commands()
|
||||||
|
|
||||||
|
|
||||||
class Command(GroupMixin):
|
class Command(GroupMixin):
|
||||||
@ -363,8 +370,20 @@ class CommandHandler:
|
|||||||
self.commands.pop(alias.lower(), None)
|
self.commands.pop(alias.lower(), None)
|
||||||
return command
|
return command
|
||||||
|
|
||||||
def get_command(self, name: str) -> Optional[Command]:
|
def get_command(self, name: str) -> Optional[Command]:
|
||||||
return self.commands.get(name.lower())
|
return self.commands.get(name.lower())
|
||||||
|
|
||||||
|
def walk_commands(self):
|
||||||
|
"""Yield every registered command, including subcommands."""
|
||||||
|
for cmd in dict.fromkeys(self.commands.values()):
|
||||||
|
yield cmd
|
||||||
|
if isinstance(cmd, Group):
|
||||||
|
yield from cmd.walk_commands()
|
||||||
|
|
||||||
|
def get_cog(self, name: str) -> Optional["Cog"]:
|
||||||
|
"""Return a loaded cog by name if present."""
|
||||||
|
|
||||||
|
return self.cogs.get(name)
|
||||||
|
|
||||||
def add_cog(self, cog_to_add: "Cog") -> None:
|
def add_cog(self, cog_to_add: "Cog") -> None:
|
||||||
from .cog import Cog
|
from .cog import Cog
|
||||||
@ -471,9 +490,9 @@ class CommandHandler:
|
|||||||
return self.prefix(self.client, message) # type: ignore
|
return self.prefix(self.client, message) # type: ignore
|
||||||
return self.prefix
|
return self.prefix
|
||||||
|
|
||||||
async def _parse_arguments(
|
async def _parse_arguments(
|
||||||
self, command: Command, ctx: CommandContext, view: StringView
|
self, command: Command, ctx: CommandContext, view: StringView
|
||||||
) -> Tuple[List[Any], Dict[str, Any]]:
|
) -> Tuple[List[Any], Dict[str, Any]]:
|
||||||
args_list = []
|
args_list = []
|
||||||
kwargs_dict = {}
|
kwargs_dict = {}
|
||||||
params_to_parse = list(command.params.values())
|
params_to_parse = list(command.params.values())
|
||||||
@ -636,7 +655,79 @@ class CommandHandler:
|
|||||||
elif param.kind == inspect.Parameter.KEYWORD_ONLY:
|
elif param.kind == inspect.Parameter.KEYWORD_ONLY:
|
||||||
kwargs_dict[param.name] = final_value_for_param
|
kwargs_dict[param.name] = final_value_for_param
|
||||||
|
|
||||||
return args_list, kwargs_dict
|
return args_list, kwargs_dict
|
||||||
|
|
||||||
|
async def get_context(self, message: "Message") -> Optional[CommandContext]:
|
||||||
|
"""Parse a message and return a :class:`CommandContext` without executing the command.
|
||||||
|
|
||||||
|
Returns ``None`` if the message does not invoke a command."""
|
||||||
|
|
||||||
|
if not message.content:
|
||||||
|
return None
|
||||||
|
|
||||||
|
prefix_to_use = await self.get_prefix(message)
|
||||||
|
if not prefix_to_use:
|
||||||
|
return None
|
||||||
|
|
||||||
|
actual_prefix: Optional[str] = None
|
||||||
|
if isinstance(prefix_to_use, list):
|
||||||
|
for p in prefix_to_use:
|
||||||
|
if message.content.startswith(p):
|
||||||
|
actual_prefix = p
|
||||||
|
break
|
||||||
|
if not actual_prefix:
|
||||||
|
return None
|
||||||
|
elif isinstance(prefix_to_use, str):
|
||||||
|
if message.content.startswith(prefix_to_use):
|
||||||
|
actual_prefix = prefix_to_use
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if actual_prefix is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
view = StringView(message.content[len(actual_prefix) :])
|
||||||
|
|
||||||
|
command_name = view.get_word()
|
||||||
|
if not command_name:
|
||||||
|
return None
|
||||||
|
|
||||||
|
command = self.get_command(command_name)
|
||||||
|
if not command:
|
||||||
|
return None
|
||||||
|
|
||||||
|
invoked_with = command_name
|
||||||
|
|
||||||
|
if isinstance(command, Group):
|
||||||
|
view.skip_whitespace()
|
||||||
|
potential_subcommand = view.get_word()
|
||||||
|
if potential_subcommand:
|
||||||
|
subcommand = command.get_command(potential_subcommand)
|
||||||
|
if subcommand:
|
||||||
|
command = subcommand
|
||||||
|
invoked_with += f" {potential_subcommand}"
|
||||||
|
elif command.invoke_without_command:
|
||||||
|
view.index -= len(potential_subcommand) + view.previous
|
||||||
|
else:
|
||||||
|
raise CommandNotFound(
|
||||||
|
f"Subcommand '{potential_subcommand}' not found."
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = CommandContext(
|
||||||
|
message=message,
|
||||||
|
bot=self.client,
|
||||||
|
prefix=actual_prefix,
|
||||||
|
command=command,
|
||||||
|
invoked_with=invoked_with,
|
||||||
|
cog=command.cog,
|
||||||
|
)
|
||||||
|
|
||||||
|
parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view)
|
||||||
|
ctx.args = parsed_args
|
||||||
|
ctx.kwargs = parsed_kwargs
|
||||||
|
return ctx
|
||||||
|
|
||||||
async def process_commands(self, message: "Message") -> None:
|
async def process_commands(self, message: "Message") -> None:
|
||||||
if not message.content:
|
if not message.content:
|
||||||
|
@ -292,3 +292,19 @@ def has_any_role(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return check(predicate)
|
return check(predicate)
|
||||||
|
|
||||||
|
|
||||||
|
def is_owner() -> (
|
||||||
|
Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]
|
||||||
|
):
|
||||||
|
"""Check that the invoking user is listed as a bot owner."""
|
||||||
|
|
||||||
|
async def predicate(ctx: "CommandContext") -> bool:
|
||||||
|
from .errors import CheckFailure
|
||||||
|
|
||||||
|
owner_ids = getattr(ctx.bot, "owner_ids", [])
|
||||||
|
if str(ctx.author.id) not in {str(o) for o in owner_ids}:
|
||||||
|
raise CheckFailure("This command can only be used by the bot owner.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
return check(predicate)
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
|
from collections import defaultdict
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from .core import Command, CommandContext, CommandHandler
|
from ...utils import Paginator
|
||||||
|
from .core import Command, CommandContext, CommandHandler, Group
|
||||||
|
|
||||||
|
|
||||||
class HelpCommand(Command):
|
class HelpCommand(Command):
|
||||||
@ -15,17 +17,22 @@ class HelpCommand(Command):
|
|||||||
if not cmd or cmd.name.lower() != command.lower():
|
if not cmd or cmd.name.lower() != command.lower():
|
||||||
await ctx.send(f"Command '{command}' not found.")
|
await ctx.send(f"Command '{command}' not found.")
|
||||||
return
|
return
|
||||||
description = cmd.description or cmd.brief or "No description provided."
|
if isinstance(cmd, Group):
|
||||||
await ctx.send(f"**{ctx.prefix}{cmd.name}**\n{description}")
|
await self.send_group_help(ctx, cmd)
|
||||||
else:
|
elif cmd:
|
||||||
lines: List[str] = []
|
description = cmd.description or cmd.brief or "No description provided."
|
||||||
for registered in dict.fromkeys(handler.commands.values()):
|
await ctx.send(f"**{ctx.prefix}{cmd.name}**\n{description}")
|
||||||
brief = registered.brief or registered.description or ""
|
|
||||||
lines.append(f"{ctx.prefix}{registered.name} - {brief}".strip())
|
|
||||||
if lines:
|
|
||||||
await ctx.send("\n".join(lines))
|
|
||||||
else:
|
else:
|
||||||
await ctx.send("No commands available.")
|
lines: List[str] = []
|
||||||
|
for registered in handler.walk_commands():
|
||||||
|
brief = registered.brief or registered.description or ""
|
||||||
|
lines.append(f"{ctx.prefix}{registered.name} - {brief}".strip())
|
||||||
|
if lines:
|
||||||
|
await ctx.send("\n".join(lines))
|
||||||
|
else:
|
||||||
|
await self.send_command_help(ctx, cmd)
|
||||||
|
else:
|
||||||
|
await self.send_bot_help(ctx)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
callback,
|
callback,
|
||||||
@ -33,3 +40,42 @@ class HelpCommand(Command):
|
|||||||
brief="Show command help.",
|
brief="Show command help.",
|
||||||
description="Displays help for commands.",
|
description="Displays help for commands.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def send_bot_help(self, ctx: CommandContext) -> None:
|
||||||
|
groups = defaultdict(list)
|
||||||
|
for cmd in dict.fromkeys(self.handler.commands.values()):
|
||||||
|
key = cmd.cog.cog_name if cmd.cog else "No Category"
|
||||||
|
groups[key].append(cmd)
|
||||||
|
|
||||||
|
paginator = Paginator()
|
||||||
|
for cog_name, cmds in groups.items():
|
||||||
|
paginator.add_line(f"**{cog_name}**")
|
||||||
|
for cmd in cmds:
|
||||||
|
brief = cmd.brief or cmd.description or ""
|
||||||
|
paginator.add_line(f"{ctx.prefix}{cmd.name} - {brief}".strip())
|
||||||
|
paginator.add_line("")
|
||||||
|
|
||||||
|
pages = paginator.pages
|
||||||
|
if not pages:
|
||||||
|
await ctx.send("No commands available.")
|
||||||
|
return
|
||||||
|
for page in pages:
|
||||||
|
await ctx.send(page)
|
||||||
|
|
||||||
|
async def send_command_help(self, ctx: CommandContext, command: Command) -> None:
|
||||||
|
description = command.description or command.brief or "No description provided."
|
||||||
|
await ctx.send(f"**{ctx.prefix}{command.name}**\n{description}")
|
||||||
|
|
||||||
|
async def send_group_help(self, ctx: CommandContext, group: Group) -> None:
|
||||||
|
paginator = Paginator()
|
||||||
|
description = group.description or group.brief or "No description provided."
|
||||||
|
paginator.add_line(f"**{ctx.prefix}{group.name}**\n{description}")
|
||||||
|
if group.commands:
|
||||||
|
for sub in dict.fromkeys(group.commands.values()):
|
||||||
|
brief = sub.brief or sub.description or ""
|
||||||
|
paginator.add_line(
|
||||||
|
f"{ctx.prefix}{group.name} {sub.name} - {brief}".strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
for page in paginator.pages:
|
||||||
|
await ctx.send(page)
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
import inspect
|
||||||
import sys
|
import sys
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Dict
|
from typing import Any, Coroutine, Dict, cast
|
||||||
|
|
||||||
__all__ = ["load_extension", "unload_extension", "reload_extension"]
|
__all__ = ["load_extension", "unload_extension", "reload_extension"]
|
||||||
|
|
||||||
@ -25,7 +27,20 @@ def load_extension(name: str) -> ModuleType:
|
|||||||
if not hasattr(module, "setup"):
|
if not hasattr(module, "setup"):
|
||||||
raise ImportError(f"Extension '{name}' does not define a setup function")
|
raise ImportError(f"Extension '{name}' does not define a setup function")
|
||||||
|
|
||||||
module.setup()
|
result = module.setup()
|
||||||
|
if inspect.isawaitable(result):
|
||||||
|
coro = cast(Coroutine[Any, Any, Any], result)
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
asyncio.run(coro)
|
||||||
|
else:
|
||||||
|
if loop.is_running():
|
||||||
|
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||||
|
future.result()
|
||||||
|
else:
|
||||||
|
loop.run_until_complete(coro)
|
||||||
|
|
||||||
_loaded_extensions[name] = module
|
_loaded_extensions[name] = module
|
||||||
return module
|
return module
|
||||||
|
|
||||||
@ -38,7 +53,19 @@ def unload_extension(name: str) -> None:
|
|||||||
raise ValueError(f"Extension '{name}' is not loaded")
|
raise ValueError(f"Extension '{name}' is not loaded")
|
||||||
|
|
||||||
if hasattr(module, "teardown"):
|
if hasattr(module, "teardown"):
|
||||||
module.teardown()
|
result = module.teardown()
|
||||||
|
if inspect.isawaitable(result):
|
||||||
|
coro = cast(Coroutine[Any, Any, Any], result)
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
asyncio.run(coro)
|
||||||
|
else:
|
||||||
|
if loop.is_running():
|
||||||
|
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||||
|
future.result()
|
||||||
|
else:
|
||||||
|
loop.run_until_complete(coro)
|
||||||
|
|
||||||
sys.modules.pop(name, None)
|
sys.modules.pop(name, None)
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ class Task:
|
|||||||
) -> None:
|
) -> None:
|
||||||
self._coro = coro
|
self._coro = coro
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
self._task: Optional[asyncio.Task[None]] = None
|
||||||
|
self._current_loop = 0
|
||||||
if time_of_day is not None and (
|
if time_of_day is not None and (
|
||||||
seconds or minutes or hours or delta is not None
|
seconds or minutes or hours or delta is not None
|
||||||
):
|
):
|
||||||
@ -68,6 +69,7 @@ class Task:
|
|||||||
await _maybe_call(self._on_error, exc)
|
await _maybe_call(self._on_error, exc)
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
self._current_loop += 1
|
||||||
|
|
||||||
first = False
|
first = False
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
@ -78,6 +80,7 @@ class Task:
|
|||||||
|
|
||||||
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
||||||
if self._task is None or self._task.done():
|
if self._task is None or self._task.done():
|
||||||
|
self._current_loop = 0
|
||||||
self._task = asyncio.create_task(self._run(*args, **kwargs))
|
self._task = asyncio.create_task(self._run(*args, **kwargs))
|
||||||
return self._task
|
return self._task
|
||||||
|
|
||||||
@ -90,6 +93,34 @@ class Task:
|
|||||||
def running(self) -> bool:
|
def running(self) -> bool:
|
||||||
return self._task is not None and not self._task.done()
|
return self._task is not None and not self._task.done()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_loop(self) -> int:
|
||||||
|
return self._current_loop
|
||||||
|
|
||||||
|
def change_interval(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seconds: float = 0.0,
|
||||||
|
minutes: float = 0.0,
|
||||||
|
hours: float = 0.0,
|
||||||
|
delta: Optional[datetime.timedelta] = None,
|
||||||
|
time_of_day: Optional[datetime.time] = None,
|
||||||
|
) -> None:
|
||||||
|
if time_of_day is not None and (
|
||||||
|
seconds or minutes or hours or delta is not None
|
||||||
|
):
|
||||||
|
raise ValueError("time_of_day cannot be used with an interval")
|
||||||
|
|
||||||
|
if delta is not None:
|
||||||
|
if not isinstance(delta, datetime.timedelta):
|
||||||
|
raise TypeError("delta must be a datetime.timedelta")
|
||||||
|
interval_seconds = delta.total_seconds()
|
||||||
|
else:
|
||||||
|
interval_seconds = seconds + minutes * 60.0 + hours * 3600.0
|
||||||
|
|
||||||
|
self._seconds = float(interval_seconds)
|
||||||
|
self._time_of_day = time_of_day
|
||||||
|
|
||||||
|
|
||||||
async def _maybe_call(
|
async def _maybe_call(
|
||||||
func: Callable[[Exception], Awaitable[None] | None], exc: Exception
|
func: Callable[[Exception], Awaitable[None] | None], exc: Exception
|
||||||
@ -181,10 +212,37 @@ class _Loop:
|
|||||||
if self._task is not None:
|
if self._task is not None:
|
||||||
self._task.stop()
|
self._task.stop()
|
||||||
|
|
||||||
|
def change_interval(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seconds: float = 0.0,
|
||||||
|
minutes: float = 0.0,
|
||||||
|
hours: float = 0.0,
|
||||||
|
delta: Optional[datetime.timedelta] = None,
|
||||||
|
time_of_day: Optional[datetime.time] = None,
|
||||||
|
) -> None:
|
||||||
|
self.seconds = seconds
|
||||||
|
self.minutes = minutes
|
||||||
|
self.hours = hours
|
||||||
|
self.delta = delta
|
||||||
|
self.time_of_day = time_of_day
|
||||||
|
if self._task is not None:
|
||||||
|
self._task.change_interval(
|
||||||
|
seconds=seconds,
|
||||||
|
minutes=minutes,
|
||||||
|
hours=hours,
|
||||||
|
delta=delta,
|
||||||
|
time_of_day=time_of_day,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def running(self) -> bool:
|
def running(self) -> bool:
|
||||||
return self._task.running if self._task else False
|
return self._task.running if self._task else False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_loop(self) -> int:
|
||||||
|
return self._task.current_loop if self._task else 0
|
||||||
|
|
||||||
|
|
||||||
class _BoundLoop:
|
class _BoundLoop:
|
||||||
def __init__(self, parent: _Loop, owner: Any) -> None:
|
def __init__(self, parent: _Loop, owner: Any) -> None:
|
||||||
@ -202,6 +260,27 @@ class _BoundLoop:
|
|||||||
def running(self) -> bool:
|
def running(self) -> bool:
|
||||||
return self._parent.running
|
return self._parent.running
|
||||||
|
|
||||||
|
def change_interval(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seconds: float = 0.0,
|
||||||
|
minutes: float = 0.0,
|
||||||
|
hours: float = 0.0,
|
||||||
|
delta: Optional[datetime.timedelta] = None,
|
||||||
|
time_of_day: Optional[datetime.time] = None,
|
||||||
|
) -> None:
|
||||||
|
self._parent.change_interval(
|
||||||
|
seconds=seconds,
|
||||||
|
minutes=minutes,
|
||||||
|
hours=hours,
|
||||||
|
delta=delta,
|
||||||
|
time_of_day=time_of_day,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_loop(self) -> int:
|
||||||
|
return self._parent.current_loop
|
||||||
|
|
||||||
|
|
||||||
def loop(
|
def loop(
|
||||||
*,
|
*,
|
||||||
|
@ -334,7 +334,19 @@ class GatewayClient:
|
|||||||
self._resume_gateway_url,
|
self._resume_gateway_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# The client is now ready for operations. Set the event before dispatching to user code.
|
||||||
|
self._client_instance._ready_event.set()
|
||||||
|
logger.info("Client is now marked as ready.")
|
||||||
|
|
||||||
|
if isinstance(raw_event_d_payload, dict) and self._shard_id is not None:
|
||||||
|
raw_event_d_payload["shard_id"] = self._shard_id
|
||||||
await self._dispatcher.dispatch(event_name, raw_event_d_payload)
|
await self._dispatcher.dispatch(event_name, raw_event_d_payload)
|
||||||
|
|
||||||
|
if (
|
||||||
|
getattr(self._client_instance, "sync_commands_on_ready", True)
|
||||||
|
and self._client_instance.application_id
|
||||||
|
):
|
||||||
|
asyncio.create_task(self._client_instance.sync_application_commands())
|
||||||
elif event_name == "GUILD_MEMBERS_CHUNK":
|
elif event_name == "GUILD_MEMBERS_CHUNK":
|
||||||
if isinstance(raw_event_d_payload, dict):
|
if isinstance(raw_event_d_payload, dict):
|
||||||
nonce = raw_event_d_payload.get("nonce")
|
nonce = raw_event_d_payload.get("nonce")
|
||||||
@ -384,6 +396,8 @@ class GatewayClient:
|
|||||||
event_data_to_dispatch = (
|
event_data_to_dispatch = (
|
||||||
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
|
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
|
||||||
)
|
)
|
||||||
|
if isinstance(event_data_to_dispatch, dict) and self._shard_id is not None:
|
||||||
|
event_data_to_dispatch["shard_id"] = self._shard_id
|
||||||
await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
|
await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
|
||||||
await self._dispatcher.dispatch(
|
await self._dispatcher.dispatch(
|
||||||
"SHARD_RESUME", {"shard_id": self._shard_id}
|
"SHARD_RESUME", {"shard_id": self._shard_id}
|
||||||
@ -394,6 +408,8 @@ class GatewayClient:
|
|||||||
event_data_to_dispatch = (
|
event_data_to_dispatch = (
|
||||||
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
|
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
|
||||||
)
|
)
|
||||||
|
if isinstance(event_data_to_dispatch, dict) and self._shard_id is not None:
|
||||||
|
event_data_to_dispatch["shard_id"] = self._shard_id
|
||||||
|
|
||||||
await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
|
await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
|
||||||
else:
|
else:
|
||||||
@ -553,6 +569,7 @@ class GatewayClient:
|
|||||||
await self._dispatcher.dispatch(
|
await self._dispatcher.dispatch(
|
||||||
"SHARD_CONNECT", {"shard_id": self._shard_id}
|
"SHARD_CONNECT", {"shard_id": self._shard_id}
|
||||||
)
|
)
|
||||||
|
await self._dispatcher.dispatch("CONNECT", {"shard_id": self._shard_id})
|
||||||
|
|
||||||
except aiohttp.ClientConnectorError as e:
|
except aiohttp.ClientConnectorError as e:
|
||||||
raise GatewayException(
|
raise GatewayException(
|
||||||
@ -608,6 +625,7 @@ class GatewayClient:
|
|||||||
await self._dispatcher.dispatch(
|
await self._dispatcher.dispatch(
|
||||||
"SHARD_DISCONNECT", {"shard_id": self._shard_id}
|
"SHARD_DISCONNECT", {"shard_id": self._shard_id}
|
||||||
)
|
)
|
||||||
|
await self._dispatcher.dispatch("DISCONNECT", {"shard_id": self._shard_id})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def latency(self) -> Optional[float]:
|
def latency(self) -> Optional[float]:
|
||||||
|
@ -663,6 +663,15 @@ class HTTPClient:
|
|||||||
|
|
||||||
await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}")
|
await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}")
|
||||||
|
|
||||||
|
async def crosspost_message(
|
||||||
|
self, channel_id: "Snowflake", message_id: "Snowflake"
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Crossposts a message to any following channels."""
|
||||||
|
|
||||||
|
return await self.request(
|
||||||
|
"POST", f"/channels/{channel_id}/messages/{message_id}/crosspost"
|
||||||
|
)
|
||||||
|
|
||||||
async def delete_channel(
|
async def delete_channel(
|
||||||
self, channel_id: str, reason: Optional[str] = None
|
self, channel_id: str, reason: Optional[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -702,6 +711,22 @@ class HTTPClient:
|
|||||||
"""Fetches a channel by ID."""
|
"""Fetches a channel by ID."""
|
||||||
return await self.request("GET", f"/channels/{channel_id}")
|
return await self.request("GET", f"/channels/{channel_id}")
|
||||||
|
|
||||||
|
async def create_guild_channel(
|
||||||
|
self,
|
||||||
|
guild_id: "Snowflake",
|
||||||
|
payload: Dict[str, Any],
|
||||||
|
reason: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Creates a new channel in the specified guild."""
|
||||||
|
|
||||||
|
headers = {"X-Audit-Log-Reason": reason} if reason else None
|
||||||
|
return await self.request(
|
||||||
|
"POST",
|
||||||
|
f"/guilds/{guild_id}/channels",
|
||||||
|
payload=payload,
|
||||||
|
custom_headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
async def get_channel_invites(
|
async def get_channel_invites(
|
||||||
self, channel_id: "Snowflake"
|
self, channel_id: "Snowflake"
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
@ -721,11 +746,36 @@ class HTTPClient:
|
|||||||
|
|
||||||
return Invite.from_dict(data)
|
return Invite.from_dict(data)
|
||||||
|
|
||||||
|
async def create_channel_invite(
|
||||||
|
self,
|
||||||
|
channel_id: "Snowflake",
|
||||||
|
payload: Dict[str, Any],
|
||||||
|
*,
|
||||||
|
reason: Optional[str] = None,
|
||||||
|
) -> "Invite":
|
||||||
|
"""Creates an invite for a channel with an optional audit log reason."""
|
||||||
|
|
||||||
|
headers = {"X-Audit-Log-Reason": reason} if reason else None
|
||||||
|
data = await self.request(
|
||||||
|
"POST",
|
||||||
|
f"/channels/{channel_id}/invites",
|
||||||
|
payload=payload,
|
||||||
|
custom_headers=headers,
|
||||||
|
)
|
||||||
|
from .models import Invite
|
||||||
|
|
||||||
|
return Invite.from_dict(data)
|
||||||
|
|
||||||
async def delete_invite(self, code: str) -> None:
|
async def delete_invite(self, code: str) -> None:
|
||||||
"""Deletes an invite by code."""
|
"""Deletes an invite by code."""
|
||||||
|
|
||||||
await self.request("DELETE", f"/invites/{code}")
|
await self.request("DELETE", f"/invites/{code}")
|
||||||
|
|
||||||
|
async def get_invite(self, code: "Snowflake") -> Dict[str, Any]:
|
||||||
|
"""Fetches a single invite by its code."""
|
||||||
|
|
||||||
|
return await self.request("GET", f"/invites/{code}")
|
||||||
|
|
||||||
async def create_webhook(
|
async def create_webhook(
|
||||||
self, channel_id: "Snowflake", payload: Dict[str, Any]
|
self, channel_id: "Snowflake", payload: Dict[str, Any]
|
||||||
) -> "Webhook":
|
) -> "Webhook":
|
||||||
@ -738,6 +788,11 @@ class HTTPClient:
|
|||||||
|
|
||||||
return Webhook(data)
|
return Webhook(data)
|
||||||
|
|
||||||
|
async def get_webhook(self, webhook_id: "Snowflake") -> Dict[str, Any]:
|
||||||
|
"""Fetches a webhook by ID and returns the raw payload."""
|
||||||
|
|
||||||
|
return await self.request("GET", f"/webhooks/{webhook_id}")
|
||||||
|
|
||||||
async def edit_webhook(
|
async def edit_webhook(
|
||||||
self, webhook_id: "Snowflake", payload: Dict[str, Any]
|
self, webhook_id: "Snowflake", payload: Dict[str, Any]
|
||||||
) -> "Webhook":
|
) -> "Webhook":
|
||||||
@ -753,6 +808,24 @@ class HTTPClient:
|
|||||||
|
|
||||||
await self.request("DELETE", f"/webhooks/{webhook_id}")
|
await self.request("DELETE", f"/webhooks/{webhook_id}")
|
||||||
|
|
||||||
|
async def get_webhook_with_token(
|
||||||
|
self, webhook_id: "Snowflake", token: Optional[str] = None
|
||||||
|
) -> "Webhook":
|
||||||
|
"""Fetches a webhook by ID, optionally using its token."""
|
||||||
|
|
||||||
|
endpoint = f"/webhooks/{webhook_id}"
|
||||||
|
use_auth = True
|
||||||
|
if token is not None:
|
||||||
|
endpoint += f"/{token}"
|
||||||
|
use_auth = False
|
||||||
|
if use_auth:
|
||||||
|
data = await self.request("GET", endpoint)
|
||||||
|
else:
|
||||||
|
data = await self.request("GET", endpoint, use_auth_header=False)
|
||||||
|
from .models import Webhook
|
||||||
|
|
||||||
|
return Webhook(data)
|
||||||
|
|
||||||
async def execute_webhook(
|
async def execute_webhook(
|
||||||
self,
|
self,
|
||||||
webhook_id: "Snowflake",
|
webhook_id: "Snowflake",
|
||||||
@ -839,13 +912,13 @@ class HTTPClient:
|
|||||||
use_auth_header=False,
|
use_auth_header=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_user(self, user_id: "Snowflake") -> Dict[str, Any]:
|
async def get_user(self, user_id: "Snowflake") -> Dict[str, Any]:
|
||||||
"""Fetches a user object for a given user ID."""
|
"""Fetches a user object for a given user ID."""
|
||||||
return await self.request("GET", f"/users/{user_id}")
|
return await self.request("GET", f"/users/{user_id}")
|
||||||
|
|
||||||
async def get_current_user_guilds(self) -> List[Dict[str, Any]]:
|
async def get_current_user_guilds(self) -> List[Dict[str, Any]]:
|
||||||
"""Returns the guilds the current user is in."""
|
"""Returns the guilds the current user is in."""
|
||||||
return await self.request("GET", "/users/@me/guilds")
|
return await self.request("GET", "/users/@me/guilds")
|
||||||
|
|
||||||
async def get_guild_member(
|
async def get_guild_member(
|
||||||
self, guild_id: "Snowflake", user_id: "Snowflake"
|
self, guild_id: "Snowflake", user_id: "Snowflake"
|
||||||
@ -902,6 +975,29 @@ class HTTPClient:
|
|||||||
custom_headers=headers,
|
custom_headers=headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_guild_prune_count(self, guild_id: "Snowflake", *, days: int) -> int:
|
||||||
|
"""Returns the number of members that would be pruned."""
|
||||||
|
|
||||||
|
data = await self.request(
|
||||||
|
"GET",
|
||||||
|
f"/guilds/{guild_id}/prune",
|
||||||
|
params={"days": days},
|
||||||
|
)
|
||||||
|
return int(data.get("pruned", 0))
|
||||||
|
|
||||||
|
async def begin_guild_prune(
|
||||||
|
self, guild_id: "Snowflake", *, days: int, compute_count: bool = True
|
||||||
|
) -> int:
|
||||||
|
"""Begins a prune operation for the guild and returns the count."""
|
||||||
|
|
||||||
|
payload = {"days": days, "compute_prune_count": compute_count}
|
||||||
|
data = await self.request(
|
||||||
|
"POST",
|
||||||
|
f"/guilds/{guild_id}/prune",
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
return int(data.get("pruned", 0))
|
||||||
|
|
||||||
async def get_guild_roles(self, guild_id: "Snowflake") -> List[Dict[str, Any]]:
|
async def get_guild_roles(self, guild_id: "Snowflake") -> List[Dict[str, Any]]:
|
||||||
"""Returns a list of role objects for the guild."""
|
"""Returns a list of role objects for the guild."""
|
||||||
return await self.request("GET", f"/guilds/{guild_id}/roles")
|
return await self.request("GET", f"/guilds/{guild_id}/roles")
|
||||||
@ -1364,3 +1460,8 @@ class HTTPClient:
|
|||||||
async def leave_thread(self, channel_id: "Snowflake") -> None:
|
async def leave_thread(self, channel_id: "Snowflake") -> None:
|
||||||
"""Removes the current user from a thread."""
|
"""Removes the current user from a thread."""
|
||||||
await self.request("DELETE", f"/channels/{channel_id}/thread-members/@me")
|
await self.request("DELETE", f"/channels/{channel_id}/thread-members/@me")
|
||||||
|
|
||||||
|
async def create_dm(self, recipient_id: "Snowflake") -> Dict[str, Any]:
|
||||||
|
"""Creates (or opens) a DM channel with the given user."""
|
||||||
|
payload = {"recipient_id": str(recipient_id)}
|
||||||
|
return await self.request("POST", "/users/@me/channels", payload=payload)
|
||||||
|
@ -2,11 +2,26 @@
|
|||||||
Data models for Discord objects.
|
Data models for Discord objects.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import datetime
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast
|
from typing import (
|
||||||
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
IO,
|
||||||
|
)
|
||||||
|
|
||||||
from .cache import ChannelCache, MemberCache
|
from .cache import ChannelCache, MemberCache
|
||||||
from .caching import MemberCacheFlags
|
from .caching import MemberCacheFlags
|
||||||
@ -40,29 +55,43 @@ if TYPE_CHECKING:
|
|||||||
from .ui.view import View
|
from .ui.view import View
|
||||||
from .interactions import Snowflake
|
from .interactions import Snowflake
|
||||||
from .typing import Typing
|
from .typing import Typing
|
||||||
|
from .shard_manager import Shard
|
||||||
|
from .asset import Asset
|
||||||
|
|
||||||
# Forward reference Message if it were used in type hints before its definition
|
# Forward reference Message if it were used in type hints before its definition
|
||||||
# from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc.
|
# from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc.
|
||||||
from .components import component_factory
|
from .components import component_factory
|
||||||
|
|
||||||
|
|
||||||
class User:
|
class HashableById:
|
||||||
"""Represents a Discord User.
|
"""Mixin providing equality and hashing based on the ``id`` attribute."""
|
||||||
|
|
||||||
Attributes:
|
id: str
|
||||||
id (str): The user's unique ID.
|
|
||||||
username (str): The user's username.
|
|
||||||
discriminator (str): The user's 4-digit discord-tag.
|
|
||||||
bot (bool): Whether the user belongs to an OAuth2 application. Defaults to False.
|
|
||||||
avatar (Optional[str]): The user's avatar hash, if any.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, data: dict):
|
def __eq__(self, other: object) -> bool:
|
||||||
|
return isinstance(other, self.__class__) and self.id == other.id # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
def __hash__(self) -> int: # pragma: no cover - trivial
|
||||||
|
return hash(self.id)
|
||||||
|
|
||||||
|
|
||||||
|
class User(HashableById):
|
||||||
|
"""Represents a Discord User."""
|
||||||
|
|
||||||
|
def __init__(self, data: dict, client_instance: Optional["Client"] = None) -> None:
|
||||||
|
self._client = client_instance
|
||||||
|
if "id" not in data and "user" in data:
|
||||||
|
data = data["user"]
|
||||||
self.id: str = data["id"]
|
self.id: str = data["id"]
|
||||||
self.username: str = data["username"]
|
self.username: Optional[str] = data.get("username")
|
||||||
self.discriminator: str = data["discriminator"]
|
self.discriminator: Optional[str] = data.get("discriminator")
|
||||||
self.bot: bool = data.get("bot", False)
|
self.bot: bool = data.get("bot", False)
|
||||||
self.avatar: Optional[str] = data.get("avatar")
|
avatar_hash = data.get("avatar")
|
||||||
|
self._avatar: Optional[str] = (
|
||||||
|
f"https://cdn.discordapp.com/avatars/{self.id}/{avatar_hash}.png"
|
||||||
|
if avatar_hash
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mention(self) -> str:
|
def mention(self) -> str:
|
||||||
@ -70,10 +99,45 @@ class User:
|
|||||||
return f"<@{self.id}>"
|
return f"<@{self.id}>"
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<User id='{self.id}' username='{self.username}' discriminator='{self.discriminator}'>"
|
username = self.username or "Unknown"
|
||||||
|
disc = self.discriminator or "????"
|
||||||
|
return f"<User id='{self.id}' username='{username}' discriminator='{disc}'>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def avatar(self) -> Optional["Asset"]:
|
||||||
|
"""Return the user's avatar as an :class:`Asset`."""
|
||||||
|
|
||||||
|
if self._avatar:
|
||||||
|
from .asset import Asset
|
||||||
|
|
||||||
|
return Asset(self._avatar, self._client)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@avatar.setter
|
||||||
|
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||||
|
if isinstance(value, str):
|
||||||
|
self._avatar = value
|
||||||
|
elif value is None:
|
||||||
|
self._avatar = None
|
||||||
|
else:
|
||||||
|
self._avatar = value.url
|
||||||
|
|
||||||
|
async def send(
|
||||||
|
self,
|
||||||
|
content: Optional[str] = None,
|
||||||
|
*,
|
||||||
|
client: Optional["Client"] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "Message":
|
||||||
|
"""Send a direct message to this user."""
|
||||||
|
|
||||||
|
target_client = client or self._client
|
||||||
|
if target_client is None:
|
||||||
|
raise DisagreementException("User.send requires a Client instance")
|
||||||
|
return await target_client.send_dm(self.id, content=content, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class Message:
|
class Message(HashableById):
|
||||||
"""Represents a message sent in a channel on Discord.
|
"""Represents a message sent in a channel on Discord.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
@ -96,9 +160,10 @@ class Message:
|
|||||||
self.id: str = data["id"]
|
self.id: str = data["id"]
|
||||||
self.channel_id: str = data["channel_id"]
|
self.channel_id: str = data["channel_id"]
|
||||||
self.guild_id: Optional[str] = data.get("guild_id")
|
self.guild_id: Optional[str] = data.get("guild_id")
|
||||||
self.author: User = User(data["author"])
|
self.author: User = User(data["author"], client_instance)
|
||||||
self.content: str = data["content"]
|
self.content: str = data["content"]
|
||||||
self.timestamp: str = data["timestamp"]
|
self.timestamp: str = data["timestamp"]
|
||||||
|
self.edited_timestamp: Optional[str] = data.get("edited_timestamp")
|
||||||
if data.get("components"):
|
if data.get("components"):
|
||||||
self.components: Optional[List[ActionRow]] = [
|
self.components: Optional[List[ActionRow]] = [
|
||||||
ActionRow.from_dict(c, client_instance)
|
ActionRow.from_dict(c, client_instance)
|
||||||
@ -106,21 +171,21 @@ class Message:
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
self.components = None
|
self.components = None
|
||||||
self.attachments: List[Attachment] = [
|
self.attachments: List[Attachment] = [
|
||||||
Attachment(a) for a in data.get("attachments", [])
|
Attachment(a) for a in data.get("attachments", [])
|
||||||
]
|
]
|
||||||
self.pinned: bool = data.get("pinned", False)
|
self.pinned: bool = data.get("pinned", False)
|
||||||
# Add other fields as needed, e.g., attachments, embeds, reactions, etc.
|
# Add other fields as needed, e.g., attachments, embeds, reactions, etc.
|
||||||
# self.mentions: List[User] = [User(u) for u in data.get("mentions", [])]
|
# self.mentions: List[User] = [User(u) for u in data.get("mentions", [])]
|
||||||
# self.mention_roles: List[str] = data.get("mention_roles", [])
|
# self.mention_roles: List[str] = data.get("mention_roles", [])
|
||||||
# self.mention_everyone: bool = data.get("mention_everyone", False)
|
# self.mention_everyone: bool = data.get("mention_everyone", False)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def jump_url(self) -> str:
|
def jump_url(self) -> str:
|
||||||
"""Return a URL that jumps to this message in the Discord client."""
|
"""Return a URL that jumps to this message in the Discord client."""
|
||||||
|
|
||||||
guild_or_dm = self.guild_id or "@me"
|
guild_or_dm = self.guild_id or "@me"
|
||||||
return f"https://discord.com/channels/{guild_or_dm}/{self.channel_id}/{self.id}"
|
return f"https://discord.com/channels/{guild_or_dm}/{self.channel_id}/{self.id}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def clean_content(self) -> str:
|
def clean_content(self) -> str:
|
||||||
@ -130,6 +195,20 @@ class Message:
|
|||||||
cleaned = pattern.sub("", self.content)
|
cleaned = pattern.sub("", self.content)
|
||||||
return " ".join(cleaned.split())
|
return " ".join(cleaned.split())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def created_at(self) -> datetime.datetime:
|
||||||
|
"""Return message timestamp as a :class:`~datetime.datetime`."""
|
||||||
|
|
||||||
|
return datetime.datetime.fromisoformat(self.timestamp)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def edited_at(self) -> Optional[datetime.datetime]:
|
||||||
|
"""Return edited timestamp as :class:`~datetime.datetime` if present."""
|
||||||
|
|
||||||
|
if self.edited_timestamp is None:
|
||||||
|
return None
|
||||||
|
return datetime.datetime.fromisoformat(self.edited_timestamp)
|
||||||
|
|
||||||
async def pin(self) -> None:
|
async def pin(self) -> None:
|
||||||
"""|coro|
|
"""|coro|
|
||||||
|
|
||||||
@ -156,6 +235,15 @@ class Message:
|
|||||||
await self._client._http.unpin_message(self.channel_id, self.id)
|
await self._client._http.unpin_message(self.channel_id, self.id)
|
||||||
self.pinned = False
|
self.pinned = False
|
||||||
|
|
||||||
|
async def crosspost(self) -> "Message":
|
||||||
|
"""|coro|
|
||||||
|
|
||||||
|
Crossposts this message to all follower channels and return the resulting message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
data = await self._client._http.crosspost_message(self.channel_id, self.id)
|
||||||
|
return self._client.parse_message(data)
|
||||||
|
|
||||||
async def reply(
|
async def reply(
|
||||||
self,
|
self,
|
||||||
content: Optional[str] = None,
|
content: Optional[str] = None,
|
||||||
@ -194,14 +282,14 @@ class Message:
|
|||||||
ValueError: If both `embed` and `embeds` are provided.
|
ValueError: If both `embed` and `embeds` are provided.
|
||||||
"""
|
"""
|
||||||
# Determine allowed mentions for the reply
|
# Determine allowed mentions for the reply
|
||||||
if mention_author is None:
|
if mention_author is None:
|
||||||
mention_author = getattr(self._client, "mention_replies", False)
|
mention_author = getattr(self._client, "mention_replies", False)
|
||||||
|
|
||||||
if allowed_mentions is None:
|
if allowed_mentions is None:
|
||||||
allowed_mentions = dict(getattr(self._client, "allowed_mentions", {}) or {})
|
allowed_mentions = dict(getattr(self._client, "allowed_mentions", {}) or {})
|
||||||
else:
|
else:
|
||||||
allowed_mentions = dict(allowed_mentions)
|
allowed_mentions = dict(allowed_mentions)
|
||||||
allowed_mentions.setdefault("replied_user", mention_author)
|
allowed_mentions.setdefault("replied_user", mention_author)
|
||||||
|
|
||||||
# Client.send_message is already updated to handle these parameters
|
# Client.send_message is already updated to handle these parameters
|
||||||
return await self._client.send_message(
|
return await self._client.send_message(
|
||||||
@ -624,38 +712,72 @@ class Attachment:
|
|||||||
|
|
||||||
|
|
||||||
class File:
|
class File:
|
||||||
"""Represents a file to be uploaded."""
|
"""Represents a file to be uploaded.
|
||||||
|
|
||||||
def __init__(self, filename: str, data: bytes):
|
Parameters
|
||||||
self.filename = filename
|
----------
|
||||||
self.data = data
|
fp:
|
||||||
|
A file path, file-like object, or bytes-like object containing the
|
||||||
|
data to upload.
|
||||||
|
filename:
|
||||||
|
Optional name of the file. If not provided and ``fp`` is a path or has
|
||||||
|
a ``name`` attribute, the name will be inferred.
|
||||||
|
spoiler:
|
||||||
|
When ``True`` the filename will be prefixed with ``"SPOILER_"``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fp: Union[str, bytes, os.PathLike[Any], IO[bytes]],
|
||||||
|
*,
|
||||||
|
filename: Optional[str] = None,
|
||||||
|
spoiler: bool = False,
|
||||||
|
) -> None:
|
||||||
|
if isinstance(fp, (str, os.PathLike)):
|
||||||
|
self.data = open(fp, "rb")
|
||||||
|
inferred = os.path.basename(fp)
|
||||||
|
elif isinstance(fp, bytes):
|
||||||
|
self.data = io.BytesIO(fp)
|
||||||
|
inferred = None
|
||||||
|
else:
|
||||||
|
self.data = fp
|
||||||
|
inferred = getattr(fp, "name", None)
|
||||||
|
|
||||||
|
name = filename or inferred
|
||||||
|
if name is None:
|
||||||
|
raise ValueError("filename could not be inferred")
|
||||||
|
|
||||||
|
if spoiler and not name.startswith("SPOILER_"):
|
||||||
|
name = f"SPOILER_{name}"
|
||||||
|
self.filename = name
|
||||||
|
self.spoiler = spoiler
|
||||||
|
|
||||||
|
|
||||||
class AllowedMentions:
|
class AllowedMentions:
|
||||||
"""Represents allowed mentions for a message or interaction response."""
|
"""Represents allowed mentions for a message or interaction response."""
|
||||||
|
|
||||||
def __init__(self, data: Dict[str, Any]):
|
def __init__(self, data: Dict[str, Any]):
|
||||||
self.parse: List[str] = data.get("parse", [])
|
self.parse: List[str] = data.get("parse", [])
|
||||||
self.roles: List[str] = data.get("roles", [])
|
self.roles: List[str] = data.get("roles", [])
|
||||||
self.users: List[str] = data.get("users", [])
|
self.users: List[str] = data.get("users", [])
|
||||||
self.replied_user: bool = data.get("replied_user", False)
|
self.replied_user: bool = data.get("replied_user", False)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def all(cls) -> "AllowedMentions":
|
def all(cls) -> "AllowedMentions":
|
||||||
"""Return an instance allowing all mention types."""
|
"""Return an instance allowing all mention types."""
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
{
|
{
|
||||||
"parse": ["users", "roles", "everyone"],
|
"parse": ["users", "roles", "everyone"],
|
||||||
"replied_user": True,
|
"replied_user": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def none(cls) -> "AllowedMentions":
|
def none(cls) -> "AllowedMentions":
|
||||||
"""Return an instance disallowing all mentions."""
|
"""Return an instance disallowing all mentions."""
|
||||||
|
|
||||||
return cls({"parse": [], "replied_user": False})
|
return cls({"parse": [], "replied_user": False})
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
payload: Dict[str, Any] = {"parse": self.parse}
|
payload: Dict[str, Any] = {"parse": self.parse}
|
||||||
@ -697,7 +819,12 @@ class Role:
|
|||||||
self.name: str = data["name"]
|
self.name: str = data["name"]
|
||||||
self.color: int = data["color"]
|
self.color: int = data["color"]
|
||||||
self.hoist: bool = data["hoist"]
|
self.hoist: bool = data["hoist"]
|
||||||
self.icon: Optional[str] = data.get("icon")
|
icon_hash = data.get("icon")
|
||||||
|
self._icon: Optional[str] = (
|
||||||
|
f"https://cdn.discordapp.com/role-icons/{self.id}/{icon_hash}.png"
|
||||||
|
if icon_hash
|
||||||
|
else None
|
||||||
|
)
|
||||||
self.unicode_emoji: Optional[str] = data.get("unicode_emoji")
|
self.unicode_emoji: Optional[str] = data.get("unicode_emoji")
|
||||||
self.position: int = data["position"]
|
self.position: int = data["position"]
|
||||||
self.permissions: str = data["permissions"] # String of bitwise permissions
|
self.permissions: str = data["permissions"] # String of bitwise permissions
|
||||||
@ -715,6 +842,23 @@ class Role:
|
|||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<Role id='{self.id}' name='{self.name}'>"
|
return f"<Role id='{self.id}' name='{self.name}'>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def icon(self) -> Optional["Asset"]:
|
||||||
|
if self._icon:
|
||||||
|
from .asset import Asset
|
||||||
|
|
||||||
|
return Asset(self._icon, None)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@icon.setter
|
||||||
|
def icon(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||||
|
if isinstance(value, str):
|
||||||
|
self._icon = value
|
||||||
|
elif value is None:
|
||||||
|
self._icon = None
|
||||||
|
else:
|
||||||
|
self._icon = value.url
|
||||||
|
|
||||||
|
|
||||||
class Member(User): # Member inherits from User
|
class Member(User): # Member inherits from User
|
||||||
"""Represents a Guild Member.
|
"""Represents a Guild Member.
|
||||||
@ -743,12 +887,18 @@ class Member(User): # Member inherits from User
|
|||||||
) # Pass user_data or data if user_data is empty
|
) # Pass user_data or data if user_data is empty
|
||||||
|
|
||||||
self.nick: Optional[str] = data.get("nick")
|
self.nick: Optional[str] = data.get("nick")
|
||||||
self.avatar: Optional[str] = data.get("avatar") # Guild-specific avatar hash
|
avatar_hash = data.get("avatar")
|
||||||
self.roles: List[str] = data.get("roles", []) # List of role IDs
|
if avatar_hash:
|
||||||
self.joined_at: str = data["joined_at"] # ISO8601 timestamp
|
guild_id = data.get("guild_id")
|
||||||
self.premium_since: Optional[str] = data.get(
|
if guild_id:
|
||||||
"premium_since"
|
self._avatar = f"https://cdn.discordapp.com/guilds/{guild_id}/users/{self.id}/avatars/{avatar_hash}.png"
|
||||||
) # ISO8601 timestamp
|
else:
|
||||||
|
self._avatar = (
|
||||||
|
f"https://cdn.discordapp.com/avatars/{self.id}/{avatar_hash}.png"
|
||||||
|
)
|
||||||
|
self.roles: List[str] = data.get("roles", [])
|
||||||
|
self.joined_at: str = data["joined_at"]
|
||||||
|
self.premium_since: Optional[str] = data.get("premium_since")
|
||||||
self.deaf: bool = data.get("deaf", False)
|
self.deaf: bool = data.get("deaf", False)
|
||||||
self.mute: bool = data.get("mute", False)
|
self.mute: bool = data.get("mute", False)
|
||||||
self.pending: bool = data.get("pending", False)
|
self.pending: bool = data.get("pending", False)
|
||||||
@ -758,6 +908,7 @@ class Member(User): # Member inherits from User
|
|||||||
self.communication_disabled_until: Optional[str] = data.get(
|
self.communication_disabled_until: Optional[str] = data.get(
|
||||||
"communication_disabled_until"
|
"communication_disabled_until"
|
||||||
) # ISO8601 timestamp
|
) # ISO8601 timestamp
|
||||||
|
self.voice_state = data.get("voice_state")
|
||||||
|
|
||||||
# If 'user' object was present, ensure User attributes are from there
|
# If 'user' object was present, ensure User attributes are from there
|
||||||
if user_data:
|
if user_data:
|
||||||
@ -772,11 +923,30 @@ class Member(User): # Member inherits from User
|
|||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<Member id='{self.id}' username='{self.username}' nick='{self.nick}'>"
|
return f"<Member id='{self.id}' username='{self.username}' nick='{self.nick}'>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def avatar(self) -> Optional["Asset"]:
|
||||||
|
"""Return the member's avatar as an :class:`Asset`."""
|
||||||
|
|
||||||
|
if self._avatar:
|
||||||
|
from .asset import Asset
|
||||||
|
|
||||||
|
return Asset(self._avatar, self._client)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@avatar.setter
|
||||||
|
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||||
|
if isinstance(value, str):
|
||||||
|
self._avatar = value
|
||||||
|
elif value is None:
|
||||||
|
self._avatar = None
|
||||||
|
else:
|
||||||
|
self._avatar = value.url
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def display_name(self) -> str:
|
def display_name(self) -> str:
|
||||||
"""Return the nickname if set, otherwise the username."""
|
"""Return the nickname if set, otherwise the username."""
|
||||||
|
|
||||||
return self.nick or self.username
|
return self.nick or self.username or ""
|
||||||
|
|
||||||
async def kick(self, *, reason: Optional[str] = None) -> None:
|
async def kick(self, *, reason: Optional[str] = None) -> None:
|
||||||
if not self.guild_id or not self._client:
|
if not self.guild_id or not self._client:
|
||||||
@ -838,6 +1008,41 @@ class Member(User): # Member inherits from User
|
|||||||
|
|
||||||
return max(role_objects, key=lambda r: r.position)
|
return max(role_objects, key=lambda r: r.position)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def guild_permissions(self) -> "Permissions":
|
||||||
|
"""Return the member's guild-level permissions."""
|
||||||
|
|
||||||
|
if not self.guild_id or not self._client:
|
||||||
|
return Permissions(0)
|
||||||
|
|
||||||
|
guild = self._client.get_guild(self.guild_id)
|
||||||
|
if guild is None:
|
||||||
|
return Permissions(0)
|
||||||
|
|
||||||
|
base = Permissions(0)
|
||||||
|
|
||||||
|
everyone = guild.get_role(guild.id)
|
||||||
|
if everyone is not None:
|
||||||
|
base |= Permissions(int(everyone.permissions))
|
||||||
|
|
||||||
|
for rid in self.roles:
|
||||||
|
role = guild.get_role(rid)
|
||||||
|
if role is not None:
|
||||||
|
base |= Permissions(int(role.permissions))
|
||||||
|
|
||||||
|
if base & Permissions.ADMINISTRATOR:
|
||||||
|
return Permissions(~0)
|
||||||
|
|
||||||
|
return base
|
||||||
|
|
||||||
|
@property
|
||||||
|
def voice(self) -> Optional["VoiceState"]:
|
||||||
|
"""Return the member's cached voice state as a :class:`VoiceState`."""
|
||||||
|
|
||||||
|
if self.voice_state is None:
|
||||||
|
return None
|
||||||
|
return VoiceState.from_dict(self.voice_state)
|
||||||
|
|
||||||
|
|
||||||
class PartialEmoji:
|
class PartialEmoji:
|
||||||
"""Represents a partial emoji, often used in components or reactions.
|
"""Represents a partial emoji, often used in components or reactions.
|
||||||
@ -1035,7 +1240,7 @@ class PermissionOverwrite:
|
|||||||
return f"<PermissionOverwrite id='{self.id}' type='{self.type.name if hasattr(self.type, 'name') else self._type_val}' allow='{self.allow}' deny='{self.deny}'>"
|
return f"<PermissionOverwrite id='{self.id}' type='{self.type.name if hasattr(self.type, 'name') else self._type_val}' allow='{self.allow}' deny='{self.deny}'>"
|
||||||
|
|
||||||
|
|
||||||
class Guild:
|
class Guild(HashableById):
|
||||||
"""Represents a Discord Guild (Server).
|
"""Represents a Discord Guild (Server).
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
@ -1075,15 +1280,42 @@ class Guild:
|
|||||||
nsfw_level (GuildNSFWLevel): Guild NSFW level.
|
nsfw_level (GuildNSFWLevel): Guild NSFW level.
|
||||||
stickers (Optional[List[Dict]]): Custom stickers in the guild. (Consider a Sticker model)
|
stickers (Optional[List[Dict]]): Custom stickers in the guild. (Consider a Sticker model)
|
||||||
premium_progress_bar_enabled (bool): Whether the guild has the premium progress bar enabled.
|
premium_progress_bar_enabled (bool): Whether the guild has the premium progress bar enabled.
|
||||||
|
text_channels (List[TextChannel]): List of text-based channels in this guild.
|
||||||
|
voice_channels (List[VoiceChannel]): List of voice-based channels in this guild.
|
||||||
|
category_channels (List[CategoryChannel]): List of category channels in this guild.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data: Dict[str, Any], client_instance: "Client"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
client_instance: "Client",
|
||||||
|
*,
|
||||||
|
shard_id: Optional[int] = None,
|
||||||
|
):
|
||||||
self._client: "Client" = client_instance
|
self._client: "Client" = client_instance
|
||||||
|
self._shard_id: Optional[int] = (
|
||||||
|
shard_id if shard_id is not None else data.get("shard_id")
|
||||||
|
)
|
||||||
self.id: str = data["id"]
|
self.id: str = data["id"]
|
||||||
self.name: str = data["name"]
|
self.name: str = data["name"]
|
||||||
self.icon: Optional[str] = data.get("icon")
|
icon_hash = data.get("icon")
|
||||||
self.splash: Optional[str] = data.get("splash")
|
self._icon: Optional[str] = (
|
||||||
self.discovery_splash: Optional[str] = data.get("discovery_splash")
|
f"https://cdn.discordapp.com/icons/{self.id}/{icon_hash}.png"
|
||||||
|
if icon_hash
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
splash_hash = data.get("splash")
|
||||||
|
self._splash: Optional[str] = (
|
||||||
|
f"https://cdn.discordapp.com/splashes/{self.id}/{splash_hash}.png"
|
||||||
|
if splash_hash
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
discovery_hash = data.get("discovery_splash")
|
||||||
|
self._discovery_splash: Optional[str] = (
|
||||||
|
f"https://cdn.discordapp.com/discovery-splashes/{self.id}/{discovery_hash}.png"
|
||||||
|
if discovery_hash
|
||||||
|
else None
|
||||||
|
)
|
||||||
self.owner: Optional[bool] = data.get("owner")
|
self.owner: Optional[bool] = data.get("owner")
|
||||||
self.owner_id: str = data["owner_id"]
|
self.owner_id: str = data["owner_id"]
|
||||||
self.permissions: Optional[str] = data.get("permissions")
|
self.permissions: Optional[str] = data.get("permissions")
|
||||||
@ -1120,7 +1352,12 @@ class Guild:
|
|||||||
self.max_members: Optional[int] = data.get("max_members")
|
self.max_members: Optional[int] = data.get("max_members")
|
||||||
self.vanity_url_code: Optional[str] = data.get("vanity_url_code")
|
self.vanity_url_code: Optional[str] = data.get("vanity_url_code")
|
||||||
self.description: Optional[str] = data.get("description")
|
self.description: Optional[str] = data.get("description")
|
||||||
self.banner: Optional[str] = data.get("banner")
|
banner_hash = data.get("banner")
|
||||||
|
self._banner: Optional[str] = (
|
||||||
|
f"https://cdn.discordapp.com/banners/{self.id}/{banner_hash}.png"
|
||||||
|
if banner_hash
|
||||||
|
else None
|
||||||
|
)
|
||||||
self.premium_tier: PremiumTier = PremiumTier(data["premium_tier"])
|
self.premium_tier: PremiumTier = PremiumTier(data["premium_tier"])
|
||||||
self.premium_subscription_count: Optional[int] = data.get(
|
self.premium_subscription_count: Optional[int] = data.get(
|
||||||
"premium_subscription_count"
|
"premium_subscription_count"
|
||||||
@ -1159,6 +1396,28 @@ class Guild:
|
|||||||
getattr(client_instance, "member_cache_flags", MemberCacheFlags())
|
getattr(client_instance, "member_cache_flags", MemberCacheFlags())
|
||||||
)
|
)
|
||||||
self._threads: Dict[str, "Thread"] = {}
|
self._threads: Dict[str, "Thread"] = {}
|
||||||
|
self.text_channels: List["TextChannel"] = []
|
||||||
|
self.voice_channels: List["VoiceChannel"] = []
|
||||||
|
self.category_channels: List["CategoryChannel"] = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shard_id(self) -> Optional[int]:
|
||||||
|
"""ID of the shard that received this guild, if any."""
|
||||||
|
|
||||||
|
return self._shard_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shard(self) -> Optional["Shard"]:
|
||||||
|
"""The :class:`Shard` this guild belongs to."""
|
||||||
|
|
||||||
|
if self._shard_id is None:
|
||||||
|
return None
|
||||||
|
manager = getattr(self._client, "_shard_manager", None)
|
||||||
|
if not manager:
|
||||||
|
return None
|
||||||
|
if 0 <= self._shard_id < len(manager.shards):
|
||||||
|
return manager.shards[self._shard_id]
|
||||||
|
return None
|
||||||
|
|
||||||
def get_channel(self, channel_id: str) -> Optional["Channel"]:
|
def get_channel(self, channel_id: str) -> Optional["Channel"]:
|
||||||
return self._channels.get(channel_id)
|
return self._channels.get(channel_id)
|
||||||
@ -1185,7 +1444,7 @@ class Guild:
|
|||||||
|
|
||||||
lowered = name.lower()
|
lowered = name.lower()
|
||||||
for member in self._members.values():
|
for member in self._members.values():
|
||||||
if member.username.lower() == lowered:
|
if member.username and member.username.lower() == lowered:
|
||||||
return member
|
return member
|
||||||
if member.nick and member.nick.lower() == lowered:
|
if member.nick and member.nick.lower() == lowered:
|
||||||
return member
|
return member
|
||||||
@ -1194,9 +1453,86 @@ class Guild:
|
|||||||
def get_role(self, role_id: str) -> Optional[Role]:
|
def get_role(self, role_id: str) -> Optional[Role]:
|
||||||
return next((role for role in self.roles if role.id == role_id), None)
|
return next((role for role in self.roles if role.id == role_id), None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def me(self) -> Optional[Member]:
|
||||||
|
"""The member object for the connected bot in this guild, if present."""
|
||||||
|
|
||||||
|
client_user = getattr(self._client, "user", None)
|
||||||
|
if not client_user:
|
||||||
|
return None
|
||||||
|
return self.get_member(client_user.id)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<Guild id='{self.id}' name='{self.name}'>"
|
return f"<Guild id='{self.id}' name='{self.name}'>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def icon(self) -> Optional["Asset"]:
|
||||||
|
if self._icon:
|
||||||
|
from .asset import Asset
|
||||||
|
|
||||||
|
return Asset(self._icon, self._client)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@icon.setter
|
||||||
|
def icon(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||||
|
if isinstance(value, str):
|
||||||
|
self._icon = value
|
||||||
|
elif value is None:
|
||||||
|
self._icon = None
|
||||||
|
else:
|
||||||
|
self._icon = value.url
|
||||||
|
|
||||||
|
@property
|
||||||
|
def splash(self) -> Optional["Asset"]:
|
||||||
|
if self._splash:
|
||||||
|
from .asset import Asset
|
||||||
|
|
||||||
|
return Asset(self._splash, self._client)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@splash.setter
|
||||||
|
def splash(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||||
|
if isinstance(value, str):
|
||||||
|
self._splash = value
|
||||||
|
elif value is None:
|
||||||
|
self._splash = None
|
||||||
|
else:
|
||||||
|
self._splash = value.url
|
||||||
|
|
||||||
|
@property
|
||||||
|
def discovery_splash(self) -> Optional["Asset"]:
|
||||||
|
if self._discovery_splash:
|
||||||
|
from .asset import Asset
|
||||||
|
|
||||||
|
return Asset(self._discovery_splash, self._client)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@discovery_splash.setter
|
||||||
|
def discovery_splash(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||||
|
if isinstance(value, str):
|
||||||
|
self._discovery_splash = value
|
||||||
|
elif value is None:
|
||||||
|
self._discovery_splash = None
|
||||||
|
else:
|
||||||
|
self._discovery_splash = value.url
|
||||||
|
|
||||||
|
@property
|
||||||
|
def banner(self) -> Optional["Asset"]:
|
||||||
|
if self._banner:
|
||||||
|
from .asset import Asset
|
||||||
|
|
||||||
|
return Asset(self._banner, self._client)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@banner.setter
|
||||||
|
def banner(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||||
|
if isinstance(value, str):
|
||||||
|
self._banner = value
|
||||||
|
elif value is None:
|
||||||
|
self._banner = None
|
||||||
|
else:
|
||||||
|
self._banner = value.url
|
||||||
|
|
||||||
async def fetch_widget(self) -> Dict[str, Any]:
|
async def fetch_widget(self) -> Dict[str, Any]:
|
||||||
"""|coro| Fetch this guild's widget settings."""
|
"""|coro| Fetch this guild's widget settings."""
|
||||||
|
|
||||||
@ -1250,8 +1586,82 @@ class Guild:
|
|||||||
del self._client._gateway._member_chunk_requests[nonce]
|
del self._client._gateway._member_chunk_requests[nonce]
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def prune_members(self, days: int, *, compute_count: bool = True) -> int:
|
||||||
|
"""|coro| Remove inactive members from the guild.
|
||||||
|
|
||||||
class Channel:
|
Parameters
|
||||||
|
----------
|
||||||
|
days: int
|
||||||
|
Number of days of inactivity required to be pruned.
|
||||||
|
compute_count: bool
|
||||||
|
Whether to return the number of members pruned.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
int
|
||||||
|
The number of members pruned.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return await self._client._http.begin_guild_prune(
|
||||||
|
self.id, days=days, compute_count=compute_count
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create_text_channel(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
*,
|
||||||
|
reason: Optional[str] = None,
|
||||||
|
**options: Any,
|
||||||
|
) -> "TextChannel":
|
||||||
|
"""|coro| Create a new text channel in this guild."""
|
||||||
|
|
||||||
|
payload: Dict[str, Any] = {"name": name, "type": ChannelType.GUILD_TEXT.value}
|
||||||
|
payload.update(options)
|
||||||
|
data = await self._client._http.create_guild_channel(
|
||||||
|
self.id, payload, reason=reason
|
||||||
|
)
|
||||||
|
return cast("TextChannel", self._client.parse_channel(data))
|
||||||
|
|
||||||
|
async def create_voice_channel(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
*,
|
||||||
|
reason: Optional[str] = None,
|
||||||
|
**options: Any,
|
||||||
|
) -> "VoiceChannel":
|
||||||
|
"""|coro| Create a new voice channel in this guild."""
|
||||||
|
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"name": name,
|
||||||
|
"type": ChannelType.GUILD_VOICE.value,
|
||||||
|
}
|
||||||
|
payload.update(options)
|
||||||
|
data = await self._client._http.create_guild_channel(
|
||||||
|
self.id, payload, reason=reason
|
||||||
|
)
|
||||||
|
return cast("VoiceChannel", self._client.parse_channel(data))
|
||||||
|
|
||||||
|
async def create_category(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
*,
|
||||||
|
reason: Optional[str] = None,
|
||||||
|
**options: Any,
|
||||||
|
) -> "CategoryChannel":
|
||||||
|
"""|coro| Create a new category channel in this guild."""
|
||||||
|
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"name": name,
|
||||||
|
"type": ChannelType.GUILD_CATEGORY.value,
|
||||||
|
}
|
||||||
|
payload.update(options)
|
||||||
|
data = await self._client._http.create_guild_channel(
|
||||||
|
self.id, payload, reason=reason
|
||||||
|
)
|
||||||
|
return cast("CategoryChannel", self._client.parse_channel(data))
|
||||||
|
|
||||||
|
|
||||||
|
class Channel(HashableById):
|
||||||
"""Base class for Discord channels."""
|
"""Base class for Discord channels."""
|
||||||
|
|
||||||
def __init__(self, data: Dict[str, Any], client_instance: "Client"):
|
def __init__(self, data: Dict[str, Any], client_instance: "Client"):
|
||||||
@ -1531,6 +1941,31 @@ class TextChannel(Channel, Messageable):
|
|||||||
data = await self._client._http.start_thread_without_message(self.id, payload)
|
data = await self._client._http.start_thread_without_message(self.id, payload)
|
||||||
return cast("Thread", self._client.parse_channel(data))
|
return cast("Thread", self._client.parse_channel(data))
|
||||||
|
|
||||||
|
async def create_invite(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
max_age: Optional[int] = None,
|
||||||
|
max_uses: Optional[int] = None,
|
||||||
|
temporary: Optional[bool] = None,
|
||||||
|
unique: Optional[bool] = None,
|
||||||
|
reason: Optional[str] = None,
|
||||||
|
) -> "Invite":
|
||||||
|
"""|coro| Create an invite to this channel."""
|
||||||
|
|
||||||
|
payload: Dict[str, Any] = {}
|
||||||
|
if max_age is not None:
|
||||||
|
payload["max_age"] = max_age
|
||||||
|
if max_uses is not None:
|
||||||
|
payload["max_uses"] = max_uses
|
||||||
|
if temporary is not None:
|
||||||
|
payload["temporary"] = temporary
|
||||||
|
if unique is not None:
|
||||||
|
payload["unique"] = unique
|
||||||
|
|
||||||
|
return await self._client._http.create_channel_invite(
|
||||||
|
self.id, payload, reason=reason
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class VoiceChannel(Channel):
|
class VoiceChannel(Channel):
|
||||||
"""Represents a guild voice channel or stage voice channel."""
|
"""Represents a guild voice channel or stage voice channel."""
|
||||||
@ -1830,7 +2265,12 @@ class Webhook:
|
|||||||
self.guild_id: Optional[str] = data.get("guild_id")
|
self.guild_id: Optional[str] = data.get("guild_id")
|
||||||
self.channel_id: Optional[str] = data.get("channel_id")
|
self.channel_id: Optional[str] = data.get("channel_id")
|
||||||
self.name: Optional[str] = data.get("name")
|
self.name: Optional[str] = data.get("name")
|
||||||
self.avatar: Optional[str] = data.get("avatar")
|
avatar_hash = data.get("avatar")
|
||||||
|
self._avatar: Optional[str] = (
|
||||||
|
f"https://cdn.discordapp.com/webhooks/{self.id}/{avatar_hash}.png"
|
||||||
|
if avatar_hash
|
||||||
|
else None
|
||||||
|
)
|
||||||
self.token: Optional[str] = data.get("token")
|
self.token: Optional[str] = data.get("token")
|
||||||
self.application_id: Optional[str] = data.get("application_id")
|
self.application_id: Optional[str] = data.get("application_id")
|
||||||
self.url: Optional[str] = data.get("url")
|
self.url: Optional[str] = data.get("url")
|
||||||
@ -1839,6 +2279,25 @@ class Webhook:
|
|||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<Webhook id='{self.id}' name='{self.name}'>"
|
return f"<Webhook id='{self.id}' name='{self.name}'>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def avatar(self) -> Optional["Asset"]:
|
||||||
|
"""Return the webhook's avatar as an :class:`Asset`."""
|
||||||
|
|
||||||
|
if self._avatar:
|
||||||
|
from .asset import Asset
|
||||||
|
|
||||||
|
return Asset(self._avatar, self._client)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@avatar.setter
|
||||||
|
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
|
||||||
|
if isinstance(value, str):
|
||||||
|
self._avatar = value
|
||||||
|
elif value is None:
|
||||||
|
self._avatar = None
|
||||||
|
else:
|
||||||
|
self._avatar = value.url
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_url(
|
def from_url(
|
||||||
cls, url: str, session: Optional[aiohttp.ClientSession] = None
|
cls, url: str, session: Optional[aiohttp.ClientSession] = None
|
||||||
@ -1866,6 +2325,33 @@ class Webhook:
|
|||||||
|
|
||||||
return cls({"id": webhook_id, "token": token, "url": url})
|
return cls({"id": webhook_id, "token": token, "url": url})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_token(
|
||||||
|
cls,
|
||||||
|
webhook_id: str,
|
||||||
|
token: str,
|
||||||
|
session: Optional[aiohttp.ClientSession] = None,
|
||||||
|
) -> "Webhook":
|
||||||
|
"""Create a minimal :class:`Webhook` from an ID and token.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
webhook_id:
|
||||||
|
The ID of the webhook.
|
||||||
|
token:
|
||||||
|
The webhook token.
|
||||||
|
session:
|
||||||
|
Unused for now. Present for API compatibility.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Webhook
|
||||||
|
A webhook instance containing only the ``id``, ``token`` and ``url``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
url = f"https://discord.com/api/webhooks/{webhook_id}/{token}"
|
||||||
|
return cls({"id": webhook_id, "token": token, "url": url})
|
||||||
|
|
||||||
async def send(
|
async def send(
|
||||||
self,
|
self,
|
||||||
content: Optional[str] = None,
|
content: Optional[str] = None,
|
||||||
@ -2471,7 +2957,7 @@ class PresenceUpdate:
|
|||||||
self, data: Dict[str, Any], client_instance: Optional["Client"] = None
|
self, data: Dict[str, Any], client_instance: Optional["Client"] = None
|
||||||
):
|
):
|
||||||
self._client = client_instance
|
self._client = client_instance
|
||||||
self.user = User(data["user"])
|
self.user = User(data["user"], client_instance)
|
||||||
self.guild_id: Optional[str] = data.get("guild_id")
|
self.guild_id: Optional[str] = data.get("guild_id")
|
||||||
self.status: Optional[str] = data.get("status")
|
self.status: Optional[str] = data.get("status")
|
||||||
self.activities: List[Activity] = []
|
self.activities: List[Activity] = []
|
||||||
@ -2491,7 +2977,7 @@ class PresenceUpdate:
|
|||||||
return f"<PresenceUpdate user_id='{self.user.id}' guild_id='{self.guild_id}' status='{self.status}'>"
|
return f"<PresenceUpdate user_id='{self.user.id}' guild_id='{self.guild_id}' status='{self.status}'>"
|
||||||
|
|
||||||
|
|
||||||
class TypingStart:
|
class TypingStart:
|
||||||
"""Represents a TYPING_START event."""
|
"""Represents a TYPING_START event."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -2504,39 +2990,78 @@ class TypingStart:
|
|||||||
self.timestamp: int = data["timestamp"]
|
self.timestamp: int = data["timestamp"]
|
||||||
self.member: Optional[Member] = (
|
self.member: Optional[Member] = (
|
||||||
Member(data["member"], client_instance) if data.get("member") else None
|
Member(data["member"], client_instance) if data.get("member") else None
|
||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<TypingStart channel_id='{self.channel_id}' user_id='{self.user_id}'>"
|
return f"<TypingStart channel_id='{self.channel_id}' user_id='{self.user_id}'>"
|
||||||
|
|
||||||
|
|
||||||
class VoiceStateUpdate:
|
class VoiceStateUpdate:
|
||||||
"""Represents a VOICE_STATE_UPDATE event."""
|
"""Represents a VOICE_STATE_UPDATE event."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, data: Dict[str, Any], client_instance: Optional["Client"] = None
|
self, data: Dict[str, Any], client_instance: Optional["Client"] = None
|
||||||
):
|
):
|
||||||
self._client = client_instance
|
self._client = client_instance
|
||||||
self.guild_id: Optional[str] = data.get("guild_id")
|
self.guild_id: Optional[str] = data.get("guild_id")
|
||||||
self.channel_id: Optional[str] = data.get("channel_id")
|
self.channel_id: Optional[str] = data.get("channel_id")
|
||||||
self.user_id: str = data["user_id"]
|
self.user_id: str = data["user_id"]
|
||||||
self.member: Optional[Member] = (
|
self.member: Optional[Member] = (
|
||||||
Member(data["member"], client_instance) if data.get("member") else None
|
Member(data["member"], client_instance) if data.get("member") else None
|
||||||
)
|
)
|
||||||
self.session_id: str = data["session_id"]
|
self.session_id: str = data["session_id"]
|
||||||
self.deaf: bool = data.get("deaf", False)
|
self.deaf: bool = data.get("deaf", False)
|
||||||
self.mute: bool = data.get("mute", False)
|
self.mute: bool = data.get("mute", False)
|
||||||
self.self_deaf: bool = data.get("self_deaf", False)
|
self.self_deaf: bool = data.get("self_deaf", False)
|
||||||
self.self_mute: bool = data.get("self_mute", False)
|
self.self_mute: bool = data.get("self_mute", False)
|
||||||
self.self_stream: Optional[bool] = data.get("self_stream")
|
self.self_stream: Optional[bool] = data.get("self_stream")
|
||||||
self.self_video: bool = data.get("self_video", False)
|
self.self_video: bool = data.get("self_video", False)
|
||||||
self.suppress: bool = data.get("suppress", False)
|
self.suppress: bool = data.get("suppress", False)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (
|
return (
|
||||||
f"<VoiceStateUpdate guild_id='{self.guild_id}' user_id='{self.user_id}' "
|
f"<VoiceStateUpdate guild_id='{self.guild_id}' user_id='{self.user_id}' "
|
||||||
f"channel_id='{self.channel_id}'>"
|
f"channel_id='{self.channel_id}'>"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VoiceState:
|
||||||
|
"""Represents a cached voice state for a member."""
|
||||||
|
|
||||||
|
guild_id: Optional[str]
|
||||||
|
channel_id: Optional[str]
|
||||||
|
user_id: Optional[str]
|
||||||
|
session_id: Optional[str]
|
||||||
|
deaf: bool = False
|
||||||
|
mute: bool = False
|
||||||
|
self_deaf: bool = False
|
||||||
|
self_mute: bool = False
|
||||||
|
self_stream: Optional[bool] = None
|
||||||
|
self_video: bool = False
|
||||||
|
suppress: bool = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "VoiceState":
|
||||||
|
return cls(
|
||||||
|
guild_id=data.get("guild_id"),
|
||||||
|
channel_id=data.get("channel_id"),
|
||||||
|
user_id=data.get("user_id"),
|
||||||
|
session_id=data.get("session_id"),
|
||||||
|
deaf=data.get("deaf", False),
|
||||||
|
mute=data.get("mute", False),
|
||||||
|
self_deaf=data.get("self_deaf", False),
|
||||||
|
self_mute=data.get("self_mute", False),
|
||||||
|
self_stream=data.get("self_stream"),
|
||||||
|
self_video=data.get("self_video", False),
|
||||||
|
suppress=data.get("suppress", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"<VoiceState guild_id='{self.guild_id}' user_id='{self.user_id}' "
|
||||||
|
f"channel_id='{self.channel_id}'>"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Reaction:
|
class Reaction:
|
||||||
@ -2631,6 +3156,18 @@ class Invite:
|
|||||||
return f"<Invite code='{self.code}' guild_id='{self.guild_id}' channel_id='{self.channel_id}'>"
|
return f"<Invite code='{self.code}' guild_id='{self.guild_id}' channel_id='{self.channel_id}'>"
|
||||||
|
|
||||||
|
|
||||||
|
class InviteDelete:
|
||||||
|
"""Represents an INVITE_DELETE event."""
|
||||||
|
|
||||||
|
def __init__(self, data: Dict[str, Any]):
|
||||||
|
self.channel_id: str = data["channel_id"]
|
||||||
|
self.guild_id: Optional[str] = data.get("guild_id")
|
||||||
|
self.code: str = data["code"]
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<InviteDelete code='{self.code}' guild_id='{self.guild_id}' channel_id='{self.channel_id}'>"
|
||||||
|
|
||||||
|
|
||||||
class GuildMemberRemove:
|
class GuildMemberRemove:
|
||||||
"""Represents a GUILD_MEMBER_REMOVE event."""
|
"""Represents a GUILD_MEMBER_REMOVE event."""
|
||||||
|
|
||||||
|
19
disagreement/object.py
Normal file
19
disagreement/object.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
class Object:
|
||||||
|
"""A minimal wrapper around a Discord snowflake ID."""
|
||||||
|
|
||||||
|
__slots__ = ("id",)
|
||||||
|
|
||||||
|
def __init__(self, object_id: int) -> None:
|
||||||
|
self.id = int(object_id)
|
||||||
|
|
||||||
|
def __int__(self) -> int:
|
||||||
|
return self.id
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return hash(self.id)
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
return isinstance(other, Object) and self.id == other.id
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<Object id={self.id}>"
|
@ -57,6 +57,15 @@ class Permissions(IntFlag):
|
|||||||
USE_EXTERNAL_SOUNDS = 1 << 45
|
USE_EXTERNAL_SOUNDS = 1 << 45
|
||||||
SEND_VOICE_MESSAGES = 1 << 46
|
SEND_VOICE_MESSAGES = 1 << 46
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def all(cls) -> "Permissions":
|
||||||
|
"""Return a ``Permissions`` object with every permission bit enabled."""
|
||||||
|
|
||||||
|
value = 0
|
||||||
|
for perm in cls:
|
||||||
|
value |= perm.value
|
||||||
|
return cls(value)
|
||||||
|
|
||||||
|
|
||||||
def permissions_value(*perms: Permissions | int | Iterable[Permissions | int]) -> int:
|
def permissions_value(*perms: Permissions | int | Iterable[Permissions | int]) -> int:
|
||||||
"""Return a combined integer value for multiple permissions."""
|
"""Return a combined integer value for multiple permissions."""
|
||||||
|
@ -3,7 +3,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, AsyncIterator, Dict, Optional, TYPE_CHECKING
|
from typing import Any, AsyncIterator, Dict, Iterable, Optional, TYPE_CHECKING, Callable
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Discord epoch in milliseconds (2015-01-01T00:00:00Z)
|
||||||
|
DISCORD_EPOCH = 1420070400000
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma: no cover - for type hinting only
|
if TYPE_CHECKING: # pragma: no cover - for type hinting only
|
||||||
from .models import Message, TextChannel
|
from .models import Message, TextChannel
|
||||||
@ -14,6 +18,27 @@ def utcnow() -> datetime:
|
|||||||
return datetime.now(timezone.utc)
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
def find(predicate: Callable[[Any], bool], iterable: Iterable[Any]) -> Optional[Any]:
|
||||||
|
"""Return the first element in ``iterable`` matching the ``predicate``."""
|
||||||
|
for element in iterable:
|
||||||
|
if predicate(element):
|
||||||
|
return element
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get(iterable: Iterable[Any], **attrs: Any) -> Optional[Any]:
|
||||||
|
"""Return the first element with matching attributes."""
|
||||||
|
def predicate(elem: Any) -> bool:
|
||||||
|
return all(getattr(elem, attr, None) == value for attr, value in attrs.items())
|
||||||
|
return find(predicate, iterable)
|
||||||
|
|
||||||
|
|
||||||
|
def snowflake_time(snowflake: int) -> datetime:
|
||||||
|
"""Return the creation time of a Discord snowflake."""
|
||||||
|
timestamp_ms = (snowflake >> 22) + DISCORD_EPOCH
|
||||||
|
return datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
async def message_pager(
|
async def message_pager(
|
||||||
channel: "TextChannel",
|
channel: "TextChannel",
|
||||||
*,
|
*,
|
||||||
@ -21,32 +46,11 @@ async def message_pager(
|
|||||||
before: Optional[str] = None,
|
before: Optional[str] = None,
|
||||||
after: Optional[str] = None,
|
after: Optional[str] = None,
|
||||||
) -> AsyncIterator["Message"]:
|
) -> AsyncIterator["Message"]:
|
||||||
"""Asynchronously paginate a channel's messages.
|
"""Asynchronously paginate a channel's messages."""
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
channel:
|
|
||||||
The :class:`TextChannel` to fetch messages from.
|
|
||||||
limit:
|
|
||||||
The maximum number of messages to yield. ``None`` fetches until no
|
|
||||||
more messages are returned.
|
|
||||||
before:
|
|
||||||
Fetch messages with IDs less than this snowflake.
|
|
||||||
after:
|
|
||||||
Fetch messages with IDs greater than this snowflake.
|
|
||||||
|
|
||||||
Yields
|
|
||||||
------
|
|
||||||
Message
|
|
||||||
Messages in the channel, oldest first.
|
|
||||||
"""
|
|
||||||
|
|
||||||
remaining = limit
|
remaining = limit
|
||||||
last_id = before
|
last_id = before
|
||||||
while remaining is None or remaining > 0:
|
while remaining is None or remaining > 0:
|
||||||
fetch_limit = 100
|
fetch_limit = min(100, remaining) if remaining is not None else 100
|
||||||
if remaining is not None:
|
|
||||||
fetch_limit = min(fetch_limit, remaining)
|
|
||||||
|
|
||||||
params: Dict[str, Any] = {"limit": fetch_limit}
|
params: Dict[str, Any] = {"limit": fetch_limit}
|
||||||
if last_id is not None:
|
if last_id is not None:
|
||||||
@ -71,3 +75,52 @@ async def message_pager(
|
|||||||
remaining -= 1
|
remaining -= 1
|
||||||
if remaining == 0:
|
if remaining == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class Paginator:
|
||||||
|
"""Helper to split text into pages under a character limit."""
|
||||||
|
|
||||||
|
def __init__(self, limit: int = 2000) -> None:
|
||||||
|
self.limit = limit
|
||||||
|
self._pages: list[str] = []
|
||||||
|
self._current = ""
|
||||||
|
|
||||||
|
def add_line(self, line: str) -> None:
|
||||||
|
"""Add a line of text to the paginator."""
|
||||||
|
if len(line) > self.limit:
|
||||||
|
if self._current:
|
||||||
|
self._pages.append(self._current)
|
||||||
|
self._current = ""
|
||||||
|
for i in range(0, len(line), self.limit):
|
||||||
|
chunk = line[i : i + self.limit]
|
||||||
|
if len(chunk) == self.limit:
|
||||||
|
self._pages.append(chunk)
|
||||||
|
else:
|
||||||
|
self._current = chunk
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self._current:
|
||||||
|
self._current = line
|
||||||
|
elif len(self._current) + 1 + len(line) <= self.limit:
|
||||||
|
self._current += "\n" + line
|
||||||
|
else:
|
||||||
|
self._pages.append(self._current)
|
||||||
|
self._current = line
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pages(self) -> list[str]:
|
||||||
|
"""Return the accumulated pages."""
|
||||||
|
pages = list(self._pages)
|
||||||
|
if self._current:
|
||||||
|
pages.append(self._current)
|
||||||
|
return pages
|
||||||
|
|
||||||
|
|
||||||
|
def escape_markdown(text: str) -> str:
|
||||||
|
"""Escape Discord markdown formatting in ``text``."""
|
||||||
|
return re.sub(r"([\\*_~`>|])", r"\\\1", text)
|
||||||
|
|
||||||
|
|
||||||
|
def escape_mentions(text: str) -> str:
|
||||||
|
"""Escape Discord mentions in ``text``."""
|
||||||
|
return text.replace("@", "@\u200b")
|
||||||
|
@ -77,11 +77,14 @@ class VoiceClient:
|
|||||||
self.secret_key: Optional[Sequence[int]] = None
|
self.secret_key: Optional[Sequence[int]] = None
|
||||||
self._server_ip: Optional[str] = None
|
self._server_ip: Optional[str] = None
|
||||||
self._server_port: Optional[int] = None
|
self._server_port: Optional[int] = None
|
||||||
self._current_source: Optional[AudioSource] = None
|
self._current_source: Optional[AudioSource] = None
|
||||||
self._play_task: Optional[asyncio.Task] = None
|
self._play_task: Optional[asyncio.Task] = None
|
||||||
self._sink: Optional[AudioSink] = None
|
self._pause_event = asyncio.Event()
|
||||||
self._ssrc_map: dict[int, int] = {}
|
self._pause_event.set()
|
||||||
self._ssrc_lock = threading.Lock()
|
self._is_playing = False
|
||||||
|
self._sink: Optional[AudioSink] = None
|
||||||
|
self._ssrc_map: dict[int, int] = {}
|
||||||
|
self._ssrc_lock = threading.Lock()
|
||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
if self._ws is None:
|
if self._ws is None:
|
||||||
@ -189,31 +192,37 @@ class VoiceClient:
|
|||||||
raise RuntimeError("UDP socket not initialised")
|
raise RuntimeError("UDP socket not initialised")
|
||||||
self._udp.send(frame)
|
self._udp.send(frame)
|
||||||
|
|
||||||
async def _play_loop(self) -> None:
|
async def _play_loop(self) -> None:
|
||||||
assert self._current_source is not None
|
assert self._current_source is not None
|
||||||
try:
|
self._is_playing = True
|
||||||
while True:
|
try:
|
||||||
data = await self._current_source.read()
|
while True:
|
||||||
if not data:
|
await self._pause_event.wait()
|
||||||
break
|
data = await self._current_source.read()
|
||||||
volume = getattr(self._current_source, "volume", 1.0)
|
if not data:
|
||||||
if volume != 1.0:
|
break
|
||||||
data = _apply_volume(data, volume)
|
volume = getattr(self._current_source, "volume", 1.0)
|
||||||
await self.send_audio_frame(data)
|
if volume != 1.0:
|
||||||
finally:
|
data = _apply_volume(data, volume)
|
||||||
await self._current_source.close()
|
await self.send_audio_frame(data)
|
||||||
self._current_source = None
|
finally:
|
||||||
self._play_task = None
|
await self._current_source.close()
|
||||||
|
self._current_source = None
|
||||||
|
self._play_task = None
|
||||||
|
self._is_playing = False
|
||||||
|
self._pause_event.set()
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
if self._play_task:
|
if self._play_task:
|
||||||
self._play_task.cancel()
|
self._play_task.cancel()
|
||||||
with contextlib.suppress(asyncio.CancelledError):
|
self._pause_event.set()
|
||||||
await self._play_task
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
self._play_task = None
|
await self._play_task
|
||||||
if self._current_source:
|
self._play_task = None
|
||||||
await self._current_source.close()
|
self._is_playing = False
|
||||||
self._current_source = None
|
if self._current_source:
|
||||||
|
await self._current_source.close()
|
||||||
|
self._current_source = None
|
||||||
|
|
||||||
async def play(self, source: AudioSource, *, wait: bool = True) -> None:
|
async def play(self, source: AudioSource, *, wait: bool = True) -> None:
|
||||||
"""|coro| Play an :class:`AudioSource` on the voice connection."""
|
"""|coro| Play an :class:`AudioSource` on the voice connection."""
|
||||||
@ -224,10 +233,31 @@ class VoiceClient:
|
|||||||
if wait:
|
if wait:
|
||||||
await self._play_task
|
await self._play_task
|
||||||
|
|
||||||
async def play_file(self, filename: str, *, wait: bool = True) -> None:
|
async def play_file(self, filename: str, *, wait: bool = True) -> None:
|
||||||
"""|coro| Stream an audio file or URL using FFmpeg."""
|
"""|coro| Stream an audio file or URL using FFmpeg."""
|
||||||
|
|
||||||
await self.play(FFmpegAudioSource(filename), wait=wait)
|
await self.play(FFmpegAudioSource(filename), wait=wait)
|
||||||
|
|
||||||
|
def pause(self) -> None:
|
||||||
|
"""Pause the current audio source."""
|
||||||
|
|
||||||
|
if self._play_task and not self._play_task.done():
|
||||||
|
self._pause_event.clear()
|
||||||
|
|
||||||
|
def resume(self) -> None:
|
||||||
|
"""Resume playback of a paused source."""
|
||||||
|
|
||||||
|
if self._play_task and not self._play_task.done():
|
||||||
|
self._pause_event.set()
|
||||||
|
|
||||||
|
def is_paused(self) -> bool:
|
||||||
|
"""Return ``True`` if playback is currently paused."""
|
||||||
|
|
||||||
|
return bool(self._play_task and not self._pause_event.is_set())
|
||||||
|
|
||||||
|
def is_playing(self) -> bool:
|
||||||
|
"""Return ``True`` if audio is actively being played."""
|
||||||
|
return self._is_playing and self._pause_event.is_set()
|
||||||
|
|
||||||
def listen(self, sink: AudioSink) -> None:
|
def listen(self, sink: AudioSink) -> None:
|
||||||
"""Start listening to voice and routing to a sink."""
|
"""Start listening to voice and routing to a sink."""
|
||||||
|
@ -13,12 +13,28 @@ if member:
|
|||||||
print(member.display_name)
|
print(member.display_name)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
To access the bot's own member object, use the ``Guild.me`` property. It returns
|
||||||
|
``None`` if the bot is not in the guild or its user data hasn't been loaded:
|
||||||
|
|
||||||
|
```python
|
||||||
|
bot_member = guild.me
|
||||||
|
if bot_member:
|
||||||
|
print(bot_member.joined_at)
|
||||||
|
```
|
||||||
|
|
||||||
The cache can be cleared manually if needed:
|
The cache can be cleared manually if needed:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
client.cache.clear()
|
client.cache.clear()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Partial Objects
|
||||||
|
|
||||||
|
Some events only include minimal data for related resources. When only an ``id``
|
||||||
|
is available, Disagreement represents the resource using :class:`~disagreement.Object`.
|
||||||
|
These objects can be compared and used in sets or dictionaries and can be passed
|
||||||
|
to API methods to fetch the full data when needed.
|
||||||
|
|
||||||
## Next Steps
|
## Next Steps
|
||||||
|
|
||||||
- [Components](using_components.md)
|
- [Components](using_components.md)
|
||||||
|
@ -11,7 +11,11 @@ The command handler registers a `help` command automatically. Use it to list all
|
|||||||
!help ping # shows help for the "ping" command
|
!help ping # shows help for the "ping" command
|
||||||
```
|
```
|
||||||
|
|
||||||
The help command will show each command's brief description if provided.
|
Commands are grouped by their Cog name and paginated so that long help
|
||||||
|
lists are split into multiple messages using the `Paginator` utility.
|
||||||
|
|
||||||
|
If you need custom formatting you can subclass
|
||||||
|
`HelpCommand` and override `send_command_help` or `send_group_help`.
|
||||||
|
|
||||||
## Checks
|
## Checks
|
||||||
|
|
||||||
|
@ -132,6 +132,28 @@ async def on_shard_resume(info: dict):
|
|||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## CONNECT
|
||||||
|
|
||||||
|
Dispatched when the WebSocket connection opens. The callback receives a
|
||||||
|
dictionary with the shard ID.
|
||||||
|
|
||||||
|
```python
|
||||||
|
@client.event
|
||||||
|
async def on_connect(info: dict):
|
||||||
|
print("connected", info.get("shard_id"))
|
||||||
|
```
|
||||||
|
|
||||||
|
## DISCONNECT
|
||||||
|
|
||||||
|
Fired when the WebSocket connection closes. The callback receives a dictionary
|
||||||
|
with the shard ID.
|
||||||
|
|
||||||
|
```python
|
||||||
|
@client.event
|
||||||
|
async def on_disconnect(info: dict):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
## VOICE_STATE_UPDATE
|
## VOICE_STATE_UPDATE
|
||||||
|
|
||||||
Triggered when a user's voice connection state changes, such as joining or leaving a voice channel. The callback receives a `VoiceStateUpdate` model.
|
Triggered when a user's voice connection state changes, such as joining or leaving a voice channel. The callback receives a `VoiceStateUpdate` model.
|
||||||
|
@ -14,6 +14,7 @@ A Python library for interacting with the Discord API, with a focus on bot devel
|
|||||||
- Built-in caching layer
|
- Built-in caching layer
|
||||||
- Experimental voice support
|
- Experimental voice support
|
||||||
- Helpful error handling utilities
|
- Helpful error handling utilities
|
||||||
|
- Paginator utility for splitting long messages
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
@ -60,13 +61,9 @@ if not token:
|
|||||||
|
|
||||||
intents = GatewayIntent.default() | GatewayIntent.MESSAGE_CONTENT
|
intents = GatewayIntent.default() | GatewayIntent.MESSAGE_CONTENT
|
||||||
client = Client(token=token, command_prefix="!", intents=intents, mention_replies=True)
|
client = Client(token=token, command_prefix="!", intents=intents, mention_replies=True)
|
||||||
async def main() -> None:
|
|
||||||
client.add_cog(Basics(client))
|
|
||||||
await client.run()
|
|
||||||
|
|
||||||
|
client.add_cog(Basics(client))
|
||||||
if __name__ == "__main__":
|
client.run()
|
||||||
asyncio.run(main())
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Global Error Handling
|
### Global Error Handling
|
||||||
|
@ -15,6 +15,12 @@ from disagreement import Permissions
|
|||||||
value = Permissions.SEND_MESSAGES | Permissions.MANAGE_MESSAGES
|
value = Permissions.SEND_MESSAGES | Permissions.MANAGE_MESSAGES
|
||||||
```
|
```
|
||||||
|
|
||||||
|
You can also get a bitmask containing **every** permission:
|
||||||
|
|
||||||
|
```python
|
||||||
|
all_perms = Permissions.all()
|
||||||
|
```
|
||||||
|
|
||||||
## Helper Functions
|
## Helper Functions
|
||||||
|
|
||||||
### ``permissions_value``
|
### ``permissions_value``
|
||||||
|
@ -8,13 +8,8 @@ manually.
|
|||||||
and configures the `ShardManager` automatically.
|
and configures the `ShardManager` automatically.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import asyncio
|
|
||||||
import disagreement
|
import disagreement
|
||||||
|
|
||||||
bot = disagreement.AutoShardedClient(token="YOUR_TOKEN")
|
bot = disagreement.AutoShardedClient(token="YOUR_TOKEN")
|
||||||
|
bot.run()
|
||||||
async def main():
|
|
||||||
await bot.run()
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
```
|
```
|
||||||
|
@ -157,6 +157,22 @@ container = Container(
|
|||||||
A container can itself contain layout and content components, letting you build complex messages.
|
A container can itself contain layout and content components, letting you build complex messages.
|
||||||
|
|
||||||
|
|
||||||
|
## Persistent Views
|
||||||
|
|
||||||
|
Views with ``timeout=None`` are persistent. Their ``custom_id`` components are saved to ``persistent_views.json`` so they survive bot restarts.
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MyView(View):
|
||||||
|
@button(label="Press", custom_id="press")
|
||||||
|
async def handle(self, view, inter):
|
||||||
|
await inter.respond("Pressed!")
|
||||||
|
|
||||||
|
client.add_persistent_view(MyView())
|
||||||
|
```
|
||||||
|
|
||||||
|
When the client starts, it loads this file and registers each view again. Remove
|
||||||
|
the file to clear stored views.
|
||||||
|
|
||||||
## Next Steps
|
## Next Steps
|
||||||
|
|
||||||
- [Slash Commands](slash_commands.md)
|
- [Slash Commands](slash_commands.md)
|
||||||
|
20
docs/utils.md
Normal file
20
docs/utils.md
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# Utility Helpers
|
||||||
|
|
||||||
|
Disagreement provides a few small utility functions for working with Discord data.
|
||||||
|
|
||||||
|
## `utcnow`
|
||||||
|
|
||||||
|
Returns the current timezone-aware UTC `datetime`.
|
||||||
|
|
||||||
|
## `snowflake_time`
|
||||||
|
|
||||||
|
Converts a Discord snowflake ID into the UTC timestamp when it was generated.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.utils import snowflake_time
|
||||||
|
|
||||||
|
created_at = snowflake_time(175928847299117063)
|
||||||
|
print(created_at.isoformat())
|
||||||
|
```
|
||||||
|
|
||||||
|
The function extracts the timestamp from the snowflake and returns a `datetime` in UTC.
|
@ -6,6 +6,10 @@ Disagreement includes experimental support for connecting to voice channels. You
|
|||||||
voice = await client.join_voice(guild_id, channel_id)
|
voice = await client.join_voice(guild_id, channel_id)
|
||||||
await voice.play_file("welcome.mp3")
|
await voice.play_file("welcome.mp3")
|
||||||
await voice.play_file("another.mp3") # switch sources while connected
|
await voice.play_file("another.mp3") # switch sources while connected
|
||||||
|
voice.pause()
|
||||||
|
voice.resume()
|
||||||
|
if voice.is_playing():
|
||||||
|
print("audio is playing")
|
||||||
await voice.close()
|
await voice.close()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -67,9 +67,7 @@ BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
|||||||
# --- Intents Configuration ---
|
# --- Intents Configuration ---
|
||||||
# Define the intents your bot needs. For basic message reading and responding:
|
# Define the intents your bot needs. For basic message reading and responding:
|
||||||
intents = (
|
intents = (
|
||||||
GatewayIntent.GUILDS
|
GatewayIntent.GUILDS | GatewayIntent.GUILD_MESSAGES | GatewayIntent.MESSAGE_CONTENT
|
||||||
| GatewayIntent.GUILD_MESSAGES
|
|
||||||
| GatewayIntent.MESSAGE_CONTENT
|
|
||||||
) # MESSAGE_CONTENT is privileged!
|
) # MESSAGE_CONTENT is privileged!
|
||||||
|
|
||||||
# If you don't need message content and only react to commands/mentions,
|
# If you don't need message content and only react to commands/mentions,
|
||||||
@ -210,14 +208,14 @@ async def on_guild_available(guild: Guild):
|
|||||||
|
|
||||||
|
|
||||||
# --- Main Execution ---
|
# --- Main Execution ---
|
||||||
async def main():
|
def main():
|
||||||
print("Starting Disagreement Bot...")
|
print("Starting Disagreement Bot...")
|
||||||
try:
|
try:
|
||||||
# Add the Cog to the client
|
# Add the Cog to the client
|
||||||
client.add_cog(ExampleCog(client)) # Pass client instance to Cog constructor
|
client.add_cog(ExampleCog(client)) # Pass client instance to Cog constructor
|
||||||
# client.add_cog is synchronous, but it schedules cog.cog_load() if it's async.
|
# client.add_cog is synchronous, but it schedules cog.cog_load() if it's async.
|
||||||
|
|
||||||
await client.run()
|
client.run()
|
||||||
except AuthenticationError:
|
except AuthenticationError:
|
||||||
print(
|
print(
|
||||||
"Authentication failed. Please check your bot token and ensure it's correct."
|
"Authentication failed. Please check your bot token and ensure it's correct."
|
||||||
@ -232,7 +230,7 @@ async def main():
|
|||||||
finally:
|
finally:
|
||||||
if not client.is_closed():
|
if not client.is_closed():
|
||||||
print("Ensuring client is closed...")
|
print("Ensuring client is closed...")
|
||||||
await client.close()
|
asyncio.run(client.close())
|
||||||
print("Bot has been shut down.")
|
print("Bot has been shut down.")
|
||||||
|
|
||||||
|
|
||||||
@ -244,4 +242,4 @@ if __name__ == "__main__":
|
|||||||
# if os.name == 'nt':
|
# if os.name == 'nt':
|
||||||
# asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
# asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||||
|
|
||||||
asyncio.run(main())
|
main()
|
||||||
|
@ -263,7 +263,7 @@ class ComponentCommandsCog(Cog):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
def main():
|
||||||
@client.event
|
@client.event
|
||||||
async def on_ready():
|
async def on_ready():
|
||||||
if client.user:
|
if client.user:
|
||||||
@ -283,8 +283,8 @@ async def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
client.add_cog(ComponentCommandsCog(client))
|
client.add_cog(ComponentCommandsCog(client))
|
||||||
await client.run()
|
client.run()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
main()
|
||||||
|
@ -65,11 +65,9 @@ client.app_command_handler.add_command(user_info)
|
|||||||
client.app_command_handler.add_command(quote)
|
client.app_command_handler.add_command(quote)
|
||||||
|
|
||||||
|
|
||||||
async def main() -> None:
|
def main() -> None:
|
||||||
await client.run()
|
client.run()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
main()
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
|
@ -27,10 +27,10 @@ intents = GatewayIntent.default() | GatewayIntent.MESSAGE_CONTENT
|
|||||||
client = Client(token=token, command_prefix="!", intents=intents, mention_replies=True)
|
client = Client(token=token, command_prefix="!", intents=intents, mention_replies=True)
|
||||||
|
|
||||||
|
|
||||||
async def main() -> None:
|
def main() -> None:
|
||||||
client.add_cog(Basics(client))
|
client.add_cog(Basics(client))
|
||||||
await client.run()
|
client.run()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
main()
|
||||||
|
@ -230,7 +230,7 @@ class TestCog(Cog):
|
|||||||
|
|
||||||
|
|
||||||
# --- Main Bot Script ---
|
# --- Main Bot Script ---
|
||||||
async def main():
|
def main():
|
||||||
bot_token = os.getenv("DISCORD_BOT_TOKEN")
|
bot_token = os.getenv("DISCORD_BOT_TOKEN")
|
||||||
application_id = os.getenv("DISCORD_APPLICATION_ID")
|
application_id = os.getenv("DISCORD_APPLICATION_ID")
|
||||||
|
|
||||||
@ -291,7 +291,7 @@ async def main():
|
|||||||
client.add_cog(TestCog(client))
|
client.add_cog(TestCog(client))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await client.run()
|
client.run()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("Bot shutting down...")
|
logger.info("Bot shutting down...")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -300,7 +300,7 @@ async def main():
|
|||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
if not client.is_closed():
|
if not client.is_closed():
|
||||||
await client.close()
|
asyncio.run(client.close())
|
||||||
logger.info("Bot has been closed.")
|
logger.info("Bot has been closed.")
|
||||||
|
|
||||||
|
|
||||||
@ -310,6 +310,6 @@ if __name__ == "__main__":
|
|||||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
asyncio.run(main())
|
main()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("Main loop interrupted. Exiting.")
|
logger.info("Main loop interrupted. Exiting.")
|
||||||
|
@ -62,9 +62,9 @@ async def on_ready():
|
|||||||
print("------")
|
print("------")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
def main():
|
||||||
await client.run()
|
client.run()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
main()
|
||||||
|
@ -63,6 +63,4 @@ async def on_ready():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
client.run()
|
||||||
|
|
||||||
asyncio.run(client.run())
|
|
||||||
|
@ -9,7 +9,15 @@ from typing import Set
|
|||||||
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
|
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
|
||||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
from disagreement import Client, GatewayIntent, Member, Message, Cog, command, CommandContext
|
from disagreement import (
|
||||||
|
Client,
|
||||||
|
GatewayIntent,
|
||||||
|
Member,
|
||||||
|
Message,
|
||||||
|
Cog,
|
||||||
|
command,
|
||||||
|
CommandContext,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
@ -26,9 +34,7 @@ if not BOT_TOKEN:
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
intents = (
|
intents = (
|
||||||
GatewayIntent.GUILDS
|
GatewayIntent.GUILDS | GatewayIntent.GUILD_MESSAGES | GatewayIntent.MESSAGE_CONTENT
|
||||||
| GatewayIntent.GUILD_MESSAGES
|
|
||||||
| GatewayIntent.MESSAGE_CONTENT
|
|
||||||
)
|
)
|
||||||
client = Client(token=BOT_TOKEN, command_prefix="!", intents=intents)
|
client = Client(token=BOT_TOKEN, command_prefix="!", intents=intents)
|
||||||
|
|
||||||
@ -78,10 +84,10 @@ async def on_message(message: Message) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def main() -> None:
|
def main() -> None:
|
||||||
client.add_cog(ModerationCog(client))
|
client.add_cog(ModerationCog(client))
|
||||||
await client.run()
|
client.run()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
main()
|
||||||
|
@ -137,11 +137,11 @@ async def on_reaction_remove(reaction: Reaction, user: User | Member):
|
|||||||
|
|
||||||
|
|
||||||
# --- Main Execution ---
|
# --- Main Execution ---
|
||||||
async def main():
|
def main():
|
||||||
print("Starting Reaction Bot...")
|
print("Starting Reaction Bot...")
|
||||||
try:
|
try:
|
||||||
client.add_cog(ReactionCog(client))
|
client.add_cog(ReactionCog(client))
|
||||||
await client.run()
|
client.run()
|
||||||
except AuthenticationError:
|
except AuthenticationError:
|
||||||
print("Authentication failed. Check your bot token.")
|
print("Authentication failed. Check your bot token.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -149,9 +149,9 @@ async def main():
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
finally:
|
finally:
|
||||||
if not client.is_closed():
|
if not client.is_closed():
|
||||||
await client.close()
|
asyncio.run(client.close())
|
||||||
print("Bot has been shut down.")
|
print("Bot has been shut down.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
main()
|
||||||
|
@ -34,12 +34,12 @@ async def on_ready():
|
|||||||
print("Shard bot ready")
|
print("Shard bot ready")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
def main():
|
||||||
if not TOKEN:
|
if not TOKEN:
|
||||||
print("DISCORD_BOT_TOKEN environment variable not set")
|
print("DISCORD_BOT_TOKEN environment variable not set")
|
||||||
return
|
return
|
||||||
await client.run()
|
client.run()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
main()
|
||||||
|
@ -53,9 +53,7 @@ BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
|||||||
|
|
||||||
# --- Intents Configuration ---
|
# --- Intents Configuration ---
|
||||||
intents = (
|
intents = (
|
||||||
GatewayIntent.GUILDS
|
GatewayIntent.GUILDS | GatewayIntent.GUILD_MESSAGES | GatewayIntent.MESSAGE_CONTENT
|
||||||
| GatewayIntent.GUILD_MESSAGES
|
|
||||||
| GatewayIntent.MESSAGE_CONTENT
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- Initialize the Client ---
|
# --- Initialize the Client ---
|
||||||
@ -106,11 +104,11 @@ async def on_ready():
|
|||||||
|
|
||||||
|
|
||||||
# --- Main Execution ---
|
# --- Main Execution ---
|
||||||
async def main():
|
def main():
|
||||||
print("Starting Typing Indicator Bot...")
|
print("Starting Typing Indicator Bot...")
|
||||||
try:
|
try:
|
||||||
client.add_cog(TypingCog(client))
|
client.add_cog(TypingCog(client))
|
||||||
await client.run()
|
client.run()
|
||||||
except AuthenticationError:
|
except AuthenticationError:
|
||||||
print("Authentication failed. Check your bot token.")
|
print("Authentication failed. Check your bot token.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -118,9 +116,9 @@ async def main():
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
finally:
|
finally:
|
||||||
if not client.is_closed():
|
if not client.is_closed():
|
||||||
await client.close()
|
asyncio.run(client.close())
|
||||||
print("Bot has been shut down.")
|
print("Bot has been shut down.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
main()
|
||||||
|
@ -62,4 +62,5 @@ nav:
|
|||||||
- 'Mentions': 'mentions.md'
|
- 'Mentions': 'mentions.md'
|
||||||
- 'OAuth2': 'oauth2.md'
|
- 'OAuth2': 'oauth2.md'
|
||||||
- 'Presence': 'presence.md'
|
- 'Presence': 'presence.md'
|
||||||
- 'Voice Client': 'voice_client.md'
|
- 'Voice Client': 'voice_client.md'
|
||||||
|
- 'Utility Helpers': 'utils.md'
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "disagreement"
|
name = "disagreement"
|
||||||
version = "0.6.0"
|
version = "0.8.1"
|
||||||
description = "A Python library for the Discord API."
|
description = "A Python library for the Discord API."
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
@ -3,7 +3,16 @@ import pytest
|
|||||||
from disagreement.ext.commands.converters import run_converters
|
from disagreement.ext.commands.converters import run_converters
|
||||||
from disagreement.ext.commands.core import CommandContext, Command
|
from disagreement.ext.commands.core import CommandContext, Command
|
||||||
from disagreement.ext.commands.errors import BadArgument
|
from disagreement.ext.commands.errors import BadArgument
|
||||||
from disagreement.models import Message, Member, Role, Guild
|
from disagreement.models import (
|
||||||
|
Message,
|
||||||
|
Member,
|
||||||
|
Role,
|
||||||
|
Guild,
|
||||||
|
User,
|
||||||
|
TextChannel,
|
||||||
|
VoiceChannel,
|
||||||
|
PartialEmoji,
|
||||||
|
)
|
||||||
from disagreement.enums import (
|
from disagreement.enums import (
|
||||||
VerificationLevel,
|
VerificationLevel,
|
||||||
MessageNotificationLevel,
|
MessageNotificationLevel,
|
||||||
@ -11,21 +20,27 @@ from disagreement.enums import (
|
|||||||
MFALevel,
|
MFALevel,
|
||||||
GuildNSFWLevel,
|
GuildNSFWLevel,
|
||||||
PremiumTier,
|
PremiumTier,
|
||||||
|
ChannelType,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
from disagreement.client import Client
|
from disagreement.client import Client
|
||||||
from disagreement.cache import GuildCache
|
from disagreement.cache import GuildCache, Cache, ChannelCache
|
||||||
|
|
||||||
|
|
||||||
class DummyBot(Client):
|
class DummyBot(Client):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(token="test")
|
super().__init__(token="test")
|
||||||
self._guilds = GuildCache()
|
self._guilds = GuildCache()
|
||||||
|
self._users = Cache()
|
||||||
|
self._channels = ChannelCache()
|
||||||
|
|
||||||
def get_guild(self, guild_id):
|
def get_guild(self, guild_id):
|
||||||
return self._guilds.get(guild_id)
|
return self._guilds.get(guild_id)
|
||||||
|
|
||||||
|
def get_channel(self, channel_id):
|
||||||
|
return self._channels.get(channel_id)
|
||||||
|
|
||||||
async def fetch_member(self, guild_id, member_id):
|
async def fetch_member(self, guild_id, member_id):
|
||||||
guild = self._guilds.get(guild_id)
|
guild = self._guilds.get(guild_id)
|
||||||
return guild.get_member(member_id) if guild else None
|
return guild.get_member(member_id) if guild else None
|
||||||
@ -37,6 +52,12 @@ class DummyBot(Client):
|
|||||||
async def fetch_guild(self, guild_id):
|
async def fetch_guild(self, guild_id):
|
||||||
return self._guilds.get(guild_id)
|
return self._guilds.get(guild_id)
|
||||||
|
|
||||||
|
async def fetch_user(self, user_id):
|
||||||
|
return self._users.get(user_id)
|
||||||
|
|
||||||
|
async def fetch_channel(self, channel_id):
|
||||||
|
return self._channels.get(channel_id)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def guild_objects():
|
def guild_objects():
|
||||||
@ -60,6 +81,9 @@ def guild_objects():
|
|||||||
guild = Guild(guild_data, client_instance=bot)
|
guild = Guild(guild_data, client_instance=bot)
|
||||||
bot._guilds.set(guild.id, guild)
|
bot._guilds.set(guild.id, guild)
|
||||||
|
|
||||||
|
user = User({"id": "7", "username": "u", "discriminator": "0001"})
|
||||||
|
bot._users.set(user.id, user)
|
||||||
|
|
||||||
member = Member(
|
member = Member(
|
||||||
{
|
{
|
||||||
"user": {"id": "3", "username": "m", "discriminator": "0001"},
|
"user": {"id": "3", "username": "m", "discriminator": "0001"},
|
||||||
@ -86,12 +110,38 @@ def guild_objects():
|
|||||||
guild._members.set(member.id, member)
|
guild._members.set(member.id, member)
|
||||||
guild.roles.append(role)
|
guild.roles.append(role)
|
||||||
|
|
||||||
return guild, member, role
|
text_channel = TextChannel(
|
||||||
|
{
|
||||||
|
"id": "20",
|
||||||
|
"type": ChannelType.GUILD_TEXT.value,
|
||||||
|
"guild_id": guild.id,
|
||||||
|
"permission_overwrites": [],
|
||||||
|
},
|
||||||
|
client_instance=bot,
|
||||||
|
)
|
||||||
|
voice_channel = VoiceChannel(
|
||||||
|
{
|
||||||
|
"id": "21",
|
||||||
|
"type": ChannelType.GUILD_VOICE.value,
|
||||||
|
"guild_id": guild.id,
|
||||||
|
"permission_overwrites": [],
|
||||||
|
},
|
||||||
|
client_instance=bot,
|
||||||
|
)
|
||||||
|
|
||||||
|
guild._channels.set(text_channel.id, text_channel)
|
||||||
|
guild.text_channels.append(text_channel)
|
||||||
|
guild._channels.set(voice_channel.id, voice_channel)
|
||||||
|
guild.voice_channels.append(voice_channel)
|
||||||
|
bot._channels.set(text_channel.id, text_channel)
|
||||||
|
bot._channels.set(voice_channel.id, voice_channel)
|
||||||
|
|
||||||
|
return guild, member, role, user, text_channel, voice_channel
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def command_context(guild_objects):
|
def command_context(guild_objects):
|
||||||
guild, member, role = guild_objects
|
guild, member, role, _, _, _ = guild_objects
|
||||||
bot = guild._client
|
bot = guild._client
|
||||||
message_data = {
|
message_data = {
|
||||||
"id": "10",
|
"id": "10",
|
||||||
@ -114,7 +164,7 @@ def command_context(guild_objects):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_member_converter(command_context, guild_objects):
|
async def test_member_converter(command_context, guild_objects):
|
||||||
_, member, _ = guild_objects
|
_, member, _, _, _, _ = guild_objects
|
||||||
mention = f"<@!{member.id}>"
|
mention = f"<@!{member.id}>"
|
||||||
result = await run_converters(command_context, Member, mention)
|
result = await run_converters(command_context, Member, mention)
|
||||||
assert result is member
|
assert result is member
|
||||||
@ -124,7 +174,7 @@ async def test_member_converter(command_context, guild_objects):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_role_converter(command_context, guild_objects):
|
async def test_role_converter(command_context, guild_objects):
|
||||||
_, _, role = guild_objects
|
_, _, role, _, _, _ = guild_objects
|
||||||
mention = f"<@&{role.id}>"
|
mention = f"<@&{role.id}>"
|
||||||
result = await run_converters(command_context, Role, mention)
|
result = await run_converters(command_context, Role, mention)
|
||||||
assert result is role
|
assert result is role
|
||||||
@ -132,13 +182,55 @@ async def test_role_converter(command_context, guild_objects):
|
|||||||
assert result is role
|
assert result is role
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_converter(command_context, guild_objects):
|
||||||
|
_, _, _, user, _, _ = guild_objects
|
||||||
|
mention = f"<@{user.id}>"
|
||||||
|
result = await run_converters(command_context, User, mention)
|
||||||
|
assert result is user
|
||||||
|
result = await run_converters(command_context, User, user.id)
|
||||||
|
assert result is user
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_guild_converter(command_context, guild_objects):
|
async def test_guild_converter(command_context, guild_objects):
|
||||||
guild, _, _ = guild_objects
|
guild, _, _, _, _, _ = guild_objects
|
||||||
result = await run_converters(command_context, Guild, guild.id)
|
result = await run_converters(command_context, Guild, guild.id)
|
||||||
assert result is guild
|
assert result is guild
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_text_channel_converter(command_context, guild_objects):
|
||||||
|
_, _, _, _, text_channel, _ = guild_objects
|
||||||
|
mention = f"<#{text_channel.id}>"
|
||||||
|
result = await run_converters(command_context, TextChannel, mention)
|
||||||
|
assert result is text_channel
|
||||||
|
result = await run_converters(command_context, TextChannel, text_channel.id)
|
||||||
|
assert result is text_channel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_voice_channel_converter(command_context, guild_objects):
|
||||||
|
_, _, _, _, _, voice_channel = guild_objects
|
||||||
|
mention = f"<#{voice_channel.id}>"
|
||||||
|
result = await run_converters(command_context, VoiceChannel, mention)
|
||||||
|
assert result is voice_channel
|
||||||
|
result = await run_converters(command_context, VoiceChannel, voice_channel.id)
|
||||||
|
assert result is voice_channel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_emoji_converter(command_context):
|
||||||
|
result = await run_converters(command_context, PartialEmoji, "<:smile:1>")
|
||||||
|
assert isinstance(result, PartialEmoji)
|
||||||
|
assert result.id == "1"
|
||||||
|
assert result.name == "smile"
|
||||||
|
|
||||||
|
result = await run_converters(command_context, PartialEmoji, "😄")
|
||||||
|
assert result.id is None
|
||||||
|
assert result.name == "😄"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_member_converter_no_guild():
|
async def test_member_converter_no_guild():
|
||||||
guild_data = {
|
guild_data = {
|
||||||
|
14
tests/test_asset.py
Normal file
14
tests/test_asset.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from disagreement.models import User
|
||||||
|
from disagreement.asset import Asset
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_avatar_returns_asset():
|
||||||
|
user = User({"id": "1", "username": "u", "discriminator": "0001", "avatar": "abc"})
|
||||||
|
avatar = user.avatar
|
||||||
|
assert isinstance(avatar, Asset)
|
||||||
|
assert avatar.url == "https://cdn.discordapp.com/avatars/1/abc.png"
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_avatar_none():
|
||||||
|
user = User({"id": "1", "username": "u", "discriminator": "0001"})
|
||||||
|
assert user.avatar is None
|
83
tests/test_auto_sync_commands.py
Normal file
83
tests/test_auto_sync_commands.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.gateway import GatewayClient
|
||||||
|
from disagreement.event_dispatcher import EventDispatcher
|
||||||
|
|
||||||
|
|
||||||
|
class DummyHTTP:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DummyUser:
|
||||||
|
username = "u"
|
||||||
|
discriminator = "0001"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_sync_on_ready(monkeypatch):
|
||||||
|
client = Client(token="t", application_id="123")
|
||||||
|
http = DummyHTTP()
|
||||||
|
dispatcher = EventDispatcher(client)
|
||||||
|
gw = GatewayClient(
|
||||||
|
http_client=http,
|
||||||
|
event_dispatcher=dispatcher,
|
||||||
|
token="t",
|
||||||
|
intents=0,
|
||||||
|
client_instance=client,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(client, "parse_user", lambda d: DummyUser())
|
||||||
|
monkeypatch.setattr(gw._dispatcher, "dispatch", AsyncMock())
|
||||||
|
sync_mock = AsyncMock()
|
||||||
|
monkeypatch.setattr(client, "sync_application_commands", sync_mock)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"t": "READY",
|
||||||
|
"s": 1,
|
||||||
|
"d": {
|
||||||
|
"session_id": "s1",
|
||||||
|
"resume_gateway_url": "url",
|
||||||
|
"application": {"id": "123"},
|
||||||
|
"user": {"id": "1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
await gw._handle_dispatch(data)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
sync_mock.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_sync_disabled(monkeypatch):
|
||||||
|
client = Client(token="t", application_id="123", sync_commands_on_ready=False)
|
||||||
|
http = DummyHTTP()
|
||||||
|
dispatcher = EventDispatcher(client)
|
||||||
|
gw = GatewayClient(
|
||||||
|
http_client=http,
|
||||||
|
event_dispatcher=dispatcher,
|
||||||
|
token="t",
|
||||||
|
intents=0,
|
||||||
|
client_instance=client,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(client, "parse_user", lambda d: DummyUser())
|
||||||
|
monkeypatch.setattr(gw._dispatcher, "dispatch", AsyncMock())
|
||||||
|
sync_mock = AsyncMock()
|
||||||
|
monkeypatch.setattr(client, "sync_application_commands", sync_mock)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"t": "READY",
|
||||||
|
"s": 1,
|
||||||
|
"d": {
|
||||||
|
"session_id": "s1",
|
||||||
|
"resume_gateway_url": "url",
|
||||||
|
"application": {"id": "123"},
|
||||||
|
"user": {"id": "1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
await gw._handle_dispatch(data)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
sync_mock.assert_not_called()
|
@ -1,6 +1,60 @@
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from disagreement.cache import Cache
|
from disagreement.cache import Cache
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.caching import MemberCacheFlags
|
||||||
|
from disagreement.enums import (
|
||||||
|
ChannelType,
|
||||||
|
ExplicitContentFilterLevel,
|
||||||
|
GuildNSFWLevel,
|
||||||
|
MFALevel,
|
||||||
|
MessageNotificationLevel,
|
||||||
|
PremiumTier,
|
||||||
|
VerificationLevel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _guild_payload(gid: str, channel_count: int, member_count: int) -> dict:
|
||||||
|
base = {
|
||||||
|
"id": gid,
|
||||||
|
"name": f"g{gid}",
|
||||||
|
"owner_id": "1",
|
||||||
|
"afk_timeout": 60,
|
||||||
|
"verification_level": VerificationLevel.NONE.value,
|
||||||
|
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||||
|
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||||
|
"roles": [],
|
||||||
|
"emojis": [],
|
||||||
|
"features": [],
|
||||||
|
"mfa_level": MFALevel.NONE.value,
|
||||||
|
"system_channel_flags": 0,
|
||||||
|
"premium_tier": PremiumTier.NONE.value,
|
||||||
|
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||||
|
"channels": [],
|
||||||
|
"members": [],
|
||||||
|
}
|
||||||
|
for i in range(channel_count):
|
||||||
|
base["channels"].append(
|
||||||
|
{
|
||||||
|
"id": f"{gid}-c{i}",
|
||||||
|
"type": ChannelType.GUILD_TEXT.value,
|
||||||
|
"guild_id": gid,
|
||||||
|
"permission_overwrites": [],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
for i in range(member_count):
|
||||||
|
base["members"].append(
|
||||||
|
{
|
||||||
|
"user": {
|
||||||
|
"id": f"{gid}-m{i}",
|
||||||
|
"username": f"u{i}",
|
||||||
|
"discriminator": "0001",
|
||||||
|
},
|
||||||
|
"joined_at": "t",
|
||||||
|
"roles": [],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
def test_cache_store_and_get():
|
def test_cache_store_and_get():
|
||||||
@ -65,3 +119,22 @@ def test_get_or_fetch_fetches_expired_item():
|
|||||||
|
|
||||||
assert cache.get_or_fetch("c", fetch) == 3
|
assert cache.get_or_fetch("c", fetch) == 3
|
||||||
assert called
|
assert called
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_get_all_channels_and_members():
|
||||||
|
client = Client(token="t")
|
||||||
|
client.parse_guild(_guild_payload("1", 2, 2))
|
||||||
|
client.parse_guild(_guild_payload("2", 1, 1))
|
||||||
|
|
||||||
|
channels = {c.id for c in client.get_all_channels()}
|
||||||
|
members = {m.id for m in client.get_all_members()}
|
||||||
|
|
||||||
|
assert channels == {"1-c0", "1-c1", "2-c0"}
|
||||||
|
assert members == {"1-m0", "1-m1", "2-m0"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_get_all_members_disabled_cache():
|
||||||
|
client = Client(token="t", member_cache_flags=MemberCacheFlags.none())
|
||||||
|
client.parse_guild(_guild_payload("1", 1, 2))
|
||||||
|
|
||||||
|
assert client.get_all_members() == []
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import pytest
|
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# pylint: disable=no-member
|
||||||
|
|
||||||
from disagreement.client import Client
|
from disagreement.client import Client
|
||||||
|
|
||||||
|
|
||||||
|
42
tests/test_client_uptime.py
Normal file
42
tests/test_client_uptime.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import pytest
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from disagreement.client import Client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_records_start_time(monkeypatch):
|
||||||
|
start = datetime(2020, 1, 1, tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
monkeypatch.setattr("disagreement.client.utcnow", lambda: start)
|
||||||
|
|
||||||
|
client = Client(token="t")
|
||||||
|
monkeypatch.setattr(client, "_initialize_gateway", AsyncMock())
|
||||||
|
client._gateway = SimpleNamespace(connect=AsyncMock())
|
||||||
|
monkeypatch.setattr(client, "wait_until_ready", AsyncMock())
|
||||||
|
|
||||||
|
assert client.start_time is None
|
||||||
|
await client.connect()
|
||||||
|
assert client.start_time == start
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_uptime(monkeypatch):
|
||||||
|
start = datetime(2020, 1, 1, tzinfo=timezone.utc)
|
||||||
|
end = start + timedelta(seconds=5)
|
||||||
|
times = [start, end]
|
||||||
|
|
||||||
|
def fake_now():
|
||||||
|
return times.pop(0)
|
||||||
|
|
||||||
|
monkeypatch.setattr("disagreement.client.utcnow", fake_now)
|
||||||
|
|
||||||
|
client = Client(token="t")
|
||||||
|
monkeypatch.setattr(client, "_initialize_gateway", AsyncMock())
|
||||||
|
client._gateway = SimpleNamespace(connect=AsyncMock())
|
||||||
|
monkeypatch.setattr(client, "wait_until_ready", AsyncMock())
|
||||||
|
|
||||||
|
await client.connect()
|
||||||
|
assert client.uptime() == timedelta(seconds=5)
|
@ -6,6 +6,7 @@ from disagreement.ext.commands.decorators import (
|
|||||||
check,
|
check,
|
||||||
cooldown,
|
cooldown,
|
||||||
requires_permissions,
|
requires_permissions,
|
||||||
|
is_owner,
|
||||||
)
|
)
|
||||||
from disagreement.ext.commands.errors import CheckFailure, CommandOnCooldown
|
from disagreement.ext.commands.errors import CheckFailure, CommandOnCooldown
|
||||||
from disagreement.permissions import Permissions
|
from disagreement.permissions import Permissions
|
||||||
@ -133,3 +134,44 @@ async def test_requires_permissions_fail(message):
|
|||||||
|
|
||||||
with pytest.raises(CheckFailure):
|
with pytest.raises(CheckFailure):
|
||||||
await cmd.invoke(ctx)
|
await cmd.invoke(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_is_owner_pass(message):
|
||||||
|
message._client.owner_ids = ["2"]
|
||||||
|
|
||||||
|
@is_owner()
|
||||||
|
async def cb(ctx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
cmd = Command(cb)
|
||||||
|
ctx = CommandContext(
|
||||||
|
message=message,
|
||||||
|
bot=message._client,
|
||||||
|
prefix="!",
|
||||||
|
command=cmd,
|
||||||
|
invoked_with="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
await cmd.invoke(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_is_owner_fail(message):
|
||||||
|
message._client.owner_ids = ["1"]
|
||||||
|
|
||||||
|
@is_owner()
|
||||||
|
async def cb(ctx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
cmd = Command(cb)
|
||||||
|
ctx = CommandContext(
|
||||||
|
message=message,
|
||||||
|
bot=message._client,
|
||||||
|
prefix="!",
|
||||||
|
command=cmd,
|
||||||
|
invoked_with="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(CheckFailure):
|
||||||
|
await cmd.invoke(ctx)
|
||||||
|
59
tests/test_connect_events.py
Normal file
59
tests/test_connect_events.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from disagreement.shard_manager import ShardManager
|
||||||
|
from disagreement.event_dispatcher import EventDispatcher
|
||||||
|
|
||||||
|
|
||||||
|
class DummyGateway:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.connect = AsyncMock()
|
||||||
|
self.close = AsyncMock()
|
||||||
|
|
||||||
|
dispatcher = kwargs.get("event_dispatcher")
|
||||||
|
shard_id = kwargs.get("shard_id")
|
||||||
|
|
||||||
|
async def emit_connect():
|
||||||
|
await dispatcher.dispatch("CONNECT", {"shard_id": shard_id})
|
||||||
|
|
||||||
|
async def emit_close():
|
||||||
|
await dispatcher.dispatch("DISCONNECT", {"shard_id": shard_id})
|
||||||
|
|
||||||
|
self.connect.side_effect = emit_connect
|
||||||
|
self.close.side_effect = emit_close
|
||||||
|
|
||||||
|
|
||||||
|
class DummyClient:
|
||||||
|
def __init__(self):
|
||||||
|
self._http = object()
|
||||||
|
self._event_dispatcher = EventDispatcher(self)
|
||||||
|
self.token = "t"
|
||||||
|
self.intents = 0
|
||||||
|
self.verbose = False
|
||||||
|
self.gateway_max_retries = 5
|
||||||
|
self.gateway_max_backoff = 60.0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_disconnect_events(monkeypatch):
|
||||||
|
monkeypatch.setattr("disagreement.shard_manager.GatewayClient", DummyGateway)
|
||||||
|
client = DummyClient()
|
||||||
|
manager = ShardManager(client, shard_count=1)
|
||||||
|
|
||||||
|
events: list[tuple[str, int | None]] = []
|
||||||
|
|
||||||
|
async def on_connect(info):
|
||||||
|
events.append(("connect", info.get("shard_id")))
|
||||||
|
|
||||||
|
async def on_disconnect(info):
|
||||||
|
events.append(("disconnect", info.get("shard_id")))
|
||||||
|
|
||||||
|
client._event_dispatcher.register("CONNECT", on_connect)
|
||||||
|
client._event_dispatcher.register("DISCONNECT", on_disconnect)
|
||||||
|
|
||||||
|
await manager.start()
|
||||||
|
await manager.close()
|
||||||
|
|
||||||
|
assert ("connect", 0) in events
|
||||||
|
assert ("disconnect", 0) in events
|
41
tests/test_crosspost_message.py
Normal file
41
tests/test_crosspost_message.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import pytest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from disagreement.http import HTTPClient
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.models import Message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_crosspost_message_calls_request():
|
||||||
|
http = HTTPClient(token="t")
|
||||||
|
http.request = AsyncMock(return_value={"id": "m"})
|
||||||
|
data = await http.crosspost_message("c", "m")
|
||||||
|
http.request.assert_called_once_with(
|
||||||
|
"POST",
|
||||||
|
"/channels/c/messages/m/crosspost",
|
||||||
|
)
|
||||||
|
assert data == {"id": "m"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_message_crosspost_returns_message():
|
||||||
|
payload = {
|
||||||
|
"id": "2",
|
||||||
|
"channel_id": "1",
|
||||||
|
"author": {"id": "3", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "hi",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
http = SimpleNamespace(crosspost_message=AsyncMock(return_value=payload))
|
||||||
|
client = Client.__new__(Client)
|
||||||
|
client._http = http
|
||||||
|
client.parse_message = lambda d: Message(d, client_instance=client)
|
||||||
|
message = Message(payload, client_instance=client)
|
||||||
|
|
||||||
|
new_msg = await message.crosspost()
|
||||||
|
|
||||||
|
http.crosspost_message.assert_awaited_once_with("1", "2")
|
||||||
|
assert isinstance(new_msg, Message)
|
||||||
|
assert new_msg._client is client
|
@ -22,6 +22,38 @@ def create_dummy_module(name):
|
|||||||
return called
|
return called
|
||||||
|
|
||||||
|
|
||||||
|
def create_async_module(name):
|
||||||
|
mod = types.ModuleType(name)
|
||||||
|
called = {"setup": False, "teardown": False}
|
||||||
|
|
||||||
|
async def setup():
|
||||||
|
called["setup"] = True
|
||||||
|
|
||||||
|
def teardown():
|
||||||
|
called["teardown"] = True
|
||||||
|
|
||||||
|
mod.setup = setup
|
||||||
|
mod.teardown = teardown
|
||||||
|
sys.modules[name] = mod
|
||||||
|
return called
|
||||||
|
|
||||||
|
|
||||||
|
def create_async_teardown_module(name):
|
||||||
|
mod = types.ModuleType(name)
|
||||||
|
called = {"setup": False, "teardown": False}
|
||||||
|
|
||||||
|
def setup():
|
||||||
|
called["setup"] = True
|
||||||
|
|
||||||
|
async def teardown():
|
||||||
|
called["teardown"] = True
|
||||||
|
|
||||||
|
mod.setup = setup
|
||||||
|
mod.teardown = teardown
|
||||||
|
sys.modules[name] = mod
|
||||||
|
return called
|
||||||
|
|
||||||
|
|
||||||
def test_load_and_unload_extension():
|
def test_load_and_unload_extension():
|
||||||
called = create_dummy_module("dummy_ext")
|
called = create_dummy_module("dummy_ext")
|
||||||
|
|
||||||
@ -75,3 +107,23 @@ def test_reload_extension(monkeypatch):
|
|||||||
|
|
||||||
loader.unload_extension("reload_ext")
|
loader.unload_extension("reload_ext")
|
||||||
assert called_second["teardown"] is True
|
assert called_second["teardown"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_setup():
|
||||||
|
called = create_async_module("async_ext")
|
||||||
|
|
||||||
|
loader.load_extension("async_ext")
|
||||||
|
assert called["setup"] is True
|
||||||
|
|
||||||
|
loader.unload_extension("async_ext")
|
||||||
|
assert called["teardown"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_teardown():
|
||||||
|
called = create_async_teardown_module("async_teardown_ext")
|
||||||
|
|
||||||
|
loader.load_extension("async_teardown_ext")
|
||||||
|
assert called["setup"] is True
|
||||||
|
|
||||||
|
loader.unload_extension("async_teardown_ext")
|
||||||
|
assert called["teardown"] is True
|
||||||
|
29
tests/test_get_cog.py
Normal file
29
tests/test_get_cog.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.ext import commands
|
||||||
|
|
||||||
|
|
||||||
|
class DummyCog(commands.Cog):
|
||||||
|
def __init__(self, client: Client) -> None:
|
||||||
|
super().__init__(client)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_command_handler_get_cog():
|
||||||
|
bot = object()
|
||||||
|
handler = commands.core.CommandHandler(client=bot, prefix="!")
|
||||||
|
cog = DummyCog(bot) # type: ignore[arg-type]
|
||||||
|
handler.add_cog(cog)
|
||||||
|
await asyncio.sleep(0) # allow any scheduled tasks to start
|
||||||
|
assert handler.get_cog("DummyCog") is cog
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_client_get_cog():
|
||||||
|
client = Client(token="t")
|
||||||
|
cog = DummyCog(client)
|
||||||
|
client.add_cog(cog)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert client.get_cog("DummyCog") is cog
|
59
tests/test_get_context.py
Normal file
59
tests/test_get_context.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.ext.commands.core import Command, CommandHandler
|
||||||
|
from disagreement.models import Message
|
||||||
|
|
||||||
|
|
||||||
|
class DummyBot:
|
||||||
|
def __init__(self):
|
||||||
|
self.executed = False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_context_parses_without_execution():
|
||||||
|
bot = DummyBot()
|
||||||
|
handler = CommandHandler(client=bot, prefix="!")
|
||||||
|
|
||||||
|
async def foo(ctx, number: int, word: str):
|
||||||
|
bot.executed = True
|
||||||
|
|
||||||
|
handler.add_command(Command(foo, name="foo"))
|
||||||
|
|
||||||
|
msg_data = {
|
||||||
|
"id": "1",
|
||||||
|
"channel_id": "c",
|
||||||
|
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "!foo 1 bar",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
msg = Message(msg_data, client_instance=bot)
|
||||||
|
|
||||||
|
ctx = await handler.get_context(msg)
|
||||||
|
assert ctx is not None
|
||||||
|
assert ctx.command.name == "foo"
|
||||||
|
assert ctx.args == [1, "bar"]
|
||||||
|
assert bot.executed is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_get_context():
|
||||||
|
client = Client(token="t")
|
||||||
|
|
||||||
|
async def foo(ctx):
|
||||||
|
raise RuntimeError("should not run")
|
||||||
|
|
||||||
|
client.command_handler.add_command(Command(foo, name="foo"))
|
||||||
|
|
||||||
|
msg_data = {
|
||||||
|
"id": "1",
|
||||||
|
"channel_id": "c",
|
||||||
|
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "!foo",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
msg = Message(msg_data, client_instance=client)
|
||||||
|
|
||||||
|
ctx = await client.get_context(msg)
|
||||||
|
assert ctx is not None
|
||||||
|
assert ctx.command.name == "foo"
|
126
tests/test_guild_channel_create.py
Normal file
126
tests/test_guild_channel_create.py
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
import pytest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from disagreement.http import HTTPClient
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.models import Guild, TextChannel, VoiceChannel, CategoryChannel
|
||||||
|
from disagreement.enums import (
|
||||||
|
VerificationLevel,
|
||||||
|
MessageNotificationLevel,
|
||||||
|
ExplicitContentFilterLevel,
|
||||||
|
MFALevel,
|
||||||
|
GuildNSFWLevel,
|
||||||
|
PremiumTier,
|
||||||
|
ChannelType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _guild_data():
|
||||||
|
return {
|
||||||
|
"id": "1",
|
||||||
|
"name": "g",
|
||||||
|
"owner_id": "1",
|
||||||
|
"afk_timeout": 60,
|
||||||
|
"verification_level": VerificationLevel.NONE.value,
|
||||||
|
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||||
|
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||||
|
"roles": [],
|
||||||
|
"emojis": [],
|
||||||
|
"features": [],
|
||||||
|
"mfa_level": MFALevel.NONE.value,
|
||||||
|
"system_channel_flags": 0,
|
||||||
|
"premium_tier": PremiumTier.NONE.value,
|
||||||
|
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_create_guild_channel_calls_request():
|
||||||
|
http = HTTPClient(token="t")
|
||||||
|
http.request = AsyncMock(return_value={})
|
||||||
|
payload = {"name": "chan", "type": ChannelType.GUILD_TEXT.value}
|
||||||
|
|
||||||
|
await http.create_guild_channel("1", payload, reason="r")
|
||||||
|
|
||||||
|
http.request.assert_called_once_with(
|
||||||
|
"POST",
|
||||||
|
"/guilds/1/channels",
|
||||||
|
payload=payload,
|
||||||
|
custom_headers={"X-Audit-Log-Reason": "r"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_guild_create_text_channel_returns_channel():
|
||||||
|
http = SimpleNamespace(
|
||||||
|
create_guild_channel=AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"id": "10",
|
||||||
|
"type": ChannelType.GUILD_TEXT.value,
|
||||||
|
"guild_id": "1",
|
||||||
|
"permission_overwrites": [],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
client = Client(token="t")
|
||||||
|
client._http = http
|
||||||
|
guild = Guild(_guild_data(), client_instance=client)
|
||||||
|
|
||||||
|
channel = await guild.create_text_channel("general")
|
||||||
|
|
||||||
|
http.create_guild_channel.assert_awaited_once_with(
|
||||||
|
"1", {"name": "general", "type": ChannelType.GUILD_TEXT.value}, reason=None
|
||||||
|
)
|
||||||
|
assert isinstance(channel, TextChannel)
|
||||||
|
assert client._channels.get("10") is channel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_guild_create_voice_channel_returns_channel():
|
||||||
|
http = SimpleNamespace(
|
||||||
|
create_guild_channel=AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"id": "11",
|
||||||
|
"type": ChannelType.GUILD_VOICE.value,
|
||||||
|
"guild_id": "1",
|
||||||
|
"permission_overwrites": [],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
client = Client(token="t")
|
||||||
|
client._http = http
|
||||||
|
guild = Guild(_guild_data(), client_instance=client)
|
||||||
|
|
||||||
|
channel = await guild.create_voice_channel("Voice")
|
||||||
|
|
||||||
|
http.create_guild_channel.assert_awaited_once_with(
|
||||||
|
"1", {"name": "Voice", "type": ChannelType.GUILD_VOICE.value}, reason=None
|
||||||
|
)
|
||||||
|
assert isinstance(channel, VoiceChannel)
|
||||||
|
assert client._channels.get("11") is channel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_guild_create_category_returns_channel():
|
||||||
|
http = SimpleNamespace(
|
||||||
|
create_guild_channel=AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"id": "12",
|
||||||
|
"type": ChannelType.GUILD_CATEGORY.value,
|
||||||
|
"guild_id": "1",
|
||||||
|
"permission_overwrites": [],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
client = Client(token="t")
|
||||||
|
client._http = http
|
||||||
|
guild = Guild(_guild_data(), client_instance=client)
|
||||||
|
|
||||||
|
channel = await guild.create_category("Cat")
|
||||||
|
|
||||||
|
http.create_guild_channel.assert_awaited_once_with(
|
||||||
|
"1", {"name": "Cat", "type": ChannelType.GUILD_CATEGORY.value}, reason=None
|
||||||
|
)
|
||||||
|
assert isinstance(channel, CategoryChannel)
|
||||||
|
assert client._channels.get("12") is channel
|
63
tests/test_guild_channel_lists.py
Normal file
63
tests/test_guild_channel_lists.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.enums import (
|
||||||
|
ChannelType,
|
||||||
|
VerificationLevel,
|
||||||
|
MessageNotificationLevel,
|
||||||
|
ExplicitContentFilterLevel,
|
||||||
|
MFALevel,
|
||||||
|
GuildNSFWLevel,
|
||||||
|
PremiumTier,
|
||||||
|
)
|
||||||
|
from disagreement.models import TextChannel, VoiceChannel, CategoryChannel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_guild_channel_lists_populated():
|
||||||
|
client = Client(token="t")
|
||||||
|
guild_data = {
|
||||||
|
"id": "1",
|
||||||
|
"name": "g",
|
||||||
|
"owner_id": "1",
|
||||||
|
"afk_timeout": 60,
|
||||||
|
"verification_level": VerificationLevel.NONE.value,
|
||||||
|
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||||
|
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||||
|
"roles": [],
|
||||||
|
"emojis": [],
|
||||||
|
"features": [],
|
||||||
|
"mfa_level": MFALevel.NONE.value,
|
||||||
|
"system_channel_flags": 0,
|
||||||
|
"premium_tier": PremiumTier.NONE.value,
|
||||||
|
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||||
|
"channels": [
|
||||||
|
{
|
||||||
|
"id": "10",
|
||||||
|
"type": ChannelType.GUILD_TEXT.value,
|
||||||
|
"guild_id": "1",
|
||||||
|
"permission_overwrites": [],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "11",
|
||||||
|
"type": ChannelType.GUILD_VOICE.value,
|
||||||
|
"guild_id": "1",
|
||||||
|
"permission_overwrites": [],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "12",
|
||||||
|
"type": ChannelType.GUILD_CATEGORY.value,
|
||||||
|
"guild_id": "1",
|
||||||
|
"permission_overwrites": [],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
guild = client.parse_guild(guild_data)
|
||||||
|
|
||||||
|
assert len(guild.text_channels) == 1
|
||||||
|
assert isinstance(guild.text_channels[0], TextChannel)
|
||||||
|
assert len(guild.voice_channels) == 1
|
||||||
|
assert isinstance(guild.voice_channels[0], VoiceChannel)
|
||||||
|
assert len(guild.category_channels) == 1
|
||||||
|
assert isinstance(guild.category_channels[0], CategoryChannel)
|
64
tests/test_guild_prune.py
Normal file
64
tests/test_guild_prune.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
import pytest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from disagreement.http import HTTPClient
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.enums import (
|
||||||
|
VerificationLevel,
|
||||||
|
MessageNotificationLevel,
|
||||||
|
ExplicitContentFilterLevel,
|
||||||
|
MFALevel,
|
||||||
|
GuildNSFWLevel,
|
||||||
|
PremiumTier,
|
||||||
|
)
|
||||||
|
from disagreement.models import Guild
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_get_guild_prune_count_calls_request():
|
||||||
|
http = HTTPClient(token="t")
|
||||||
|
http.request = AsyncMock(return_value={"pruned": 3})
|
||||||
|
count = await http.get_guild_prune_count("1", days=7)
|
||||||
|
http.request.assert_called_once_with("GET", f"/guilds/1/prune", params={"days": 7})
|
||||||
|
assert count == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_begin_guild_prune_calls_request():
|
||||||
|
http = HTTPClient(token="t")
|
||||||
|
http.request = AsyncMock(return_value={"pruned": 2})
|
||||||
|
count = await http.begin_guild_prune("1", days=1, compute_count=True)
|
||||||
|
http.request.assert_called_once_with(
|
||||||
|
"POST",
|
||||||
|
f"/guilds/1/prune",
|
||||||
|
payload={"days": 1, "compute_prune_count": True},
|
||||||
|
)
|
||||||
|
assert count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_guild_prune_members_calls_http():
|
||||||
|
http = SimpleNamespace(begin_guild_prune=AsyncMock(return_value=1))
|
||||||
|
client = Client(token="t")
|
||||||
|
client._http = http
|
||||||
|
guild_data = {
|
||||||
|
"id": "1",
|
||||||
|
"name": "g",
|
||||||
|
"owner_id": "1",
|
||||||
|
"afk_timeout": 60,
|
||||||
|
"verification_level": VerificationLevel.NONE.value,
|
||||||
|
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||||
|
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||||
|
"roles": [],
|
||||||
|
"emojis": [],
|
||||||
|
"features": [],
|
||||||
|
"mfa_level": MFALevel.NONE.value,
|
||||||
|
"system_channel_flags": 0,
|
||||||
|
"premium_tier": PremiumTier.NONE.value,
|
||||||
|
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||||
|
}
|
||||||
|
guild = Guild(guild_data, client_instance=client)
|
||||||
|
count = await guild.prune_members(2)
|
||||||
|
http.begin_guild_prune.assert_awaited_once_with("1", days=2, compute_count=True)
|
||||||
|
assert count == 1
|
56
tests/test_guild_shard_property.py
Normal file
56
tests/test_guild_shard_property.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from disagreement.models import Guild
|
||||||
|
from disagreement.enums import (
|
||||||
|
VerificationLevel,
|
||||||
|
MessageNotificationLevel,
|
||||||
|
ExplicitContentFilterLevel,
|
||||||
|
MFALevel,
|
||||||
|
GuildNSFWLevel,
|
||||||
|
PremiumTier,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyShard:
|
||||||
|
def __init__(self, shard_id):
|
||||||
|
self.id = shard_id
|
||||||
|
self.count = 1
|
||||||
|
self.gateway = Mock()
|
||||||
|
|
||||||
|
|
||||||
|
class DummyManager:
|
||||||
|
def __init__(self):
|
||||||
|
self.shards = [DummyShard(0)]
|
||||||
|
|
||||||
|
|
||||||
|
class DummyClient:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _guild_data():
|
||||||
|
return {
|
||||||
|
"id": "1",
|
||||||
|
"name": "g",
|
||||||
|
"owner_id": "1",
|
||||||
|
"afk_timeout": 60,
|
||||||
|
"verification_level": VerificationLevel.NONE.value,
|
||||||
|
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||||
|
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||||
|
"roles": [],
|
||||||
|
"emojis": [],
|
||||||
|
"features": [],
|
||||||
|
"mfa_level": MFALevel.NONE.value,
|
||||||
|
"system_channel_flags": 0,
|
||||||
|
"premium_tier": PremiumTier.NONE.value,
|
||||||
|
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||||
|
"shard_id": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_guild_shard_property():
|
||||||
|
client = DummyClient()
|
||||||
|
client._shard_manager = DummyManager()
|
||||||
|
guild = Guild(_guild_data(), client_instance=client, shard_id=0)
|
||||||
|
assert guild.shard_id == 0
|
||||||
|
assert guild.shard is client._shard_manager.shards[0]
|
86
tests/test_hashable_mixin.py
Normal file
86
tests/test_hashable_mixin.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
import types
|
||||||
|
from disagreement.models import User, Guild, Channel, Message
|
||||||
|
from disagreement.enums import (
|
||||||
|
VerificationLevel,
|
||||||
|
MessageNotificationLevel,
|
||||||
|
ExplicitContentFilterLevel,
|
||||||
|
MFALevel,
|
||||||
|
GuildNSFWLevel,
|
||||||
|
PremiumTier,
|
||||||
|
ChannelType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _guild_data(gid="1"):
|
||||||
|
return {
|
||||||
|
"id": gid,
|
||||||
|
"name": "g",
|
||||||
|
"owner_id": gid,
|
||||||
|
"afk_timeout": 60,
|
||||||
|
"verification_level": VerificationLevel.NONE.value,
|
||||||
|
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||||
|
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||||
|
"roles": [],
|
||||||
|
"emojis": [],
|
||||||
|
"features": [],
|
||||||
|
"mfa_level": MFALevel.NONE.value,
|
||||||
|
"system_channel_flags": 0,
|
||||||
|
"premium_tier": PremiumTier.NONE.value,
|
||||||
|
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _user(uid="1"):
|
||||||
|
return User({"id": uid, "username": "u", "discriminator": "0001"})
|
||||||
|
|
||||||
|
|
||||||
|
def _message(mid="1"):
|
||||||
|
data = {
|
||||||
|
"id": mid,
|
||||||
|
"channel_id": "c",
|
||||||
|
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "hi",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
return Message(data, client_instance=types.SimpleNamespace())
|
||||||
|
|
||||||
|
|
||||||
|
def _channel(cid="1"):
|
||||||
|
data = {"id": cid, "type": ChannelType.GUILD_TEXT.value}
|
||||||
|
return Channel(data, client_instance=types.SimpleNamespace())
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_hash_and_eq():
|
||||||
|
a = _user()
|
||||||
|
b = _user()
|
||||||
|
c = _user("2")
|
||||||
|
assert a == b
|
||||||
|
assert hash(a) == hash(b)
|
||||||
|
assert a != c
|
||||||
|
|
||||||
|
|
||||||
|
def test_guild_hash_and_eq():
|
||||||
|
a = Guild(_guild_data(), client_instance=types.SimpleNamespace())
|
||||||
|
b = Guild(_guild_data(), client_instance=types.SimpleNamespace())
|
||||||
|
c = Guild(_guild_data("2"), client_instance=types.SimpleNamespace())
|
||||||
|
assert a == b
|
||||||
|
assert hash(a) == hash(b)
|
||||||
|
assert a != c
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_hash_and_eq():
|
||||||
|
a = _channel()
|
||||||
|
b = _channel()
|
||||||
|
c = _channel("2")
|
||||||
|
assert a == b
|
||||||
|
assert hash(a) == hash(b)
|
||||||
|
assert a != c
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_hash_and_eq():
|
||||||
|
a = _message()
|
||||||
|
b = _message()
|
||||||
|
c = _message("2")
|
||||||
|
assert a == b
|
||||||
|
assert hash(a) == hash(b)
|
||||||
|
assert a != c
|
@ -1,7 +1,9 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from disagreement.ext.commands.core import CommandHandler, Command
|
from disagreement.ext import commands
|
||||||
|
from disagreement.ext.commands.core import CommandHandler, Command, Group
|
||||||
from disagreement.models import Message
|
from disagreement.models import Message
|
||||||
|
from disagreement.ext.commands.help import HelpCommand
|
||||||
|
|
||||||
|
|
||||||
class DummyBot:
|
class DummyBot:
|
||||||
@ -13,15 +15,21 @@ class DummyBot:
|
|||||||
return {"id": "1", "channel_id": channel_id, "content": content}
|
return {"id": "1", "channel_id": channel_id, "content": content}
|
||||||
|
|
||||||
|
|
||||||
|
class MyCog(commands.Cog):
|
||||||
|
def __init__(self, client) -> None:
|
||||||
|
super().__init__(client)
|
||||||
|
|
||||||
|
@commands.command()
|
||||||
|
async def foo(self, ctx: commands.CommandContext) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_help_lists_commands():
|
async def test_help_lists_commands():
|
||||||
bot = DummyBot()
|
bot = DummyBot()
|
||||||
handler = CommandHandler(client=bot, prefix="!")
|
handler = CommandHandler(client=bot, prefix="!")
|
||||||
|
|
||||||
async def foo(ctx):
|
handler.add_cog(MyCog(bot))
|
||||||
pass
|
|
||||||
|
|
||||||
handler.add_command(Command(foo, name="foo", brief="Foo cmd"))
|
|
||||||
|
|
||||||
msg_data = {
|
msg_data = {
|
||||||
"id": "1",
|
"id": "1",
|
||||||
@ -33,6 +41,7 @@ async def test_help_lists_commands():
|
|||||||
msg = Message(msg_data, client_instance=bot)
|
msg = Message(msg_data, client_instance=bot)
|
||||||
await handler.process_commands(msg)
|
await handler.process_commands(msg)
|
||||||
assert any("foo" in m for m in bot.sent)
|
assert any("foo" in m for m in bot.sent)
|
||||||
|
assert any("MyCog" in m for m in bot.sent)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -55,3 +64,65 @@ async def test_help_specific_command():
|
|||||||
msg = Message(msg_data, client_instance=bot)
|
msg = Message(msg_data, client_instance=bot)
|
||||||
await handler.process_commands(msg)
|
await handler.process_commands(msg)
|
||||||
assert any("Bar desc" in m for m in bot.sent)
|
assert any("Bar desc" in m for m in bot.sent)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomHelp(HelpCommand):
|
||||||
|
async def send_command_help(self, ctx, command):
|
||||||
|
await ctx.send(f"custom {command.name}")
|
||||||
|
|
||||||
|
async def send_group_help(self, ctx, group):
|
||||||
|
await ctx.send(f"group {group.name}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_custom_help_methods():
|
||||||
|
bot = DummyBot()
|
||||||
|
handler = CommandHandler(client=bot, prefix="!")
|
||||||
|
handler.remove_command("help")
|
||||||
|
handler.add_command(CustomHelp(handler))
|
||||||
|
|
||||||
|
async def sub(ctx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
group = Group(sub, name="grp")
|
||||||
|
handler.add_command(group)
|
||||||
|
|
||||||
|
msg_data = {
|
||||||
|
"id": "1",
|
||||||
|
"channel_id": "c",
|
||||||
|
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "!help grp",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
msg = Message(msg_data, client_instance=bot)
|
||||||
|
await handler.process_commands(msg)
|
||||||
|
assert any("group grp" in m for m in bot.sent)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_help_lists_subcommands():
|
||||||
|
bot = DummyBot()
|
||||||
|
handler = CommandHandler(client=bot, prefix="!")
|
||||||
|
|
||||||
|
async def root(ctx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
group = Group(root, name="root")
|
||||||
|
|
||||||
|
@group.command(name="child")
|
||||||
|
async def child(ctx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
handler.add_command(group)
|
||||||
|
|
||||||
|
msg_data = {
|
||||||
|
"id": "1",
|
||||||
|
"channel_id": "c",
|
||||||
|
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "!help",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
msg = Message(msg_data, client_instance=bot)
|
||||||
|
await handler.process_commands(msg)
|
||||||
|
assert any("root" in m for m in bot.sent)
|
||||||
|
assert any("child" in m for m in bot.sent)
|
||||||
|
31
tests/test_invite.py
Normal file
31
tests/test_invite.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import pytest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from disagreement.http import HTTPClient
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.models import Invite
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_get_invite_calls_request():
|
||||||
|
http = HTTPClient(token="t")
|
||||||
|
http.request = AsyncMock(return_value={"code": "abc"})
|
||||||
|
|
||||||
|
result = await http.get_invite("abc")
|
||||||
|
|
||||||
|
http.request.assert_called_once_with("GET", "/invites/abc")
|
||||||
|
assert result == {"code": "abc"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_fetch_invite_returns_invite():
|
||||||
|
http = SimpleNamespace(get_invite=AsyncMock(return_value={"code": "abc"}))
|
||||||
|
client = Client.__new__(Client)
|
||||||
|
client._http = http
|
||||||
|
client._closed = False
|
||||||
|
|
||||||
|
invite = await client.fetch_invite("abc")
|
||||||
|
|
||||||
|
http.get_invite.assert_awaited_once_with("abc")
|
||||||
|
assert isinstance(invite, Invite)
|
39
tests/test_invites.py
Normal file
39
tests/test_invites.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
import pytest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.http import HTTPClient
|
||||||
|
from disagreement.models import TextChannel, Invite
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_channel_invite_calls_request_and_returns_model():
|
||||||
|
http = HTTPClient(token="t")
|
||||||
|
http.request = AsyncMock(return_value={"code": "abc"})
|
||||||
|
invite = await http.create_channel_invite("123", {"max_age": 60}, reason="r")
|
||||||
|
|
||||||
|
http.request.assert_called_once_with(
|
||||||
|
"POST",
|
||||||
|
"/channels/123/invites",
|
||||||
|
payload={"max_age": 60},
|
||||||
|
custom_headers={"X-Audit-Log-Reason": "r"},
|
||||||
|
)
|
||||||
|
assert isinstance(invite, Invite)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_textchannel_create_invite_uses_http():
|
||||||
|
http = SimpleNamespace(
|
||||||
|
create_channel_invite=AsyncMock(return_value=Invite.from_dict({"code": "a"}))
|
||||||
|
)
|
||||||
|
client = Client(token="t")
|
||||||
|
client._http = http
|
||||||
|
|
||||||
|
channel = TextChannel({"id": "c", "type": 0}, client)
|
||||||
|
invite = await channel.create_invite(max_age=30, reason="why")
|
||||||
|
|
||||||
|
http.create_channel_invite.assert_awaited_once_with(
|
||||||
|
"c", {"max_age": 30}, reason="why"
|
||||||
|
)
|
||||||
|
assert isinstance(invite, Invite)
|
@ -1,4 +1,21 @@
|
|||||||
from disagreement.models import Member
|
import pytest # pylint: disable=E0401
|
||||||
|
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.enums import (
|
||||||
|
VerificationLevel,
|
||||||
|
MessageNotificationLevel,
|
||||||
|
ExplicitContentFilterLevel,
|
||||||
|
MFALevel,
|
||||||
|
GuildNSFWLevel,
|
||||||
|
PremiumTier,
|
||||||
|
)
|
||||||
|
from disagreement.models import Member, Guild, Role
|
||||||
|
from disagreement.permissions import Permissions
|
||||||
|
|
||||||
|
|
||||||
|
class DummyClient(Client):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(token="test")
|
||||||
|
|
||||||
|
|
||||||
def _make_member(member_id: str, username: str, nick: str | None):
|
def _make_member(member_id: str, username: str, nick: str | None):
|
||||||
@ -12,6 +29,58 @@ def _make_member(member_id: str, username: str, nick: str | None):
|
|||||||
return Member(data, client_instance=None)
|
return Member(data, client_instance=None)
|
||||||
|
|
||||||
|
|
||||||
|
def _base_guild(client: Client) -> Guild:
|
||||||
|
data = {
|
||||||
|
"id": "1",
|
||||||
|
"name": "g",
|
||||||
|
"owner_id": "1",
|
||||||
|
"afk_timeout": 60,
|
||||||
|
"verification_level": VerificationLevel.NONE.value,
|
||||||
|
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||||
|
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||||
|
"roles": [],
|
||||||
|
"emojis": [],
|
||||||
|
"features": [],
|
||||||
|
"mfa_level": MFALevel.NONE.value,
|
||||||
|
"system_channel_flags": 0,
|
||||||
|
"premium_tier": PremiumTier.NONE.value,
|
||||||
|
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||||
|
}
|
||||||
|
guild = Guild(data, client_instance=client)
|
||||||
|
client._guilds.set(guild.id, guild)
|
||||||
|
return guild
|
||||||
|
|
||||||
|
|
||||||
|
def _role(guild: Guild, rid: str, perms: Permissions) -> Role:
|
||||||
|
role = Role(
|
||||||
|
{
|
||||||
|
"id": rid,
|
||||||
|
"name": f"r{rid}",
|
||||||
|
"color": 0,
|
||||||
|
"hoist": False,
|
||||||
|
"position": 0,
|
||||||
|
"permissions": str(int(perms)),
|
||||||
|
"managed": False,
|
||||||
|
"mentionable": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
guild.roles.append(role)
|
||||||
|
return role
|
||||||
|
|
||||||
|
|
||||||
|
def _member(guild: Guild, client: Client, *roles: Role) -> Member:
|
||||||
|
data = {
|
||||||
|
"user": {"id": "10", "username": "u", "discriminator": "0001"},
|
||||||
|
"joined_at": "t",
|
||||||
|
"roles": [r.id for r in roles] or [guild.id],
|
||||||
|
}
|
||||||
|
member = Member(data, client_instance=client)
|
||||||
|
member.guild_id = guild.id
|
||||||
|
member._client = client
|
||||||
|
guild._members.set(member.id, member)
|
||||||
|
return member
|
||||||
|
|
||||||
|
|
||||||
def test_display_name_prefers_nick():
|
def test_display_name_prefers_nick():
|
||||||
member = _make_member("1", "u", "nickname")
|
member = _make_member("1", "u", "nickname")
|
||||||
assert member.display_name == "nickname"
|
assert member.display_name == "nickname"
|
||||||
@ -20,3 +89,25 @@ def test_display_name_prefers_nick():
|
|||||||
def test_display_name_falls_back_to_username():
|
def test_display_name_falls_back_to_username():
|
||||||
member = _make_member("2", "u2", None)
|
member = _make_member("2", "u2", None)
|
||||||
assert member.display_name == "u2"
|
assert member.display_name == "u2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_guild_permissions_from_roles():
|
||||||
|
client = DummyClient()
|
||||||
|
guild = _base_guild(client)
|
||||||
|
everyone = _role(guild, guild.id, Permissions.VIEW_CHANNEL)
|
||||||
|
mod = _role(guild, "2", Permissions.MANAGE_MESSAGES)
|
||||||
|
member = _member(guild, client, everyone, mod)
|
||||||
|
|
||||||
|
perms = member.guild_permissions
|
||||||
|
assert perms & Permissions.VIEW_CHANNEL
|
||||||
|
assert perms & Permissions.MANAGE_MESSAGES
|
||||||
|
assert not perms & Permissions.BAN_MEMBERS
|
||||||
|
|
||||||
|
|
||||||
|
def test_guild_permissions_administrator_role_grants_all():
|
||||||
|
client = DummyClient()
|
||||||
|
guild = _base_guild(client)
|
||||||
|
admin = _role(guild, "2", Permissions.ADMINISTRATOR)
|
||||||
|
member = _member(guild, client, admin)
|
||||||
|
|
||||||
|
assert member.guild_permissions == Permissions(~0)
|
||||||
|
36
tests/test_member_voice.py
Normal file
36
tests/test_member_voice.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from disagreement.models import Member, VoiceState
|
||||||
|
|
||||||
|
|
||||||
|
def test_member_voice_dataclass():
|
||||||
|
data = {
|
||||||
|
"user": {"id": "1", "username": "u", "discriminator": "0001"},
|
||||||
|
"joined_at": "t",
|
||||||
|
"roles": [],
|
||||||
|
"voice_state": {
|
||||||
|
"guild_id": "g",
|
||||||
|
"channel_id": "c",
|
||||||
|
"user_id": "1",
|
||||||
|
"session_id": "s",
|
||||||
|
"deaf": False,
|
||||||
|
"mute": True,
|
||||||
|
"self_deaf": False,
|
||||||
|
"self_mute": False,
|
||||||
|
"self_video": False,
|
||||||
|
"suppress": False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
member = Member(data, client_instance=None)
|
||||||
|
voice = member.voice
|
||||||
|
assert isinstance(voice, VoiceState)
|
||||||
|
assert voice.channel_id == "c"
|
||||||
|
assert voice.mute is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_member_voice_none():
|
||||||
|
data = {
|
||||||
|
"user": {"id": "2", "username": "u2", "discriminator": "0001"},
|
||||||
|
"joined_at": "t",
|
||||||
|
"roles": [],
|
||||||
|
}
|
||||||
|
member = Member(data, client_instance=None)
|
||||||
|
assert member.voice is None
|
@ -21,3 +21,19 @@ def test_clean_content_removes_mentions():
|
|||||||
def test_clean_content_no_mentions():
|
def test_clean_content_no_mentions():
|
||||||
msg = make_message("Just text")
|
msg = make_message("Just text")
|
||||||
assert msg.clean_content == "Just text"
|
assert msg.clean_content == "Just text"
|
||||||
|
|
||||||
|
|
||||||
|
def test_created_at_parses_timestamp():
|
||||||
|
ts = "2024-05-04T12:34:56+00:00"
|
||||||
|
msg = make_message("hi")
|
||||||
|
msg.timestamp = ts
|
||||||
|
assert msg.created_at.isoformat() == ts
|
||||||
|
|
||||||
|
|
||||||
|
def test_edited_at_parses_timestamp_or_none():
|
||||||
|
ts = "2024-05-04T12:35:56+00:00"
|
||||||
|
msg = make_message("hi")
|
||||||
|
msg.timestamp = ts
|
||||||
|
assert msg.edited_at is None
|
||||||
|
msg.edited_timestamp = ts
|
||||||
|
assert msg.edited_at.isoformat() == ts
|
||||||
|
15
tests/test_object.py
Normal file
15
tests/test_object.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from disagreement.object import Object
|
||||||
|
|
||||||
|
|
||||||
|
def test_object_int():
|
||||||
|
obj = Object(123)
|
||||||
|
assert int(obj) == 123
|
||||||
|
|
||||||
|
|
||||||
|
def test_object_equality_and_hash():
|
||||||
|
a = Object(1)
|
||||||
|
b = Object(1)
|
||||||
|
c = Object(2)
|
||||||
|
assert a == b
|
||||||
|
assert a != c
|
||||||
|
assert hash(a) == hash(b)
|
23
tests/test_paginator.py
Normal file
23
tests/test_paginator.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from disagreement.utils import Paginator
|
||||||
|
|
||||||
|
|
||||||
|
def test_paginator_single_page():
|
||||||
|
p = Paginator(limit=10)
|
||||||
|
p.add_line("hi")
|
||||||
|
p.add_line("there")
|
||||||
|
assert p.pages == ["hi\nthere"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_paginator_splits_pages():
|
||||||
|
p = Paginator(limit=10)
|
||||||
|
p.add_line("12345")
|
||||||
|
p.add_line("67890")
|
||||||
|
assert p.pages == ["12345", "67890"]
|
||||||
|
p.add_line("xyz")
|
||||||
|
assert p.pages == ["12345", "67890\nxyz"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_paginator_handles_long_line():
|
||||||
|
p = Paginator(limit=5)
|
||||||
|
p.add_line("abcdef")
|
||||||
|
assert p.pages == ["abcde", "f"]
|
@ -32,3 +32,11 @@ def test_missing_permissions():
|
|||||||
current, Permissions.SEND_MESSAGES, Permissions.MANAGE_MESSAGES
|
current, Permissions.SEND_MESSAGES, Permissions.MANAGE_MESSAGES
|
||||||
)
|
)
|
||||||
assert missing == [Permissions.MANAGE_MESSAGES]
|
assert missing == [Permissions.MANAGE_MESSAGES]
|
||||||
|
|
||||||
|
|
||||||
|
def test_permissions_all():
|
||||||
|
all_value = Permissions.all()
|
||||||
|
union = Permissions(0)
|
||||||
|
for perm in Permissions:
|
||||||
|
union |= perm
|
||||||
|
assert all_value == union
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import io
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
@ -38,7 +39,9 @@ async def test_http_send_message_with_files_uses_formdata():
|
|||||||
"timestamp": "t",
|
"timestamp": "t",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
await http.send_message("c", "hi", files=[File("f.txt", b"data")])
|
await http.send_message(
|
||||||
|
"c", "hi", files=[File(io.BytesIO(b"data"), filename="f.txt")]
|
||||||
|
)
|
||||||
args, kwargs = http.request.call_args
|
args, kwargs = http.request.call_args
|
||||||
assert kwargs["is_json"] is False
|
assert kwargs["is_json"] is False
|
||||||
|
|
||||||
@ -75,7 +78,33 @@ async def test_client_send_message_passes_files():
|
|||||||
"timestamp": "t",
|
"timestamp": "t",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
await client.send_message("c", "hi", files=[File("f.txt", b"data")])
|
await client.send_message(
|
||||||
|
"c", "hi", files=[File(io.BytesIO(b"data"), filename="f.txt")]
|
||||||
|
)
|
||||||
client._http.send_message.assert_awaited_once()
|
client._http.send_message.assert_awaited_once()
|
||||||
kwargs = client._http.send_message.call_args.kwargs
|
kwargs = client._http.send_message.call_args.kwargs
|
||||||
assert kwargs["files"][0].filename == "f.txt"
|
assert kwargs["files"][0].filename == "f.txt"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_from_path(tmp_path):
|
||||||
|
file_path = tmp_path / "path.txt"
|
||||||
|
file_path.write_bytes(b"ok")
|
||||||
|
http = HTTPClient(token="t")
|
||||||
|
http.request = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"id": "1",
|
||||||
|
"channel_id": "c",
|
||||||
|
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "hi",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await http.send_message("c", "hi", files=[File(file_path)])
|
||||||
|
_, kwargs = http.request.call_args
|
||||||
|
assert kwargs["is_json"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_spoiler():
|
||||||
|
f = File(io.BytesIO(b"d"), filename="a.txt", spoiler=True)
|
||||||
|
assert f.filename == "SPOILER_a.txt"
|
||||||
|
@ -82,3 +82,24 @@ async def test_before_after_loop_callbacks() -> None:
|
|||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
assert events and events[0] == "before"
|
assert events and events[0] == "before"
|
||||||
assert "after" in events
|
assert "after" in events
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_change_interval_and_current_loop() -> None:
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
@tasks.loop(seconds=0.01)
|
||||||
|
async def ticker() -> None:
|
||||||
|
nonlocal count
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
ticker.start()
|
||||||
|
await asyncio.sleep(0.03)
|
||||||
|
initial = ticker.current_loop
|
||||||
|
ticker.change_interval(seconds=0.02)
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
ticker.stop()
|
||||||
|
|
||||||
|
assert initial >= 2
|
||||||
|
assert ticker.current_loop > initial
|
||||||
|
assert count == ticker.current_loop
|
||||||
|
@ -1,8 +1,54 @@
|
|||||||
from datetime import timezone
|
from datetime import datetime, timezone
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
from disagreement.utils import utcnow
|
from disagreement.utils import (
|
||||||
|
escape_markdown,
|
||||||
|
escape_mentions,
|
||||||
|
utcnow,
|
||||||
|
snowflake_time,
|
||||||
|
find,
|
||||||
|
get
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_utcnow_timezone():
|
def test_utcnow_timezone():
|
||||||
now = utcnow()
|
now = utcnow()
|
||||||
assert now.tzinfo == timezone.utc
|
assert now.tzinfo == timezone.utc
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_returns_matching_element():
|
||||||
|
seq = [1, 2, 3]
|
||||||
|
assert find(lambda x: x > 1, seq) == 2
|
||||||
|
assert find(lambda x: x > 3, seq) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_matches_attributes():
|
||||||
|
items = [
|
||||||
|
SimpleNamespace(id=1, name="a"),
|
||||||
|
SimpleNamespace(id=2, name="b"),
|
||||||
|
]
|
||||||
|
assert get(items, id=2) is items[1]
|
||||||
|
assert get(items, id=1, name="a") is items[0]
|
||||||
|
assert get(items, name="c") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_snowflake_time():
|
||||||
|
dt = datetime(2020, 1, 1, tzinfo=timezone.utc)
|
||||||
|
ms = int(dt.timestamp() * 1000) - 1420070400000
|
||||||
|
snowflake = ms << 22
|
||||||
|
assert snowflake_time(snowflake) == dt
|
||||||
|
|
||||||
|
|
||||||
|
def test_escape_markdown():
|
||||||
|
text = "**bold** _under_ ~strike~ `code` > quote | pipe"
|
||||||
|
escaped = escape_markdown(text)
|
||||||
|
assert (
|
||||||
|
escaped
|
||||||
|
== "\\*\\*bold\\*\\* \\_under\\_ \\~strike\\~ \\`code\\` \\> quote \\| pipe"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_escape_mentions():
|
||||||
|
text = "Hello @everyone and <@123>!"
|
||||||
|
escaped = escape_mentions(text)
|
||||||
|
assert escaped == "Hello @\u200beveryone and <@\u200b123>!"
|
||||||
|
@ -59,6 +59,17 @@ class DummySource(AudioSource):
|
|||||||
return b""
|
return b""
|
||||||
|
|
||||||
|
|
||||||
|
class SlowSource(AudioSource):
|
||||||
|
def __init__(self, chunks):
|
||||||
|
self.chunks = list(chunks)
|
||||||
|
|
||||||
|
async def read(self) -> bytes:
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
if self.chunks:
|
||||||
|
return self.chunks.pop(0)
|
||||||
|
return b""
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_voice_client_handshake():
|
async def test_voice_client_handshake():
|
||||||
hello = {"d": {"heartbeat_interval": 50}}
|
hello = {"d": {"heartbeat_interval": 50}}
|
||||||
@ -205,3 +216,49 @@ async def test_voice_client_volume_scaling(monkeypatch):
|
|||||||
samples[1] = int(samples[1] * 0.5)
|
samples[1] = int(samples[1] * 0.5)
|
||||||
expected = samples.tobytes()
|
expected = samples.tobytes()
|
||||||
assert udp.sent == [expected]
|
assert udp.sent == [expected]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pause_resume_and_status():
|
||||||
|
ws = DummyWebSocket(
|
||||||
|
[
|
||||||
|
{"d": {"heartbeat_interval": 50}},
|
||||||
|
{"d": {"ssrc": 1, "ip": "127.0.0.1", "port": 4000}},
|
||||||
|
{"d": {"secret_key": []}},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
udp = DummyUDP()
|
||||||
|
vc = VoiceClient(
|
||||||
|
client=DummyVoiceClient(),
|
||||||
|
endpoint="ws://localhost",
|
||||||
|
session_id="sess",
|
||||||
|
token="tok",
|
||||||
|
guild_id=1,
|
||||||
|
user_id=2,
|
||||||
|
ws=ws,
|
||||||
|
udp=udp,
|
||||||
|
)
|
||||||
|
await vc.connect()
|
||||||
|
vc._heartbeat_task.cancel()
|
||||||
|
|
||||||
|
src = SlowSource([b"a", b"b", b"c"])
|
||||||
|
await vc.play(src, wait=False)
|
||||||
|
|
||||||
|
while not udp.sent:
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert vc.is_playing()
|
||||||
|
vc.pause()
|
||||||
|
assert vc.is_paused()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
sent = len(udp.sent)
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
assert len(udp.sent) == sent
|
||||||
|
assert not vc.is_playing()
|
||||||
|
|
||||||
|
vc.resume()
|
||||||
|
assert not vc.is_paused()
|
||||||
|
|
||||||
|
await vc._play_task
|
||||||
|
assert udp.sent == [b"a", b"b", b"c"]
|
||||||
|
assert not vc.is_playing()
|
||||||
|
35
tests/test_walk_commands.py
Normal file
35
tests/test_walk_commands.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.ext.commands.core import CommandHandler, Command, Group
|
||||||
|
|
||||||
|
|
||||||
|
class DummyBot:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_walk_commands_recurses_groups():
|
||||||
|
bot = DummyBot()
|
||||||
|
handler = CommandHandler(client=bot, prefix="!")
|
||||||
|
|
||||||
|
async def root(ctx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
root_group = Group(root, name="root")
|
||||||
|
|
||||||
|
@root_group.command(name="child")
|
||||||
|
async def child(ctx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@root_group.group(name="sub")
|
||||||
|
async def sub(ctx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@sub.command(name="leaf")
|
||||||
|
async def leaf(ctx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
handler.add_command(root_group)
|
||||||
|
|
||||||
|
names = [cmd.name for cmd in handler.walk_commands()]
|
||||||
|
assert set(names) == {"help", "root", "child", "sub", "leaf"}
|
@ -146,6 +146,16 @@ def test_webhook_from_url_parses_id_and_token():
|
|||||||
assert webhook.url == url
|
assert webhook.url == url
|
||||||
|
|
||||||
|
|
||||||
|
def test_webhook_from_token_builds_url_and_fields():
|
||||||
|
from disagreement.models import Webhook
|
||||||
|
|
||||||
|
webhook = Webhook.from_token("123", "token")
|
||||||
|
|
||||||
|
assert webhook.id == "123"
|
||||||
|
assert webhook.token == "token"
|
||||||
|
assert webhook.url == "https://discord.com/api/webhooks/123/token"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_execute_webhook_calls_request():
|
async def test_execute_webhook_calls_request():
|
||||||
http = HTTPClient(token="t")
|
http = HTTPClient(token="t")
|
||||||
@ -185,3 +195,31 @@ async def test_webhook_send_uses_http():
|
|||||||
|
|
||||||
http.execute_webhook.assert_awaited_once()
|
http.execute_webhook.assert_awaited_once()
|
||||||
assert isinstance(msg, Message)
|
assert isinstance(msg, Message)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_webhook_calls_request():
|
||||||
|
http = HTTPClient(token="t")
|
||||||
|
http.request = AsyncMock(return_value={"id": "1"})
|
||||||
|
|
||||||
|
await http.get_webhook("1")
|
||||||
|
|
||||||
|
http.request.assert_called_once_with("GET", "/webhooks/1", use_auth_header=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_fetch_webhook_returns_model():
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.models import Webhook
|
||||||
|
|
||||||
|
http = SimpleNamespace(get_webhook=AsyncMock(return_value={"id": "1"}))
|
||||||
|
client = Client(token="test")
|
||||||
|
client._http = http
|
||||||
|
client._closed = False
|
||||||
|
|
||||||
|
webhook = await client.fetch_webhook("1")
|
||||||
|
|
||||||
|
http.get_webhook.assert_awaited_once_with("1")
|
||||||
|
assert isinstance(webhook, Webhook)
|
||||||
|
assert client._webhooks.get("1") is webhook
|
||||||
|
Loading…
x
Reference in New Issue
Block a user