Merge branch 'master' of https://github.com/Slipstreamm/disagreement
This commit is contained in:
commit
a702c66603
22
README.md
22
README.md
@ -23,6 +23,13 @@ pip install -e .
|
|||||||
|
|
||||||
Requires Python 3.10 or newer.
|
Requires Python 3.10 or newer.
|
||||||
|
|
||||||
|
To run the example scripts, you'll need the `python-dotenv` package to load
|
||||||
|
environment variables. Install the development extras with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install "disagreement[dev]"
|
||||||
|
```
|
||||||
|
|
||||||
## Basic Usage
|
## Basic Usage
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@ -102,6 +109,20 @@ These options are forwarded to ``HTTPClient`` when it creates the underlying
|
|||||||
``aiohttp.ClientSession``. You can specify a custom ``connector`` or any other
|
``aiohttp.ClientSession``. You can specify a custom ``connector`` or any other
|
||||||
session parameter supported by ``aiohttp``.
|
session parameter supported by ``aiohttp``.
|
||||||
|
|
||||||
|
### Default Allowed Mentions
|
||||||
|
|
||||||
|
Specify default mention behaviour for all outgoing messages when constructing the client:
|
||||||
|
|
||||||
|
```python
|
||||||
|
client = disagreement.Client(
|
||||||
|
token=token,
|
||||||
|
allowed_mentions={"parse": [], "replied_user": False},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
This dictionary is used whenever ``send_message`` is called without an explicit
|
||||||
|
``allowed_mentions`` argument.
|
||||||
|
|
||||||
### Defining Subcommands with `AppCommandGroup`
|
### Defining Subcommands with `AppCommandGroup`
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@ -120,6 +141,7 @@ async def show(ctx: AppCommandContext, key: str):
|
|||||||
@slash_command(name="set", description="Update a setting.", parent=admin_group)
|
@slash_command(name="set", description="Update a setting.", parent=admin_group)
|
||||||
async def set_setting(ctx: AppCommandContext, key: str, value: str):
|
async def set_setting(ctx: AppCommandContext, key: str, value: str):
|
||||||
...
|
...
|
||||||
|
```
|
||||||
## Fetching Guilds
|
## Fetching Guilds
|
||||||
|
|
||||||
Use `Client.fetch_guild` to retrieve a guild from the Discord API if it
|
Use `Client.fetch_guild` to retrieve a guild from the Discord API if it
|
||||||
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import io
|
import io
|
||||||
|
import shlex
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
|
||||||
@ -35,15 +36,27 @@ class FFmpegAudioSource(AudioSource):
|
|||||||
A filename, URL, or file-like object to read from.
|
A filename, URL, or file-like object to read from.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, source: Union[str, io.BufferedIOBase]):
|
def __init__(
|
||||||
|
self,
|
||||||
|
source: Union[str, io.BufferedIOBase],
|
||||||
|
*,
|
||||||
|
before_options: Optional[str] = None,
|
||||||
|
options: Optional[str] = None,
|
||||||
|
volume: float = 1.0,
|
||||||
|
):
|
||||||
self.source = source
|
self.source = source
|
||||||
|
self.before_options = before_options
|
||||||
|
self.options = options
|
||||||
|
self.volume = volume
|
||||||
self.process: Optional[asyncio.subprocess.Process] = None
|
self.process: Optional[asyncio.subprocess.Process] = None
|
||||||
self._feeder: Optional[asyncio.Task] = None
|
self._feeder: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
async def _spawn(self) -> None:
|
async def _spawn(self) -> None:
|
||||||
if isinstance(self.source, str):
|
if isinstance(self.source, str):
|
||||||
args = [
|
args = ["ffmpeg"]
|
||||||
"ffmpeg",
|
if self.before_options:
|
||||||
|
args += shlex.split(self.before_options)
|
||||||
|
args += [
|
||||||
"-i",
|
"-i",
|
||||||
self.source,
|
self.source,
|
||||||
"-f",
|
"-f",
|
||||||
@ -54,14 +67,18 @@ class FFmpegAudioSource(AudioSource):
|
|||||||
"2",
|
"2",
|
||||||
"pipe:1",
|
"pipe:1",
|
||||||
]
|
]
|
||||||
|
if self.options:
|
||||||
|
args += shlex.split(self.options)
|
||||||
self.process = await asyncio.create_subprocess_exec(
|
self.process = await asyncio.create_subprocess_exec(
|
||||||
*args,
|
*args,
|
||||||
stdout=asyncio.subprocess.PIPE,
|
stdout=asyncio.subprocess.PIPE,
|
||||||
stderr=asyncio.subprocess.DEVNULL,
|
stderr=asyncio.subprocess.DEVNULL,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
args = [
|
args = ["ffmpeg"]
|
||||||
"ffmpeg",
|
if self.before_options:
|
||||||
|
args += shlex.split(self.before_options)
|
||||||
|
args += [
|
||||||
"-i",
|
"-i",
|
||||||
"pipe:0",
|
"pipe:0",
|
||||||
"-f",
|
"-f",
|
||||||
@ -72,6 +89,8 @@ class FFmpegAudioSource(AudioSource):
|
|||||||
"2",
|
"2",
|
||||||
"pipe:1",
|
"pipe:1",
|
||||||
]
|
]
|
||||||
|
if self.options:
|
||||||
|
args += shlex.split(self.options)
|
||||||
self.process = await asyncio.create_subprocess_exec(
|
self.process = await asyncio.create_subprocess_exec(
|
||||||
*args,
|
*args,
|
||||||
stdin=asyncio.subprocess.PIPE,
|
stdin=asyncio.subprocess.PIPE,
|
||||||
@ -115,6 +134,7 @@ class FFmpegAudioSource(AudioSource):
|
|||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
self.source.close()
|
self.source.close()
|
||||||
|
|
||||||
|
|
||||||
class AudioSink:
|
class AudioSink:
|
||||||
"""Abstract base class for audio sinks."""
|
"""Abstract base class for audio sinks."""
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
|
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .models import Channel, Guild, Member
|
from .models import Channel, Guild, Member
|
||||||
@ -11,15 +12,22 @@ T = TypeVar("T")
|
|||||||
|
|
||||||
|
|
||||||
class Cache(Generic[T]):
|
class Cache(Generic[T]):
|
||||||
"""Simple in-memory cache with optional TTL support."""
|
"""Simple in-memory cache with optional TTL and max size support."""
|
||||||
|
|
||||||
def __init__(self, ttl: Optional[float] = None) -> None:
|
def __init__(
|
||||||
|
self, ttl: Optional[float] = None, maxlen: Optional[int] = None
|
||||||
|
) -> None:
|
||||||
self.ttl = ttl
|
self.ttl = ttl
|
||||||
self._data: Dict[str, tuple[T, Optional[float]]] = {}
|
self.maxlen = maxlen
|
||||||
|
self._data: "OrderedDict[str, tuple[T, Optional[float]]]" = OrderedDict()
|
||||||
|
|
||||||
def set(self, key: str, value: T) -> None:
|
def set(self, key: str, value: T) -> None:
|
||||||
expiry = time.monotonic() + self.ttl if self.ttl is not None else None
|
expiry = time.monotonic() + self.ttl if self.ttl is not None else None
|
||||||
|
if key in self._data:
|
||||||
|
self._data.move_to_end(key)
|
||||||
self._data[key] = (value, expiry)
|
self._data[key] = (value, expiry)
|
||||||
|
if self.maxlen is not None and len(self._data) > self.maxlen:
|
||||||
|
self._data.popitem(last=False)
|
||||||
|
|
||||||
def get(self, key: str) -> Optional[T]:
|
def get(self, key: str) -> Optional[T]:
|
||||||
item = self._data.get(key)
|
item = self._data.get(key)
|
||||||
@ -29,6 +37,7 @@ class Cache(Generic[T]):
|
|||||||
if expiry is not None and expiry < time.monotonic():
|
if expiry is not None and expiry < time.monotonic():
|
||||||
self.invalidate(key)
|
self.invalidate(key)
|
||||||
return None
|
return None
|
||||||
|
self._data.move_to_end(key)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def invalidate(self, key: str) -> None:
|
def invalidate(self, key: str) -> None:
|
||||||
|
@ -8,10 +8,10 @@ class _MemberCacheFlagValue:
|
|||||||
flag: int
|
flag: int
|
||||||
|
|
||||||
def __init__(self, func: Callable[[Any], bool]):
|
def __init__(self, func: Callable[[Any], bool]):
|
||||||
self.flag = getattr(func, 'flag', 0)
|
self.flag = getattr(func, "flag", 0)
|
||||||
self.__doc__ = func.__doc__
|
self.__doc__ = func.__doc__
|
||||||
|
|
||||||
def __get__(self, instance: 'MemberCacheFlags', owner: type) -> Any:
|
def __get__(self, instance: "MemberCacheFlags", owner: type) -> Any:
|
||||||
if instance is None:
|
if instance is None:
|
||||||
return self
|
return self
|
||||||
return instance.value & self.flag != 0
|
return instance.value & self.flag != 0
|
||||||
@ -23,23 +23,24 @@ class _MemberCacheFlagValue:
|
|||||||
instance.value &= ~self.flag
|
instance.value &= ~self.flag
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f'<{self.__class__.__name__} flag={self.flag}>'
|
return f"<{self.__class__.__name__} flag={self.flag}>"
|
||||||
|
|
||||||
|
|
||||||
def flag_value(flag: int) -> Callable[[Callable[[Any], bool]], _MemberCacheFlagValue]:
|
def flag_value(flag: int) -> Callable[[Callable[[Any], bool]], _MemberCacheFlagValue]:
|
||||||
def decorator(func: Callable[[Any], bool]) -> _MemberCacheFlagValue:
|
def decorator(func: Callable[[Any], bool]) -> _MemberCacheFlagValue:
|
||||||
setattr(func, 'flag', flag)
|
setattr(func, "flag", flag)
|
||||||
return _MemberCacheFlagValue(func)
|
return _MemberCacheFlagValue(func)
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
class MemberCacheFlags:
|
class MemberCacheFlags:
|
||||||
__slots__ = ('value',)
|
__slots__ = ("value",)
|
||||||
|
|
||||||
VALID_FLAGS: ClassVar[Dict[str, int]] = {
|
VALID_FLAGS: ClassVar[Dict[str, int]] = {
|
||||||
'joined': 1 << 0,
|
"joined": 1 << 0,
|
||||||
'voice': 1 << 1,
|
"voice": 1 << 1,
|
||||||
'online': 1 << 2,
|
"online": 1 << 2,
|
||||||
}
|
}
|
||||||
DEFAULT_FLAGS: ClassVar[int] = 1 | 2 | 4
|
DEFAULT_FLAGS: ClassVar[int] = 1 | 2 | 4
|
||||||
ALL_FLAGS: ClassVar[int] = sum(VALID_FLAGS.values())
|
ALL_FLAGS: ClassVar[int] = sum(VALID_FLAGS.values())
|
||||||
@ -48,7 +49,7 @@ class MemberCacheFlags:
|
|||||||
self.value = self.DEFAULT_FLAGS
|
self.value = self.DEFAULT_FLAGS
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if key not in self.VALID_FLAGS:
|
if key not in self.VALID_FLAGS:
|
||||||
raise TypeError(f'{key!r} is not a valid member cache flag.')
|
raise TypeError(f"{key!r} is not a valid member cache flag.")
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -67,7 +68,7 @@ class MemberCacheFlags:
|
|||||||
return hash(self.value)
|
return hash(self.value)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f'<MemberCacheFlags value={self.value}>'
|
return f"<MemberCacheFlags value={self.value}>"
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[Tuple[str, bool]]:
|
def __iter__(self) -> Iterator[Tuple[str, bool]]:
|
||||||
for name in self.VALID_FLAGS:
|
for name in self.VALID_FLAGS:
|
||||||
@ -92,17 +93,17 @@ class MemberCacheFlags:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def only_joined(cls) -> MemberCacheFlags:
|
def only_joined(cls) -> MemberCacheFlags:
|
||||||
"""A factory method that creates a :class:`MemberCacheFlags` with only the `joined` flag enabled."""
|
"""A factory method that creates a :class:`MemberCacheFlags` with only the `joined` flag enabled."""
|
||||||
return cls._from_value(cls.VALID_FLAGS['joined'])
|
return cls._from_value(cls.VALID_FLAGS["joined"])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def only_voice(cls) -> MemberCacheFlags:
|
def only_voice(cls) -> MemberCacheFlags:
|
||||||
"""A factory method that creates a :class:`MemberCacheFlags` with only the `voice` flag enabled."""
|
"""A factory method that creates a :class:`MemberCacheFlags` with only the `voice` flag enabled."""
|
||||||
return cls._from_value(cls.VALID_FLAGS['voice'])
|
return cls._from_value(cls.VALID_FLAGS["voice"])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def only_online(cls) -> MemberCacheFlags:
|
def only_online(cls) -> MemberCacheFlags:
|
||||||
"""A factory method that creates a :class:`MemberCacheFlags` with only the `online` flag enabled."""
|
"""A factory method that creates a :class:`MemberCacheFlags` with only the `online` flag enabled."""
|
||||||
return cls._from_value(cls.VALID_FLAGS['online'])
|
return cls._from_value(cls.VALID_FLAGS["online"])
|
||||||
|
|
||||||
@flag_value(1 << 0)
|
@flag_value(1 << 0)
|
||||||
def joined(self) -> bool:
|
def joined(self) -> bool:
|
||||||
|
@ -36,6 +36,7 @@ from .ext import loader as ext_loader
|
|||||||
from .interactions import Interaction, Snowflake
|
from .interactions import Interaction, Snowflake
|
||||||
from .error_handler import setup_global_error_handler
|
from .error_handler import setup_global_error_handler
|
||||||
from .voice_client import VoiceClient
|
from .voice_client import VoiceClient
|
||||||
|
from .models import Activity
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .models import (
|
from .models import (
|
||||||
@ -75,13 +76,21 @@ class Client:
|
|||||||
intents (Optional[int]): The Gateway Intents to use. Defaults to `GatewayIntent.default()`.
|
intents (Optional[int]): The Gateway Intents to use. Defaults to `GatewayIntent.default()`.
|
||||||
You might need to enable privileged intents in your bot's application page.
|
You might need to enable privileged intents in your bot's application page.
|
||||||
loop (Optional[asyncio.AbstractEventLoop]): The event loop to use for asynchronous operations.
|
loop (Optional[asyncio.AbstractEventLoop]): The event loop to use for asynchronous operations.
|
||||||
Defaults to `asyncio.get_event_loop()`.
|
Defaults to the running loop
|
||||||
|
via `asyncio.get_running_loop()`,
|
||||||
|
or a new loop from
|
||||||
|
`asyncio.new_event_loop()` if
|
||||||
|
none is running.
|
||||||
command_prefix (Union[str, List[str], Callable[['Client', Message], Union[str, List[str]]]]):
|
command_prefix (Union[str, List[str], Callable[['Client', Message], Union[str, List[str]]]]):
|
||||||
The prefix(es) for commands. Defaults to '!'.
|
The prefix(es) for commands. Defaults to '!'.
|
||||||
verbose (bool): If True, print raw HTTP and Gateway traffic for debugging.
|
verbose (bool): If True, print raw HTTP and Gateway traffic for debugging.
|
||||||
|
mention_replies (bool): Whether replies mention the author by default.
|
||||||
|
allowed_mentions (Optional[Dict[str, Any]]): Default allowed mentions for messages.
|
||||||
http_options (Optional[Dict[str, Any]]): Extra options passed to
|
http_options (Optional[Dict[str, Any]]): Extra options passed to
|
||||||
:class:`HTTPClient` for creating the internal
|
:class:`HTTPClient` for creating the internal
|
||||||
:class:`aiohttp.ClientSession`.
|
:class:`aiohttp.ClientSession`.
|
||||||
|
message_cache_maxlen (Optional[int]): Maximum number of messages to keep
|
||||||
|
in the cache. When ``None``, the cache size is unlimited.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -95,10 +104,12 @@ class Client:
|
|||||||
application_id: Optional[Union[str, int]] = None,
|
application_id: Optional[Union[str, int]] = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
mention_replies: bool = False,
|
mention_replies: bool = False,
|
||||||
|
allowed_mentions: Optional[Dict[str, Any]] = None,
|
||||||
shard_count: Optional[int] = None,
|
shard_count: Optional[int] = None,
|
||||||
gateway_max_retries: int = 5,
|
gateway_max_retries: int = 5,
|
||||||
gateway_max_backoff: float = 60.0,
|
gateway_max_backoff: float = 60.0,
|
||||||
member_cache_flags: Optional[MemberCacheFlags] = None,
|
member_cache_flags: Optional[MemberCacheFlags] = None,
|
||||||
|
message_cache_maxlen: Optional[int] = None,
|
||||||
http_options: Optional[Dict[str, Any]] = None,
|
http_options: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
if not token:
|
if not token:
|
||||||
@ -108,6 +119,7 @@ class Client:
|
|||||||
self.member_cache_flags: MemberCacheFlags = (
|
self.member_cache_flags: MemberCacheFlags = (
|
||||||
member_cache_flags if member_cache_flags is not None else MemberCacheFlags()
|
member_cache_flags if member_cache_flags is not None else MemberCacheFlags()
|
||||||
)
|
)
|
||||||
|
self.message_cache_maxlen: Optional[int] = message_cache_maxlen
|
||||||
self.intents: int = intents if intents is not None else GatewayIntent.default()
|
self.intents: int = intents if intents is not None else GatewayIntent.default()
|
||||||
if loop:
|
if loop:
|
||||||
self.loop: asyncio.AbstractEventLoop = loop
|
self.loop: asyncio.AbstractEventLoop = loop
|
||||||
@ -157,7 +169,7 @@ class Client:
|
|||||||
self._guilds: GuildCache = GuildCache()
|
self._guilds: GuildCache = GuildCache()
|
||||||
self._channels: ChannelCache = ChannelCache()
|
self._channels: ChannelCache = ChannelCache()
|
||||||
self._users: Cache["User"] = Cache()
|
self._users: Cache["User"] = Cache()
|
||||||
self._messages: Cache["Message"] = Cache(ttl=3600) # Cache messages for an hour
|
self._messages: Cache["Message"] = Cache(ttl=3600, maxlen=message_cache_maxlen)
|
||||||
self._views: Dict[Snowflake, "View"] = {}
|
self._views: Dict[Snowflake, "View"] = {}
|
||||||
self._persistent_views: Dict[str, "View"] = {}
|
self._persistent_views: Dict[str, "View"] = {}
|
||||||
self._voice_clients: Dict[Snowflake, VoiceClient] = {}
|
self._voice_clients: Dict[Snowflake, VoiceClient] = {}
|
||||||
@ -165,6 +177,7 @@ class Client:
|
|||||||
|
|
||||||
# Default whether replies mention the user
|
# Default whether replies mention the user
|
||||||
self.mention_replies: bool = mention_replies
|
self.mention_replies: bool = mention_replies
|
||||||
|
self.allowed_mentions: Optional[Dict[str, Any]] = allowed_mentions
|
||||||
|
|
||||||
# Basic signal handling for graceful shutdown
|
# Basic signal handling for graceful shutdown
|
||||||
# This might be better handled by the user's application code, but can be a nice default.
|
# This might be better handled by the user's application code, but can be a nice default.
|
||||||
@ -435,8 +448,7 @@ class Client:
|
|||||||
async def change_presence(
|
async def change_presence(
|
||||||
self,
|
self,
|
||||||
status: str,
|
status: str,
|
||||||
activity_name: Optional[str] = None,
|
activity: Optional[Activity] = None,
|
||||||
activity_type: int = 0,
|
|
||||||
since: int = 0,
|
since: int = 0,
|
||||||
afk: bool = False,
|
afk: bool = False,
|
||||||
):
|
):
|
||||||
@ -445,8 +457,7 @@ class Client:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
status (str): The new status for the client (e.g., "online", "idle", "dnd", "invisible").
|
status (str): The new status for the client (e.g., "online", "idle", "dnd", "invisible").
|
||||||
activity_name (Optional[str]): The name of the activity.
|
activity (Optional[Activity]): Activity instance describing what the bot is doing.
|
||||||
activity_type (int): The type of the activity.
|
|
||||||
since (int): The timestamp (in milliseconds) of when the client went idle.
|
since (int): The timestamp (in milliseconds) of when the client went idle.
|
||||||
afk (bool): Whether the client is AFK.
|
afk (bool): Whether the client is AFK.
|
||||||
"""
|
"""
|
||||||
@ -456,8 +467,7 @@ class Client:
|
|||||||
if self._gateway:
|
if self._gateway:
|
||||||
await self._gateway.update_presence(
|
await self._gateway.update_presence(
|
||||||
status=status,
|
status=status,
|
||||||
activity_name=activity_name,
|
activity=activity,
|
||||||
activity_type=activity_type,
|
|
||||||
since=since,
|
since=since,
|
||||||
afk=afk,
|
afk=afk,
|
||||||
)
|
)
|
||||||
@ -693,7 +703,7 @@ class Client:
|
|||||||
)
|
)
|
||||||
# import traceback
|
# import traceback
|
||||||
# traceback.print_exception(type(error.original), error.original, error.original.__traceback__)
|
# traceback.print_exception(type(error.original), error.original, error.original.__traceback__)
|
||||||
|
|
||||||
async def on_command_completion(self, ctx: "CommandContext") -> None:
|
async def on_command_completion(self, ctx: "CommandContext") -> None:
|
||||||
"""
|
"""
|
||||||
Default command completion handler. Called when a command has successfully completed.
|
Default command completion handler. Called when a command has successfully completed.
|
||||||
@ -1010,7 +1020,7 @@ class Client:
|
|||||||
embeds (Optional[List[Embed]]): A list of embeds to send. Cannot be used with `embed`.
|
embeds (Optional[List[Embed]]): A list of embeds to send. Cannot be used with `embed`.
|
||||||
Discord supports up to 10 embeds per message.
|
Discord supports up to 10 embeds per message.
|
||||||
components (Optional[List[ActionRow]]): A list of ActionRow components to include.
|
components (Optional[List[ActionRow]]): A list of ActionRow components to include.
|
||||||
allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message.
|
allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message. Defaults to :attr:`Client.allowed_mentions`.
|
||||||
message_reference (Optional[Dict[str, Any]]): Message reference for replying.
|
message_reference (Optional[Dict[str, Any]]): Message reference for replying.
|
||||||
attachments (Optional[List[Any]]): Attachments to include with the message.
|
attachments (Optional[List[Any]]): Attachments to include with the message.
|
||||||
files (Optional[List[Any]]): Files to upload with the message.
|
files (Optional[List[Any]]): Files to upload with the message.
|
||||||
@ -1057,6 +1067,9 @@ class Client:
|
|||||||
if isinstance(comp, ComponentModel)
|
if isinstance(comp, ComponentModel)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if allowed_mentions is None:
|
||||||
|
allowed_mentions = self.allowed_mentions
|
||||||
|
|
||||||
message_data = await self._http.send_message(
|
message_data = await self._http.send_message(
|
||||||
channel_id=channel_id,
|
channel_id=channel_id,
|
||||||
content=content,
|
content=content,
|
||||||
@ -1428,6 +1441,24 @@ class Client:
|
|||||||
|
|
||||||
await self._http.delete_guild_template(guild_id, template_code)
|
await self._http.delete_guild_template(guild_id, template_code)
|
||||||
|
|
||||||
|
async def fetch_widget(self, guild_id: Snowflake) -> Dict[str, Any]:
|
||||||
|
"""|coro| Fetch a guild's widget settings."""
|
||||||
|
|
||||||
|
if self._closed:
|
||||||
|
raise DisagreementException("Client is closed.")
|
||||||
|
|
||||||
|
return await self._http.get_guild_widget(guild_id)
|
||||||
|
|
||||||
|
async def edit_widget(
|
||||||
|
self, guild_id: Snowflake, payload: Dict[str, Any]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""|coro| Edit a guild's widget settings."""
|
||||||
|
|
||||||
|
if self._closed:
|
||||||
|
raise DisagreementException("Client is closed.")
|
||||||
|
|
||||||
|
return await self._http.edit_guild_widget(guild_id, payload)
|
||||||
|
|
||||||
async def fetch_scheduled_events(
|
async def fetch_scheduled_events(
|
||||||
self, guild_id: Snowflake
|
self, guild_id: Snowflake
|
||||||
) -> List["ScheduledEvent"]:
|
) -> List["ScheduledEvent"]:
|
||||||
@ -1514,35 +1545,35 @@ class Client:
|
|||||||
return [self.parse_invite(inv) for inv in data]
|
return [self.parse_invite(inv) for inv in data]
|
||||||
|
|
||||||
def add_persistent_view(self, view: "View") -> None:
|
def add_persistent_view(self, view: "View") -> None:
|
||||||
"""
|
"""
|
||||||
Registers a persistent view with the client.
|
Registers a persistent view with the client.
|
||||||
|
|
||||||
Persistent views have a timeout of `None` and their components must have a `custom_id`.
|
Persistent views have a timeout of `None` and their components must have a `custom_id`.
|
||||||
This allows the view to be re-instantiated across bot restarts.
|
This allows the view to be re-instantiated across bot restarts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
view (View): The view instance to register.
|
view (View): The view instance to register.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the view is not persistent (timeout is not None) or if a component's
|
ValueError: If the view is not persistent (timeout is not None) or if a component's
|
||||||
custom_id is already registered.
|
custom_id is already registered.
|
||||||
"""
|
"""
|
||||||
if self.is_ready():
|
if self.is_ready():
|
||||||
print(
|
print(
|
||||||
"Warning: Adding a persistent view after the client is ready. "
|
"Warning: Adding a persistent view after the client is ready. "
|
||||||
"This view will only be available for interactions on this session."
|
"This view will only be available for interactions on this session."
|
||||||
)
|
)
|
||||||
|
|
||||||
if view.timeout is not None:
|
if view.timeout is not None:
|
||||||
raise ValueError("Persistent views must have a timeout of None.")
|
raise ValueError("Persistent views must have a timeout of None.")
|
||||||
|
|
||||||
for item in view.children:
|
for item in view.children:
|
||||||
if item.custom_id: # Ensure custom_id is not None
|
if item.custom_id: # Ensure custom_id is not None
|
||||||
if item.custom_id in self._persistent_views:
|
if item.custom_id in self._persistent_views:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"A component with custom_id '{item.custom_id}' is already registered."
|
f"A component with custom_id '{item.custom_id}' is already registered."
|
||||||
)
|
)
|
||||||
self._persistent_views[item.custom_id] = view
|
self._persistent_views[item.custom_id] = view
|
||||||
|
|
||||||
# --- Application Command Methods ---
|
# --- Application Command Methods ---
|
||||||
async def process_interaction(self, interaction: Interaction) -> None:
|
async def process_interaction(self, interaction: Interaction) -> None:
|
||||||
|
@ -375,6 +375,15 @@ class OverwriteType(IntEnum):
|
|||||||
MEMBER = 1
|
MEMBER = 1
|
||||||
|
|
||||||
|
|
||||||
|
class AutoArchiveDuration(IntEnum):
|
||||||
|
"""Thread auto-archive duration in minutes."""
|
||||||
|
|
||||||
|
HOUR = 60
|
||||||
|
DAY = 1440
|
||||||
|
THREE_DAYS = 4320
|
||||||
|
WEEK = 10080
|
||||||
|
|
||||||
|
|
||||||
# --- Component Enums ---
|
# --- Component Enums ---
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,7 +14,11 @@ def setup_global_error_handler(
|
|||||||
The handler logs unhandled exceptions so they don't crash the bot.
|
The handler logs unhandled exceptions so they don't crash the bot.
|
||||||
"""
|
"""
|
||||||
if loop is None:
|
if loop is None:
|
||||||
loop = asyncio.get_event_loop()
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
if not logging.getLogger().hasHandlers():
|
if not logging.getLogger().hasHandlers():
|
||||||
setup_logging(logging.ERROR)
|
setup_logging(logging.ERROR)
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
# disagreement/ext/commands/converters.py
|
# disagreement/ext/commands/converters.py
|
||||||
|
# pyright: reportIncompatibleMethodOverride=false
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, Generic
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, Generic
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import re
|
import re
|
||||||
|
import inspect
|
||||||
|
|
||||||
from .errors import BadArgument
|
from .errors import BadArgument
|
||||||
from disagreement.models import Member, Guild, Role
|
from disagreement.models import Member, Guild, Role
|
||||||
@ -36,6 +38,20 @@ class Converter(ABC, Generic[T]):
|
|||||||
raise NotImplementedError("Converter subclass must implement convert method.")
|
raise NotImplementedError("Converter subclass must implement convert method.")
|
||||||
|
|
||||||
|
|
||||||
|
class Greedy(list):
|
||||||
|
"""Type hint helper to greedily consume arguments."""
|
||||||
|
|
||||||
|
converter: Any = None
|
||||||
|
|
||||||
|
def __class_getitem__(cls, param: Any) -> type: # pyright: ignore[override]
|
||||||
|
if isinstance(param, tuple):
|
||||||
|
if len(param) != 1:
|
||||||
|
raise TypeError("Greedy[...] expects a single parameter")
|
||||||
|
param = param[0]
|
||||||
|
name = f"Greedy[{getattr(param, '__name__', str(param))}]"
|
||||||
|
return type(name, (Greedy,), {"converter": param})
|
||||||
|
|
||||||
|
|
||||||
# --- Built-in Type Converters ---
|
# --- Built-in Type Converters ---
|
||||||
|
|
||||||
|
|
||||||
@ -169,7 +185,3 @@ async def run_converters(ctx: "CommandContext", annotation: Any, argument: str)
|
|||||||
raise BadArgument(f"No converter found for type annotation '{annotation}'.")
|
raise BadArgument(f"No converter found for type annotation '{annotation}'.")
|
||||||
|
|
||||||
return argument # Default to string if no annotation or annotation is str
|
return argument # Default to string if no annotation or annotation is str
|
||||||
|
|
||||||
|
|
||||||
# Need to import inspect for the run_converters function
|
|
||||||
import inspect
|
|
||||||
|
@ -29,7 +29,7 @@ from .errors import (
|
|||||||
CheckFailure,
|
CheckFailure,
|
||||||
CommandInvokeError,
|
CommandInvokeError,
|
||||||
)
|
)
|
||||||
from .converters import run_converters, DEFAULT_CONVERTERS, Converter
|
from .converters import Greedy, run_converters, DEFAULT_CONVERTERS, Converter
|
||||||
from disagreement.typing import Typing
|
from disagreement.typing import Typing
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -46,29 +46,39 @@ class GroupMixin:
|
|||||||
self.commands: Dict[str, "Command"] = {}
|
self.commands: Dict[str, "Command"] = {}
|
||||||
self.name: str = ""
|
self.name: str = ""
|
||||||
|
|
||||||
def command(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Command"]:
|
def command(
|
||||||
|
self, **attrs: Any
|
||||||
|
) -> Callable[[Callable[..., Awaitable[None]]], "Command"]:
|
||||||
def decorator(func: Callable[..., Awaitable[None]]) -> "Command":
|
def decorator(func: Callable[..., Awaitable[None]]) -> "Command":
|
||||||
cmd = Command(func, **attrs)
|
cmd = Command(func, **attrs)
|
||||||
cmd.cog = getattr(self, "cog", None)
|
cmd.cog = getattr(self, "cog", None)
|
||||||
self.add_command(cmd)
|
self.add_command(cmd)
|
||||||
return cmd
|
return cmd
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def group(self, **attrs: Any) -> Callable[[Callable[..., Awaitable[None]]], "Group"]:
|
def group(
|
||||||
|
self, **attrs: Any
|
||||||
|
) -> Callable[[Callable[..., Awaitable[None]]], "Group"]:
|
||||||
def decorator(func: Callable[..., Awaitable[None]]) -> "Group":
|
def decorator(func: Callable[..., Awaitable[None]]) -> "Group":
|
||||||
cmd = Group(func, **attrs)
|
cmd = Group(func, **attrs)
|
||||||
cmd.cog = getattr(self, "cog", None)
|
cmd.cog = getattr(self, "cog", None)
|
||||||
self.add_command(cmd)
|
self.add_command(cmd)
|
||||||
return cmd
|
return cmd
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def add_command(self, command: "Command") -> None:
|
def add_command(self, command: "Command") -> None:
|
||||||
if command.name in self.commands:
|
if command.name in self.commands:
|
||||||
raise ValueError(f"Command '{command.name}' is already registered in group '{self.name}'.")
|
raise ValueError(
|
||||||
|
f"Command '{command.name}' is already registered in group '{self.name}'."
|
||||||
|
)
|
||||||
self.commands[command.name.lower()] = command
|
self.commands[command.name.lower()] = command
|
||||||
for alias in command.aliases:
|
for alias in command.aliases:
|
||||||
if alias in self.commands:
|
if alias in self.commands:
|
||||||
logger.warning(f"Alias '{alias}' for command '{command.name}' in group '{self.name}' conflicts with an existing command or alias.")
|
logger.warning(
|
||||||
|
f"Alias '{alias}' for command '{command.name}' in group '{self.name}' conflicts with an existing command or alias."
|
||||||
|
)
|
||||||
self.commands[alias.lower()] = command
|
self.commands[alias.lower()] = command
|
||||||
|
|
||||||
def get_command(self, name: str) -> Optional["Command"]:
|
def get_command(self, name: str) -> Optional["Command"]:
|
||||||
@ -181,6 +191,7 @@ class Command(GroupMixin):
|
|||||||
|
|
||||||
class Group(Command):
|
class Group(Command):
|
||||||
"""A command that can have subcommands."""
|
"""A command that can have subcommands."""
|
||||||
|
|
||||||
def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any):
|
def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any):
|
||||||
super().__init__(callback, **attrs)
|
super().__init__(callback, **attrs)
|
||||||
|
|
||||||
@ -494,7 +505,34 @@ class CommandHandler:
|
|||||||
None # Holds the raw string for current param
|
None # Holds the raw string for current param
|
||||||
)
|
)
|
||||||
|
|
||||||
if view.eof: # No more input string
|
annotation = param.annotation
|
||||||
|
if inspect.isclass(annotation) and issubclass(annotation, Greedy):
|
||||||
|
greedy_values = []
|
||||||
|
converter_type = annotation.converter
|
||||||
|
while not view.eof:
|
||||||
|
view.skip_whitespace()
|
||||||
|
if view.eof:
|
||||||
|
break
|
||||||
|
start = view.index
|
||||||
|
if view.buffer[view.index] == '"':
|
||||||
|
arg_str_value = view.get_quoted_string()
|
||||||
|
if arg_str_value == "" and view.buffer[view.index] == '"':
|
||||||
|
raise BadArgument(
|
||||||
|
f"Unterminated quoted string for argument '{param.name}'."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
arg_str_value = view.get_word()
|
||||||
|
try:
|
||||||
|
converted = await run_converters(
|
||||||
|
ctx, converter_type, arg_str_value
|
||||||
|
)
|
||||||
|
except BadArgument:
|
||||||
|
view.index = start
|
||||||
|
break
|
||||||
|
greedy_values.append(converted)
|
||||||
|
final_value_for_param = greedy_values
|
||||||
|
arg_str_value = None
|
||||||
|
elif view.eof: # No more input string
|
||||||
if param.default is not inspect.Parameter.empty:
|
if param.default is not inspect.Parameter.empty:
|
||||||
final_value_for_param = param.default
|
final_value_for_param = param.default
|
||||||
elif param.kind != inspect.Parameter.VAR_KEYWORD:
|
elif param.kind != inspect.Parameter.VAR_KEYWORD:
|
||||||
@ -656,7 +694,9 @@ class CommandHandler:
|
|||||||
elif command.invoke_without_command:
|
elif command.invoke_without_command:
|
||||||
view.index -= len(potential_subcommand) + view.previous
|
view.index -= len(potential_subcommand) + view.previous
|
||||||
else:
|
else:
|
||||||
raise CommandNotFound(f"Subcommand '{potential_subcommand}' not found.")
|
raise CommandNotFound(
|
||||||
|
f"Subcommand '{potential_subcommand}' not found."
|
||||||
|
)
|
||||||
|
|
||||||
ctx = CommandContext(
|
ctx = CommandContext(
|
||||||
message=message,
|
message=message,
|
||||||
@ -681,7 +721,9 @@ class CommandHandler:
|
|||||||
if hasattr(self.client, "on_command_error"):
|
if hasattr(self.client, "on_command_error"):
|
||||||
await self.client.on_command_error(ctx, e)
|
await self.client.on_command_error(ctx, e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Unexpected error invoking command '%s': %s", original_command.name, e)
|
logger.error(
|
||||||
|
"Unexpected error invoking command '%s': %s", original_command.name, e
|
||||||
|
)
|
||||||
exc = CommandInvokeError(e)
|
exc = CommandInvokeError(e)
|
||||||
if hasattr(self.client, "on_command_error"):
|
if hasattr(self.client, "on_command_error"):
|
||||||
await self.client.on_command_error(ctx, exc)
|
await self.client.on_command_error(ctx, exc)
|
||||||
|
@ -218,6 +218,7 @@ def requires_permissions(
|
|||||||
|
|
||||||
return check(predicate)
|
return check(predicate)
|
||||||
|
|
||||||
|
|
||||||
def has_role(
|
def has_role(
|
||||||
name_or_id: str | int,
|
name_or_id: str | int,
|
||||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||||
@ -241,9 +242,7 @@ def has_role(
|
|||||||
raise CheckFailure("Could not resolve author to a guild member.")
|
raise CheckFailure("Could not resolve author to a guild member.")
|
||||||
|
|
||||||
# Create a list of the member's role objects by looking them up in the guild's roles list
|
# Create a list of the member's role objects by looking them up in the guild's roles list
|
||||||
member_roles = [
|
member_roles = [role for role in ctx.guild.roles if role.id in author.roles]
|
||||||
role for role in ctx.guild.roles if role.id in author.roles
|
|
||||||
]
|
|
||||||
|
|
||||||
if any(
|
if any(
|
||||||
role.id == str(name_or_id) or role.name == name_or_id
|
role.id == str(name_or_id) or role.name == name_or_id
|
||||||
@ -278,9 +277,7 @@ def has_any_role(
|
|||||||
if not author:
|
if not author:
|
||||||
raise CheckFailure("Could not resolve author to a guild member.")
|
raise CheckFailure("Could not resolve author to a guild member.")
|
||||||
|
|
||||||
member_roles = [
|
member_roles = [role for role in ctx.guild.roles if role.id in author.roles]
|
||||||
role for role in ctx.guild.roles if role.id in author.roles
|
|
||||||
]
|
|
||||||
# Convert names_or_ids to a set for efficient lookup
|
# Convert names_or_ids to a set for efficient lookup
|
||||||
names_or_ids_set = set(map(str, names_or_ids))
|
names_or_ids_set = set(map(str, names_or_ids))
|
||||||
|
|
||||||
|
@ -14,6 +14,8 @@ import time
|
|||||||
import random
|
import random
|
||||||
from typing import Optional, TYPE_CHECKING, Any, Dict
|
from typing import Optional, TYPE_CHECKING, Any, Dict
|
||||||
|
|
||||||
|
from .models import Activity
|
||||||
|
|
||||||
from .enums import GatewayOpcode, GatewayIntent
|
from .enums import GatewayOpcode, GatewayIntent
|
||||||
from .errors import GatewayException, DisagreementException, AuthenticationError
|
from .errors import GatewayException, DisagreementException, AuthenticationError
|
||||||
from .interactions import Interaction
|
from .interactions import Interaction
|
||||||
@ -63,7 +65,11 @@ class GatewayClient:
|
|||||||
self._max_backoff: float = max_backoff
|
self._max_backoff: float = max_backoff
|
||||||
|
|
||||||
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
|
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
|
||||||
self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
try:
|
||||||
|
self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
self._loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(self._loop)
|
||||||
self._heartbeat_interval: Optional[float] = None
|
self._heartbeat_interval: Optional[float] = None
|
||||||
self._last_sequence: Optional[int] = None
|
self._last_sequence: Optional[int] = None
|
||||||
self._session_id: Optional[str] = None
|
self._session_id: Optional[str] = None
|
||||||
@ -213,26 +219,17 @@ class GatewayClient:
|
|||||||
async def update_presence(
|
async def update_presence(
|
||||||
self,
|
self,
|
||||||
status: str,
|
status: str,
|
||||||
activity_name: Optional[str] = None,
|
activity: Optional[Activity] = None,
|
||||||
activity_type: int = 0,
|
*,
|
||||||
since: int = 0,
|
since: int = 0,
|
||||||
afk: bool = False,
|
afk: bool = False,
|
||||||
):
|
) -> None:
|
||||||
"""Sends the presence update payload to the Gateway."""
|
"""Sends the presence update payload to the Gateway."""
|
||||||
payload = {
|
payload = {
|
||||||
"op": GatewayOpcode.PRESENCE_UPDATE,
|
"op": GatewayOpcode.PRESENCE_UPDATE,
|
||||||
"d": {
|
"d": {
|
||||||
"since": since,
|
"since": since,
|
||||||
"activities": (
|
"activities": [activity.to_dict()] if activity else [],
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": activity_name,
|
|
||||||
"type": activity_type,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
if activity_name
|
|
||||||
else []
|
|
||||||
),
|
|
||||||
"status": status,
|
"status": status,
|
||||||
"afk": afk,
|
"afk": afk,
|
||||||
},
|
},
|
||||||
@ -353,7 +350,10 @@ class GatewayClient:
|
|||||||
future._members.extend(raw_event_d_payload.get("members", [])) # type: ignore
|
future._members.extend(raw_event_d_payload.get("members", [])) # type: ignore
|
||||||
|
|
||||||
# If this is the last chunk, resolve the future
|
# If this is the last chunk, resolve the future
|
||||||
if raw_event_d_payload.get("chunk_index") == raw_event_d_payload.get("chunk_count", 1) - 1:
|
if (
|
||||||
|
raw_event_d_payload.get("chunk_index")
|
||||||
|
== raw_event_d_payload.get("chunk_count", 1) - 1
|
||||||
|
):
|
||||||
future.set_result(future._members) # type: ignore
|
future.set_result(future._members) # type: ignore
|
||||||
del self._member_chunk_requests[nonce]
|
del self._member_chunk_requests[nonce]
|
||||||
|
|
||||||
|
@ -601,18 +601,18 @@ class HTTPClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def delete_user_reaction(
|
async def delete_user_reaction(
|
||||||
self,
|
self,
|
||||||
channel_id: "Snowflake",
|
channel_id: "Snowflake",
|
||||||
message_id: "Snowflake",
|
message_id: "Snowflake",
|
||||||
emoji: str,
|
emoji: str,
|
||||||
user_id: "Snowflake",
|
user_id: "Snowflake",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Removes another user's reaction from a message."""
|
"""Removes another user's reaction from a message."""
|
||||||
encoded = quote(emoji)
|
encoded = quote(emoji)
|
||||||
await self.request(
|
await self.request(
|
||||||
"DELETE",
|
"DELETE",
|
||||||
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/{user_id}",
|
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/{user_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_reactions(
|
async def get_reactions(
|
||||||
self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
|
self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
|
||||||
@ -910,6 +910,20 @@ class HTTPClient:
|
|||||||
"""Fetches a guild object for a given guild ID."""
|
"""Fetches a guild object for a given guild ID."""
|
||||||
return await self.request("GET", f"/guilds/{guild_id}")
|
return await self.request("GET", f"/guilds/{guild_id}")
|
||||||
|
|
||||||
|
async def get_guild_widget(self, guild_id: "Snowflake") -> Dict[str, Any]:
|
||||||
|
"""Fetches the guild widget settings."""
|
||||||
|
|
||||||
|
return await self.request("GET", f"/guilds/{guild_id}/widget")
|
||||||
|
|
||||||
|
async def edit_guild_widget(
|
||||||
|
self, guild_id: "Snowflake", payload: Dict[str, Any]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Edits the guild widget settings."""
|
||||||
|
|
||||||
|
return await self.request(
|
||||||
|
"PATCH", f"/guilds/{guild_id}/widget", payload=payload
|
||||||
|
)
|
||||||
|
|
||||||
async def get_guild_templates(self, guild_id: "Snowflake") -> List[Dict[str, Any]]:
|
async def get_guild_templates(self, guild_id: "Snowflake") -> List[Dict[str, Any]]:
|
||||||
"""Fetches all templates for the given guild."""
|
"""Fetches all templates for the given guild."""
|
||||||
return await self.request("GET", f"/guilds/{guild_id}/templates")
|
return await self.request("GET", f"/guilds/{guild_id}/templates")
|
||||||
|
@ -6,6 +6,7 @@ Data models for Discord objects.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast
|
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast
|
||||||
|
|
||||||
@ -24,6 +25,7 @@ from .enums import ( # These enums will need to be defined in disagreement/enum
|
|||||||
PremiumTier,
|
PremiumTier,
|
||||||
GuildFeature,
|
GuildFeature,
|
||||||
ChannelType,
|
ChannelType,
|
||||||
|
AutoArchiveDuration,
|
||||||
ComponentType,
|
ComponentType,
|
||||||
ButtonStyle, # Added for Button
|
ButtonStyle, # Added for Button
|
||||||
GuildScheduledEventPrivacyLevel,
|
GuildScheduledEventPrivacyLevel,
|
||||||
@ -39,6 +41,7 @@ if TYPE_CHECKING:
|
|||||||
from .enums import OverwriteType # For PermissionOverwrite model
|
from .enums import OverwriteType # For PermissionOverwrite model
|
||||||
from .ui.view import View
|
from .ui.view import View
|
||||||
from .interactions import Snowflake
|
from .interactions import Snowflake
|
||||||
|
from .typing import Typing
|
||||||
|
|
||||||
# Forward reference Message if it were used in type hints before its definition
|
# Forward reference Message if it were used in type hints before its definition
|
||||||
# from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc.
|
# from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc.
|
||||||
@ -114,31 +117,39 @@ class Message:
|
|||||||
# self.mention_roles: List[str] = data.get("mention_roles", [])
|
# self.mention_roles: List[str] = data.get("mention_roles", [])
|
||||||
# self.mention_everyone: bool = data.get("mention_everyone", False)
|
# self.mention_everyone: bool = data.get("mention_everyone", False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def clean_content(self) -> str:
|
||||||
|
"""Returns message content without user, role, or channel mentions."""
|
||||||
|
|
||||||
|
pattern = re.compile(r"<@!?\d+>|<#\d+>|<@&\d+>")
|
||||||
|
cleaned = pattern.sub("", self.content)
|
||||||
|
return " ".join(cleaned.split())
|
||||||
|
|
||||||
async def pin(self) -> None:
|
async def pin(self) -> None:
|
||||||
"""|coro|
|
"""|coro|
|
||||||
|
|
||||||
Pins this message to its channel.
|
Pins this message to its channel.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
HTTPException
|
HTTPException
|
||||||
Pinning the message failed.
|
Pinning the message failed.
|
||||||
"""
|
"""
|
||||||
await self._client._http.pin_message(self.channel_id, self.id)
|
await self._client._http.pin_message(self.channel_id, self.id)
|
||||||
self.pinned = True
|
self.pinned = True
|
||||||
|
|
||||||
async def unpin(self) -> None:
|
async def unpin(self) -> None:
|
||||||
"""|coro|
|
"""|coro|
|
||||||
|
|
||||||
Unpins this message from its channel.
|
Unpins this message from its channel.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
HTTPException
|
HTTPException
|
||||||
Unpinning the message failed.
|
Unpinning the message failed.
|
||||||
"""
|
"""
|
||||||
await self._client._http.unpin_message(self.channel_id, self.id)
|
await self._client._http.unpin_message(self.channel_id, self.id)
|
||||||
self.pinned = False
|
self.pinned = False
|
||||||
|
|
||||||
async def reply(
|
async def reply(
|
||||||
self,
|
self,
|
||||||
@ -241,16 +252,16 @@ class Message:
|
|||||||
await self._client.add_reaction(self.channel_id, self.id, emoji)
|
await self._client.add_reaction(self.channel_id, self.id, emoji)
|
||||||
|
|
||||||
async def remove_reaction(self, emoji: str, member: Optional[User] = None) -> None:
|
async def remove_reaction(self, emoji: str, member: Optional[User] = None) -> None:
|
||||||
"""|coro|
|
"""|coro|
|
||||||
Removes a reaction from this message.
|
Removes a reaction from this message.
|
||||||
If no ``member`` is provided, removes the bot's own reaction.
|
If no ``member`` is provided, removes the bot's own reaction.
|
||||||
"""
|
"""
|
||||||
if member:
|
if member:
|
||||||
await self._client._http.delete_user_reaction(
|
await self._client._http.delete_user_reaction(
|
||||||
self.channel_id, self.id, emoji, member.id
|
self.channel_id, self.id, emoji, member.id
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await self._client.remove_reaction(self.channel_id, self.id, emoji)
|
await self._client.remove_reaction(self.channel_id, self.id, emoji)
|
||||||
|
|
||||||
async def clear_reactions(self) -> None:
|
async def clear_reactions(self) -> None:
|
||||||
"""|coro| Remove all reactions from this message."""
|
"""|coro| Remove all reactions from this message."""
|
||||||
@ -280,7 +291,7 @@ class Message:
|
|||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
*,
|
*,
|
||||||
auto_archive_duration: Optional[int] = None,
|
auto_archive_duration: Optional[AutoArchiveDuration] = None,
|
||||||
rate_limit_per_user: Optional[int] = None,
|
rate_limit_per_user: Optional[int] = None,
|
||||||
reason: Optional[str] = None,
|
reason: Optional[str] = None,
|
||||||
) -> "Thread":
|
) -> "Thread":
|
||||||
@ -292,9 +303,9 @@ class Message:
|
|||||||
----------
|
----------
|
||||||
name: str
|
name: str
|
||||||
The name of the thread.
|
The name of the thread.
|
||||||
auto_archive_duration: Optional[int]
|
auto_archive_duration: Optional[AutoArchiveDuration]
|
||||||
The duration in minutes to automatically archive the thread after recent activity.
|
How long before the thread is automatically archived after recent activity.
|
||||||
Can be one of 60, 1440, 4320, 10080.
|
See :class:`AutoArchiveDuration` for allowed values.
|
||||||
rate_limit_per_user: Optional[int]
|
rate_limit_per_user: Optional[int]
|
||||||
The number of seconds a user has to wait before sending another message.
|
The number of seconds a user has to wait before sending another message.
|
||||||
reason: Optional[str]
|
reason: Optional[str]
|
||||||
@ -307,7 +318,7 @@ class Message:
|
|||||||
"""
|
"""
|
||||||
payload: Dict[str, Any] = {"name": name}
|
payload: Dict[str, Any] = {"name": name}
|
||||||
if auto_archive_duration is not None:
|
if auto_archive_duration is not None:
|
||||||
payload["auto_archive_duration"] = auto_archive_duration
|
payload["auto_archive_duration"] = int(auto_archive_duration)
|
||||||
if rate_limit_per_user is not None:
|
if rate_limit_per_user is not None:
|
||||||
payload["rate_limit_per_user"] = rate_limit_per_user
|
payload["rate_limit_per_user"] = rate_limit_per_user
|
||||||
|
|
||||||
@ -530,8 +541,42 @@ class Embed:
|
|||||||
payload["fields"] = [f.to_dict() for f in self.fields]
|
payload["fields"] = [f.to_dict() for f in self.fields]
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
# Convenience methods for building embeds can be added here
|
# Convenience methods mirroring ``discord.py``'s ``Embed`` API
|
||||||
# e.g., set_author, add_field, set_footer, set_image, etc.
|
|
||||||
|
def set_author(
|
||||||
|
self, *, name: str, url: Optional[str] = None, icon_url: Optional[str] = None
|
||||||
|
) -> "Embed":
|
||||||
|
"""Set the embed author and return ``self`` for chaining."""
|
||||||
|
|
||||||
|
data: Dict[str, Any] = {"name": name}
|
||||||
|
if url:
|
||||||
|
data["url"] = url
|
||||||
|
if icon_url:
|
||||||
|
data["icon_url"] = icon_url
|
||||||
|
self.author = EmbedAuthor(data)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def add_field(self, *, name: str, value: str, inline: bool = False) -> "Embed":
|
||||||
|
"""Add a field to the embed."""
|
||||||
|
|
||||||
|
field = EmbedField({"name": name, "value": value, "inline": inline})
|
||||||
|
self.fields.append(field)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def set_footer(self, *, text: str, icon_url: Optional[str] = None) -> "Embed":
|
||||||
|
"""Set the embed footer."""
|
||||||
|
|
||||||
|
data: Dict[str, Any] = {"text": text}
|
||||||
|
if icon_url:
|
||||||
|
data["icon_url"] = icon_url
|
||||||
|
self.footer = EmbedFooter(data)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def set_image(self, url: str) -> "Embed":
|
||||||
|
"""Set the embed image."""
|
||||||
|
|
||||||
|
self.image = EmbedImage({"url": url})
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class Attachment:
|
class Attachment:
|
||||||
@ -1088,7 +1133,9 @@ class Guild:
|
|||||||
|
|
||||||
# Internal caches, populated by events or specific fetches
|
# Internal caches, populated by events or specific fetches
|
||||||
self._channels: ChannelCache = ChannelCache()
|
self._channels: ChannelCache = ChannelCache()
|
||||||
self._members: MemberCache = MemberCache(getattr(client_instance, "member_cache_flags", MemberCacheFlags()))
|
self._members: MemberCache = MemberCache(
|
||||||
|
getattr(client_instance, "member_cache_flags", MemberCacheFlags())
|
||||||
|
)
|
||||||
self._threads: Dict[str, "Thread"] = {}
|
self._threads: Dict[str, "Thread"] = {}
|
||||||
|
|
||||||
def get_channel(self, channel_id: str) -> Optional["Channel"]:
|
def get_channel(self, channel_id: str) -> Optional["Channel"]:
|
||||||
@ -1128,6 +1175,16 @@ class Guild:
|
|||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<Guild id='{self.id}' name='{self.name}'>"
|
return f"<Guild id='{self.id}' name='{self.name}'>"
|
||||||
|
|
||||||
|
async def fetch_widget(self) -> Dict[str, Any]:
|
||||||
|
"""|coro| Fetch this guild's widget settings."""
|
||||||
|
|
||||||
|
return await self._client.fetch_widget(self.id)
|
||||||
|
|
||||||
|
async def edit_widget(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""|coro| Edit this guild's widget settings."""
|
||||||
|
|
||||||
|
return await self._client.edit_widget(self.id, payload)
|
||||||
|
|
||||||
async def fetch_members(self, *, limit: Optional[int] = None) -> List["Member"]:
|
async def fetch_members(self, *, limit: Optional[int] = None) -> List["Member"]:
|
||||||
"""|coro|
|
"""|coro|
|
||||||
|
|
||||||
@ -1278,7 +1335,45 @@ class Channel:
|
|||||||
return base
|
return base
|
||||||
|
|
||||||
|
|
||||||
class TextChannel(Channel):
|
class Messageable:
|
||||||
|
"""Mixin for channels that can send messages and show typing."""
|
||||||
|
|
||||||
|
_client: "Client"
|
||||||
|
id: str
|
||||||
|
|
||||||
|
async def send(
|
||||||
|
self,
|
||||||
|
content: Optional[str] = None,
|
||||||
|
*,
|
||||||
|
embed: Optional["Embed"] = None,
|
||||||
|
embeds: Optional[List["Embed"]] = None,
|
||||||
|
components: Optional[List["ActionRow"]] = None,
|
||||||
|
) -> "Message":
|
||||||
|
if not hasattr(self._client, "send_message"):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Client.send_message is required for Messageable.send"
|
||||||
|
)
|
||||||
|
|
||||||
|
return await self._client.send_message(
|
||||||
|
channel_id=self.id,
|
||||||
|
content=content,
|
||||||
|
embed=embed,
|
||||||
|
embeds=embeds,
|
||||||
|
components=components,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def trigger_typing(self) -> None:
|
||||||
|
await self._client._http.trigger_typing(self.id)
|
||||||
|
|
||||||
|
def typing(self) -> "Typing":
|
||||||
|
if not hasattr(self._client, "typing"):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Client.typing is required for Messageable.typing"
|
||||||
|
)
|
||||||
|
return self._client.typing(self.id)
|
||||||
|
|
||||||
|
|
||||||
|
class TextChannel(Channel, Messageable):
|
||||||
"""Represents a guild text channel or announcement channel."""
|
"""Represents a guild text channel or announcement channel."""
|
||||||
|
|
||||||
def __init__(self, data: Dict[str, Any], client_instance: "Client"):
|
def __init__(self, data: Dict[str, Any], client_instance: "Client"):
|
||||||
@ -1304,27 +1399,6 @@ class TextChannel(Channel):
|
|||||||
|
|
||||||
return message_pager(self, limit=limit, before=before, after=after)
|
return message_pager(self, limit=limit, before=before, after=after)
|
||||||
|
|
||||||
async def send(
|
|
||||||
self,
|
|
||||||
content: Optional[str] = None,
|
|
||||||
*,
|
|
||||||
embed: Optional[Embed] = None,
|
|
||||||
embeds: Optional[List[Embed]] = None,
|
|
||||||
components: Optional[List["ActionRow"]] = None, # Added components
|
|
||||||
) -> "Message": # Forward reference Message
|
|
||||||
if not hasattr(self._client, "send_message"):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Client.send_message is required for TextChannel.send"
|
|
||||||
)
|
|
||||||
|
|
||||||
return await self._client.send_message(
|
|
||||||
channel_id=self.id,
|
|
||||||
content=content,
|
|
||||||
embed=embed,
|
|
||||||
embeds=embeds,
|
|
||||||
components=components,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def purge(
|
async def purge(
|
||||||
self, limit: int, *, before: "Snowflake | None" = None
|
self, limit: int, *, before: "Snowflake | None" = None
|
||||||
) -> List["Snowflake"]:
|
) -> List["Snowflake"]:
|
||||||
@ -1347,41 +1421,41 @@ class TextChannel(Channel):
|
|||||||
return ids
|
return ids
|
||||||
|
|
||||||
def get_partial_message(self, id: int) -> "PartialMessage":
|
def get_partial_message(self, id: int) -> "PartialMessage":
|
||||||
"""Returns a :class:`PartialMessage` for the given ID.
|
"""Returns a :class:`PartialMessage` for the given ID.
|
||||||
|
|
||||||
This allows performing actions on a message without fetching it first.
|
This allows performing actions on a message without fetching it first.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
id: int
|
id: int
|
||||||
The ID of the message to get a partial instance of.
|
The ID of the message to get a partial instance of.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
PartialMessage
|
PartialMessage
|
||||||
The partial message instance.
|
The partial message instance.
|
||||||
"""
|
"""
|
||||||
return PartialMessage(id=str(id), channel=self)
|
return PartialMessage(id=str(id), channel=self)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<TextChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>"
|
return f"<TextChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>"
|
||||||
|
|
||||||
async def pins(self) -> List["Message"]:
|
async def pins(self) -> List["Message"]:
|
||||||
"""|coro|
|
"""|coro|
|
||||||
|
|
||||||
Fetches all pinned messages in this channel.
|
Fetches all pinned messages in this channel.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
List[Message]
|
List[Message]
|
||||||
The pinned messages.
|
The pinned messages.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
HTTPException
|
HTTPException
|
||||||
Fetching the pinned messages failed.
|
Fetching the pinned messages failed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
messages_data = await self._client._http.get_pinned_messages(self.id)
|
messages_data = await self._client._http.get_pinned_messages(self.id)
|
||||||
return [self._client.parse_message(m) for m in messages_data]
|
return [self._client.parse_message(m) for m in messages_data]
|
||||||
|
|
||||||
@ -1390,7 +1464,7 @@ class TextChannel(Channel):
|
|||||||
name: str,
|
name: str,
|
||||||
*,
|
*,
|
||||||
type: ChannelType = ChannelType.PUBLIC_THREAD,
|
type: ChannelType = ChannelType.PUBLIC_THREAD,
|
||||||
auto_archive_duration: Optional[int] = None,
|
auto_archive_duration: Optional[AutoArchiveDuration] = None,
|
||||||
invitable: Optional[bool] = None,
|
invitable: Optional[bool] = None,
|
||||||
rate_limit_per_user: Optional[int] = None,
|
rate_limit_per_user: Optional[int] = None,
|
||||||
reason: Optional[str] = None,
|
reason: Optional[str] = None,
|
||||||
@ -1406,8 +1480,8 @@ class TextChannel(Channel):
|
|||||||
type: ChannelType
|
type: ChannelType
|
||||||
The type of thread to create. Defaults to PUBLIC_THREAD.
|
The type of thread to create. Defaults to PUBLIC_THREAD.
|
||||||
Can be PUBLIC_THREAD, PRIVATE_THREAD, or ANNOUNCEMENT_THREAD.
|
Can be PUBLIC_THREAD, PRIVATE_THREAD, or ANNOUNCEMENT_THREAD.
|
||||||
auto_archive_duration: Optional[int]
|
auto_archive_duration: Optional[AutoArchiveDuration]
|
||||||
The duration in minutes to automatically archive the thread after recent activity.
|
How long before the thread is automatically archived after recent activity.
|
||||||
invitable: Optional[bool]
|
invitable: Optional[bool]
|
||||||
Whether non-moderators can invite other non-moderators to a private thread.
|
Whether non-moderators can invite other non-moderators to a private thread.
|
||||||
Only applicable to private threads.
|
Only applicable to private threads.
|
||||||
@ -1426,7 +1500,7 @@ class TextChannel(Channel):
|
|||||||
"type": type.value,
|
"type": type.value,
|
||||||
}
|
}
|
||||||
if auto_archive_duration is not None:
|
if auto_archive_duration is not None:
|
||||||
payload["auto_archive_duration"] = auto_archive_duration
|
payload["auto_archive_duration"] = int(auto_archive_duration)
|
||||||
if invitable is not None and type == ChannelType.PRIVATE_THREAD:
|
if invitable is not None and type == ChannelType.PRIVATE_THREAD:
|
||||||
payload["invitable"] = invitable
|
payload["invitable"] = invitable
|
||||||
if rate_limit_per_user is not None:
|
if rate_limit_per_user is not None:
|
||||||
@ -1606,7 +1680,9 @@ class Thread(TextChannel): # Threads are a specialized TextChannel
|
|||||||
"""
|
"""
|
||||||
await self._client._http.leave_thread(self.id)
|
await self._client._http.leave_thread(self.id)
|
||||||
|
|
||||||
async def archive(self, locked: bool = False, *, reason: Optional[str] = None) -> "Thread":
|
async def archive(
|
||||||
|
self, locked: bool = False, *, reason: Optional[str] = None
|
||||||
|
) -> "Thread":
|
||||||
"""|coro|
|
"""|coro|
|
||||||
|
|
||||||
Archives this thread.
|
Archives this thread.
|
||||||
@ -1631,7 +1707,7 @@ class Thread(TextChannel): # Threads are a specialized TextChannel
|
|||||||
return cast("Thread", self._client.parse_channel(data))
|
return cast("Thread", self._client.parse_channel(data))
|
||||||
|
|
||||||
|
|
||||||
class DMChannel(Channel):
|
class DMChannel(Channel, Messageable):
|
||||||
"""Represents a Direct Message channel."""
|
"""Represents a Direct Message channel."""
|
||||||
|
|
||||||
def __init__(self, data: Dict[str, Any], client_instance: "Client"):
|
def __init__(self, data: Dict[str, Any], client_instance: "Client"):
|
||||||
@ -1645,27 +1721,6 @@ class DMChannel(Channel):
|
|||||||
def recipient(self) -> Optional[User]:
|
def recipient(self) -> Optional[User]:
|
||||||
return self.recipients[0] if self.recipients else None
|
return self.recipients[0] if self.recipients else None
|
||||||
|
|
||||||
async def send(
|
|
||||||
self,
|
|
||||||
content: Optional[str] = None,
|
|
||||||
*,
|
|
||||||
embed: Optional[Embed] = None,
|
|
||||||
embeds: Optional[List[Embed]] = None,
|
|
||||||
components: Optional[List["ActionRow"]] = None, # Added components
|
|
||||||
) -> "Message":
|
|
||||||
if not hasattr(self._client, "send_message"):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Client.send_message is required for DMChannel.send"
|
|
||||||
)
|
|
||||||
|
|
||||||
return await self._client.send_message(
|
|
||||||
channel_id=self.id,
|
|
||||||
content=content,
|
|
||||||
embed=embed,
|
|
||||||
embeds=embeds,
|
|
||||||
components=components,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def history(
|
async def history(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@ -2356,6 +2411,37 @@ class ThreadMember:
|
|||||||
return f"<ThreadMember user_id='{self.user_id}' thread_id='{self.id}'>"
|
return f"<ThreadMember user_id='{self.user_id}' thread_id='{self.id}'>"
|
||||||
|
|
||||||
|
|
||||||
|
class Activity:
|
||||||
|
"""Represents a user's presence activity."""
|
||||||
|
|
||||||
|
def __init__(self, name: str, type: int) -> None:
|
||||||
|
self.name = name
|
||||||
|
self.type = type
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {"name": self.name, "type": self.type}
|
||||||
|
|
||||||
|
|
||||||
|
class Game(Activity):
|
||||||
|
"""Represents a playing activity."""
|
||||||
|
|
||||||
|
def __init__(self, name: str) -> None:
|
||||||
|
super().__init__(name, 0)
|
||||||
|
|
||||||
|
|
||||||
|
class Streaming(Activity):
|
||||||
|
"""Represents a streaming activity."""
|
||||||
|
|
||||||
|
def __init__(self, name: str, url: str) -> None:
|
||||||
|
super().__init__(name, 1)
|
||||||
|
self.url = url
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
payload = super().to_dict()
|
||||||
|
payload["url"] = self.url
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
class PresenceUpdate:
|
class PresenceUpdate:
|
||||||
"""Represents a PRESENCE_UPDATE event."""
|
"""Represents a PRESENCE_UPDATE event."""
|
||||||
|
|
||||||
@ -2366,7 +2452,17 @@ class PresenceUpdate:
|
|||||||
self.user = User(data["user"])
|
self.user = User(data["user"])
|
||||||
self.guild_id: Optional[str] = data.get("guild_id")
|
self.guild_id: Optional[str] = data.get("guild_id")
|
||||||
self.status: Optional[str] = data.get("status")
|
self.status: Optional[str] = data.get("status")
|
||||||
self.activities: List[Dict[str, Any]] = data.get("activities", [])
|
self.activities: List[Activity] = []
|
||||||
|
for activity in data.get("activities", []):
|
||||||
|
act_type = activity.get("type", 0)
|
||||||
|
name = activity.get("name", "")
|
||||||
|
if act_type == 0:
|
||||||
|
obj = Game(name)
|
||||||
|
elif act_type == 1:
|
||||||
|
obj = Streaming(name, activity.get("url", ""))
|
||||||
|
else:
|
||||||
|
obj = Activity(name, act_type)
|
||||||
|
self.activities.append(obj)
|
||||||
self.client_status: Dict[str, Any] = data.get("client_status", {})
|
self.client_status: Dict[str, Any] = data.get("client_status", {})
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
@ -72,7 +72,7 @@ class View:
|
|||||||
rows: List[ActionRow] = []
|
rows: List[ActionRow] = []
|
||||||
|
|
||||||
for item in self.children:
|
for item in self.children:
|
||||||
rows.append(ActionRow(components=[item]))
|
rows.append(ActionRow(components=[item]))
|
||||||
|
|
||||||
return rows
|
return rows
|
||||||
|
|
||||||
|
@ -7,9 +7,26 @@ import asyncio
|
|||||||
import contextlib
|
import contextlib
|
||||||
import socket
|
import socket
|
||||||
import threading
|
import threading
|
||||||
|
from array import array
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_volume(data: bytes, volume: float) -> bytes:
|
||||||
|
samples = array("h")
|
||||||
|
samples.frombytes(data)
|
||||||
|
for i, sample in enumerate(samples):
|
||||||
|
scaled = int(sample * volume)
|
||||||
|
if scaled > 32767:
|
||||||
|
scaled = 32767
|
||||||
|
elif scaled < -32768:
|
||||||
|
scaled = -32768
|
||||||
|
samples[i] = scaled
|
||||||
|
return samples.tobytes()
|
||||||
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Optional, Sequence
|
from typing import TYPE_CHECKING, Optional, Sequence
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
# The following import is correct, but may be flagged by Pylance if the virtual
|
# The following import is correct, but may be flagged by Pylance if the virtual
|
||||||
# environment is not configured correctly.
|
# environment is not configured correctly.
|
||||||
from nacl.secret import SecretBox
|
from nacl.secret import SecretBox
|
||||||
@ -180,6 +197,9 @@ class VoiceClient:
|
|||||||
data = await self._current_source.read()
|
data = await self._current_source.read()
|
||||||
if not data:
|
if not data:
|
||||||
break
|
break
|
||||||
|
volume = getattr(self._current_source, "volume", 1.0)
|
||||||
|
if volume != 1.0:
|
||||||
|
data = _apply_volume(data, volume)
|
||||||
await self.send_audio_frame(data)
|
await self.send_audio_frame(data)
|
||||||
finally:
|
finally:
|
||||||
await self._current_source.close()
|
await self._current_source.close()
|
||||||
|
22
docs/embeds.md
Normal file
22
docs/embeds.md
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# Embeds
|
||||||
|
|
||||||
|
`Embed` objects can be constructed piece by piece much like in `discord.py`.
|
||||||
|
These helper methods return the embed instance so you can chain calls.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.models import Embed
|
||||||
|
|
||||||
|
embed = (
|
||||||
|
Embed()
|
||||||
|
.set_author(name="Disagreement", url="https://example.com", icon_url="https://cdn.example.com/bot.png")
|
||||||
|
.add_field(name="Info", value="Some details")
|
||||||
|
.set_footer(text="Made with Disagreement")
|
||||||
|
.set_image(url="https://cdn.example.com/image.png")
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Call `to_dict()` to convert the embed back to a payload dictionary before sending:
|
||||||
|
|
||||||
|
```python
|
||||||
|
payload = embed.to_dict()
|
||||||
|
```
|
23
docs/mentions.md
Normal file
23
docs/mentions.md
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# Controlling Mentions
|
||||||
|
|
||||||
|
The client exposes settings to control how mentions behave in outgoing messages.
|
||||||
|
|
||||||
|
## Default Allowed Mentions
|
||||||
|
|
||||||
|
Use the ``allowed_mentions`` parameter of :class:`disagreement.Client` to set a
|
||||||
|
default for all messages:
|
||||||
|
|
||||||
|
```python
|
||||||
|
client = disagreement.Client(
|
||||||
|
token="YOUR_TOKEN",
|
||||||
|
allowed_mentions={"parse": [], "replied_user": False},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
When ``Client.send_message`` is called without an explicit ``allowed_mentions``
|
||||||
|
argument this value will be used.
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
|
||||||
|
- [Commands](commands.md)
|
||||||
|
- [HTTP Client Options](http_client.md)
|
@ -1,6 +1,7 @@
|
|||||||
# Updating Presence
|
# Updating Presence
|
||||||
|
|
||||||
The `Client.change_presence` method allows you to update the bot's status and displayed activity.
|
The `Client.change_presence` method allows you to update the bot's status and displayed activity.
|
||||||
|
Pass an :class:`~disagreement.models.Activity` (such as :class:`~disagreement.models.Game` or :class:`~disagreement.models.Streaming`) to describe what your bot is doing.
|
||||||
|
|
||||||
## Status Strings
|
## Status Strings
|
||||||
|
|
||||||
@ -22,8 +23,18 @@ An activity dictionary must include a `name` and a `type` field. The type value
|
|||||||
| `4` | Custom |
|
| `4` | Custom |
|
||||||
| `5` | Competing |
|
| `5` | Competing |
|
||||||
|
|
||||||
Example:
|
Example using the provided activity classes:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
await client.change_presence(status="idle", activity={"name": "with Discord", "type": 0})
|
from disagreement.models import Game
|
||||||
|
|
||||||
|
await client.change_presence(status="idle", activity=Game("with Discord"))
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also specify a streaming URL:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.models import Streaming
|
||||||
|
|
||||||
|
await client.change_presence(status="online", activity=Streaming("My Stream", "https://twitch.tv/someone"))
|
||||||
```
|
```
|
||||||
|
18
docs/threads.md
Normal file
18
docs/threads.md
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# Threads
|
||||||
|
|
||||||
|
`Message.create_thread` and `TextChannel.create_thread` let you start new threads.
|
||||||
|
Use :class:`AutoArchiveDuration` to control when a thread is automatically archived.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.enums import AutoArchiveDuration
|
||||||
|
|
||||||
|
await message.create_thread(
|
||||||
|
"discussion",
|
||||||
|
auto_archive_duration=AutoArchiveDuration.DAY,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
|
||||||
|
- [Message History](message_history.md)
|
||||||
|
- [Caching](caching.md)
|
@ -39,9 +39,14 @@ except ImportError:
|
|||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
try:
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
except ImportError: # pragma: no cover - example helper
|
||||||
|
load_dotenv = None
|
||||||
|
print("python-dotenv is not installed. Environment variables will not be loaded")
|
||||||
|
|
||||||
load_dotenv()
|
if load_dotenv:
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
# Optional: Configure logging for more insight, especially for gateway events
|
# Optional: Configure logging for more insight, especially for gateway events
|
||||||
# logging.basicConfig(level=logging.DEBUG) # For very verbose output
|
# logging.basicConfig(level=logging.DEBUG) # For very verbose output
|
||||||
|
@ -37,9 +37,15 @@ from disagreement.interactions import (
|
|||||||
InteractionResponsePayload,
|
InteractionResponsePayload,
|
||||||
InteractionCallbackData,
|
InteractionCallbackData,
|
||||||
)
|
)
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
load_dotenv()
|
try:
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
except ImportError: # pragma: no cover - example helper
|
||||||
|
load_dotenv = None
|
||||||
|
print("python-dotenv is not installed. Environment variables will not be loaded")
|
||||||
|
|
||||||
|
if load_dotenv:
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
# Get the bot token and application ID from the environment variables
|
# Get the bot token and application ID from the environment variables
|
||||||
token = os.getenv("DISCORD_BOT_TOKEN")
|
token = os.getenv("DISCORD_BOT_TOKEN")
|
||||||
|
@ -15,9 +15,14 @@ from disagreement.ext.app_commands import (
|
|||||||
)
|
)
|
||||||
from disagreement.models import User, Message
|
from disagreement.models import User, Message
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
try:
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
except ImportError: # pragma: no cover - example helper
|
||||||
|
load_dotenv = None
|
||||||
|
print("python-dotenv is not installed. Environment variables will not be loaded")
|
||||||
|
|
||||||
load_dotenv()
|
if load_dotenv:
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "")
|
BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "")
|
||||||
APP_ID = os.environ.get("DISCORD_APPLICATION_ID", "")
|
APP_ID = os.environ.get("DISCORD_APPLICATION_ID", "")
|
||||||
|
@ -4,7 +4,11 @@ import asyncio
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
try:
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
except ImportError: # pragma: no cover - example helper
|
||||||
|
load_dotenv = None
|
||||||
|
print("python-dotenv is not installed. Environment variables will not be loaded")
|
||||||
|
|
||||||
# Allow running from the examples folder without installing
|
# Allow running from the examples folder without installing
|
||||||
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
|
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
|
||||||
@ -12,7 +16,8 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi
|
|||||||
|
|
||||||
from disagreement import Client
|
from disagreement import Client
|
||||||
|
|
||||||
load_dotenv()
|
if load_dotenv:
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
||||||
|
|
||||||
|
@ -36,9 +36,14 @@ from disagreement.enums import (
|
|||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
try:
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
except ImportError: # pragma: no cover - example helper
|
||||||
|
load_dotenv = None
|
||||||
|
print("python-dotenv is not installed. Environment variables will not be loaded")
|
||||||
|
|
||||||
load_dotenv()
|
if load_dotenv:
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
# --- Define a Test Cog ---
|
# --- Define a Test Cog ---
|
||||||
|
@ -10,9 +10,15 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi
|
|||||||
|
|
||||||
from disagreement.client import Client
|
from disagreement.client import Client
|
||||||
from disagreement.models import TextChannel
|
from disagreement.models import TextChannel
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
load_dotenv()
|
try:
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
except ImportError: # pragma: no cover - example helper
|
||||||
|
load_dotenv = None
|
||||||
|
print("python-dotenv is not installed. Environment variables will not be loaded")
|
||||||
|
|
||||||
|
if load_dotenv:
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "")
|
BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "")
|
||||||
CHANNEL_ID = os.environ.get("DISCORD_CHANNEL_ID", "")
|
CHANNEL_ID = os.environ.get("DISCORD_CHANNEL_ID", "")
|
||||||
|
@ -2,14 +2,20 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
try:
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
except ImportError: # pragma: no cover - example helper
|
||||||
|
load_dotenv = None
|
||||||
|
print("python-dotenv is not installed. Environment variables will not be loaded")
|
||||||
|
|
||||||
from disagreement import Client, ui
|
from disagreement import Client, ui
|
||||||
from disagreement.enums import GatewayIntent, TextInputStyle
|
from disagreement.enums import GatewayIntent, TextInputStyle
|
||||||
from disagreement.ext.app_commands.decorators import slash_command
|
from disagreement.ext.app_commands.decorators import slash_command
|
||||||
from disagreement.ext.app_commands.context import AppCommandContext
|
from disagreement.ext.app_commands.context import AppCommandContext
|
||||||
|
|
||||||
load_dotenv()
|
if load_dotenv:
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
token = os.getenv("DISCORD_BOT_TOKEN", "")
|
token = os.getenv("DISCORD_BOT_TOKEN", "")
|
||||||
application_id = os.getenv("DISCORD_APPLICATION_ID", "")
|
application_id = os.getenv("DISCORD_APPLICATION_ID", "")
|
||||||
|
@ -3,7 +3,11 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
try:
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
except ImportError: # pragma: no cover - example helper
|
||||||
|
load_dotenv = None
|
||||||
|
print("python-dotenv is not installed. Environment variables will not be loaded")
|
||||||
|
|
||||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
@ -11,7 +15,8 @@ from disagreement import Client, GatewayIntent, ui # type: ignore
|
|||||||
from disagreement.ext.app_commands.decorators import slash_command
|
from disagreement.ext.app_commands.decorators import slash_command
|
||||||
from disagreement.ext.app_commands.context import AppCommandContext
|
from disagreement.ext.app_commands.context import AppCommandContext
|
||||||
|
|
||||||
load_dotenv()
|
if load_dotenv:
|
||||||
|
load_dotenv()
|
||||||
TOKEN = os.getenv("DISCORD_BOT_TOKEN", "")
|
TOKEN = os.getenv("DISCORD_BOT_TOKEN", "")
|
||||||
APP_ID = os.getenv("DISCORD_APPLICATION_ID", "")
|
APP_ID = os.getenv("DISCORD_APPLICATION_ID", "")
|
||||||
|
|
||||||
|
@ -9,9 +9,15 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi
|
|||||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
import disagreement
|
import disagreement
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
load_dotenv()
|
try:
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
except ImportError: # pragma: no cover - example helper
|
||||||
|
load_dotenv = None
|
||||||
|
print("python-dotenv is not installed. Environment variables will not be loaded")
|
||||||
|
|
||||||
|
if load_dotenv:
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
||||||
if not TOKEN:
|
if not TOKEN:
|
||||||
|
@ -10,11 +10,16 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi
|
|||||||
|
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
try:
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
except ImportError: # pragma: no cover - example helper
|
||||||
|
load_dotenv = None
|
||||||
|
print("python-dotenv is not installed. Environment variables will not be loaded")
|
||||||
|
|
||||||
import disagreement
|
import disagreement
|
||||||
|
|
||||||
load_dotenv()
|
if load_dotenv:
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
_TOKEN = os.getenv("DISCORD_BOT_TOKEN")
|
_TOKEN = os.getenv("DISCORD_BOT_TOKEN")
|
||||||
_GUILD_ID = os.getenv("DISCORD_GUILD_ID")
|
_GUILD_ID = os.getenv("DISCORD_GUILD_ID")
|
||||||
|
@ -15,3 +15,14 @@ def test_cache_ttl_expiry():
|
|||||||
assert cache.get("b") == 1
|
assert cache.get("b") == 1
|
||||||
time.sleep(0.02)
|
time.sleep(0.02)
|
||||||
assert cache.get("b") is None
|
assert cache.get("b") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_lru_eviction():
|
||||||
|
cache = Cache(maxlen=2)
|
||||||
|
cache.set("a", 1)
|
||||||
|
cache.set("b", 2)
|
||||||
|
assert cache.get("a") == 1
|
||||||
|
cache.set("c", 3)
|
||||||
|
assert cache.get("b") is None
|
||||||
|
assert cache.get("a") == 1
|
||||||
|
assert cache.get("c") == 3
|
||||||
|
23
tests/test_client_message_cache.py
Normal file
23
tests/test_client_message_cache.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.client import Client
|
||||||
|
|
||||||
|
|
||||||
|
def _add_message(client: Client, message_id: str) -> None:
|
||||||
|
data = {
|
||||||
|
"id": message_id,
|
||||||
|
"channel_id": "c",
|
||||||
|
"author": {"id": "u", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "hi",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
client.parse_message(data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_message_cache_size():
|
||||||
|
client = Client(token="t", message_cache_maxlen=1)
|
||||||
|
_add_message(client, "1")
|
||||||
|
assert client._messages.get("1").id == "1"
|
||||||
|
_add_message(client, "2")
|
||||||
|
assert client._messages.get("1") is None
|
||||||
|
assert client._messages.get("2").id == "2"
|
18
tests/test_embed_methods.py
Normal file
18
tests/test_embed_methods.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from disagreement.models import Embed
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_helper_methods():
|
||||||
|
embed = (
|
||||||
|
Embed()
|
||||||
|
.set_author(name="name", url="url", icon_url="icon")
|
||||||
|
.add_field(name="n", value="v")
|
||||||
|
.set_footer(text="footer", icon_url="icon")
|
||||||
|
.set_image(url="https://example.com/image.png")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert embed.author.name == "name"
|
||||||
|
assert embed.author.url == "url"
|
||||||
|
assert embed.author.icon_url == "icon"
|
||||||
|
assert len(embed.fields) == 1 and embed.fields[0].name == "n"
|
||||||
|
assert embed.footer.text == "footer"
|
||||||
|
assert embed.image.url == "https://example.com/image.png"
|
@ -24,7 +24,7 @@ class DummyDispatcher:
|
|||||||
|
|
||||||
class DummyClient:
|
class DummyClient:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.loop = asyncio.get_event_loop()
|
self.loop = asyncio.get_running_loop()
|
||||||
self.application_id = None # Mock application_id for Client.connect
|
self.application_id = None # Mock application_id for Client.connect
|
||||||
|
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ async def test_client_connect_backoff(monkeypatch):
|
|||||||
client = Client(
|
client = Client(
|
||||||
token="test_token",
|
token="test_token",
|
||||||
intents=0,
|
intents=0,
|
||||||
loop=asyncio.get_event_loop(),
|
loop=asyncio.get_running_loop(),
|
||||||
command_prefix="!",
|
command_prefix="!",
|
||||||
verbose=False,
|
verbose=False,
|
||||||
mention_replies=False,
|
mention_replies=False,
|
||||||
|
23
tests/test_message_clean_content.py
Normal file
23
tests/test_message_clean_content.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import types
|
||||||
|
from disagreement.models import Message
|
||||||
|
|
||||||
|
|
||||||
|
def make_message(content: str) -> Message:
|
||||||
|
data = {
|
||||||
|
"id": "1",
|
||||||
|
"channel_id": "c",
|
||||||
|
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": content,
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
return Message(data, client_instance=types.SimpleNamespace())
|
||||||
|
|
||||||
|
|
||||||
|
def test_clean_content_removes_mentions():
|
||||||
|
msg = make_message("Hello <@123> <#456> <@&789> world")
|
||||||
|
assert msg.clean_content == "Hello world"
|
||||||
|
|
||||||
|
|
||||||
|
def test_clean_content_no_mentions():
|
||||||
|
msg = make_message("Just text")
|
||||||
|
assert msg.clean_content == "Just text"
|
@ -2,6 +2,7 @@ import pytest
|
|||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
from disagreement.client import Client
|
from disagreement.client import Client
|
||||||
|
from disagreement.models import Game
|
||||||
from disagreement.errors import DisagreementException
|
from disagreement.errors import DisagreementException
|
||||||
|
|
||||||
|
|
||||||
@ -18,11 +19,11 @@ class DummyGateway(MagicMock):
|
|||||||
async def test_change_presence_passes_arguments():
|
async def test_change_presence_passes_arguments():
|
||||||
client = Client(token="t")
|
client = Client(token="t")
|
||||||
client._gateway = DummyGateway()
|
client._gateway = DummyGateway()
|
||||||
|
game = Game("hi")
|
||||||
await client.change_presence(status="idle", activity_name="hi", activity_type=0)
|
await client.change_presence(status="idle", activity=game)
|
||||||
|
|
||||||
client._gateway.update_presence.assert_awaited_once_with(
|
client._gateway.update_presence.assert_awaited_once_with(
|
||||||
status="idle", activity_name="hi", activity_type=0, since=0, afk=False
|
status="idle", activity=game, since=0, afk=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import io
|
||||||
|
from array import array
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.audio import AudioSource, FFmpegAudioSource
|
||||||
|
|
||||||
from disagreement.voice_client import VoiceClient
|
from disagreement.voice_client import VoiceClient
|
||||||
from disagreement.audio import AudioSource
|
|
||||||
from disagreement.client import Client
|
from disagreement.client import Client
|
||||||
|
|
||||||
|
|
||||||
@ -137,3 +140,68 @@ async def test_play_and_switch_sources():
|
|||||||
await vc.play(DummySource([b"c"]))
|
await vc.play(DummySource([b"c"]))
|
||||||
|
|
||||||
assert udp.sent == [b"a", b"b", b"c"]
|
assert udp.sent == [b"a", b"b", b"c"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ffmpeg_source_custom_options(monkeypatch):
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
class DummyProcess:
|
||||||
|
def __init__(self):
|
||||||
|
self.stdout = io.BytesIO(b"")
|
||||||
|
|
||||||
|
async def wait(self):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def fake_exec(*args, **kwargs):
|
||||||
|
captured["args"] = args
|
||||||
|
return DummyProcess()
|
||||||
|
|
||||||
|
monkeypatch.setattr(asyncio, "create_subprocess_exec", fake_exec)
|
||||||
|
src = FFmpegAudioSource(
|
||||||
|
"file.mp3", before_options="-reconnect 1", options="-vn", volume=0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
await src._spawn()
|
||||||
|
|
||||||
|
cmd = captured["args"]
|
||||||
|
assert "-reconnect" in cmd
|
||||||
|
assert "-vn" in cmd
|
||||||
|
assert src.volume == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_voice_client_volume_scaling(monkeypatch):
|
||||||
|
ws = DummyWebSocket(
|
||||||
|
[
|
||||||
|
{"d": {"heartbeat_interval": 50}},
|
||||||
|
{"d": {"ssrc": 1, "ip": "127.0.0.1", "port": 4000}},
|
||||||
|
{"d": {"secret_key": []}},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
udp = DummyUDP()
|
||||||
|
vc = VoiceClient(
|
||||||
|
client=DummyVoiceClient(),
|
||||||
|
endpoint="ws://localhost",
|
||||||
|
session_id="sess",
|
||||||
|
token="tok",
|
||||||
|
guild_id=1,
|
||||||
|
user_id=2,
|
||||||
|
ws=ws,
|
||||||
|
udp=udp,
|
||||||
|
)
|
||||||
|
await vc.connect()
|
||||||
|
vc._heartbeat_task.cancel()
|
||||||
|
|
||||||
|
chunk = b"\x10\x00\x10\x00"
|
||||||
|
src = DummySource([chunk])
|
||||||
|
src.volume = 0.5
|
||||||
|
|
||||||
|
await vc.play(src)
|
||||||
|
|
||||||
|
samples = array("h")
|
||||||
|
samples.frombytes(chunk)
|
||||||
|
samples[0] = int(samples[0] * 0.5)
|
||||||
|
samples[1] = int(samples[1] * 0.5)
|
||||||
|
expected = samples.tobytes()
|
||||||
|
assert udp.sent == [expected]
|
||||||
|
50
tests/test_widget.py
Normal file
50
tests/test_widget.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
import pytest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from disagreement.http import HTTPClient
|
||||||
|
from disagreement.client import Client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_guild_widget_calls_request():
|
||||||
|
http = HTTPClient(token="t")
|
||||||
|
http.request = AsyncMock(return_value={})
|
||||||
|
await http.get_guild_widget("1")
|
||||||
|
http.request.assert_called_once_with("GET", "/guilds/1/widget")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_edit_guild_widget_calls_request():
|
||||||
|
http = HTTPClient(token="t")
|
||||||
|
http.request = AsyncMock(return_value={})
|
||||||
|
payload = {"enabled": True}
|
||||||
|
await http.edit_guild_widget("1", payload)
|
||||||
|
http.request.assert_called_once_with("PATCH", "/guilds/1/widget", payload=payload)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_fetch_widget_returns_data():
|
||||||
|
http = SimpleNamespace(get_guild_widget=AsyncMock(return_value={"enabled": True}))
|
||||||
|
client = Client.__new__(Client)
|
||||||
|
client._http = http
|
||||||
|
client._closed = False
|
||||||
|
|
||||||
|
data = await client.fetch_widget("1")
|
||||||
|
|
||||||
|
http.get_guild_widget.assert_awaited_once_with("1")
|
||||||
|
assert data == {"enabled": True}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_edit_widget_returns_data():
|
||||||
|
http = SimpleNamespace(edit_guild_widget=AsyncMock(return_value={"enabled": False}))
|
||||||
|
client = Client.__new__(Client)
|
||||||
|
client._http = http
|
||||||
|
client._closed = False
|
||||||
|
|
||||||
|
payload = {"enabled": False}
|
||||||
|
data = await client.edit_widget("1", payload)
|
||||||
|
|
||||||
|
http.edit_guild_widget.assert_awaited_once_with("1", payload)
|
||||||
|
assert data == {"enabled": False}
|
Loading…
x
Reference in New Issue
Block a user