Compare commits

...

80 Commits

Author SHA1 Message Date
380feddeeb
Merge branch 'master' of https://github.com/Slipstreamm/disagreement 2025-06-15 20:53:38 -06:00
3beaed8a1b
Testing commit signing 2025-06-15 20:50:52 -06:00
e5ad932321
Add recursive command enumeration (#115) 2025-06-15 20:46:20 -06:00
8e88aaec2f
Implement channel and member aggregation (#117) 2025-06-15 20:42:21 -06:00
d710487fc2
Add voice playback control (#111)
Some checks are pending
Deploy MkDocs / deploy (push) Waiting to run
2025-06-15 20:39:30 -06:00
506adeca20
Add channel and emoji converters (#112) 2025-06-15 20:39:28 -06:00
e2061adc55
Add logout method (#114) 2025-06-15 20:39:26 -06:00
132521fa39
Add Object class and partial docs (#113)
Some checks failed
Deploy MkDocs / deploy (push) Has been cancelled
2025-06-15 20:39:23 -06:00
cec747a575
Improve help command (#116)
Some checks failed
Deploy MkDocs / deploy (push) Has been cancelled
2025-06-15 20:39:20 -06:00
17751d3b09
Add HashableById mixin and tests (#118) 2025-06-15 20:39:16 -06:00
4b3b6aeb45
Add Asset model and avatar helpers (#119) 2025-06-15 20:39:14 -06:00
aa55aa1d4c
feat: persist views (#120)
Some checks failed
Deploy MkDocs / deploy (push) Has been cancelled
2025-06-15 20:39:12 -06:00
80f64c1f73
Add guild_permissions property and tests (#97) 2025-06-15 18:55:52 -06:00
f5f8f6908c
Add get and find helpers (#103) 2025-06-15 18:55:17 -06:00
3437050f0e
Add snowflake_time utility (#105)
Some checks failed
Deploy MkDocs / deploy (push) Has been cancelled
2025-06-15 18:52:38 -06:00
87d67eb63b
Add invite fetching support (#95) 2025-06-15 18:50:06 -06:00
9c10ab0f70
Add timestamp datetime properties (#96) 2025-06-15 18:50:03 -06:00
2008dd33d1
Add webhook retrieval (#98) 2025-06-15 18:49:59 -06:00
de40aa2c29
Add UserConverter and tests (#99) 2025-06-15 18:49:57 -06:00
2056a3ddcf
Add VoiceState dataclass and member voice property (#100) 2025-06-15 18:49:55 -06:00
ccf55adba2
Add channel invite creation (#101) 2025-06-15 18:49:53 -06:00
a335ed972c
Add guild channel creation methods (#102) 2025-06-15 18:49:51 -06:00
2586d3cd0d
Add message crosspost support (#104) 2025-06-15 18:49:48 -06:00
7f9647a442
feat(tasks): allow runtime interval change (#106) 2025-06-15 18:49:45 -06:00
a222dec661
Add Permissions.all convenience (#107)
Some checks failed
Deploy MkDocs / deploy (push) Has been cancelled
2025-06-15 18:49:43 -06:00
3f7c286322
Add guild prune support (#108) 2025-06-15 18:49:41 -06:00
cc17d11509
Add markdown and mention escaping helpers (#109) 2025-06-15 18:49:39 -06:00
9fabf1fbac
Add Webhook.from_token and fetch support (#110) 2025-06-15 18:49:37 -06:00
223c86cb78
Add cog retrieval helpers (#89)
Some checks failed
Deploy MkDocs / deploy (push) Has been cancelled
2025-06-15 18:15:45 -06:00
98afb89629
Add Guild.me property (#90)
Some checks failed
Deploy MkDocs / deploy (push) Has been cancelled
2025-06-15 18:13:06 -06:00
095e7e7192
Add get_context parsing for commands (#91) 2025-06-15 18:12:47 -06:00
c1c5cfb41a
Add Paginator utility and tests (#92)
Some checks failed
Deploy MkDocs / deploy (push) Has been cancelled
2025-06-15 18:12:10 -06:00
8be234c1f0
Support async teardown in extension loader (#93) 2025-06-15 18:11:53 -06:00
1464937f6f
Add async start and sync run (#94)
Some checks failed
Deploy MkDocs / deploy (push) Has been cancelled
2025-06-15 18:11:44 -06:00
5d66eb79cc chore(client): Remove merge conflict marker 2025-06-15 15:28:55 -06:00
5d72643390
Add is_owner decorator and owner checks (#81) 2025-06-15 15:23:52 -06:00
c7eb8563de
Add channel lists to Guild (#82) 2025-06-15 15:22:49 -06:00
a68bbe7826
Add guilds property to client (#83) 2025-06-15 15:21:01 -06:00
6eff962682
feat: dispatch connect events (#80)
Some checks failed
Deploy MkDocs / deploy (push) Has been cancelled
2025-06-15 15:20:13 -06:00
f24c1befac
feat(client): track connection time (#84) 2025-06-15 15:20:06 -06:00
c811e2b578
Extend File to handle streams (#85) 2025-06-15 15:20:04 -06:00
9f2fc0857b
feat(client): sync commands on ready (#86) 2025-06-15 15:20:00 -06:00
775dce0c80
Store shard id on guild and expose shard property (#87) 2025-06-15 15:19:58 -06:00
a93ad432b7
Add thread and invite event parsing (#88) 2025-06-15 15:19:56 -06:00
3a264f4530 feat(ext-loader): Support async setup functions
Allow extension `setup` functions to be asynchronous. The loader now checks if `module.setup` returns an awaitable and runs it using asyncio, handling cases where an event loop is already running or not.

This enables extensions to perform asynchronous initialization tasks.
2025-06-15 15:17:42 -06:00
Slipstreamm
a41a301927 fix(core): Improve client ready state and user parsing
The `_ready_event` is now set in `GatewayClient` immediately after
receiving the `READY` payload, before dispatching `on_ready` to user code.
This ensures `Client.wait_until_ready()` and `Client.is_ready()`
accurately reflect the client's state before dependent user logic executes.

This change allows simplifying `Client.sync_commands` by removing
redundant `wait_until_ready()` calls and `application_id` checks,
as the application ID is guaranteed to be available upon READY.

Additionally, `User` model initialization is improved to correctly handle
nested user data found in certain API payloads (e.g., within `member`
objects in events like `PresenceUpdate`).

Add `SOUNDBOARD` and `VIDEO_QUALITY_720_60FPS` to `GuildFeature` enum.
2025-06-14 23:49:33 -06:00
Slipstreamm
bd16b1c026 chore(release): Bump version to 0.8.0 2025-06-14 21:53:41 -06:00
460583ef30
Fix pylint errors due to conditional imports (#78) 2025-06-14 21:52:59 -06:00
f1ca18a62a
Fix pylint errors due to conditional imports (#79) 2025-06-14 21:52:29 -06:00
Slipstreamm
2c8e426353 chore(release): Bump version to 0.7.0 2025-06-14 21:41:30 -06:00
c9aec0dc7e
Improve command sync and DM support (#77) 2025-06-14 21:40:52 -06:00
Slipstreamm
bd92806c4c bump 2 0.6.0 2025-06-14 19:06:57 -06:00
Slipstreamm
e965a675c1 refactor(api): Re-export common symbols from top-level package
Some checks failed
Deploy MkDocs / deploy (push) Has been cancelled
Makes commonly used classes, functions, and decorators from `disagreement.ext` and `disagreement.ui` submodules directly accessible under the `disagreement` namespace.

This change simplifies import statements for users, leading to cleaner and more concise code. Documentation and examples have been updated to reflect these new, simplified import paths.
2025-06-14 18:57:12 -06:00
Slipstreamm
9237d12a24 docs(imports): Update import paths in documentation examples
Adjust examples to reflect the new top-level exposure of classes and enums, such as `Client`, `Permissions`, `Embed`, and `Button`, making imports simpler.
2025-06-14 18:44:04 -06:00
Slipstreamm
420c57df30 chore: Bump version to 0.5.0 in __init__.py and pyproject.toml 2025-06-14 18:22:20 -06:00
Slipstreamm
b039b2e948 refactor(init): Consolidate module imports and exports
This commit refactors the `disagreement/__init__.py` file to import and export new models, enums, and components.

The primary changes are:
- Add imports and exports for `Member`, `Role`, `Attachment`, `Channel`, `ActionRow`, `Button`, `SelectOption`, `SelectMenu`, `Embed`, `PartialEmoji`, `Section`, `TextDisplay`, `Thumbnail`, `UnfurledMediaItem`, `MediaGallery`, `MediaGalleryItem`, `Container`, and `Guild` from `disagreement.models`.
- Add imports and exports for `ButtonStyle`, `ChannelType`, `MessageFlags`, `InteractionType`, `InteractionCallbackType`, and `ComponentType` from `disagreement.enums`.
- Add `Interaction` from `disagreement.interactions`.
- Add `ui` and `ext` as top-level modules.
- Update `disagreement.ext/__init__.py` to expose `app_commands`, `commands`, and `tasks`.

These changes consolidate the library's public API, making new features more accessible.
The example files were also updated to use the direct imports from the `disagreement` package or its `ext` subpackage, improving readability and consistency.
2025-06-14 18:17:57 -06:00
f58ffe8321
Apply global allowed_mentions setting (#76)
Some checks failed
Deploy MkDocs / deploy (push) Has been cancelled
2025-06-13 22:10:19 -06:00
Slipstreamm
ffdb922142 ci(mirror): Make mirror remote addition idempotent
Some checks failed
Deploy MkDocs / deploy (push) Has been cancelled
Wrap the `git remote add mirror` command in a conditional check.

This ensures the remote is only added if it doesn't already exist,
preventing potential errors if the command is executed multiple times
or if the remote is somehow already configured in the runner environment.
2025-06-13 00:30:31 -06:00
Slipstreamm
2b8f29bde2 chore(docs): Add CNAME file for custom domain configuration 2025-06-13 00:26:44 -06:00
Slipstreamm
f7a47619ac ci(workflows): Migrate to self-hosted runners
Some checks failed
Deploy MkDocs / deploy (push) Has been cancelled
Switch GitHub Actions workflows (`docs`, `mirror`, `pypi`) from `ubuntu-latest` to `self-hosted` runners.

This change also updates the Python environment setup in `docs.yml` and `pypi.yml` to manually create and activate a virtual environment (`venv`). This provides more control over the Python environment on self-hosted machines and ensures dependencies are isolated.
2025-06-13 00:23:55 -06:00
Slipstreamm
675aab39ce chore(ci): Activate virtual environment before running Pyright and tests 2025-06-13 00:21:23 -06:00
Slipstreamm
a2bdc66ced chore(deps): Add aiohttp and python-dotenv to test dependencies 2025-06-13 00:16:23 -06:00
Slipstreamm
6fb371455b Update exclusion patterns in pyrightconfig.json to include virtual environments 2025-06-13 00:12:37 -06:00
Slipstreamm
8a228a9e1b Refactor Pyright execution step to simplify command 2025-06-13 00:11:50 -06:00
Slipstreamm
1505bdfd0a Add output of pyrightconfig.json before running Pyright 2025-06-13 00:10:33 -06:00
Slipstreamm
7354ff2244 Simplify dependency installation in CI workflow 2025-06-13 00:10:24 -06:00
Slipstreamm
66eb50833b Add output of pyrightconfig.json before running Pyright 2025-06-13 00:08:19 -06:00
Slipstreamm
398c2c34c0 Update CI workflow to include current directory output before running Pyright 2025-06-13 00:06:52 -06:00
2e72103b6a
Update ci.yml 2025-06-13 00:03:28 -06:00
91821e1c1d
Update ci.yml 2025-06-12 23:57:06 -06:00
12b14b9187
Update ci.yml 2025-06-12 23:55:41 -06:00
fae9cddb88
Update ci.yml 2025-06-12 23:52:54 -06:00
fd9ce4bbb8
Update ci.yml 2025-06-12 23:48:26 -06:00
3adce99f22
Update ci.yml 2025-06-12 23:40:59 -06:00
075811982d
Update ci.yml 2025-06-12 23:40:16 -06:00
aa01d74c01
Update ci.yml 2025-06-12 23:38:29 -06:00
ae45cc898d
Update ci.yml 2025-06-12 23:25:00 -06:00
890742b177
Update ci.yml 2025-06-12 23:20:05 -06:00
d0e55d3706
Update ci.yml 2025-06-12 23:10:37 -06:00
b5ee8dc408
Update ci.yml 2025-06-12 22:53:57 -06:00
103 changed files with 3948 additions and 721 deletions

View File

@ -17,28 +17,30 @@ on:
- 'requirements.txt' - 'requirements.txt'
- 'pyproject.toml' - 'pyproject.toml'
- 'setup.py' - 'setup.py'
workflow_dispatch:
jobs: jobs:
build: build:
runs-on: ubuntu-latest runs-on: self-hosted
steps: steps:
- uses: actions/checkout@v4 - name: Checkout code
uses: actions/checkout@v4
- name: Set up Python - name: Install deps
uses: actions/setup-python@v5 run: |
with: python -m venv venv
python-version: '3.13' source venv/bin/activate
pip install --upgrade pip
pip install -r requirements.txt
pip install -e .
- name: Run Pyright
run: |
source venv/bin/activate
pyright
- name: Install dependencies - name: Run Tests
run: | run: |
python -m pip install --upgrade pip source venv/bin/activate
pip install -r requirements.txt pytest tests/
pip install -e .
npm install -g pyright
- name: Run Pyright
run: pyright
- name: Run tests
run: |
pytest tests/

View File

@ -21,7 +21,7 @@ on:
jobs: jobs:
deploy: deploy:
runs-on: ubuntu-latest runs-on: self-hosted
permissions: permissions:
contents: write contents: write
@ -30,12 +30,14 @@ jobs:
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v5 run: |
with: python -m venv venv
python-version: '3.13' source venv/bin/activate
pip install --upgrade pip
- name: Install dependencies - name: Install dependencies
run: | run: |
source venv/bin/activate
pip install mkdocs mkdocs-material pip install mkdocs mkdocs-material
- name: Configure Git author from GitHub Actions metadata - name: Configure Git author from GitHub Actions metadata
@ -46,5 +48,6 @@ jobs:
- name: Deploy docs - name: Deploy docs
env: env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: mkdocs gh-deploy --force --clean run: |
source venv/bin/activate
mkdocs gh-deploy --force --clean

View File

@ -3,11 +3,12 @@ name: Mirror to Gitea
on: on:
push: push:
branches: branches:
- master # or change to your default branch - master
jobs: jobs:
mirror: mirror:
runs-on: ubuntu-latest runs-on: self-hosted
steps: steps:
- name: Checkout repo - name: Checkout repo
uses: actions/checkout@v3 uses: actions/checkout@v3
@ -23,5 +24,7 @@ jobs:
env: env:
MIRROR_PAT: ${{ secrets.MIRROR_PAT }} MIRROR_PAT: ${{ secrets.MIRROR_PAT }}
run: | run: |
git remote add mirror https://slipstream:${MIRROR_PAT}@git.slipstreamm.dev/slipstream/disagreement.git if ! git remote | grep -q "^mirror$"; then
git remote add mirror https://slipstream:${MIRROR_PAT}@git.slipstreamm.dev/slipstream/disagreement.git
fi
git push --mirror mirror git push --mirror mirror

View File

@ -3,14 +3,14 @@ name: Publish to PyPI
on: on:
push: push:
tags: tags:
- 'v*' # only trigger on version tags like v1.0.0 - 'v*'
jobs: jobs:
build-and-publish: build-and-publish:
runs-on: ubuntu-latest runs-on: self-hosted
permissions: permissions:
id-token: write # required for trusted publishing, if used id-token: write
contents: read contents: read
steps: steps:
@ -27,17 +27,19 @@ jobs:
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v5 run: |
with: python -m venv venv
python-version: '3.13' source venv/bin/activate
pip install --upgrade pip
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip source venv/bin/activate
pip install build twine pip install build twine
- name: Build package - name: Build package
run: | run: |
source venv/bin/activate
python -m build python -m build
- name: Publish to PyPI - name: Publish to PyPI
@ -45,4 +47,5 @@ jobs:
TWINE_USERNAME: __token__ TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: | run: |
source venv/bin/activate
python -m twine upload dist/* python -m twine upload dist/*

View File

@ -13,6 +13,8 @@ A Python library for interacting with the Discord API, with a focus on bot devel
- Message component helpers - Message component helpers
- `Message.jump_url` property for quick links to messages - `Message.jump_url` property for quick links to messages
- Built-in caching layer - Built-in caching layer
- `Guild.me` property to access the bot's member object
- Easy CDN asset handling via the `Asset` model
- Experimental voice support - Experimental voice support
- Helpful error handling utilities - Helpful error handling utilities
@ -61,13 +63,9 @@ if not token:
intents = disagreement.GatewayIntent.default() | disagreement.GatewayIntent.MESSAGE_CONTENT intents = disagreement.GatewayIntent.default() | disagreement.GatewayIntent.MESSAGE_CONTENT
client = disagreement.Client(token=token, command_prefix="!", intents=intents, mention_replies=True) client = disagreement.Client(token=token, command_prefix="!", intents=intents, mention_replies=True)
async def main() -> None:
client.add_cog(Basics(client))
await client.run()
client.add_cog(Basics(client))
if __name__ == "__main__": client.run()
asyncio.run(main())
``` ```
### Global Error Handling ### Global Error Handling
@ -114,6 +112,10 @@ These options are forwarded to ``HTTPClient`` when it creates the underlying
``aiohttp.ClientSession``. You can specify a custom ``connector`` or any other ``aiohttp.ClientSession``. You can specify a custom ``connector`` or any other
session parameter supported by ``aiohttp``. session parameter supported by ``aiohttp``.
### Logging Out
Call ``Client.logout`` to disconnect from the Gateway and clear the current bot token while keeping the HTTP session alive. Assign a new token and call ``connect`` or ``run`` to log back in.
### Default Allowed Mentions ### Default Allowed Mentions
Specify default mention behaviour for all outgoing messages when constructing the client: Specify default mention behaviour for all outgoing messages when constructing the client:
@ -126,8 +128,19 @@ client = disagreement.Client(
) )
``` ```
This dictionary is used whenever ``send_message`` is called without an explicit This dictionary is used whenever ``send_message`` or helpers like ``Message.reply``
``allowed_mentions`` argument. are called without an explicit ``allowed_mentions`` argument.
### Working With Assets
Properties like ``User.avatar`` and ``Guild.icon`` return :class:`disagreement.Asset` objects.
Use ``read`` to get the bytes or ``save`` to write them to disk.
```python
user = await client.fetch_user(123)
data = await user.avatar.read()
await user.avatar.save("avatar.png")
```
### Defining Subcommands with `AppCommandGroup` ### Defining Subcommands with `AppCommandGroup`

View File

@ -12,10 +12,35 @@ __title__ = "disagreement"
__author__ = "Slipstream" __author__ = "Slipstream"
__license__ = "BSD 3-Clause License" __license__ = "BSD 3-Clause License"
__copyright__ = "Copyright 2025 Slipstream" __copyright__ = "Copyright 2025 Slipstream"
__version__ = "0.4.2" __version__ = "0.8.1"
from .client import Client, AutoShardedClient from .client import Client, AutoShardedClient
from .models import Message, User, Reaction, AuditLogEntry from .asset import Asset
from .models import (
Message,
User,
Reaction,
AuditLogEntry,
Member,
Role,
Attachment,
Channel,
ActionRow,
Button,
SelectOption,
SelectMenu,
Embed,
PartialEmoji,
Section,
TextDisplay,
Thumbnail,
UnfurledMediaItem,
MediaGallery,
MediaGalleryItem,
Container,
Guild,
)
from .object import Object
from .voice_client import VoiceClient from .voice_client import VoiceClient
from .audio import AudioSource, FFmpegAudioSource from .audio import AudioSource, FFmpegAudioSource
from .typing import Typing from .typing import Typing
@ -28,12 +53,73 @@ from .errors import (
NotFound, NotFound,
) )
from .color import Color from .color import Color
from .utils import utcnow, message_pager from .utils import (
from .enums import GatewayIntent, GatewayOpcode utcnow,
message_pager,
get,
find,
escape_markdown,
escape_mentions,
snowflake_time,
)
from .enums import (
GatewayIntent,
GatewayOpcode,
ButtonStyle,
ChannelType,
MessageFlags,
InteractionType,
InteractionCallbackType,
ComponentType,
)
from .error_handler import setup_global_error_handler from .error_handler import setup_global_error_handler
from .hybrid_context import HybridContext from .hybrid_context import HybridContext
from .ext import tasks from .interactions import Interaction
from .logging_config import setup_logging from .logging_config import setup_logging
from . import ui, ext
from .ext.app_commands import (
AppCommand,
AppCommandContext,
AppCommandGroup,
MessageCommand,
OptionMetadata,
SlashCommand,
UserCommand,
group,
hybrid_command,
message_command,
slash_command,
subcommand,
subcommand_group,
)
from .ext.commands import (
BadArgument,
CheckAnyFailure,
CheckFailure,
Cog,
Command,
CommandContext,
CommandError,
CommandInvokeError,
CommandNotFound,
CommandOnCooldown,
MaxConcurrencyReached,
MissingRequiredArgument,
ArgumentParsingError,
check,
check_any,
command,
cooldown,
has_any_role,
is_owner,
has_role,
listener,
max_concurrency,
requires_permissions,
)
from .ext.tasks import Task, loop
from .ui import Item, Modal, Select, TextInput, View, button, select, text_input
import logging import logging
@ -41,10 +127,29 @@ import logging
__all__ = [ __all__ = [
"Client", "Client",
"AutoShardedClient", "AutoShardedClient",
"Asset",
"Message", "Message",
"User", "User",
"Reaction", "Reaction",
"AuditLogEntry", "AuditLogEntry",
"Member",
"Role",
"Attachment",
"Channel",
"ActionRow",
"Button",
"SelectOption",
"SelectMenu",
"Embed",
"PartialEmoji",
"Section",
"TextDisplay",
"Thumbnail",
"UnfurledMediaItem",
"MediaGallery",
"MediaGalleryItem",
"Container",
"Object",
"VoiceClient", "VoiceClient",
"AudioSource", "AudioSource",
"FFmpegAudioSource", "FFmpegAudioSource",
@ -57,13 +162,72 @@ __all__ = [
"NotFound", "NotFound",
"Color", "Color",
"utcnow", "utcnow",
"escape_markdown",
"escape_mentions",
"message_pager", "message_pager",
"get",
"find",
"snowflake_time",
"GatewayIntent", "GatewayIntent",
"GatewayOpcode", "GatewayOpcode",
"ButtonStyle",
"ChannelType",
"MessageFlags",
"InteractionType",
"InteractionCallbackType",
"ComponentType",
"setup_global_error_handler", "setup_global_error_handler",
"HybridContext", "HybridContext",
"tasks", "Interaction",
"setup_logging", "setup_logging",
"ui",
"ext",
"AppCommand",
"AppCommandContext",
"AppCommandGroup",
"MessageCommand",
"OptionMetadata",
"SlashCommand",
"UserCommand",
"group",
"hybrid_command",
"message_command",
"slash_command",
"subcommand",
"subcommand_group",
"BadArgument",
"CheckAnyFailure",
"CheckFailure",
"Cog",
"Command",
"CommandContext",
"CommandError",
"CommandInvokeError",
"CommandNotFound",
"CommandOnCooldown",
"MaxConcurrencyReached",
"MissingRequiredArgument",
"ArgumentParsingError",
"check",
"check_any",
"command",
"cooldown",
"has_any_role",
"is_owner",
"has_role",
"listener",
"max_concurrency",
"requires_permissions",
"Task",
"loop",
"Item",
"Modal",
"Select",
"TextInput",
"View",
"button",
"select",
"text_input",
] ]

51
disagreement/asset.py Normal file
View File

@ -0,0 +1,51 @@
"""Utility class for Discord CDN assets."""
from __future__ import annotations
import os
from typing import IO, Optional, Union, TYPE_CHECKING
import aiohttp # pylint: disable=import-error
if TYPE_CHECKING:
from .client import Client
class Asset:
"""Represents a CDN asset such as an avatar or icon."""
def __init__(self, url: str, client_instance: Optional["Client"] = None) -> None:
self.url = url
self._client = client_instance
async def read(self) -> bytes:
"""Read the asset's bytes."""
session: Optional[aiohttp.ClientSession] = None
if self._client is not None:
await self._client._http._ensure_session() # type: ignore[attr-defined]
session = self._client._http._session # type: ignore[attr-defined]
if session is None:
session = aiohttp.ClientSession()
close = True
else:
close = False
async with session.get(self.url) as resp:
data = await resp.read()
if close:
await session.close()
return data
async def save(self, fp: Union[str, os.PathLike[str], IO[bytes]]) -> None:
"""Save the asset to the given file path or file-like object."""
data = await self.read()
if isinstance(fp, (str, os.PathLike)):
path = os.fspath(fp)
with open(path, "wb") as file:
file.write(data)
else:
fp.write(data)
def __repr__(self) -> str:
return f"<Asset url='{self.url}'>"

View File

@ -84,9 +84,9 @@ class MemberCache(Cache["Member"]):
def _should_cache(self, member: Member) -> bool: def _should_cache(self, member: Member) -> bool:
"""Determines if a member should be cached based on the flags.""" """Determines if a member should be cached based on the flags."""
if self.flags.all: if self.flags.all_enabled:
return True return True
if self.flags.none: if self.flags.no_flags:
return False return False
if self.flags.online and member.status != "offline": if self.flags.online and member.status != "offline":

View File

@ -74,6 +74,14 @@ class MemberCacheFlags:
for name in self.VALID_FLAGS: for name in self.VALID_FLAGS:
yield name, getattr(self, name) yield name, getattr(self, name)
@property
def all_enabled(self) -> bool:
return self.value == self.ALL_FLAGS
@property
def no_flags(self) -> bool:
return self.value == 0
def __int__(self) -> int: def __int__(self) -> int:
return self.value return self.value

View File

@ -2,8 +2,11 @@
The main Client class for interacting with the Discord API. The main Client class for interacting with the Discord API.
""" """
import asyncio import asyncio
import signal import signal
import json
import os
import importlib
from typing import ( from typing import (
Optional, Optional,
Callable, Callable,
@ -14,8 +17,13 @@ from typing import (
Union, Union,
List, List,
Dict, Dict,
cast,
) )
from types import ModuleType from types import ModuleType
PERSISTENT_VIEWS_FILE = "persistent_views.json"
from datetime import datetime, timedelta
from .http import HTTPClient from .http import HTTPClient
from .gateway import GatewayClient from .gateway import GatewayClient
@ -35,6 +43,7 @@ from .interactions import Interaction, Snowflake
from .error_handler import setup_global_error_handler from .error_handler import setup_global_error_handler
from .voice_client import VoiceClient from .voice_client import VoiceClient
from .models import Activity from .models import Activity
from .utils import utcnow
if TYPE_CHECKING: if TYPE_CHECKING:
from .models import ( from .models import (
@ -64,7 +73,16 @@ if TYPE_CHECKING:
from .ext.app_commands.commands import AppCommand, AppCommandGroup from .ext.app_commands.commands import AppCommand, AppCommandGroup
class Client: def _update_list(lst: List[Any], item: Any) -> None:
"""Replace an item with the same ID in a list or append if missing."""
for i, existing in enumerate(lst):
if getattr(existing, "id", None) == getattr(item, "id", None):
lst[i] = item
return
lst.append(item)
class Client:
""" """
Represents a client connection that connects to Discord. Represents a client connection that connects to Discord.
This class is used to interact with the Discord WebSocket and API. This class is used to interact with the Discord WebSocket and API.
@ -89,6 +107,9 @@ class Client:
:class:`aiohttp.ClientSession`. :class:`aiohttp.ClientSession`.
message_cache_maxlen (Optional[int]): Maximum number of messages to keep message_cache_maxlen (Optional[int]): Maximum number of messages to keep
in the cache. When ``None``, the cache size is unlimited. in the cache. When ``None``, the cache size is unlimited.
sync_commands_on_ready (bool): If ``True``, automatically call
:meth:`Client.sync_application_commands` after the ``READY`` event
when :attr:`Client.application_id` is available.
""" """
def __init__( def __init__(
@ -109,7 +130,10 @@ class Client:
member_cache_flags: Optional[MemberCacheFlags] = None, member_cache_flags: Optional[MemberCacheFlags] = None,
message_cache_maxlen: Optional[int] = None, message_cache_maxlen: Optional[int] = None,
http_options: Optional[Dict[str, Any]] = None, http_options: Optional[Dict[str, Any]] = None,
owner_ids: Optional[List[Union[str, int]]] = None,
sync_commands_on_ready: bool = True,
): ):
if not token: if not token:
raise ValueError("A bot token must be provided.") raise ValueError("A bot token must be provided.")
@ -139,13 +163,14 @@ class Client:
**(http_options or {}), **(http_options or {}),
) )
self._event_dispatcher: EventDispatcher = EventDispatcher(client_instance=self) self._event_dispatcher: EventDispatcher = EventDispatcher(client_instance=self)
self._gateway: Optional[GatewayClient] = ( self._gateway: Optional[GatewayClient] = (
None # Initialized in run() or connect() None # Initialized in start() or connect()
) )
self.shard_count: Optional[int] = shard_count self.shard_count: Optional[int] = shard_count
self.gateway_max_retries: int = gateway_max_retries self.gateway_max_retries: int = gateway_max_retries
self.gateway_max_backoff: float = gateway_max_backoff self.gateway_max_backoff: float = gateway_max_backoff
self._shard_manager: Optional[ShardManager] = None self._shard_manager: Optional[ShardManager] = None
self.owner_ids: List[str] = [str(o) for o in owner_ids] if owner_ids else []
# Initialize CommandHandler # Initialize CommandHandler
self.command_handler: CommandHandler = CommandHandler( self.command_handler: CommandHandler = CommandHandler(
@ -163,6 +188,8 @@ class Client:
None # The bot's own user object, populated on READY None # The bot's own user object, populated on READY
) )
self.start_time: Optional[datetime] = None
# Internal Caches # Internal Caches
self._guilds: GuildCache = GuildCache() self._guilds: GuildCache = GuildCache()
self._channels: ChannelCache = ChannelCache() self._channels: ChannelCache = ChannelCache()
@ -171,11 +198,15 @@ class Client:
self._views: Dict[Snowflake, "View"] = {} self._views: Dict[Snowflake, "View"] = {}
self._persistent_views: Dict[str, "View"] = {} self._persistent_views: Dict[str, "View"] = {}
self._voice_clients: Dict[Snowflake, VoiceClient] = {} self._voice_clients: Dict[Snowflake, VoiceClient] = {}
self._webhooks: Dict[Snowflake, "Webhook"] = {} self._webhooks: Dict[Snowflake, "Webhook"] = {}
# Load persistent views stored on disk
self._load_persistent_views()
# Default whether replies mention the user # Default whether replies mention the user
self.mention_replies: bool = mention_replies self.mention_replies: bool = mention_replies
self.allowed_mentions: Optional[Dict[str, Any]] = allowed_mentions self.allowed_mentions: Optional[Dict[str, Any]] = allowed_mentions
self.sync_commands_on_ready: bool = sync_commands_on_ready
# Basic signal handling for graceful shutdown # Basic signal handling for graceful shutdown
# This might be better handled by the user's application code, but can be a nice default. # This might be better handled by the user's application code, but can be a nice default.
@ -187,13 +218,46 @@ class Client:
self.loop.add_signal_handler( self.loop.add_signal_handler(
signal.SIGTERM, lambda: self.loop.create_task(self.close()) signal.SIGTERM, lambda: self.loop.create_task(self.close())
) )
except NotImplementedError: except NotImplementedError:
# add_signal_handler is not available on all platforms (e.g., Windows default event loop policy) # add_signal_handler is not available on all platforms (e.g., Windows default event loop policy)
# Users on these platforms would need to handle shutdown differently. # Users on these platforms would need to handle shutdown differently.
print( print(
"Warning: Signal handlers for SIGINT/SIGTERM could not be added. " "Warning: Signal handlers for SIGINT/SIGTERM could not be added. "
"Graceful shutdown via signals might not work as expected on this platform." "Graceful shutdown via signals might not work as expected on this platform."
) )
def _load_persistent_views(self) -> None:
"""Load registered persistent views from disk."""
if not os.path.isfile(PERSISTENT_VIEWS_FILE):
return
try:
with open(PERSISTENT_VIEWS_FILE, "r") as fp:
mapping = json.load(fp)
except Exception as e: # pragma: no cover - best effort load
print(f"Failed to load persistent views: {e}")
return
for custom_id, path in mapping.items():
try:
module_name, class_name = path.rsplit(".", 1)
module = importlib.import_module(module_name)
cls = getattr(module, class_name)
view = cls()
self._persistent_views[custom_id] = view
except Exception as e: # pragma: no cover - best effort load
print(f"Failed to initialize persistent view {path}: {e}")
def _save_persistent_views(self) -> None:
"""Persist registered views to disk."""
data = {}
for custom_id, view in self._persistent_views.items():
cls = view.__class__
data[custom_id] = f"{cls.__module__}.{cls.__name__}"
try:
with open(PERSISTENT_VIEWS_FILE, "w") as fp:
json.dump(data, fp)
except Exception as e: # pragma: no cover - best effort save
print(f"Failed to save persistent views: {e}")
async def _initialize_gateway(self): async def _initialize_gateway(self):
"""Initializes the GatewayClient if it doesn't exist.""" """Initializes the GatewayClient if it doesn't exist."""
@ -238,14 +302,15 @@ class Client:
f"Client connected using {self.shard_count} shards, waiting for READY signal..." f"Client connected using {self.shard_count} shards, waiting for READY signal..."
) )
await self.wait_until_ready() await self.wait_until_ready()
self.start_time = utcnow()
print("Client is READY!") print("Client is READY!")
return return
await self._initialize_gateway() await self._initialize_gateway()
assert self._gateway is not None # Should be initialized by now assert self._gateway is not None # Should be initialized by now
retry_delay = 5 # seconds retry_delay = 5 # seconds
max_retries = 5 # For initial connection attempts by Client.run, Gateway has its own internal retries for some cases. max_retries = 5 # For initial connection attempts by Client.start, Gateway has its own internal retries for some cases.
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
@ -254,6 +319,7 @@ class Client:
# and its READY handler will set self._ready_event via dispatcher. # and its READY handler will set self._ready_event via dispatcher.
print("Client connected to Gateway, waiting for READY signal...") print("Client connected to Gateway, waiting for READY signal...")
await self.wait_until_ready() # Wait for the READY event from Gateway await self.wait_until_ready() # Wait for the READY event from Gateway
self.start_time = utcnow()
print("Client is READY!") print("Client is READY!")
return # Successfully connected and ready return # Successfully connected and ready
except AuthenticationError: # Non-recoverable by retry here except AuthenticationError: # Non-recoverable by retry here
@ -262,25 +328,24 @@ class Client:
raise raise
except DisagreementException as e: # Includes GatewayException except DisagreementException as e: # Includes GatewayException
print(f"Failed to connect (Attempt {attempt + 1}/{max_retries}): {e}") print(f"Failed to connect (Attempt {attempt + 1}/{max_retries}): {e}")
if attempt < max_retries - 1: if attempt < max_retries - 1:
print(f"Retrying in {retry_delay} seconds...") print(f"Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay) await asyncio.sleep(retry_delay)
retry_delay = min( retry_delay = min(
retry_delay * 2, 60 retry_delay * 2, 60
) # Exponential backoff up to 60s ) # Exponential backoff up to 60s
else: else:
print("Max connection retries reached. Giving up.") print("Max connection retries reached. Giving up.")
await self.close() # Ensure cleanup await self.close() # Ensure cleanup
raise raise
if max_retries == 0: # If max_retries was 0, means no retries attempted if max_retries == 0: # If max_retries was 0, means no retries attempted
raise DisagreementException("Connection failed with 0 retries allowed.") raise DisagreementException("Connection failed with 0 retries allowed.")
async def run(self) -> None: async def start(self) -> None:
""" """
A blocking call that connects the client to Discord and runs until the client is closed. Connect the client to Discord and run until the client is closed.
This method is a coroutine. This method is a coroutine containing the main run loop logic.
It handles login, Gateway connection, and keeping the connection alive. """
"""
if self._closed: if self._closed:
raise DisagreementException("Client is already closed.") raise DisagreementException("Client is already closed.")
@ -337,15 +402,19 @@ class Client:
except Exception as e: except Exception as e:
print(f"Error checking gateway receive task: {e}") print(f"Error checking gateway receive task: {e}")
break # Exit on other errors break # Exit on other errors
await asyncio.sleep(1) # Main loop check interval await asyncio.sleep(1) # Main loop check interval
except DisagreementException as e: except DisagreementException as e:
print(f"Client run loop encountered an error: {e}") print(f"Client run loop encountered an error: {e}")
# Error already logged by connect or other methods # Error already logged by connect or other methods
except asyncio.CancelledError: except asyncio.CancelledError:
print("Client run loop was cancelled.") print("Client run loop was cancelled.")
finally: finally:
if not self._closed: if not self._closed:
await self.close() await self.close()
def run(self) -> None:
"""Synchronously start the client using :func:`asyncio.run`."""
asyncio.run(self.start())
async def close(self) -> None: async def close(self) -> None:
""" """
@ -367,6 +436,7 @@ class Client:
await self._http.close() await self._http.close()
self._ready_event.set() # Ensure any waiters for ready are unblocked self._ready_event.set() # Ensure any waiters for ready are unblocked
self.start_time = None
print("Client closed.") print("Client closed.")
async def __aenter__(self) -> "Client": async def __aenter__(self) -> "Client":
@ -384,15 +454,23 @@ class Client:
await self.close() await self.close()
return False return False
async def close_gateway(self, code: int = 1000) -> None: async def close_gateway(self, code: int = 1000) -> None:
"""Closes only the gateway connection, allowing for potential reconnect.""" """Closes only the gateway connection, allowing for potential reconnect."""
if self._shard_manager: if self._shard_manager:
await self._shard_manager.close() await self._shard_manager.close()
self._shard_manager = None self._shard_manager = None
if self._gateway: if self._gateway:
await self._gateway.close(code=code) await self._gateway.close(code=code)
self._gateway = None self._gateway = None
self._ready_event.clear() # No longer ready if gateway is closed self._ready_event.clear() # No longer ready if gateway is closed
async def logout(self) -> None:
"""Invalidate the bot token and disconnect from the Gateway."""
await self.close_gateway()
self.token = ""
self._http.token = ""
self.user = None
self.start_time = None
def is_closed(self) -> bool: def is_closed(self) -> bool:
"""Indicates if the client has been closed.""" """Indicates if the client has been closed."""
@ -415,6 +493,17 @@ class Client:
latency = getattr(self._gateway, "latency_ms", None) latency = getattr(self._gateway, "latency_ms", None)
return round(latency, 2) if latency is not None else None return round(latency, 2) if latency is not None else None
@property
def guilds(self) -> List["Guild"]:
"""Returns all guilds from the internal cache."""
return self._guilds.values()
def uptime(self) -> Optional[timedelta]:
"""Return the duration since the client connected, or ``None`` if not connected."""
if self.start_time is None:
return None
return utcnow() - self.start_time
async def wait_until_ready(self) -> None: async def wait_until_ready(self) -> None:
"""|coro| """|coro|
Waits until the client is fully connected to Discord and the initial state is processed. Waits until the client is fully connected to Discord and the initial state is processed.
@ -529,38 +618,43 @@ class Client:
print(f"Message: {message.content}") print(f"Message: {message.content}")
""" """
def decorator( def decorator(
coro: Callable[..., Awaitable[None]], coro: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]: ) -> Callable[..., Awaitable[None]]:
if not asyncio.iscoroutinefunction(coro): if not asyncio.iscoroutinefunction(coro):
raise TypeError("Event registered must be a coroutine function.") raise TypeError("Event registered must be a coroutine function.")
self._event_dispatcher.register(event_name.upper(), coro) self._event_dispatcher.register(event_name.upper(), coro)
return coro return coro
return decorator
def add_listener(
self, event_name: str, coro: Callable[..., Awaitable[None]]
) -> None:
"""Register ``coro`` to listen for ``event_name``."""
self._event_dispatcher.register(event_name, coro)
def remove_listener(
self, event_name: str, coro: Callable[..., Awaitable[None]]
) -> None:
"""Remove ``coro`` from ``event_name`` listeners."""
self._event_dispatcher.unregister(event_name, coro)
async def _process_message_for_commands(self, message: "Message") -> None: return decorator
"""Internal listener to process messages for commands."""
# Make sure message object is valid and not from a bot (optional, common check) def add_listener(
if ( self, event_name: str, coro: Callable[..., Awaitable[None]]
not message or not message.author or message.author.bot ) -> None:
): # Add .bot check to User model """Register ``coro`` to listen for ``event_name``."""
return
await self.command_handler.process_commands(message) self._event_dispatcher.register(event_name, coro)
def remove_listener(
self, event_name: str, coro: Callable[..., Awaitable[None]]
) -> None:
"""Remove ``coro`` from ``event_name`` listeners."""
self._event_dispatcher.unregister(event_name, coro)
async def _process_message_for_commands(self, message: "Message") -> None:
"""Internal listener to process messages for commands."""
# Make sure message object is valid and not from a bot (optional, common check)
if (
not message or not message.author or message.author.bot
): # Add .bot check to User model
return
await self.command_handler.process_commands(message)
async def get_context(self, message: "Message") -> Optional["CommandContext"]:
"""Return a :class:`CommandContext` for ``message`` without executing the command."""
return await self.command_handler.get_context(message)
# --- Command Framework Methods --- # --- Command Framework Methods ---
@ -587,7 +681,7 @@ class Client:
f"Registered app command/group '{app_cmd_obj.name}' from cog '{cog.cog_name}'." f"Registered app command/group '{app_cmd_obj.name}' from cog '{cog.cog_name}'."
) )
def remove_cog(self, cog_name: str) -> Optional[Cog]: def remove_cog(self, cog_name: str) -> Optional[Cog]:
""" """
Removes a Cog from the bot. Removes a Cog from the bot.
@ -614,7 +708,12 @@ class Client:
# Note: AppCommandHandler.remove_command might need to be more specific if names aren't globally unique # Note: AppCommandHandler.remove_command might need to be more specific if names aren't globally unique
# (e.g. if it needs type or if groups and commands can share names). # (e.g. if it needs type or if groups and commands can share names).
# For now, assuming name is sufficient for removal from the handler's flat list. # For now, assuming name is sufficient for removal from the handler's flat list.
return removed_cog return removed_cog
def get_cog(self, name: str) -> Optional[Cog]:
"""Return a loaded cog by name if present."""
return self.command_handler.get_cog(name)
def check(self, coro: Callable[["CommandContext"], Awaitable[bool]]): def check(self, coro: Callable[["CommandContext"], Awaitable[bool]]):
""" """
@ -754,14 +853,19 @@ class Client:
"""Parses user data and returns a User object, updating cache.""" """Parses user data and returns a User object, updating cache."""
from .models import User # Ensure User model is available from .models import User # Ensure User model is available
user = User(data) user = User(data, client_instance=self)
self._users.set(user.id, user) # Cache the user self._users.set(user.id, user) # Cache the user
return user return user
def parse_channel(self, data: Dict[str, Any]) -> "Channel": def parse_channel(self, data: Dict[str, Any]) -> "Channel":
"""Parses channel data and returns a Channel object, updating caches.""" """Parses channel data and returns a Channel object, updating caches."""
from .models import channel_factory from .models import (
channel_factory,
TextChannel,
VoiceChannel,
CategoryChannel,
)
channel = channel_factory(data, self) channel = channel_factory(data, self)
self._channels.set(channel.id, channel) self._channels.set(channel.id, channel)
@ -769,6 +873,12 @@ class Client:
guild = self._guilds.get(channel.guild_id) guild = self._guilds.get(channel.guild_id)
if guild: if guild:
guild._channels.set(channel.id, channel) guild._channels.set(channel.id, channel)
if isinstance(channel, TextChannel):
_update_list(guild.text_channels, channel)
elif isinstance(channel, VoiceChannel):
_update_list(guild.voice_channels, channel)
elif isinstance(channel, CategoryChannel):
_update_list(guild.category_channels, channel)
return channel return channel
def parse_message(self, data: Dict[str, Any]) -> "Message": def parse_message(self, data: Dict[str, Any]) -> "Message":
@ -929,7 +1039,8 @@ class Client:
"""Parses guild data and returns a Guild object, updating cache.""" """Parses guild data and returns a Guild object, updating cache."""
from .models import Guild from .models import Guild
guild = Guild(data, client_instance=self) shard_id = data.get("shard_id")
guild = Guild(data, client_instance=self, shard_id=shard_id)
self._guilds.set(guild.id, guild) self._guilds.set(guild.id, guild)
presences = {p["user"]["id"]: p for p in data.get("presences", [])} presences = {p["user"]["id"]: p for p in data.get("presences", [])}
@ -1107,6 +1218,23 @@ class Client:
return self.parse_message(message_data) return self.parse_message(message_data)
async def create_dm(self, user_id: Snowflake) -> "DMChannel":
"""|coro| Create or fetch a DM channel with a user."""
from .models import DMChannel
dm_data = await self._http.create_dm(user_id)
return cast(DMChannel, self.parse_channel(dm_data))
async def send_dm(
self,
user_id: Snowflake,
content: Optional[str] = None,
**kwargs: Any,
) -> "Message":
"""|coro| Convenience method to send a direct message to a user."""
channel = await self.create_dm(user_id)
return await self.send_message(channel.id, content=content, **kwargs)
def typing(self, channel_id: str) -> Typing: def typing(self, channel_id: str) -> Typing:
"""Return a context manager to show a typing indicator in a channel.""" """Return a context manager to show a typing indicator in a channel."""
@ -1320,13 +1448,33 @@ class Client:
return self._channels.get(channel_id) return self._channels.get(channel_id)
def get_message(self, message_id: Snowflake) -> Optional["Message"]: def get_message(self, message_id: Snowflake) -> Optional["Message"]:
"""Returns a message from the internal cache.""" """Returns a message from the internal cache."""
return self._messages.get(message_id)
def get_all_channels(self) -> List["Channel"]:
"""Return all channels cached in every guild."""
channels: List["Channel"] = []
for guild in self._guilds.values():
channels.extend(guild._channels.values())
return channels
def get_all_members(self) -> List["Member"]:
"""Return all cached members across all guilds.
When member caching is disabled via :class:`MemberCacheFlags.none`, this
list will always be empty.
"""
members: List["Member"] = []
for guild in self._guilds.values():
members.extend(guild._members.values())
return members
return self._messages.get(message_id) async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]:
"""Fetches a guild by ID from Discord and caches it."""
async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]:
"""Fetches a guild by ID from Discord and caches it."""
if self._closed: if self._closed:
raise DisagreementException("Client is closed.") raise DisagreementException("Client is closed.")
@ -1340,19 +1488,19 @@ class Client:
return self.parse_guild(guild_data) return self.parse_guild(guild_data)
except DisagreementException as e: except DisagreementException as e:
print(f"Failed to fetch guild {guild_id}: {e}") print(f"Failed to fetch guild {guild_id}: {e}")
return None return None
async def fetch_guilds(self) -> List["Guild"]: async def fetch_guilds(self) -> List["Guild"]:
"""Fetch all guilds the current user is in.""" """Fetch all guilds the current user is in."""
if self._closed: if self._closed:
raise DisagreementException("Client is closed.") raise DisagreementException("Client is closed.")
data = await self._http.get_current_user_guilds() data = await self._http.get_current_user_guilds()
guilds: List["Guild"] = [] guilds: List["Guild"] = []
for guild_data in data: for guild_data in data:
guilds.append(self.parse_guild(guild_data)) guilds.append(self.parse_guild(guild_data))
return guilds return guilds
async def fetch_channel(self, channel_id: Snowflake) -> Optional["Channel"]: async def fetch_channel(self, channel_id: Snowflake) -> Optional["Channel"]:
"""Fetches a channel from Discord by its ID and updates the cache.""" """Fetches a channel from Discord by its ID and updates the cache."""
@ -1423,16 +1571,33 @@ class Client:
data = await self._http.edit_webhook(webhook_id, payload) data = await self._http.edit_webhook(webhook_id, payload)
return self.parse_webhook(data) return self.parse_webhook(data)
async def delete_webhook(self, webhook_id: Snowflake) -> None: async def delete_webhook(self, webhook_id: Snowflake) -> None:
"""|coro| Delete a webhook by ID.""" """|coro| Delete a webhook by ID."""
if self._closed: if self._closed:
raise DisagreementException("Client is closed.") raise DisagreementException("Client is closed.")
await self._http.delete_webhook(webhook_id) await self._http.delete_webhook(webhook_id)
async def fetch_templates(self, guild_id: Snowflake) -> List["GuildTemplate"]: async def fetch_webhook(self, webhook_id: Snowflake) -> Optional["Webhook"]:
"""|coro| Fetch all templates for a guild.""" """|coro| Fetch a webhook by ID."""
if self._closed:
raise DisagreementException("Client is closed.")
cached = self._webhooks.get(webhook_id)
if cached:
return cached
try:
data = await self._http.get_webhook(webhook_id)
return self.parse_webhook(data)
except DisagreementException as e:
print(f"Failed to fetch webhook {webhook_id}: {e}")
return None
async def fetch_templates(self, guild_id: Snowflake) -> List["GuildTemplate"]:
"""|coro| Fetch all templates for a guild."""
if self._closed: if self._closed:
raise DisagreementException("Client is closed.") raise DisagreementException("Client is closed.")
@ -1562,7 +1727,20 @@ class Client:
if self._closed: if self._closed:
raise DisagreementException("Client is closed.") raise DisagreementException("Client is closed.")
await self._http.delete_invite(code) await self._http.delete_invite(code)
async def fetch_invite(self, code: Snowflake) -> Optional["Invite"]:
"""|coro| Fetch a single invite by code."""
if self._closed:
raise DisagreementException("Client is closed.")
try:
data = await self._http.get_invite(code)
return self.parse_invite(data)
except DisagreementException as e:
print(f"Failed to fetch invite {code}: {e}")
return None
async def fetch_invites(self, channel_id: Snowflake) -> List["Invite"]: async def fetch_invites(self, channel_id: Snowflake) -> List["Invite"]:
"""|coro| Fetch all invites for a channel.""" """|coro| Fetch all invites for a channel."""
@ -1598,11 +1776,13 @@ class Client:
for item in view.children: for item in view.children:
if item.custom_id: # Ensure custom_id is not None if item.custom_id: # Ensure custom_id is not None
if item.custom_id in self._persistent_views: if item.custom_id in self._persistent_views:
raise ValueError( raise ValueError(
f"A component with custom_id '{item.custom_id}' is already registered." f"A component with custom_id '{item.custom_id}' is already registered."
) )
self._persistent_views[item.custom_id] = view self._persistent_views[item.custom_id] = view
self._save_persistent_views()
# --- Application Command Methods --- # --- Application Command Methods ---
async def process_interaction(self, interaction: Interaction) -> None: async def process_interaction(self, interaction: Interaction) -> None:
@ -1647,16 +1827,6 @@ class Client:
"Ensure the client is connected and READY." "Ensure the client is connected and READY."
) )
return return
if not self.is_ready():
print(
"Warning: Client is not ready. Waiting for client to be ready before syncing commands."
)
await self.wait_until_ready()
if not self.application_id:
print(
"Error: application_id still not set after client is ready. Cannot sync commands."
)
return
await self.app_command_handler.sync_commands( await self.app_command_handler.sync_commands(
application_id=self.application_id, guild_id=guild_id application_id=self.application_id, guild_id=guild_id
@ -1677,6 +1847,16 @@ class Client:
pass pass
async def on_connect(self) -> None:
"""|coro| Called when the WebSocket connection opens."""
pass
async def on_disconnect(self) -> None:
"""|coro| Called when the WebSocket connection closes."""
pass
async def on_app_command_error( async def on_app_command_error(
self, context: AppCommandContext, error: Exception self, context: AppCommandContext, error: Exception
) -> None: ) -> None:

View File

@ -268,12 +268,19 @@ class GuildFeature(str, Enum): # Changed from IntEnum to Enum
VERIFIED = "VERIFIED" VERIFIED = "VERIFIED"
VIP_REGIONS = "VIP_REGIONS" VIP_REGIONS = "VIP_REGIONS"
WELCOME_SCREEN_ENABLED = "WELCOME_SCREEN_ENABLED" WELCOME_SCREEN_ENABLED = "WELCOME_SCREEN_ENABLED"
SOUNDBOARD = "SOUNDBOARD"
VIDEO_QUALITY_720_60FPS = "VIDEO_QUALITY_720_60FPS"
# Add more as they become known or needed # Add more as they become known or needed
# This allows GuildFeature("UNKNOWN_FEATURE_STRING") to work # This allows GuildFeature("UNKNOWN_FEATURE_STRING") to work
@classmethod @classmethod
def _missing_(cls, value): # type: ignore def _missing_(cls, value): # type: ignore
return str(value) member = object.__new__(cls)
member._name_ = str(value)
member._value_ = str(value)
cls._value2member_map_[member._value_] = member # pylint: disable=no-member
cls._member_map_[member._name_] = member # pylint: disable=no-member
return member
# --- Guild Scheduled Event Enums --- # --- Guild Scheduled Event Enums ---
@ -329,7 +336,12 @@ class VoiceRegion(str, Enum):
@classmethod @classmethod
def _missing_(cls, value): # type: ignore def _missing_(cls, value): # type: ignore
return str(value) member = object.__new__(cls)
member._name_ = str(value)
member._value_ = str(value)
cls._value2member_map_[member._value_] = member # pylint: disable=no-member
cls._member_map_[member._name_] = member # pylint: disable=no-member
return member
# --- Channel Enums --- # --- Channel Enums ---

View File

@ -61,6 +61,11 @@ class EventDispatcher:
"GUILD_ROLE_UPDATE": self._parse_guild_role_update, "GUILD_ROLE_UPDATE": self._parse_guild_role_update,
"TYPING_START": self._parse_typing_start, "TYPING_START": self._parse_typing_start,
"VOICE_STATE_UPDATE": self._parse_voice_state_update, "VOICE_STATE_UPDATE": self._parse_voice_state_update,
"THREAD_CREATE": self._parse_thread_create,
"THREAD_UPDATE": self._parse_thread_update,
"THREAD_DELETE": self._parse_thread_delete,
"INVITE_CREATE": self._parse_invite_create,
"INVITE_DELETE": self._parse_invite_delete,
} }
def _parse_message_create(self, data: Dict[str, Any]) -> Message: def _parse_message_create(self, data: Dict[str, Any]) -> Message:
@ -165,6 +170,43 @@ class EventDispatcher:
return GuildRoleUpdate(data, client_instance=self._client) return GuildRoleUpdate(data, client_instance=self._client)
def _parse_thread_create(self, data: Dict[str, Any]):
"""Parses THREAD_CREATE into a Thread object and updates caches."""
return self._client.parse_channel(data)
def _parse_thread_update(self, data: Dict[str, Any]):
"""Parses THREAD_UPDATE into a Thread object."""
return self._client.parse_channel(data)
def _parse_thread_delete(self, data: Dict[str, Any]):
"""Parses THREAD_DELETE, removing the thread from caches."""
thread = self._client.parse_channel(data)
thread_id = data.get("id")
if thread_id:
self._client._channels.invalidate(thread_id)
guild_id = data.get("guild_id")
if guild_id:
guild = self._client._guilds.get(guild_id)
if guild:
guild._channels.invalidate(thread_id)
guild._threads.pop(thread_id, None)
return thread
def _parse_invite_create(self, data: Dict[str, Any]):
"""Parses INVITE_CREATE into an Invite object."""
return self._client.parse_invite(data)
def _parse_invite_delete(self, data: Dict[str, Any]):
"""Parses INVITE_DELETE into an InviteDelete model."""
from .models import InviteDelete
return InviteDelete(data)
# Potentially add _parse_user for events that directly provide a full user object # Potentially add _parse_user for events that directly provide a full user object
# def _parse_user_update(self, data: Dict[str, Any]) -> User: # def _parse_user_update(self, data: Dict[str, Any]) -> User:
# return User(data=data) # return User(data=data)

View File

@ -1 +1,3 @@
from . import app_commands, commands, tasks
__all__ = ["app_commands", "commands", "tasks"]

View File

@ -253,6 +253,8 @@ class AppCommandContext:
Optional[Message]: The sent message object if a new message was created and not ephemeral. Optional[Message]: The sent message object if a new message was created and not ephemeral.
None if the response was ephemeral or an edit to a deferred message. None if the response was ephemeral or an edit to a deferred message.
""" """
if allowed_mentions is None:
allowed_mentions = getattr(self.bot, "allowed_mentions", None)
if not self._responded and self._deferred: # Editing a deferred response if not self._responded and self._deferred: # Editing a deferred response
# Use edit_original_interaction_response # Use edit_original_interaction_response
payload: Dict[str, Any] = {} payload: Dict[str, Any] = {}
@ -393,6 +395,9 @@ class AppCommandContext:
"Must acknowledge or defer the interaction before sending a followup." "Must acknowledge or defer the interaction before sending a followup."
) )
if allowed_mentions is None:
allowed_mentions = getattr(self.bot, "allowed_mentions", None)
payload: Dict[str, Any] = {} payload: Dict[str, Any] = {}
if content is not None: if content is not None:
payload["content"] = content payload["content"] = content
@ -473,6 +478,9 @@ class AppCommandContext:
"Cannot edit response if interaction hasn't been responded to or deferred." "Cannot edit response if interaction hasn't been responded to or deferred."
) )
if allowed_mentions is None:
allowed_mentions = getattr(self.bot, "allowed_mentions", None)
payload: Dict[str, Any] = {} payload: Dict[str, Any] = {}
if content is not None: if content is not None:
payload["content"] = content # Use None to clear payload["content"] = content # Use None to clear

View File

@ -18,51 +18,23 @@ from typing import (
if TYPE_CHECKING: if TYPE_CHECKING:
from disagreement.client import Client from disagreement.client import Client
from disagreement.interactions import Interaction, ResolvedData, Snowflake from disagreement.interactions import Interaction, ResolvedData, Snowflake
from disagreement.enums import (
ApplicationCommandType,
ApplicationCommandOptionType,
InteractionType,
)
from .commands import (
AppCommand,
SlashCommand,
UserCommand,
MessageCommand,
AppCommandGroup,
)
from .context import AppCommandContext
from disagreement.models import (
User,
Member,
Role,
Attachment,
Message,
) # For resolved data
# Channel models would also go here from disagreement.enums import (
ApplicationCommandType,
ApplicationCommandOptionType,
InteractionType,
)
from .commands import (
AppCommand,
SlashCommand,
UserCommand,
MessageCommand,
AppCommandGroup,
)
from .context import AppCommandContext
from disagreement.models import User, Member, Role, Attachment, Message
# Placeholder for models not yet fully defined or imported Channel = Any
if not TYPE_CHECKING:
from disagreement.enums import (
ApplicationCommandType,
ApplicationCommandOptionType,
InteractionType,
)
from .commands import (
AppCommand,
SlashCommand,
UserCommand,
MessageCommand,
AppCommandGroup,
)
from .context import AppCommandContext
User = Any
Member = Any
Role = Any
Attachment = Any
Channel = Any
Message = Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -587,12 +559,19 @@ class AppCommandHandler:
# print(f"Failed to send error message for app command: {send_e}") # print(f"Failed to send error message for app command: {send_e}")
async def sync_commands( async def sync_commands(
self, application_id: "Snowflake", guild_id: Optional["Snowflake"] = None self,
application_id: Optional["Snowflake"] = None,
guild_id: Optional["Snowflake"] = None,
) -> None: ) -> None:
""" """
Synchronizes (registers/updates) all application commands with Discord. Synchronizes (registers/updates) all application commands with Discord.
If guild_id is provided, syncs commands for that guild. Otherwise, syncs global commands. If guild_id is provided, syncs commands for that guild. Otherwise, syncs global commands.
""" """
if application_id is None:
application_id = self.client.application_id
if application_id is None:
raise ValueError("application_id must be provided to sync commands")
cache = self._load_cached_ids() cache = self._load_cached_ids()
scope_key = str(guild_id) if guild_id else "global" scope_key = str(guild_id) if guild_id else "global"
stored = cache.get(scope_key, {}) stored = cache.get(scope_key, {})

View File

@ -14,11 +14,12 @@ from .decorators import (
check, check,
check_any, check_any,
cooldown, cooldown,
max_concurrency, max_concurrency,
requires_permissions, requires_permissions,
has_role, has_role,
has_any_role, has_any_role,
) is_owner,
)
from .errors import ( from .errors import (
CommandError, CommandError,
CommandNotFound, CommandNotFound,
@ -47,9 +48,10 @@ __all__ = [
"cooldown", "cooldown",
"max_concurrency", "max_concurrency",
"requires_permissions", "requires_permissions",
"has_role", "has_role",
"has_any_role", "has_any_role",
# Errors "is_owner",
# Errors
"CommandError", "CommandError",
"CommandNotFound", "CommandNotFound",
"BadArgument", "BadArgument",

View File

@ -6,7 +6,16 @@ import re
import inspect import inspect
from .errors import BadArgument from .errors import BadArgument
from disagreement.models import Member, Guild, Role from disagreement.models import (
Member,
Guild,
Role,
User,
TextChannel,
VoiceChannel,
Emoji,
PartialEmoji,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from .core import CommandContext from .core import CommandContext
@ -143,6 +152,97 @@ class GuildConverter(Converter["Guild"]):
raise BadArgument(f"Guild '{argument}' not found.") raise BadArgument(f"Guild '{argument}' not found.")
class UserConverter(Converter["User"]):
async def convert(self, ctx: "CommandContext", argument: str) -> "User":
match = re.match(r"<@!?(\d+)>$", argument)
user_id = match.group(1) if match else argument
user = ctx.bot._users.get(user_id)
if user:
return user
user = await ctx.bot.fetch_user(user_id)
if user:
return user
raise BadArgument(f"User '{argument}' not found.")
class TextChannelConverter(Converter["TextChannel"]):
async def convert(self, ctx: "CommandContext", argument: str) -> "TextChannel":
if not ctx.message.guild_id:
raise BadArgument("TextChannel converter requires guild context.")
match = re.match(r"<#(?P<id>\d+)>$", argument)
channel_id = match.group("id") if match else argument
guild = ctx.bot.get_guild(ctx.message.guild_id)
if guild:
channel = guild.get_channel(channel_id)
if isinstance(channel, TextChannel):
return channel
channel = (
ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None
)
if isinstance(channel, TextChannel):
return channel
if hasattr(ctx.bot, "fetch_channel"):
channel = await ctx.bot.fetch_channel(channel_id)
if isinstance(channel, TextChannel):
return channel
raise BadArgument(f"Text channel '{argument}' not found.")
class VoiceChannelConverter(Converter["VoiceChannel"]):
async def convert(self, ctx: "CommandContext", argument: str) -> "VoiceChannel":
if not ctx.message.guild_id:
raise BadArgument("VoiceChannel converter requires guild context.")
match = re.match(r"<#(?P<id>\d+)>$", argument)
channel_id = match.group("id") if match else argument
guild = ctx.bot.get_guild(ctx.message.guild_id)
if guild:
channel = guild.get_channel(channel_id)
if isinstance(channel, VoiceChannel):
return channel
channel = (
ctx.bot.get_channel(channel_id) if hasattr(ctx.bot, "get_channel") else None
)
if isinstance(channel, VoiceChannel):
return channel
if hasattr(ctx.bot, "fetch_channel"):
channel = await ctx.bot.fetch_channel(channel_id)
if isinstance(channel, VoiceChannel):
return channel
raise BadArgument(f"Voice channel '{argument}' not found.")
class EmojiConverter(Converter["PartialEmoji"]):
_CUSTOM_RE = re.compile(r"<(?P<animated>a)?:(?P<name>[^:]+):(?P<id>\d+)>$")
async def convert(self, ctx: "CommandContext", argument: str) -> "PartialEmoji":
match = self._CUSTOM_RE.match(argument)
if match:
return PartialEmoji(
{
"id": match.group("id"),
"name": match.group("name"),
"animated": bool(match.group("animated")),
}
)
if argument:
return PartialEmoji({"id": None, "name": argument})
raise BadArgument(f"Emoji '{argument}' not found.")
# Default converters mapping # Default converters mapping
DEFAULT_CONVERTERS: dict[type, Converter[Any]] = { DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
int: IntConverter(), int: IntConverter(),
@ -152,7 +252,11 @@ DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
Member: MemberConverter(), Member: MemberConverter(),
Guild: GuildConverter(), Guild: GuildConverter(),
Role: RoleConverter(), Role: RoleConverter(),
# User: UserConverter(), # Add when User model and converter are ready User: UserConverter(),
TextChannel: TextChannelConverter(),
VoiceChannel: VoiceChannelConverter(),
PartialEmoji: EmojiConverter(),
Emoji: EmojiConverter(),
} }

View File

@ -79,8 +79,15 @@ class GroupMixin:
) )
self.commands[alias.lower()] = command self.commands[alias.lower()] = command
def get_command(self, name: str) -> Optional["Command"]: def get_command(self, name: str) -> Optional["Command"]:
return self.commands.get(name.lower()) return self.commands.get(name.lower())
def walk_commands(self):
"""Yield all commands in this group recursively."""
for cmd in dict.fromkeys(self.commands.values()):
yield cmd
if isinstance(cmd, Group):
yield from cmd.walk_commands()
class Command(GroupMixin): class Command(GroupMixin):
@ -250,15 +257,15 @@ class CommandContext:
client's :attr:`mention_replies` value is used. client's :attr:`mention_replies` value is used.
""" """
allowed_mentions = kwargs.pop("allowed_mentions", None) allowed_mentions = kwargs.pop("allowed_mentions", None)
if mention_author is None: if mention_author is None:
mention_author = getattr(self.bot, "mention_replies", False) mention_author = getattr(self.bot, "mention_replies", False)
if allowed_mentions is None: if allowed_mentions is None:
allowed_mentions = {"replied_user": mention_author} allowed_mentions = dict(getattr(self.bot, "allowed_mentions", {}) or {})
else: else:
allowed_mentions = dict(allowed_mentions) allowed_mentions = dict(allowed_mentions)
allowed_mentions.setdefault("replied_user", mention_author) allowed_mentions.setdefault("replied_user", mention_author)
return await self.bot.send_message( return await self.bot.send_message(
channel_id=self.message.channel_id, channel_id=self.message.channel_id,
@ -363,8 +370,20 @@ class CommandHandler:
self.commands.pop(alias.lower(), None) self.commands.pop(alias.lower(), None)
return command return command
def get_command(self, name: str) -> Optional[Command]: def get_command(self, name: str) -> Optional[Command]:
return self.commands.get(name.lower()) return self.commands.get(name.lower())
def walk_commands(self):
"""Yield every registered command, including subcommands."""
for cmd in dict.fromkeys(self.commands.values()):
yield cmd
if isinstance(cmd, Group):
yield from cmd.walk_commands()
def get_cog(self, name: str) -> Optional["Cog"]:
"""Return a loaded cog by name if present."""
return self.cogs.get(name)
def add_cog(self, cog_to_add: "Cog") -> None: def add_cog(self, cog_to_add: "Cog") -> None:
from .cog import Cog from .cog import Cog
@ -471,9 +490,9 @@ class CommandHandler:
return self.prefix(self.client, message) # type: ignore return self.prefix(self.client, message) # type: ignore
return self.prefix return self.prefix
async def _parse_arguments( async def _parse_arguments(
self, command: Command, ctx: CommandContext, view: StringView self, command: Command, ctx: CommandContext, view: StringView
) -> Tuple[List[Any], Dict[str, Any]]: ) -> Tuple[List[Any], Dict[str, Any]]:
args_list = [] args_list = []
kwargs_dict = {} kwargs_dict = {}
params_to_parse = list(command.params.values()) params_to_parse = list(command.params.values())
@ -636,7 +655,79 @@ class CommandHandler:
elif param.kind == inspect.Parameter.KEYWORD_ONLY: elif param.kind == inspect.Parameter.KEYWORD_ONLY:
kwargs_dict[param.name] = final_value_for_param kwargs_dict[param.name] = final_value_for_param
return args_list, kwargs_dict return args_list, kwargs_dict
async def get_context(self, message: "Message") -> Optional[CommandContext]:
"""Parse a message and return a :class:`CommandContext` without executing the command.
Returns ``None`` if the message does not invoke a command."""
if not message.content:
return None
prefix_to_use = await self.get_prefix(message)
if not prefix_to_use:
return None
actual_prefix: Optional[str] = None
if isinstance(prefix_to_use, list):
for p in prefix_to_use:
if message.content.startswith(p):
actual_prefix = p
break
if not actual_prefix:
return None
elif isinstance(prefix_to_use, str):
if message.content.startswith(prefix_to_use):
actual_prefix = prefix_to_use
else:
return None
else:
return None
if actual_prefix is None:
return None
view = StringView(message.content[len(actual_prefix) :])
command_name = view.get_word()
if not command_name:
return None
command = self.get_command(command_name)
if not command:
return None
invoked_with = command_name
if isinstance(command, Group):
view.skip_whitespace()
potential_subcommand = view.get_word()
if potential_subcommand:
subcommand = command.get_command(potential_subcommand)
if subcommand:
command = subcommand
invoked_with += f" {potential_subcommand}"
elif command.invoke_without_command:
view.index -= len(potential_subcommand) + view.previous
else:
raise CommandNotFound(
f"Subcommand '{potential_subcommand}' not found."
)
ctx = CommandContext(
message=message,
bot=self.client,
prefix=actual_prefix,
command=command,
invoked_with=invoked_with,
cog=command.cog,
)
parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view)
ctx.args = parsed_args
ctx.kwargs = parsed_kwargs
return ctx
async def process_commands(self, message: "Message") -> None: async def process_commands(self, message: "Message") -> None:
if not message.content: if not message.content:

View File

@ -292,3 +292,19 @@ def has_any_role(
) )
return check(predicate) return check(predicate)
def is_owner() -> (
Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]
):
"""Check that the invoking user is listed as a bot owner."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckFailure
owner_ids = getattr(ctx.bot, "owner_ids", [])
if str(ctx.author.id) not in {str(o) for o in owner_ids}:
raise CheckFailure("This command can only be used by the bot owner.")
return True
return check(predicate)

View File

@ -1,6 +1,8 @@
from collections import defaultdict
from typing import List, Optional from typing import List, Optional
from .core import Command, CommandContext, CommandHandler from ...utils import Paginator
from .core import Command, CommandContext, CommandHandler, Group
class HelpCommand(Command): class HelpCommand(Command):
@ -15,17 +17,22 @@ class HelpCommand(Command):
if not cmd or cmd.name.lower() != command.lower(): if not cmd or cmd.name.lower() != command.lower():
await ctx.send(f"Command '{command}' not found.") await ctx.send(f"Command '{command}' not found.")
return return
description = cmd.description or cmd.brief or "No description provided." if isinstance(cmd, Group):
await ctx.send(f"**{ctx.prefix}{cmd.name}**\n{description}") await self.send_group_help(ctx, cmd)
else: elif cmd:
lines: List[str] = [] description = cmd.description or cmd.brief or "No description provided."
for registered in dict.fromkeys(handler.commands.values()): await ctx.send(f"**{ctx.prefix}{cmd.name}**\n{description}")
brief = registered.brief or registered.description or ""
lines.append(f"{ctx.prefix}{registered.name} - {brief}".strip())
if lines:
await ctx.send("\n".join(lines))
else: else:
await ctx.send("No commands available.") lines: List[str] = []
for registered in handler.walk_commands():
brief = registered.brief or registered.description or ""
lines.append(f"{ctx.prefix}{registered.name} - {brief}".strip())
if lines:
await ctx.send("\n".join(lines))
else:
await self.send_command_help(ctx, cmd)
else:
await self.send_bot_help(ctx)
super().__init__( super().__init__(
callback, callback,
@ -33,3 +40,42 @@ class HelpCommand(Command):
brief="Show command help.", brief="Show command help.",
description="Displays help for commands.", description="Displays help for commands.",
) )
async def send_bot_help(self, ctx: CommandContext) -> None:
groups = defaultdict(list)
for cmd in dict.fromkeys(self.handler.commands.values()):
key = cmd.cog.cog_name if cmd.cog else "No Category"
groups[key].append(cmd)
paginator = Paginator()
for cog_name, cmds in groups.items():
paginator.add_line(f"**{cog_name}**")
for cmd in cmds:
brief = cmd.brief or cmd.description or ""
paginator.add_line(f"{ctx.prefix}{cmd.name} - {brief}".strip())
paginator.add_line("")
pages = paginator.pages
if not pages:
await ctx.send("No commands available.")
return
for page in pages:
await ctx.send(page)
async def send_command_help(self, ctx: CommandContext, command: Command) -> None:
description = command.description or command.brief or "No description provided."
await ctx.send(f"**{ctx.prefix}{command.name}**\n{description}")
async def send_group_help(self, ctx: CommandContext, group: Group) -> None:
paginator = Paginator()
description = group.description or group.brief or "No description provided."
paginator.add_line(f"**{ctx.prefix}{group.name}**\n{description}")
if group.commands:
for sub in dict.fromkeys(group.commands.values()):
brief = sub.brief or sub.description or ""
paginator.add_line(
f"{ctx.prefix}{group.name} {sub.name} - {brief}".strip()
)
for page in paginator.pages:
await ctx.send(page)

View File

@ -1,9 +1,11 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from importlib import import_module from importlib import import_module
import inspect
import sys import sys
from types import ModuleType from types import ModuleType
from typing import Dict from typing import Any, Coroutine, Dict, cast
__all__ = ["load_extension", "unload_extension", "reload_extension"] __all__ = ["load_extension", "unload_extension", "reload_extension"]
@ -25,7 +27,20 @@ def load_extension(name: str) -> ModuleType:
if not hasattr(module, "setup"): if not hasattr(module, "setup"):
raise ImportError(f"Extension '{name}' does not define a setup function") raise ImportError(f"Extension '{name}' does not define a setup function")
module.setup() result = module.setup()
if inspect.isawaitable(result):
coro = cast(Coroutine[Any, Any, Any], result)
try:
loop = asyncio.get_running_loop()
except RuntimeError:
asyncio.run(coro)
else:
if loop.is_running():
future = asyncio.run_coroutine_threadsafe(coro, loop)
future.result()
else:
loop.run_until_complete(coro)
_loaded_extensions[name] = module _loaded_extensions[name] = module
return module return module
@ -38,7 +53,19 @@ def unload_extension(name: str) -> None:
raise ValueError(f"Extension '{name}' is not loaded") raise ValueError(f"Extension '{name}' is not loaded")
if hasattr(module, "teardown"): if hasattr(module, "teardown"):
module.teardown() result = module.teardown()
if inspect.isawaitable(result):
coro = cast(Coroutine[Any, Any, Any], result)
try:
loop = asyncio.get_running_loop()
except RuntimeError:
asyncio.run(coro)
else:
if loop.is_running():
future = asyncio.run_coroutine_threadsafe(coro, loop)
future.result()
else:
loop.run_until_complete(coro)
sys.modules.pop(name, None) sys.modules.pop(name, None)

View File

@ -23,6 +23,7 @@ class Task:
) -> None: ) -> None:
self._coro = coro self._coro = coro
self._task: Optional[asyncio.Task[None]] = None self._task: Optional[asyncio.Task[None]] = None
self._current_loop = 0
if time_of_day is not None and ( if time_of_day is not None and (
seconds or minutes or hours or delta is not None seconds or minutes or hours or delta is not None
): ):
@ -68,6 +69,7 @@ class Task:
await _maybe_call(self._on_error, exc) await _maybe_call(self._on_error, exc)
else: else:
raise raise
self._current_loop += 1
first = False first = False
except asyncio.CancelledError: except asyncio.CancelledError:
@ -78,6 +80,7 @@ class Task:
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
if self._task is None or self._task.done(): if self._task is None or self._task.done():
self._current_loop = 0
self._task = asyncio.create_task(self._run(*args, **kwargs)) self._task = asyncio.create_task(self._run(*args, **kwargs))
return self._task return self._task
@ -90,6 +93,34 @@ class Task:
def running(self) -> bool: def running(self) -> bool:
return self._task is not None and not self._task.done() return self._task is not None and not self._task.done()
@property
def current_loop(self) -> int:
return self._current_loop
def change_interval(
self,
*,
seconds: float = 0.0,
minutes: float = 0.0,
hours: float = 0.0,
delta: Optional[datetime.timedelta] = None,
time_of_day: Optional[datetime.time] = None,
) -> None:
if time_of_day is not None and (
seconds or minutes or hours or delta is not None
):
raise ValueError("time_of_day cannot be used with an interval")
if delta is not None:
if not isinstance(delta, datetime.timedelta):
raise TypeError("delta must be a datetime.timedelta")
interval_seconds = delta.total_seconds()
else:
interval_seconds = seconds + minutes * 60.0 + hours * 3600.0
self._seconds = float(interval_seconds)
self._time_of_day = time_of_day
async def _maybe_call( async def _maybe_call(
func: Callable[[Exception], Awaitable[None] | None], exc: Exception func: Callable[[Exception], Awaitable[None] | None], exc: Exception
@ -181,10 +212,37 @@ class _Loop:
if self._task is not None: if self._task is not None:
self._task.stop() self._task.stop()
def change_interval(
self,
*,
seconds: float = 0.0,
minutes: float = 0.0,
hours: float = 0.0,
delta: Optional[datetime.timedelta] = None,
time_of_day: Optional[datetime.time] = None,
) -> None:
self.seconds = seconds
self.minutes = minutes
self.hours = hours
self.delta = delta
self.time_of_day = time_of_day
if self._task is not None:
self._task.change_interval(
seconds=seconds,
minutes=minutes,
hours=hours,
delta=delta,
time_of_day=time_of_day,
)
@property @property
def running(self) -> bool: def running(self) -> bool:
return self._task.running if self._task else False return self._task.running if self._task else False
@property
def current_loop(self) -> int:
return self._task.current_loop if self._task else 0
class _BoundLoop: class _BoundLoop:
def __init__(self, parent: _Loop, owner: Any) -> None: def __init__(self, parent: _Loop, owner: Any) -> None:
@ -202,6 +260,27 @@ class _BoundLoop:
def running(self) -> bool: def running(self) -> bool:
return self._parent.running return self._parent.running
def change_interval(
self,
*,
seconds: float = 0.0,
minutes: float = 0.0,
hours: float = 0.0,
delta: Optional[datetime.timedelta] = None,
time_of_day: Optional[datetime.time] = None,
) -> None:
self._parent.change_interval(
seconds=seconds,
minutes=minutes,
hours=hours,
delta=delta,
time_of_day=time_of_day,
)
@property
def current_loop(self) -> int:
return self._parent.current_loop
def loop( def loop(
*, *,

View File

@ -334,7 +334,19 @@ class GatewayClient:
self._resume_gateway_url, self._resume_gateway_url,
) )
# The client is now ready for operations. Set the event before dispatching to user code.
self._client_instance._ready_event.set()
logger.info("Client is now marked as ready.")
if isinstance(raw_event_d_payload, dict) and self._shard_id is not None:
raw_event_d_payload["shard_id"] = self._shard_id
await self._dispatcher.dispatch(event_name, raw_event_d_payload) await self._dispatcher.dispatch(event_name, raw_event_d_payload)
if (
getattr(self._client_instance, "sync_commands_on_ready", True)
and self._client_instance.application_id
):
asyncio.create_task(self._client_instance.sync_application_commands())
elif event_name == "GUILD_MEMBERS_CHUNK": elif event_name == "GUILD_MEMBERS_CHUNK":
if isinstance(raw_event_d_payload, dict): if isinstance(raw_event_d_payload, dict):
nonce = raw_event_d_payload.get("nonce") nonce = raw_event_d_payload.get("nonce")
@ -384,6 +396,8 @@ class GatewayClient:
event_data_to_dispatch = ( event_data_to_dispatch = (
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {} raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
) )
if isinstance(event_data_to_dispatch, dict) and self._shard_id is not None:
event_data_to_dispatch["shard_id"] = self._shard_id
await self._dispatcher.dispatch(event_name, event_data_to_dispatch) await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
await self._dispatcher.dispatch( await self._dispatcher.dispatch(
"SHARD_RESUME", {"shard_id": self._shard_id} "SHARD_RESUME", {"shard_id": self._shard_id}
@ -394,6 +408,8 @@ class GatewayClient:
event_data_to_dispatch = ( event_data_to_dispatch = (
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {} raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
) )
if isinstance(event_data_to_dispatch, dict) and self._shard_id is not None:
event_data_to_dispatch["shard_id"] = self._shard_id
await self._dispatcher.dispatch(event_name, event_data_to_dispatch) await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
else: else:
@ -553,6 +569,7 @@ class GatewayClient:
await self._dispatcher.dispatch( await self._dispatcher.dispatch(
"SHARD_CONNECT", {"shard_id": self._shard_id} "SHARD_CONNECT", {"shard_id": self._shard_id}
) )
await self._dispatcher.dispatch("CONNECT", {"shard_id": self._shard_id})
except aiohttp.ClientConnectorError as e: except aiohttp.ClientConnectorError as e:
raise GatewayException( raise GatewayException(
@ -608,6 +625,7 @@ class GatewayClient:
await self._dispatcher.dispatch( await self._dispatcher.dispatch(
"SHARD_DISCONNECT", {"shard_id": self._shard_id} "SHARD_DISCONNECT", {"shard_id": self._shard_id}
) )
await self._dispatcher.dispatch("DISCONNECT", {"shard_id": self._shard_id})
@property @property
def latency(self) -> Optional[float]: def latency(self) -> Optional[float]:

View File

@ -663,6 +663,15 @@ class HTTPClient:
await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}") await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}")
async def crosspost_message(
self, channel_id: "Snowflake", message_id: "Snowflake"
) -> Dict[str, Any]:
"""Crossposts a message to any following channels."""
return await self.request(
"POST", f"/channels/{channel_id}/messages/{message_id}/crosspost"
)
async def delete_channel( async def delete_channel(
self, channel_id: str, reason: Optional[str] = None self, channel_id: str, reason: Optional[str] = None
) -> None: ) -> None:
@ -702,6 +711,22 @@ class HTTPClient:
"""Fetches a channel by ID.""" """Fetches a channel by ID."""
return await self.request("GET", f"/channels/{channel_id}") return await self.request("GET", f"/channels/{channel_id}")
async def create_guild_channel(
self,
guild_id: "Snowflake",
payload: Dict[str, Any],
reason: Optional[str] = None,
) -> Dict[str, Any]:
"""Creates a new channel in the specified guild."""
headers = {"X-Audit-Log-Reason": reason} if reason else None
return await self.request(
"POST",
f"/guilds/{guild_id}/channels",
payload=payload,
custom_headers=headers,
)
async def get_channel_invites( async def get_channel_invites(
self, channel_id: "Snowflake" self, channel_id: "Snowflake"
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
@ -721,11 +746,36 @@ class HTTPClient:
return Invite.from_dict(data) return Invite.from_dict(data)
async def create_channel_invite(
self,
channel_id: "Snowflake",
payload: Dict[str, Any],
*,
reason: Optional[str] = None,
) -> "Invite":
"""Creates an invite for a channel with an optional audit log reason."""
headers = {"X-Audit-Log-Reason": reason} if reason else None
data = await self.request(
"POST",
f"/channels/{channel_id}/invites",
payload=payload,
custom_headers=headers,
)
from .models import Invite
return Invite.from_dict(data)
async def delete_invite(self, code: str) -> None: async def delete_invite(self, code: str) -> None:
"""Deletes an invite by code.""" """Deletes an invite by code."""
await self.request("DELETE", f"/invites/{code}") await self.request("DELETE", f"/invites/{code}")
async def get_invite(self, code: "Snowflake") -> Dict[str, Any]:
"""Fetches a single invite by its code."""
return await self.request("GET", f"/invites/{code}")
async def create_webhook( async def create_webhook(
self, channel_id: "Snowflake", payload: Dict[str, Any] self, channel_id: "Snowflake", payload: Dict[str, Any]
) -> "Webhook": ) -> "Webhook":
@ -738,6 +788,11 @@ class HTTPClient:
return Webhook(data) return Webhook(data)
async def get_webhook(self, webhook_id: "Snowflake") -> Dict[str, Any]:
"""Fetches a webhook by ID and returns the raw payload."""
return await self.request("GET", f"/webhooks/{webhook_id}")
async def edit_webhook( async def edit_webhook(
self, webhook_id: "Snowflake", payload: Dict[str, Any] self, webhook_id: "Snowflake", payload: Dict[str, Any]
) -> "Webhook": ) -> "Webhook":
@ -753,6 +808,24 @@ class HTTPClient:
await self.request("DELETE", f"/webhooks/{webhook_id}") await self.request("DELETE", f"/webhooks/{webhook_id}")
async def get_webhook_with_token(
self, webhook_id: "Snowflake", token: Optional[str] = None
) -> "Webhook":
"""Fetches a webhook by ID, optionally using its token."""
endpoint = f"/webhooks/{webhook_id}"
use_auth = True
if token is not None:
endpoint += f"/{token}"
use_auth = False
if use_auth:
data = await self.request("GET", endpoint)
else:
data = await self.request("GET", endpoint, use_auth_header=False)
from .models import Webhook
return Webhook(data)
async def execute_webhook( async def execute_webhook(
self, self,
webhook_id: "Snowflake", webhook_id: "Snowflake",
@ -839,13 +912,13 @@ class HTTPClient:
use_auth_header=False, use_auth_header=False,
) )
async def get_user(self, user_id: "Snowflake") -> Dict[str, Any]: async def get_user(self, user_id: "Snowflake") -> Dict[str, Any]:
"""Fetches a user object for a given user ID.""" """Fetches a user object for a given user ID."""
return await self.request("GET", f"/users/{user_id}") return await self.request("GET", f"/users/{user_id}")
async def get_current_user_guilds(self) -> List[Dict[str, Any]]: async def get_current_user_guilds(self) -> List[Dict[str, Any]]:
"""Returns the guilds the current user is in.""" """Returns the guilds the current user is in."""
return await self.request("GET", "/users/@me/guilds") return await self.request("GET", "/users/@me/guilds")
async def get_guild_member( async def get_guild_member(
self, guild_id: "Snowflake", user_id: "Snowflake" self, guild_id: "Snowflake", user_id: "Snowflake"
@ -902,6 +975,29 @@ class HTTPClient:
custom_headers=headers, custom_headers=headers,
) )
async def get_guild_prune_count(self, guild_id: "Snowflake", *, days: int) -> int:
"""Returns the number of members that would be pruned."""
data = await self.request(
"GET",
f"/guilds/{guild_id}/prune",
params={"days": days},
)
return int(data.get("pruned", 0))
async def begin_guild_prune(
self, guild_id: "Snowflake", *, days: int, compute_count: bool = True
) -> int:
"""Begins a prune operation for the guild and returns the count."""
payload = {"days": days, "compute_prune_count": compute_count}
data = await self.request(
"POST",
f"/guilds/{guild_id}/prune",
payload=payload,
)
return int(data.get("pruned", 0))
async def get_guild_roles(self, guild_id: "Snowflake") -> List[Dict[str, Any]]: async def get_guild_roles(self, guild_id: "Snowflake") -> List[Dict[str, Any]]:
"""Returns a list of role objects for the guild.""" """Returns a list of role objects for the guild."""
return await self.request("GET", f"/guilds/{guild_id}/roles") return await self.request("GET", f"/guilds/{guild_id}/roles")
@ -1364,3 +1460,8 @@ class HTTPClient:
async def leave_thread(self, channel_id: "Snowflake") -> None: async def leave_thread(self, channel_id: "Snowflake") -> None:
"""Removes the current user from a thread.""" """Removes the current user from a thread."""
await self.request("DELETE", f"/channels/{channel_id}/thread-members/@me") await self.request("DELETE", f"/channels/{channel_id}/thread-members/@me")
async def create_dm(self, recipient_id: "Snowflake") -> Dict[str, Any]:
"""Creates (or opens) a DM channel with the given user."""
payload = {"recipient_id": str(recipient_id)}
return await self.request("POST", "/users/@me/channels", payload=payload)

View File

@ -2,11 +2,26 @@
Data models for Discord objects. Data models for Discord objects.
""" """
from __future__ import annotations
import asyncio import asyncio
import datetime
import io
import json import json
import os
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast from typing import (
Any,
AsyncIterator,
Dict,
List,
Optional,
TYPE_CHECKING,
Union,
cast,
IO,
)
from .cache import ChannelCache, MemberCache from .cache import ChannelCache, MemberCache
from .caching import MemberCacheFlags from .caching import MemberCacheFlags
@ -40,29 +55,43 @@ if TYPE_CHECKING:
from .ui.view import View from .ui.view import View
from .interactions import Snowflake from .interactions import Snowflake
from .typing import Typing from .typing import Typing
from .shard_manager import Shard
from .asset import Asset
# Forward reference Message if it were used in type hints before its definition # Forward reference Message if it were used in type hints before its definition
# from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc. # from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc.
from .components import component_factory from .components import component_factory
class User: class HashableById:
"""Represents a Discord User. """Mixin providing equality and hashing based on the ``id`` attribute."""
Attributes: id: str
id (str): The user's unique ID.
username (str): The user's username.
discriminator (str): The user's 4-digit discord-tag.
bot (bool): Whether the user belongs to an OAuth2 application. Defaults to False.
avatar (Optional[str]): The user's avatar hash, if any.
"""
def __init__(self, data: dict): def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.id == other.id # type: ignore[attr-defined]
def __hash__(self) -> int: # pragma: no cover - trivial
return hash(self.id)
class User(HashableById):
"""Represents a Discord User."""
def __init__(self, data: dict, client_instance: Optional["Client"] = None) -> None:
self._client = client_instance
if "id" not in data and "user" in data:
data = data["user"]
self.id: str = data["id"] self.id: str = data["id"]
self.username: str = data["username"] self.username: Optional[str] = data.get("username")
self.discriminator: str = data["discriminator"] self.discriminator: Optional[str] = data.get("discriminator")
self.bot: bool = data.get("bot", False) self.bot: bool = data.get("bot", False)
self.avatar: Optional[str] = data.get("avatar") avatar_hash = data.get("avatar")
self._avatar: Optional[str] = (
f"https://cdn.discordapp.com/avatars/{self.id}/{avatar_hash}.png"
if avatar_hash
else None
)
@property @property
def mention(self) -> str: def mention(self) -> str:
@ -70,10 +99,45 @@ class User:
return f"<@{self.id}>" return f"<@{self.id}>"
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<User id='{self.id}' username='{self.username}' discriminator='{self.discriminator}'>" username = self.username or "Unknown"
disc = self.discriminator or "????"
return f"<User id='{self.id}' username='{username}' discriminator='{disc}'>"
@property
def avatar(self) -> Optional["Asset"]:
"""Return the user's avatar as an :class:`Asset`."""
if self._avatar:
from .asset import Asset
return Asset(self._avatar, self._client)
return None
@avatar.setter
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._avatar = value
elif value is None:
self._avatar = None
else:
self._avatar = value.url
async def send(
self,
content: Optional[str] = None,
*,
client: Optional["Client"] = None,
**kwargs: Any,
) -> "Message":
"""Send a direct message to this user."""
target_client = client or self._client
if target_client is None:
raise DisagreementException("User.send requires a Client instance")
return await target_client.send_dm(self.id, content=content, **kwargs)
class Message: class Message(HashableById):
"""Represents a message sent in a channel on Discord. """Represents a message sent in a channel on Discord.
Attributes: Attributes:
@ -96,9 +160,10 @@ class Message:
self.id: str = data["id"] self.id: str = data["id"]
self.channel_id: str = data["channel_id"] self.channel_id: str = data["channel_id"]
self.guild_id: Optional[str] = data.get("guild_id") self.guild_id: Optional[str] = data.get("guild_id")
self.author: User = User(data["author"]) self.author: User = User(data["author"], client_instance)
self.content: str = data["content"] self.content: str = data["content"]
self.timestamp: str = data["timestamp"] self.timestamp: str = data["timestamp"]
self.edited_timestamp: Optional[str] = data.get("edited_timestamp")
if data.get("components"): if data.get("components"):
self.components: Optional[List[ActionRow]] = [ self.components: Optional[List[ActionRow]] = [
ActionRow.from_dict(c, client_instance) ActionRow.from_dict(c, client_instance)
@ -106,21 +171,21 @@ class Message:
] ]
else: else:
self.components = None self.components = None
self.attachments: List[Attachment] = [ self.attachments: List[Attachment] = [
Attachment(a) for a in data.get("attachments", []) Attachment(a) for a in data.get("attachments", [])
] ]
self.pinned: bool = data.get("pinned", False) self.pinned: bool = data.get("pinned", False)
# Add other fields as needed, e.g., attachments, embeds, reactions, etc. # Add other fields as needed, e.g., attachments, embeds, reactions, etc.
# self.mentions: List[User] = [User(u) for u in data.get("mentions", [])] # self.mentions: List[User] = [User(u) for u in data.get("mentions", [])]
# self.mention_roles: List[str] = data.get("mention_roles", []) # self.mention_roles: List[str] = data.get("mention_roles", [])
# self.mention_everyone: bool = data.get("mention_everyone", False) # self.mention_everyone: bool = data.get("mention_everyone", False)
@property @property
def jump_url(self) -> str: def jump_url(self) -> str:
"""Return a URL that jumps to this message in the Discord client.""" """Return a URL that jumps to this message in the Discord client."""
guild_or_dm = self.guild_id or "@me" guild_or_dm = self.guild_id or "@me"
return f"https://discord.com/channels/{guild_or_dm}/{self.channel_id}/{self.id}" return f"https://discord.com/channels/{guild_or_dm}/{self.channel_id}/{self.id}"
@property @property
def clean_content(self) -> str: def clean_content(self) -> str:
@ -130,6 +195,20 @@ class Message:
cleaned = pattern.sub("", self.content) cleaned = pattern.sub("", self.content)
return " ".join(cleaned.split()) return " ".join(cleaned.split())
@property
def created_at(self) -> datetime.datetime:
"""Return message timestamp as a :class:`~datetime.datetime`."""
return datetime.datetime.fromisoformat(self.timestamp)
@property
def edited_at(self) -> Optional[datetime.datetime]:
"""Return edited timestamp as :class:`~datetime.datetime` if present."""
if self.edited_timestamp is None:
return None
return datetime.datetime.fromisoformat(self.edited_timestamp)
async def pin(self) -> None: async def pin(self) -> None:
"""|coro| """|coro|
@ -156,6 +235,15 @@ class Message:
await self._client._http.unpin_message(self.channel_id, self.id) await self._client._http.unpin_message(self.channel_id, self.id)
self.pinned = False self.pinned = False
async def crosspost(self) -> "Message":
"""|coro|
Crossposts this message to all follower channels and return the resulting message.
"""
data = await self._client._http.crosspost_message(self.channel_id, self.id)
return self._client.parse_message(data)
async def reply( async def reply(
self, self,
content: Optional[str] = None, content: Optional[str] = None,
@ -198,10 +286,10 @@ class Message:
mention_author = getattr(self._client, "mention_replies", False) mention_author = getattr(self._client, "mention_replies", False)
if allowed_mentions is None: if allowed_mentions is None:
allowed_mentions = {"replied_user": mention_author} allowed_mentions = dict(getattr(self._client, "allowed_mentions", {}) or {})
else: else:
allowed_mentions = dict(allowed_mentions) allowed_mentions = dict(allowed_mentions)
allowed_mentions.setdefault("replied_user", mention_author) allowed_mentions.setdefault("replied_user", mention_author)
# Client.send_message is already updated to handle these parameters # Client.send_message is already updated to handle these parameters
return await self._client.send_message( return await self._client.send_message(
@ -624,38 +712,72 @@ class Attachment:
class File: class File:
"""Represents a file to be uploaded.""" """Represents a file to be uploaded.
def __init__(self, filename: str, data: bytes): Parameters
self.filename = filename ----------
self.data = data fp:
A file path, file-like object, or bytes-like object containing the
data to upload.
filename:
Optional name of the file. If not provided and ``fp`` is a path or has
a ``name`` attribute, the name will be inferred.
spoiler:
When ``True`` the filename will be prefixed with ``"SPOILER_"``.
"""
def __init__(
self,
fp: Union[str, bytes, os.PathLike[Any], IO[bytes]],
*,
filename: Optional[str] = None,
spoiler: bool = False,
) -> None:
if isinstance(fp, (str, os.PathLike)):
self.data = open(fp, "rb")
inferred = os.path.basename(fp)
elif isinstance(fp, bytes):
self.data = io.BytesIO(fp)
inferred = None
else:
self.data = fp
inferred = getattr(fp, "name", None)
name = filename or inferred
if name is None:
raise ValueError("filename could not be inferred")
if spoiler and not name.startswith("SPOILER_"):
name = f"SPOILER_{name}"
self.filename = name
self.spoiler = spoiler
class AllowedMentions: class AllowedMentions:
"""Represents allowed mentions for a message or interaction response.""" """Represents allowed mentions for a message or interaction response."""
def __init__(self, data: Dict[str, Any]): def __init__(self, data: Dict[str, Any]):
self.parse: List[str] = data.get("parse", []) self.parse: List[str] = data.get("parse", [])
self.roles: List[str] = data.get("roles", []) self.roles: List[str] = data.get("roles", [])
self.users: List[str] = data.get("users", []) self.users: List[str] = data.get("users", [])
self.replied_user: bool = data.get("replied_user", False) self.replied_user: bool = data.get("replied_user", False)
@classmethod @classmethod
def all(cls) -> "AllowedMentions": def all(cls) -> "AllowedMentions":
"""Return an instance allowing all mention types.""" """Return an instance allowing all mention types."""
return cls( return cls(
{ {
"parse": ["users", "roles", "everyone"], "parse": ["users", "roles", "everyone"],
"replied_user": True, "replied_user": True,
} }
) )
@classmethod @classmethod
def none(cls) -> "AllowedMentions": def none(cls) -> "AllowedMentions":
"""Return an instance disallowing all mentions.""" """Return an instance disallowing all mentions."""
return cls({"parse": [], "replied_user": False}) return cls({"parse": [], "replied_user": False})
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
payload: Dict[str, Any] = {"parse": self.parse} payload: Dict[str, Any] = {"parse": self.parse}
@ -697,7 +819,12 @@ class Role:
self.name: str = data["name"] self.name: str = data["name"]
self.color: int = data["color"] self.color: int = data["color"]
self.hoist: bool = data["hoist"] self.hoist: bool = data["hoist"]
self.icon: Optional[str] = data.get("icon") icon_hash = data.get("icon")
self._icon: Optional[str] = (
f"https://cdn.discordapp.com/role-icons/{self.id}/{icon_hash}.png"
if icon_hash
else None
)
self.unicode_emoji: Optional[str] = data.get("unicode_emoji") self.unicode_emoji: Optional[str] = data.get("unicode_emoji")
self.position: int = data["position"] self.position: int = data["position"]
self.permissions: str = data["permissions"] # String of bitwise permissions self.permissions: str = data["permissions"] # String of bitwise permissions
@ -715,6 +842,23 @@ class Role:
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Role id='{self.id}' name='{self.name}'>" return f"<Role id='{self.id}' name='{self.name}'>"
@property
def icon(self) -> Optional["Asset"]:
if self._icon:
from .asset import Asset
return Asset(self._icon, None)
return None
@icon.setter
def icon(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._icon = value
elif value is None:
self._icon = None
else:
self._icon = value.url
class Member(User): # Member inherits from User class Member(User): # Member inherits from User
"""Represents a Guild Member. """Represents a Guild Member.
@ -743,12 +887,18 @@ class Member(User): # Member inherits from User
) # Pass user_data or data if user_data is empty ) # Pass user_data or data if user_data is empty
self.nick: Optional[str] = data.get("nick") self.nick: Optional[str] = data.get("nick")
self.avatar: Optional[str] = data.get("avatar") # Guild-specific avatar hash avatar_hash = data.get("avatar")
self.roles: List[str] = data.get("roles", []) # List of role IDs if avatar_hash:
self.joined_at: str = data["joined_at"] # ISO8601 timestamp guild_id = data.get("guild_id")
self.premium_since: Optional[str] = data.get( if guild_id:
"premium_since" self._avatar = f"https://cdn.discordapp.com/guilds/{guild_id}/users/{self.id}/avatars/{avatar_hash}.png"
) # ISO8601 timestamp else:
self._avatar = (
f"https://cdn.discordapp.com/avatars/{self.id}/{avatar_hash}.png"
)
self.roles: List[str] = data.get("roles", [])
self.joined_at: str = data["joined_at"]
self.premium_since: Optional[str] = data.get("premium_since")
self.deaf: bool = data.get("deaf", False) self.deaf: bool = data.get("deaf", False)
self.mute: bool = data.get("mute", False) self.mute: bool = data.get("mute", False)
self.pending: bool = data.get("pending", False) self.pending: bool = data.get("pending", False)
@ -758,6 +908,7 @@ class Member(User): # Member inherits from User
self.communication_disabled_until: Optional[str] = data.get( self.communication_disabled_until: Optional[str] = data.get(
"communication_disabled_until" "communication_disabled_until"
) # ISO8601 timestamp ) # ISO8601 timestamp
self.voice_state = data.get("voice_state")
# If 'user' object was present, ensure User attributes are from there # If 'user' object was present, ensure User attributes are from there
if user_data: if user_data:
@ -772,11 +923,30 @@ class Member(User): # Member inherits from User
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Member id='{self.id}' username='{self.username}' nick='{self.nick}'>" return f"<Member id='{self.id}' username='{self.username}' nick='{self.nick}'>"
@property
def avatar(self) -> Optional["Asset"]:
"""Return the member's avatar as an :class:`Asset`."""
if self._avatar:
from .asset import Asset
return Asset(self._avatar, self._client)
return None
@avatar.setter
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._avatar = value
elif value is None:
self._avatar = None
else:
self._avatar = value.url
@property @property
def display_name(self) -> str: def display_name(self) -> str:
"""Return the nickname if set, otherwise the username.""" """Return the nickname if set, otherwise the username."""
return self.nick or self.username return self.nick or self.username or ""
async def kick(self, *, reason: Optional[str] = None) -> None: async def kick(self, *, reason: Optional[str] = None) -> None:
if not self.guild_id or not self._client: if not self.guild_id or not self._client:
@ -838,6 +1008,41 @@ class Member(User): # Member inherits from User
return max(role_objects, key=lambda r: r.position) return max(role_objects, key=lambda r: r.position)
@property
def guild_permissions(self) -> "Permissions":
"""Return the member's guild-level permissions."""
if not self.guild_id or not self._client:
return Permissions(0)
guild = self._client.get_guild(self.guild_id)
if guild is None:
return Permissions(0)
base = Permissions(0)
everyone = guild.get_role(guild.id)
if everyone is not None:
base |= Permissions(int(everyone.permissions))
for rid in self.roles:
role = guild.get_role(rid)
if role is not None:
base |= Permissions(int(role.permissions))
if base & Permissions.ADMINISTRATOR:
return Permissions(~0)
return base
@property
def voice(self) -> Optional["VoiceState"]:
"""Return the member's cached voice state as a :class:`VoiceState`."""
if self.voice_state is None:
return None
return VoiceState.from_dict(self.voice_state)
class PartialEmoji: class PartialEmoji:
"""Represents a partial emoji, often used in components or reactions. """Represents a partial emoji, often used in components or reactions.
@ -1035,7 +1240,7 @@ class PermissionOverwrite:
return f"<PermissionOverwrite id='{self.id}' type='{self.type.name if hasattr(self.type, 'name') else self._type_val}' allow='{self.allow}' deny='{self.deny}'>" return f"<PermissionOverwrite id='{self.id}' type='{self.type.name if hasattr(self.type, 'name') else self._type_val}' allow='{self.allow}' deny='{self.deny}'>"
class Guild: class Guild(HashableById):
"""Represents a Discord Guild (Server). """Represents a Discord Guild (Server).
Attributes: Attributes:
@ -1075,15 +1280,42 @@ class Guild:
nsfw_level (GuildNSFWLevel): Guild NSFW level. nsfw_level (GuildNSFWLevel): Guild NSFW level.
stickers (Optional[List[Dict]]): Custom stickers in the guild. (Consider a Sticker model) stickers (Optional[List[Dict]]): Custom stickers in the guild. (Consider a Sticker model)
premium_progress_bar_enabled (bool): Whether the guild has the premium progress bar enabled. premium_progress_bar_enabled (bool): Whether the guild has the premium progress bar enabled.
text_channels (List[TextChannel]): List of text-based channels in this guild.
voice_channels (List[VoiceChannel]): List of voice-based channels in this guild.
category_channels (List[CategoryChannel]): List of category channels in this guild.
""" """
def __init__(self, data: Dict[str, Any], client_instance: "Client"): def __init__(
self,
data: Dict[str, Any],
client_instance: "Client",
*,
shard_id: Optional[int] = None,
):
self._client: "Client" = client_instance self._client: "Client" = client_instance
self._shard_id: Optional[int] = (
shard_id if shard_id is not None else data.get("shard_id")
)
self.id: str = data["id"] self.id: str = data["id"]
self.name: str = data["name"] self.name: str = data["name"]
self.icon: Optional[str] = data.get("icon") icon_hash = data.get("icon")
self.splash: Optional[str] = data.get("splash") self._icon: Optional[str] = (
self.discovery_splash: Optional[str] = data.get("discovery_splash") f"https://cdn.discordapp.com/icons/{self.id}/{icon_hash}.png"
if icon_hash
else None
)
splash_hash = data.get("splash")
self._splash: Optional[str] = (
f"https://cdn.discordapp.com/splashes/{self.id}/{splash_hash}.png"
if splash_hash
else None
)
discovery_hash = data.get("discovery_splash")
self._discovery_splash: Optional[str] = (
f"https://cdn.discordapp.com/discovery-splashes/{self.id}/{discovery_hash}.png"
if discovery_hash
else None
)
self.owner: Optional[bool] = data.get("owner") self.owner: Optional[bool] = data.get("owner")
self.owner_id: str = data["owner_id"] self.owner_id: str = data["owner_id"]
self.permissions: Optional[str] = data.get("permissions") self.permissions: Optional[str] = data.get("permissions")
@ -1120,7 +1352,12 @@ class Guild:
self.max_members: Optional[int] = data.get("max_members") self.max_members: Optional[int] = data.get("max_members")
self.vanity_url_code: Optional[str] = data.get("vanity_url_code") self.vanity_url_code: Optional[str] = data.get("vanity_url_code")
self.description: Optional[str] = data.get("description") self.description: Optional[str] = data.get("description")
self.banner: Optional[str] = data.get("banner") banner_hash = data.get("banner")
self._banner: Optional[str] = (
f"https://cdn.discordapp.com/banners/{self.id}/{banner_hash}.png"
if banner_hash
else None
)
self.premium_tier: PremiumTier = PremiumTier(data["premium_tier"]) self.premium_tier: PremiumTier = PremiumTier(data["premium_tier"])
self.premium_subscription_count: Optional[int] = data.get( self.premium_subscription_count: Optional[int] = data.get(
"premium_subscription_count" "premium_subscription_count"
@ -1159,6 +1396,28 @@ class Guild:
getattr(client_instance, "member_cache_flags", MemberCacheFlags()) getattr(client_instance, "member_cache_flags", MemberCacheFlags())
) )
self._threads: Dict[str, "Thread"] = {} self._threads: Dict[str, "Thread"] = {}
self.text_channels: List["TextChannel"] = []
self.voice_channels: List["VoiceChannel"] = []
self.category_channels: List["CategoryChannel"] = []
@property
def shard_id(self) -> Optional[int]:
"""ID of the shard that received this guild, if any."""
return self._shard_id
@property
def shard(self) -> Optional["Shard"]:
"""The :class:`Shard` this guild belongs to."""
if self._shard_id is None:
return None
manager = getattr(self._client, "_shard_manager", None)
if not manager:
return None
if 0 <= self._shard_id < len(manager.shards):
return manager.shards[self._shard_id]
return None
def get_channel(self, channel_id: str) -> Optional["Channel"]: def get_channel(self, channel_id: str) -> Optional["Channel"]:
return self._channels.get(channel_id) return self._channels.get(channel_id)
@ -1185,7 +1444,7 @@ class Guild:
lowered = name.lower() lowered = name.lower()
for member in self._members.values(): for member in self._members.values():
if member.username.lower() == lowered: if member.username and member.username.lower() == lowered:
return member return member
if member.nick and member.nick.lower() == lowered: if member.nick and member.nick.lower() == lowered:
return member return member
@ -1194,9 +1453,86 @@ class Guild:
def get_role(self, role_id: str) -> Optional[Role]: def get_role(self, role_id: str) -> Optional[Role]:
return next((role for role in self.roles if role.id == role_id), None) return next((role for role in self.roles if role.id == role_id), None)
@property
def me(self) -> Optional[Member]:
"""The member object for the connected bot in this guild, if present."""
client_user = getattr(self._client, "user", None)
if not client_user:
return None
return self.get_member(client_user.id)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Guild id='{self.id}' name='{self.name}'>" return f"<Guild id='{self.id}' name='{self.name}'>"
@property
def icon(self) -> Optional["Asset"]:
if self._icon:
from .asset import Asset
return Asset(self._icon, self._client)
return None
@icon.setter
def icon(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._icon = value
elif value is None:
self._icon = None
else:
self._icon = value.url
@property
def splash(self) -> Optional["Asset"]:
if self._splash:
from .asset import Asset
return Asset(self._splash, self._client)
return None
@splash.setter
def splash(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._splash = value
elif value is None:
self._splash = None
else:
self._splash = value.url
@property
def discovery_splash(self) -> Optional["Asset"]:
if self._discovery_splash:
from .asset import Asset
return Asset(self._discovery_splash, self._client)
return None
@discovery_splash.setter
def discovery_splash(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._discovery_splash = value
elif value is None:
self._discovery_splash = None
else:
self._discovery_splash = value.url
@property
def banner(self) -> Optional["Asset"]:
if self._banner:
from .asset import Asset
return Asset(self._banner, self._client)
return None
@banner.setter
def banner(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._banner = value
elif value is None:
self._banner = None
else:
self._banner = value.url
async def fetch_widget(self) -> Dict[str, Any]: async def fetch_widget(self) -> Dict[str, Any]:
"""|coro| Fetch this guild's widget settings.""" """|coro| Fetch this guild's widget settings."""
@ -1250,8 +1586,82 @@ class Guild:
del self._client._gateway._member_chunk_requests[nonce] del self._client._gateway._member_chunk_requests[nonce]
raise raise
async def prune_members(self, days: int, *, compute_count: bool = True) -> int:
"""|coro| Remove inactive members from the guild.
class Channel: Parameters
----------
days: int
Number of days of inactivity required to be pruned.
compute_count: bool
Whether to return the number of members pruned.
Returns
-------
int
The number of members pruned.
"""
return await self._client._http.begin_guild_prune(
self.id, days=days, compute_count=compute_count
)
async def create_text_channel(
self,
name: str,
*,
reason: Optional[str] = None,
**options: Any,
) -> "TextChannel":
"""|coro| Create a new text channel in this guild."""
payload: Dict[str, Any] = {"name": name, "type": ChannelType.GUILD_TEXT.value}
payload.update(options)
data = await self._client._http.create_guild_channel(
self.id, payload, reason=reason
)
return cast("TextChannel", self._client.parse_channel(data))
async def create_voice_channel(
self,
name: str,
*,
reason: Optional[str] = None,
**options: Any,
) -> "VoiceChannel":
"""|coro| Create a new voice channel in this guild."""
payload: Dict[str, Any] = {
"name": name,
"type": ChannelType.GUILD_VOICE.value,
}
payload.update(options)
data = await self._client._http.create_guild_channel(
self.id, payload, reason=reason
)
return cast("VoiceChannel", self._client.parse_channel(data))
async def create_category(
self,
name: str,
*,
reason: Optional[str] = None,
**options: Any,
) -> "CategoryChannel":
"""|coro| Create a new category channel in this guild."""
payload: Dict[str, Any] = {
"name": name,
"type": ChannelType.GUILD_CATEGORY.value,
}
payload.update(options)
data = await self._client._http.create_guild_channel(
self.id, payload, reason=reason
)
return cast("CategoryChannel", self._client.parse_channel(data))
class Channel(HashableById):
"""Base class for Discord channels.""" """Base class for Discord channels."""
def __init__(self, data: Dict[str, Any], client_instance: "Client"): def __init__(self, data: Dict[str, Any], client_instance: "Client"):
@ -1531,6 +1941,31 @@ class TextChannel(Channel, Messageable):
data = await self._client._http.start_thread_without_message(self.id, payload) data = await self._client._http.start_thread_without_message(self.id, payload)
return cast("Thread", self._client.parse_channel(data)) return cast("Thread", self._client.parse_channel(data))
async def create_invite(
self,
*,
max_age: Optional[int] = None,
max_uses: Optional[int] = None,
temporary: Optional[bool] = None,
unique: Optional[bool] = None,
reason: Optional[str] = None,
) -> "Invite":
"""|coro| Create an invite to this channel."""
payload: Dict[str, Any] = {}
if max_age is not None:
payload["max_age"] = max_age
if max_uses is not None:
payload["max_uses"] = max_uses
if temporary is not None:
payload["temporary"] = temporary
if unique is not None:
payload["unique"] = unique
return await self._client._http.create_channel_invite(
self.id, payload, reason=reason
)
class VoiceChannel(Channel): class VoiceChannel(Channel):
"""Represents a guild voice channel or stage voice channel.""" """Represents a guild voice channel or stage voice channel."""
@ -1830,7 +2265,12 @@ class Webhook:
self.guild_id: Optional[str] = data.get("guild_id") self.guild_id: Optional[str] = data.get("guild_id")
self.channel_id: Optional[str] = data.get("channel_id") self.channel_id: Optional[str] = data.get("channel_id")
self.name: Optional[str] = data.get("name") self.name: Optional[str] = data.get("name")
self.avatar: Optional[str] = data.get("avatar") avatar_hash = data.get("avatar")
self._avatar: Optional[str] = (
f"https://cdn.discordapp.com/webhooks/{self.id}/{avatar_hash}.png"
if avatar_hash
else None
)
self.token: Optional[str] = data.get("token") self.token: Optional[str] = data.get("token")
self.application_id: Optional[str] = data.get("application_id") self.application_id: Optional[str] = data.get("application_id")
self.url: Optional[str] = data.get("url") self.url: Optional[str] = data.get("url")
@ -1839,6 +2279,25 @@ class Webhook:
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Webhook id='{self.id}' name='{self.name}'>" return f"<Webhook id='{self.id}' name='{self.name}'>"
@property
def avatar(self) -> Optional["Asset"]:
"""Return the webhook's avatar as an :class:`Asset`."""
if self._avatar:
from .asset import Asset
return Asset(self._avatar, self._client)
return None
@avatar.setter
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._avatar = value
elif value is None:
self._avatar = None
else:
self._avatar = value.url
@classmethod @classmethod
def from_url( def from_url(
cls, url: str, session: Optional[aiohttp.ClientSession] = None cls, url: str, session: Optional[aiohttp.ClientSession] = None
@ -1866,6 +2325,33 @@ class Webhook:
return cls({"id": webhook_id, "token": token, "url": url}) return cls({"id": webhook_id, "token": token, "url": url})
@classmethod
def from_token(
cls,
webhook_id: str,
token: str,
session: Optional[aiohttp.ClientSession] = None,
) -> "Webhook":
"""Create a minimal :class:`Webhook` from an ID and token.
Parameters
----------
webhook_id:
The ID of the webhook.
token:
The webhook token.
session:
Unused for now. Present for API compatibility.
Returns
-------
Webhook
A webhook instance containing only the ``id``, ``token`` and ``url``.
"""
url = f"https://discord.com/api/webhooks/{webhook_id}/{token}"
return cls({"id": webhook_id, "token": token, "url": url})
async def send( async def send(
self, self,
content: Optional[str] = None, content: Optional[str] = None,
@ -2471,7 +2957,7 @@ class PresenceUpdate:
self, data: Dict[str, Any], client_instance: Optional["Client"] = None self, data: Dict[str, Any], client_instance: Optional["Client"] = None
): ):
self._client = client_instance self._client = client_instance
self.user = User(data["user"]) self.user = User(data["user"], client_instance)
self.guild_id: Optional[str] = data.get("guild_id") self.guild_id: Optional[str] = data.get("guild_id")
self.status: Optional[str] = data.get("status") self.status: Optional[str] = data.get("status")
self.activities: List[Activity] = [] self.activities: List[Activity] = []
@ -2491,7 +2977,7 @@ class PresenceUpdate:
return f"<PresenceUpdate user_id='{self.user.id}' guild_id='{self.guild_id}' status='{self.status}'>" return f"<PresenceUpdate user_id='{self.user.id}' guild_id='{self.guild_id}' status='{self.status}'>"
class TypingStart: class TypingStart:
"""Represents a TYPING_START event.""" """Represents a TYPING_START event."""
def __init__( def __init__(
@ -2504,39 +2990,78 @@ class TypingStart:
self.timestamp: int = data["timestamp"] self.timestamp: int = data["timestamp"]
self.member: Optional[Member] = ( self.member: Optional[Member] = (
Member(data["member"], client_instance) if data.get("member") else None Member(data["member"], client_instance) if data.get("member") else None
) )
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<TypingStart channel_id='{self.channel_id}' user_id='{self.user_id}'>" return f"<TypingStart channel_id='{self.channel_id}' user_id='{self.user_id}'>"
class VoiceStateUpdate: class VoiceStateUpdate:
"""Represents a VOICE_STATE_UPDATE event.""" """Represents a VOICE_STATE_UPDATE event."""
def __init__( def __init__(
self, data: Dict[str, Any], client_instance: Optional["Client"] = None self, data: Dict[str, Any], client_instance: Optional["Client"] = None
): ):
self._client = client_instance self._client = client_instance
self.guild_id: Optional[str] = data.get("guild_id") self.guild_id: Optional[str] = data.get("guild_id")
self.channel_id: Optional[str] = data.get("channel_id") self.channel_id: Optional[str] = data.get("channel_id")
self.user_id: str = data["user_id"] self.user_id: str = data["user_id"]
self.member: Optional[Member] = ( self.member: Optional[Member] = (
Member(data["member"], client_instance) if data.get("member") else None Member(data["member"], client_instance) if data.get("member") else None
) )
self.session_id: str = data["session_id"] self.session_id: str = data["session_id"]
self.deaf: bool = data.get("deaf", False) self.deaf: bool = data.get("deaf", False)
self.mute: bool = data.get("mute", False) self.mute: bool = data.get("mute", False)
self.self_deaf: bool = data.get("self_deaf", False) self.self_deaf: bool = data.get("self_deaf", False)
self.self_mute: bool = data.get("self_mute", False) self.self_mute: bool = data.get("self_mute", False)
self.self_stream: Optional[bool] = data.get("self_stream") self.self_stream: Optional[bool] = data.get("self_stream")
self.self_video: bool = data.get("self_video", False) self.self_video: bool = data.get("self_video", False)
self.suppress: bool = data.get("suppress", False) self.suppress: bool = data.get("suppress", False)
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"<VoiceStateUpdate guild_id='{self.guild_id}' user_id='{self.user_id}' " f"<VoiceStateUpdate guild_id='{self.guild_id}' user_id='{self.user_id}' "
f"channel_id='{self.channel_id}'>" f"channel_id='{self.channel_id}'>"
) )
@dataclass
class VoiceState:
"""Represents a cached voice state for a member."""
guild_id: Optional[str]
channel_id: Optional[str]
user_id: Optional[str]
session_id: Optional[str]
deaf: bool = False
mute: bool = False
self_deaf: bool = False
self_mute: bool = False
self_stream: Optional[bool] = None
self_video: bool = False
suppress: bool = False
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "VoiceState":
return cls(
guild_id=data.get("guild_id"),
channel_id=data.get("channel_id"),
user_id=data.get("user_id"),
session_id=data.get("session_id"),
deaf=data.get("deaf", False),
mute=data.get("mute", False),
self_deaf=data.get("self_deaf", False),
self_mute=data.get("self_mute", False),
self_stream=data.get("self_stream"),
self_video=data.get("self_video", False),
suppress=data.get("suppress", False),
)
def __repr__(self) -> str:
return (
f"<VoiceState guild_id='{self.guild_id}' user_id='{self.user_id}' "
f"channel_id='{self.channel_id}'>"
)
class Reaction: class Reaction:
@ -2631,6 +3156,18 @@ class Invite:
return f"<Invite code='{self.code}' guild_id='{self.guild_id}' channel_id='{self.channel_id}'>" return f"<Invite code='{self.code}' guild_id='{self.guild_id}' channel_id='{self.channel_id}'>"
class InviteDelete:
"""Represents an INVITE_DELETE event."""
def __init__(self, data: Dict[str, Any]):
self.channel_id: str = data["channel_id"]
self.guild_id: Optional[str] = data.get("guild_id")
self.code: str = data["code"]
def __repr__(self) -> str:
return f"<InviteDelete code='{self.code}' guild_id='{self.guild_id}' channel_id='{self.channel_id}'>"
class GuildMemberRemove: class GuildMemberRemove:
"""Represents a GUILD_MEMBER_REMOVE event.""" """Represents a GUILD_MEMBER_REMOVE event."""

19
disagreement/object.py Normal file
View File

@ -0,0 +1,19 @@
class Object:
"""A minimal wrapper around a Discord snowflake ID."""
__slots__ = ("id",)
def __init__(self, object_id: int) -> None:
self.id = int(object_id)
def __int__(self) -> int:
return self.id
def __hash__(self) -> int:
return hash(self.id)
def __eq__(self, other: object) -> bool:
return isinstance(other, Object) and self.id == other.id
def __repr__(self) -> str:
return f"<Object id={self.id}>"

View File

@ -57,6 +57,15 @@ class Permissions(IntFlag):
USE_EXTERNAL_SOUNDS = 1 << 45 USE_EXTERNAL_SOUNDS = 1 << 45
SEND_VOICE_MESSAGES = 1 << 46 SEND_VOICE_MESSAGES = 1 << 46
@classmethod
def all(cls) -> "Permissions":
"""Return a ``Permissions`` object with every permission bit enabled."""
value = 0
for perm in cls:
value |= perm.value
return cls(value)
def permissions_value(*perms: Permissions | int | Iterable[Permissions | int]) -> int: def permissions_value(*perms: Permissions | int | Iterable[Permissions | int]) -> int:
"""Return a combined integer value for multiple permissions.""" """Return a combined integer value for multiple permissions."""

View File

@ -3,7 +3,11 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, AsyncIterator, Dict, Optional, TYPE_CHECKING from typing import Any, AsyncIterator, Dict, Iterable, Optional, TYPE_CHECKING, Callable
import re
# Discord epoch in milliseconds (2015-01-01T00:00:00Z)
DISCORD_EPOCH = 1420070400000
if TYPE_CHECKING: # pragma: no cover - for type hinting only if TYPE_CHECKING: # pragma: no cover - for type hinting only
from .models import Message, TextChannel from .models import Message, TextChannel
@ -14,6 +18,27 @@ def utcnow() -> datetime:
return datetime.now(timezone.utc) return datetime.now(timezone.utc)
def find(predicate: Callable[[Any], bool], iterable: Iterable[Any]) -> Optional[Any]:
"""Return the first element in ``iterable`` matching the ``predicate``."""
for element in iterable:
if predicate(element):
return element
return None
def get(iterable: Iterable[Any], **attrs: Any) -> Optional[Any]:
"""Return the first element with matching attributes."""
def predicate(elem: Any) -> bool:
return all(getattr(elem, attr, None) == value for attr, value in attrs.items())
return find(predicate, iterable)
def snowflake_time(snowflake: int) -> datetime:
"""Return the creation time of a Discord snowflake."""
timestamp_ms = (snowflake >> 22) + DISCORD_EPOCH
return datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc)
async def message_pager( async def message_pager(
channel: "TextChannel", channel: "TextChannel",
*, *,
@ -21,32 +46,11 @@ async def message_pager(
before: Optional[str] = None, before: Optional[str] = None,
after: Optional[str] = None, after: Optional[str] = None,
) -> AsyncIterator["Message"]: ) -> AsyncIterator["Message"]:
"""Asynchronously paginate a channel's messages. """Asynchronously paginate a channel's messages."""
Parameters
----------
channel:
The :class:`TextChannel` to fetch messages from.
limit:
The maximum number of messages to yield. ``None`` fetches until no
more messages are returned.
before:
Fetch messages with IDs less than this snowflake.
after:
Fetch messages with IDs greater than this snowflake.
Yields
------
Message
Messages in the channel, oldest first.
"""
remaining = limit remaining = limit
last_id = before last_id = before
while remaining is None or remaining > 0: while remaining is None or remaining > 0:
fetch_limit = 100 fetch_limit = min(100, remaining) if remaining is not None else 100
if remaining is not None:
fetch_limit = min(fetch_limit, remaining)
params: Dict[str, Any] = {"limit": fetch_limit} params: Dict[str, Any] = {"limit": fetch_limit}
if last_id is not None: if last_id is not None:
@ -71,3 +75,52 @@ async def message_pager(
remaining -= 1 remaining -= 1
if remaining == 0: if remaining == 0:
return return
class Paginator:
"""Helper to split text into pages under a character limit."""
def __init__(self, limit: int = 2000) -> None:
self.limit = limit
self._pages: list[str] = []
self._current = ""
def add_line(self, line: str) -> None:
"""Add a line of text to the paginator."""
if len(line) > self.limit:
if self._current:
self._pages.append(self._current)
self._current = ""
for i in range(0, len(line), self.limit):
chunk = line[i : i + self.limit]
if len(chunk) == self.limit:
self._pages.append(chunk)
else:
self._current = chunk
return
if not self._current:
self._current = line
elif len(self._current) + 1 + len(line) <= self.limit:
self._current += "\n" + line
else:
self._pages.append(self._current)
self._current = line
@property
def pages(self) -> list[str]:
"""Return the accumulated pages."""
pages = list(self._pages)
if self._current:
pages.append(self._current)
return pages
def escape_markdown(text: str) -> str:
"""Escape Discord markdown formatting in ``text``."""
return re.sub(r"([\\*_~`>|])", r"\\\1", text)
def escape_mentions(text: str) -> str:
"""Escape Discord mentions in ``text``."""
return text.replace("@", "@\u200b")

View File

@ -77,11 +77,14 @@ class VoiceClient:
self.secret_key: Optional[Sequence[int]] = None self.secret_key: Optional[Sequence[int]] = None
self._server_ip: Optional[str] = None self._server_ip: Optional[str] = None
self._server_port: Optional[int] = None self._server_port: Optional[int] = None
self._current_source: Optional[AudioSource] = None self._current_source: Optional[AudioSource] = None
self._play_task: Optional[asyncio.Task] = None self._play_task: Optional[asyncio.Task] = None
self._sink: Optional[AudioSink] = None self._pause_event = asyncio.Event()
self._ssrc_map: dict[int, int] = {} self._pause_event.set()
self._ssrc_lock = threading.Lock() self._is_playing = False
self._sink: Optional[AudioSink] = None
self._ssrc_map: dict[int, int] = {}
self._ssrc_lock = threading.Lock()
async def connect(self) -> None: async def connect(self) -> None:
if self._ws is None: if self._ws is None:
@ -189,31 +192,37 @@ class VoiceClient:
raise RuntimeError("UDP socket not initialised") raise RuntimeError("UDP socket not initialised")
self._udp.send(frame) self._udp.send(frame)
async def _play_loop(self) -> None: async def _play_loop(self) -> None:
assert self._current_source is not None assert self._current_source is not None
try: self._is_playing = True
while True: try:
data = await self._current_source.read() while True:
if not data: await self._pause_event.wait()
break data = await self._current_source.read()
volume = getattr(self._current_source, "volume", 1.0) if not data:
if volume != 1.0: break
data = _apply_volume(data, volume) volume = getattr(self._current_source, "volume", 1.0)
await self.send_audio_frame(data) if volume != 1.0:
finally: data = _apply_volume(data, volume)
await self._current_source.close() await self.send_audio_frame(data)
self._current_source = None finally:
self._play_task = None await self._current_source.close()
self._current_source = None
self._play_task = None
self._is_playing = False
self._pause_event.set()
async def stop(self) -> None: async def stop(self) -> None:
if self._play_task: if self._play_task:
self._play_task.cancel() self._play_task.cancel()
with contextlib.suppress(asyncio.CancelledError): self._pause_event.set()
await self._play_task with contextlib.suppress(asyncio.CancelledError):
self._play_task = None await self._play_task
if self._current_source: self._play_task = None
await self._current_source.close() self._is_playing = False
self._current_source = None if self._current_source:
await self._current_source.close()
self._current_source = None
async def play(self, source: AudioSource, *, wait: bool = True) -> None: async def play(self, source: AudioSource, *, wait: bool = True) -> None:
"""|coro| Play an :class:`AudioSource` on the voice connection.""" """|coro| Play an :class:`AudioSource` on the voice connection."""
@ -224,10 +233,31 @@ class VoiceClient:
if wait: if wait:
await self._play_task await self._play_task
async def play_file(self, filename: str, *, wait: bool = True) -> None: async def play_file(self, filename: str, *, wait: bool = True) -> None:
"""|coro| Stream an audio file or URL using FFmpeg.""" """|coro| Stream an audio file or URL using FFmpeg."""
await self.play(FFmpegAudioSource(filename), wait=wait) await self.play(FFmpegAudioSource(filename), wait=wait)
def pause(self) -> None:
"""Pause the current audio source."""
if self._play_task and not self._play_task.done():
self._pause_event.clear()
def resume(self) -> None:
"""Resume playback of a paused source."""
if self._play_task and not self._play_task.done():
self._pause_event.set()
def is_paused(self) -> bool:
"""Return ``True`` if playback is currently paused."""
return bool(self._play_task and not self._pause_event.is_set())
def is_playing(self) -> bool:
"""Return ``True`` if audio is actively being played."""
return self._is_playing and self._pause_event.is_set()
def listen(self, sink: AudioSink) -> None: def listen(self, sink: AudioSink) -> None:
"""Start listening to voice and routing to a sink.""" """Start listening to voice and routing to a sink."""

1
docs/CNAME Normal file
View File

@ -0,0 +1 @@
disagreement.xyz

View File

@ -13,12 +13,28 @@ if member:
print(member.display_name) print(member.display_name)
``` ```
To access the bot's own member object, use the ``Guild.me`` property. It returns
``None`` if the bot is not in the guild or its user data hasn't been loaded:
```python
bot_member = guild.me
if bot_member:
print(bot_member.joined_at)
```
The cache can be cleared manually if needed: The cache can be cleared manually if needed:
```python ```python
client.cache.clear() client.cache.clear()
``` ```
## Partial Objects
Some events only include minimal data for related resources. When only an ``id``
is available, Disagreement represents the resource using :class:`~disagreement.Object`.
These objects can be compared and used in sets or dictionaries and can be passed
to API methods to fetch the full data when needed.
## Next Steps ## Next Steps
- [Components](using_components.md) - [Components](using_components.md)

View File

@ -11,7 +11,11 @@ The command handler registers a `help` command automatically. Use it to list all
!help ping # shows help for the "ping" command !help ping # shows help for the "ping" command
``` ```
The help command will show each command's brief description if provided. Commands are grouped by their Cog name and paginated so that long help
lists are split into multiple messages using the `Paginator` utility.
If you need custom formatting you can subclass
`HelpCommand` and override `send_command_help` or `send_group_help`.
## Checks ## Checks
@ -20,7 +24,7 @@ returns ``True``. Checks may be regular or async callables that accept a
`CommandContext`. `CommandContext`.
```python ```python
from disagreement.ext.commands import command, check, CheckFailure from disagreement import command, check, CheckFailure
def is_owner(ctx): def is_owner(ctx):
return ctx.author.id == "1" return ctx.author.id == "1"
@ -40,7 +44,7 @@ Commands can be rate limited using the ``cooldown`` decorator. The example
below restricts usage to once every three seconds per user: below restricts usage to once every three seconds per user:
```python ```python
from disagreement.ext.commands import command, cooldown from disagreement import command, cooldown
@command() @command()
@cooldown(1, 3.0) @cooldown(1, 3.0)
@ -56,8 +60,8 @@ Use `commands.requires_permissions` to ensure the invoking member has the
required permissions in the channel. required permissions in the channel.
```python ```python
from disagreement.ext.commands import command, requires_permissions from disagreement import command, requires_permissions
from disagreement.permissions import Permissions from disagreement import Permissions
@command() @command()
@requires_permissions(Permissions.MANAGE_MESSAGES) @requires_permissions(Permissions.MANAGE_MESSAGES)

View File

@ -1,12 +1,11 @@
# Context Menu Commands # Context Menu Commands
`disagreement` supports Discord's user and message context menu commands. Use the `disagreement` supports Discord's user and message context menu commands. Use the
`user_command` and `message_command` decorators from `ext.app_commands` to `user_command` and `message_command` decorators to define them.
define them.
```python ```python
from disagreement.ext.app_commands import user_command, message_command, AppCommandContext from disagreement import User, Message
from disagreement.models import User, Message from disagreement import User, Message, user_command, message_command, AppCommandContext
@user_command(name="User Info") @user_command(name="User Info")
async def user_info(ctx: AppCommandContext, user: User) -> None: async def user_info(ctx: AppCommandContext, user: User) -> None:

View File

@ -13,8 +13,8 @@
```python ```python
from disagreement.ext.commands import command from disagreement.ext.commands import command
from disagreement import Member
from disagreement.ext.commands.core import CommandContext from disagreement.ext.commands.core import CommandContext
from disagreement.models import Member
@command() @command()
async def kick(ctx: CommandContext, target: Member): async def kick(ctx: CommandContext, target: Member):

View File

@ -4,7 +4,7 @@
These helper methods return the embed instance so you can chain calls. These helper methods return the embed instance so you can chain calls.
```python ```python
from disagreement.models import Embed from disagreement import Embed
embed = ( embed = (
Embed() Embed()

View File

@ -20,7 +20,7 @@ Triggered when a user's presence changes. The callback receives a `PresenceUpdat
```python ```python
@client.event @client.event
async def on_presence_update(presence: disagreement.PresenceUpdate): async def on_presence_update(presence: PresenceUpdate):
... ...
``` ```
@ -30,7 +30,7 @@ Dispatched when a user begins typing in a channel. The callback receives a `Typi
```python ```python
@client.event @client.event
async def on_typing_start(typing: disagreement.TypingStart): async def on_typing_start(typing: TypingStart):
... ...
``` ```
@ -40,7 +40,7 @@ Fired when a new member joins a guild. The callback receives a `Member` model.
```python ```python
@client.event @client.event
async def on_guild_member_add(member: disagreement.Member): async def on_guild_member_add(member: Member):
... ...
``` ```
@ -51,7 +51,7 @@ receives a `GuildMemberRemove` model.
```python ```python
@client.event @client.event
async def on_guild_member_remove(event: disagreement.GuildMemberRemove): async def on_guild_member_remove(event: GuildMemberRemove):
... ...
``` ```
@ -62,7 +62,7 @@ Dispatched when a user is banned from a guild. The callback receives a
```python ```python
@client.event @client.event
async def on_guild_ban_add(event: disagreement.GuildBanAdd): async def on_guild_ban_add(event: GuildBanAdd):
... ...
``` ```
@ -73,7 +73,7 @@ Dispatched when a user's ban is lifted. The callback receives a
```python ```python
@client.event @client.event
async def on_guild_ban_remove(event: disagreement.GuildBanRemove): async def on_guild_ban_remove(event: GuildBanRemove):
... ...
``` ```
@ -84,7 +84,7 @@ Sent when a channel's settings change. The callback receives an updated
```python ```python
@client.event @client.event
async def on_channel_update(channel: disagreement.Channel): async def on_channel_update(channel: Channel):
... ...
``` ```
@ -95,7 +95,7 @@ Emitted when a guild role is updated. The callback receives a
```python ```python
@client.event @client.event
async def on_guild_role_update(event: disagreement.GuildRoleUpdate): async def on_guild_role_update(event: GuildRoleUpdate):
... ...
``` ```
@ -132,12 +132,34 @@ async def on_shard_resume(info: dict):
... ...
``` ```
## CONNECT
Dispatched when the WebSocket connection opens. The callback receives a
dictionary with the shard ID.
```python
@client.event
async def on_connect(info: dict):
print("connected", info.get("shard_id"))
```
## DISCONNECT
Fired when the WebSocket connection closes. The callback receives a dictionary
with the shard ID.
```python
@client.event
async def on_disconnect(info: dict):
...
```
## VOICE_STATE_UPDATE ## VOICE_STATE_UPDATE
Triggered when a user's voice connection state changes, such as joining or leaving a voice channel. The callback receives a `VoiceStateUpdate` model. Triggered when a user's voice connection state changes, such as joining or leaving a voice channel. The callback receives a `VoiceStateUpdate` model.
```python ```python
@client.event @client.event
async def on_voice_state_update(state: disagreement.VoiceStateUpdate): async def on_voice_state_update(state: VoiceStateUpdate):
... ...
``` ```

View File

@ -8,6 +8,8 @@ You can control the maximum number of retries and the backoff cap when construct
These options are forwarded to `GatewayClient` as `max_retries` and `max_backoff`: These options are forwarded to `GatewayClient` as `max_retries` and `max_backoff`:
```python ```python
from disagreement import Client
bot = Client( bot = Client(
token="your-token", token="your-token",
gateway_max_retries=10, gateway_max_retries=10,

View File

@ -24,7 +24,7 @@ other supported session argument.
The HTTP client can list the guilds the bot user is in: The HTTP client can list the guilds the bot user is in:
```python ```python
from disagreement.http import HTTPClient from disagreement import HTTPClient
http = HTTPClient(token="TOKEN") http = HTTPClient(token="TOKEN")
guilds = await http.get_current_user_guilds() guilds = await http.get_current_user_guilds()

View File

@ -14,6 +14,7 @@ A Python library for interacting with the Discord API, with a focus on bot devel
- Built-in caching layer - Built-in caching layer
- Experimental voice support - Experimental voice support
- Helpful error handling utilities - Helpful error handling utilities
- Paginator utility for splitting long messages
## Installation ## Installation
@ -39,14 +40,14 @@ pip install "disagreement[dev]"
import asyncio import asyncio
import os import os
import disagreement from disagreement import Client, GatewayIntent
from disagreement.ext import commands from disagreement.ext import commands
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
class Basics(commands.Cog): class Basics(commands.Cog):
def __init__(self, client: disagreement.Client) -> None: def __init__(self, client: Client) -> None:
super().__init__(client) super().__init__(client)
@commands.command() @commands.command()
@ -58,15 +59,11 @@ token = os.getenv("DISCORD_BOT_TOKEN")
if not token: if not token:
raise RuntimeError("DISCORD_BOT_TOKEN environment variable not set") raise RuntimeError("DISCORD_BOT_TOKEN environment variable not set")
intents = disagreement.GatewayIntent.default() | disagreement.GatewayIntent.MESSAGE_CONTENT intents = GatewayIntent.default() | GatewayIntent.MESSAGE_CONTENT
client = disagreement.Client(token=token, command_prefix="!", intents=intents, mention_replies=True) client = Client(token=token, command_prefix="!", intents=intents, mention_replies=True)
async def main() -> None:
client.add_cog(Basics(client))
await client.run()
client.add_cog(Basics(client))
if __name__ == "__main__": client.run()
asyncio.run(main())
``` ```
### Global Error Handling ### Global Error Handling
@ -100,10 +97,10 @@ setup_logging(logging.DEBUG, file="bot.log")
### HTTP Session Options ### HTTP Session Options
Pass additional keyword arguments to ``aiohttp.ClientSession`` using the Pass additional keyword arguments to ``aiohttp.ClientSession`` using the
``http_options`` parameter when constructing :class:`disagreement.Client`: ``http_options`` parameter when constructing :class:`Client`:
```python ```python
client = disagreement.Client( client = Client(
token=token, token=token,
http_options={"proxy": "http://localhost:8080"}, http_options={"proxy": "http://localhost:8080"},
) )
@ -119,14 +116,14 @@ Specify default mention behaviour for all outgoing messages when constructing th
```python ```python
from disagreement.models import AllowedMentions from disagreement.models import AllowedMentions
client = disagreement.Client( client = Client(
token=token, token=token,
allowed_mentions=AllowedMentions.none().to_dict(), allowed_mentions=AllowedMentions.none().to_dict(),
) )
``` ```
This dictionary is used whenever ``send_message`` is called without an explicit This dictionary is used whenever ``send_message`` or helpers like ``Message.reply``
``allowed_mentions`` argument. are called without an explicit ``allowed_mentions`` argument.
The :class:`AllowedMentions` class offers ``none()`` and ``all()`` helpers for The :class:`AllowedMentions` class offers ``none()`` and ``all()`` helpers for
quickly generating these configurations. quickly generating these configurations.
@ -173,14 +170,15 @@ To run your bot across multiple gateway shards, pass ``shard_count`` when creati
the client: the client:
```python ```python
client = disagreement.Client(token=BOT_TOKEN, shard_count=2) client = Client(token=BOT_TOKEN, shard_count=2)
``` ```
If you want the library to determine the recommended shard count automatically, If you want the library to determine the recommended shard count automatically,
use ``AutoShardedClient``: use ``AutoShardedClient``:
```python ```python
client = disagreement.AutoShardedClient(token=BOT_TOKEN) from disagreement import AutoShardedClient
client = AutoShardedClient(token=BOT_TOKEN)
``` ```
See `examples/sharded_bot.py` for a full example. See `examples/sharded_bot.py` for a full example.

View File

@ -8,14 +8,15 @@ Use the ``allowed_mentions`` parameter of :class:`disagreement.Client` to set a
default for all messages: default for all messages:
```python ```python
from disagreement.models import AllowedMentions from disagreement import AllowedMentions, Client
client = disagreement.Client( client = Client(
token="YOUR_TOKEN", token="YOUR_TOKEN",
allowed_mentions=AllowedMentions.none().to_dict(), allowed_mentions=AllowedMentions.none().to_dict(),
) )
``` ```
When ``Client.send_message`` is called without an explicit ``allowed_mentions`` When ``Client.send_message`` or convenience methods like ``Message.reply`` and
``CommandContext.reply`` are called without an explicit ``allowed_mentions``
argument this value will be used. argument this value will be used.
``AllowedMentions`` also provides the convenience methods ``AllowedMentions`` also provides the convenience methods

View File

@ -10,11 +10,17 @@ Each attribute of ``Permissions`` represents a single permission bit. The value
is a power of two so multiple permissions can be combined using bitwise OR. is a power of two so multiple permissions can be combined using bitwise OR.
```python ```python
from disagreement.permissions import Permissions from disagreement import Permissions
value = Permissions.SEND_MESSAGES | Permissions.MANAGE_MESSAGES value = Permissions.SEND_MESSAGES | Permissions.MANAGE_MESSAGES
``` ```
You can also get a bitmask containing **every** permission:
```python
all_perms = Permissions.all()
```
## Helper Functions ## Helper Functions
### ``permissions_value`` ### ``permissions_value``
@ -47,10 +53,10 @@ Return a list of permissions that ``current`` does not contain.
```python ```python
from disagreement.permissions import ( from disagreement.permissions import (
Permissions,
has_permissions, has_permissions,
missing_permissions, missing_permissions,
) )
from disagreement import Permissions
current = Permissions.SEND_MESSAGES | Permissions.MANAGE_MESSAGES current = Permissions.SEND_MESSAGES | Permissions.MANAGE_MESSAGES

View File

@ -26,7 +26,7 @@ An activity dictionary must include a `name` and a `type` field. The type value
Example using the provided activity classes: Example using the provided activity classes:
```python ```python
from disagreement.models import Game from disagreement import Game
await client.change_presence(status="idle", activity=Game("with Discord")) await client.change_presence(status="idle", activity=Game("with Discord"))
``` ```
@ -34,7 +34,7 @@ await client.change_presence(status="idle", activity=Game("with Discord"))
You can also specify a streaming URL: You can also specify a streaming URL:
```python ```python
from disagreement.models import Streaming from disagreement import Streaming
await client.change_presence(status="online", activity=Streaming("My Stream", "https://twitch.tv/someone")) await client.change_presence(status="online", activity=Streaming("My Stream", "https://twitch.tv/someone"))
``` ```

View File

@ -48,7 +48,7 @@ The event handlers for these events receive both a `Reaction` object and the `Us
```python ```python
import disagreement import disagreement
from disagreement.models import Reaction, User, Member from disagreement import Reaction, User, Member
@client.on_event("MESSAGE_REACTION_ADD") @client.on_event("MESSAGE_REACTION_ADD")
async def on_reaction_add(reaction: Reaction, user: User | Member): async def on_reaction_add(reaction: Reaction, user: User | Member):

View File

@ -3,7 +3,7 @@
The `Client` provides helpers to manage guild scheduled events. The `Client` provides helpers to manage guild scheduled events.
```python ```python
from disagreement.client import Client from disagreement import Client
client = Client(token="TOKEN") client = Client(token="TOKEN")

View File

@ -8,13 +8,8 @@ manually.
and configures the `ShardManager` automatically. and configures the `ShardManager` automatically.
```python ```python
import asyncio
import disagreement import disagreement
bot = disagreement.AutoShardedClient(token="YOUR_TOKEN") bot = disagreement.AutoShardedClient(token="YOUR_TOKEN")
bot.run()
async def main():
await bot.run()
asyncio.run(main())
``` ```

View File

@ -1,9 +1,9 @@
# Using Slash Commands # Using Slash Commands
The library provides a slash command framework via the `ext.app_commands` package. Define commands with decorators and register them with Discord. The library provides a slash command framework to define commands with decorators and register them with Discord.
```python ```python
from disagreement.ext.app_commands import AppCommandGroup from disagreement import AppCommandGroup
bot_commands = AppCommandGroup("bot", "Bot commands") bot_commands = AppCommandGroup("bot", "Bot commands")

View File

@ -1,11 +1,11 @@
# Task Loops # Task Loops
The tasks extension allows you to run functions periodically. Decorate an async function with `@tasks.loop` and start it using `.start()`. The tasks extension allows you to run functions periodically. Decorate an async function with `@loop` and start it using `.start()`.
```python ```python
from disagreement.ext import tasks from disagreement import loop
@tasks.loop(minutes=1.0) @loop(minutes=1.0)
async def announce(): async def announce():
print("Hello from a loop") print("Hello from a loop")
@ -19,7 +19,7 @@ You can provide the interval in seconds, minutes, hours or as a `datetime.timede
```python ```python
import datetime import datetime
@tasks.loop(delta=datetime.timedelta(seconds=30)) @loop(delta=datetime.timedelta(seconds=30))
async def ping(): async def ping():
... ...
``` ```
@ -30,7 +30,7 @@ Handle exceptions raised by the looped coroutine using `on_error`:
async def log_error(exc: Exception) -> None: async def log_error(exc: Exception) -> None:
print("Loop failed:", exc) print("Loop failed:", exc)
@tasks.loop(seconds=5.0, on_error=log_error) @loop(seconds=5.0, on_error=log_error)
async def worker(): async def worker():
... ...
``` ```
@ -38,7 +38,7 @@ async def worker():
Run setup and teardown code using `before_loop` and `after_loop`: Run setup and teardown code using `before_loop` and `after_loop`:
```python ```python
@tasks.loop(seconds=5.0) @loop(seconds=5.0)
async def worker(): async def worker():
... ...
@ -58,7 +58,7 @@ from datetime import datetime, timedelta
time_to_run = (datetime.now() + timedelta(seconds=5)).time() time_to_run = (datetime.now() + timedelta(seconds=5)).time()
@tasks.loop(time_of_day=time_to_run) @loop(time_of_day=time_to_run)
async def daily_task(): async def daily_task():
... ...
``` ```

View File

@ -4,7 +4,7 @@
Use :class:`AutoArchiveDuration` to control when a thread is automatically archived. Use :class:`AutoArchiveDuration` to control when a thread is automatically archived.
```python ```python
from disagreement.enums import AutoArchiveDuration from disagreement import AutoArchiveDuration
await message.create_thread( await message.create_thread(
"discussion", "discussion",

View File

@ -4,9 +4,9 @@ The library exposes an async context manager to send the typing indicator for a
```python ```python
import asyncio import asyncio
import disagreement from disagreement import Client
client = disagreement.Client(token="YOUR_TOKEN") client = Client(token="YOUR_TOKEN")
async def indicate(channel_id: str): async def indicate(channel_id: str):
async with client.typing(channel_id): async with client.typing(channel_id):

View File

@ -19,8 +19,7 @@ The library exposes three broad categories of components:
`ActionRow` is a layout container. It may hold up to five buttons or a single select menu. `ActionRow` is a layout container. It may hold up to five buttons or a single select menu.
```python ```python
from disagreement.models import ActionRow, Button from disagreement import ActionRow, Button, ButtonStyle
from disagreement.enums import ButtonStyle
row = ActionRow(components=[ row = ActionRow(components=[
Button(style=ButtonStyle.PRIMARY, label="Click", custom_id="btn") Button(style=ButtonStyle.PRIMARY, label="Click", custom_id="btn")
@ -32,8 +31,7 @@ row = ActionRow(components=[
Buttons provide a clickable UI element. Buttons provide a clickable UI element.
```python ```python
from disagreement.models import Button from disagreement import Button, ButtonStyle
from disagreement.enums import ButtonStyle
button = Button( button = Button(
style=ButtonStyle.SUCCESS, style=ButtonStyle.SUCCESS,
@ -47,8 +45,7 @@ button = Button(
`SelectMenu` lets the user choose one or more options. The `type` parameter controls the menu variety (`STRING_SELECT`, `USER_SELECT`, `ROLE_SELECT`, `MENTIONABLE_SELECT`, `CHANNEL_SELECT`). `SelectMenu` lets the user choose one or more options. The `type` parameter controls the menu variety (`STRING_SELECT`, `USER_SELECT`, `ROLE_SELECT`, `MENTIONABLE_SELECT`, `CHANNEL_SELECT`).
```python ```python
from disagreement.models import SelectMenu, SelectOption from disagreement import SelectMenu, SelectOption, ComponentType, ChannelType
from disagreement.enums import ComponentType, ChannelType
menu = SelectMenu( menu = SelectMenu(
custom_id="example", custom_id="example",
@ -70,7 +67,7 @@ For channel selects you may pass `channel_types` with a list of allowed `Channel
`Section` groups one or more `TextDisplay` components and can include an accessory `Button` or `Thumbnail`. `Section` groups one or more `TextDisplay` components and can include an accessory `Button` or `Thumbnail`.
```python ```python
from disagreement.models import Section, TextDisplay, Thumbnail, UnfurledMediaItem from disagreement import Section, TextDisplay, Thumbnail, UnfurledMediaItem
section = Section( section = Section(
components=[ components=[
@ -86,7 +83,7 @@ section = Section(
`TextDisplay` simply renders markdown text. `TextDisplay` simply renders markdown text.
```python ```python
from disagreement.models import TextDisplay from disagreement import TextDisplay
text_display = TextDisplay(content="**Bold text**") text_display = TextDisplay(content="**Bold text**")
``` ```
@ -96,7 +93,7 @@ text_display = TextDisplay(content="**Bold text**")
`Thumbnail` shows a small image. Set `spoiler=True` to hide the image until clicked. `Thumbnail` shows a small image. Set `spoiler=True` to hide the image until clicked.
```python ```python
from disagreement.models import Thumbnail, UnfurledMediaItem from disagreement import Thumbnail, UnfurledMediaItem
thumb = Thumbnail( thumb = Thumbnail(
media=UnfurledMediaItem(url="https://example.com/image.png"), media=UnfurledMediaItem(url="https://example.com/image.png"),
@ -110,7 +107,7 @@ thumb = Thumbnail(
`MediaGallery` holds multiple `MediaGalleryItem` objects. `MediaGallery` holds multiple `MediaGalleryItem` objects.
```python ```python
from disagreement.models import MediaGallery, MediaGalleryItem, UnfurledMediaItem from disagreement import MediaGallery, MediaGalleryItem, UnfurledMediaItem
gallery = MediaGallery( gallery = MediaGallery(
items=[ items=[
@ -125,7 +122,7 @@ gallery = MediaGallery(
`File` displays an uploaded file. Use `spoiler=True` to mark it as a spoiler. `File` displays an uploaded file. Use `spoiler=True` to mark it as a spoiler.
```python ```python
from disagreement.models import File, UnfurledMediaItem from disagreement import File, UnfurledMediaItem
file_component = File( file_component = File(
file=UnfurledMediaItem(url="attachment://file.zip"), file=UnfurledMediaItem(url="attachment://file.zip"),
@ -138,7 +135,7 @@ file_component = File(
`Separator` adds vertical spacing or an optional divider line between components. `Separator` adds vertical spacing or an optional divider line between components.
```python ```python
from disagreement.models import Separator from disagreement import Separator
separator = Separator(divider=True, spacing=2) separator = Separator(divider=True, spacing=2)
``` ```
@ -148,7 +145,7 @@ separator = Separator(divider=True, spacing=2)
`Container` visually groups a set of components and can apply an accent colour or spoiler. `Container` visually groups a set of components and can apply an accent colour or spoiler.
```python ```python
from disagreement.models import Container, TextDisplay from disagreement import Container, TextDisplay
container = Container( container = Container(
components=[TextDisplay(content="Inside a container")], components=[TextDisplay(content="Inside a container")],
@ -160,6 +157,22 @@ container = Container(
A container can itself contain layout and content components, letting you build complex messages. A container can itself contain layout and content components, letting you build complex messages.
## Persistent Views
Views with ``timeout=None`` are persistent. Their ``custom_id`` components are saved to ``persistent_views.json`` so they survive bot restarts.
```python
class MyView(View):
@button(label="Press", custom_id="press")
async def handle(self, view, inter):
await inter.respond("Pressed!")
client.add_persistent_view(MyView())
```
When the client starts, it loads this file and registers each view again. Remove
the file to clear stored views.
## Next Steps ## Next Steps
- [Slash Commands](slash_commands.md) - [Slash Commands](slash_commands.md)

20
docs/utils.md Normal file
View File

@ -0,0 +1,20 @@
# Utility Helpers
Disagreement provides a few small utility functions for working with Discord data.
## `utcnow`
Returns the current timezone-aware UTC `datetime`.
## `snowflake_time`
Converts a Discord snowflake ID into the UTC timestamp when it was generated.
```python
from disagreement.utils import snowflake_time
created_at = snowflake_time(175928847299117063)
print(created_at.isoformat())
```
The function extracts the timestamp from the snowflake and returns a `datetime` in UTC.

View File

@ -6,6 +6,10 @@ Disagreement includes experimental support for connecting to voice channels. You
voice = await client.join_voice(guild_id, channel_id) voice = await client.join_voice(guild_id, channel_id)
await voice.play_file("welcome.mp3") await voice.play_file("welcome.mp3")
await voice.play_file("another.mp3") # switch sources while connected await voice.play_file("another.mp3") # switch sources while connected
voice.pause()
voice.resume()
if voice.is_playing():
print("audio is playing")
await voice.close() await voice.close()
``` ```

View File

@ -5,7 +5,7 @@ The `HTTPClient` includes helper methods for creating, editing and deleting Disc
## Create a webhook ## Create a webhook
```python ```python
from disagreement.http import HTTPClient from disagreement import HTTPClient
http = HTTPClient(token="TOKEN") http = HTTPClient(token="TOKEN")
payload = {"name": "My Webhook"} payload = {"name": "My Webhook"}
@ -27,7 +27,7 @@ await http.delete_webhook("456")
The methods now return a `Webhook` object directly: The methods now return a `Webhook` object directly:
```python ```python
from disagreement.models import Webhook from disagreement import Webhook
print(webhook.id, webhook.name) print(webhook.id, webhook.name)
``` ```
@ -37,7 +37,7 @@ print(webhook.id, webhook.name)
You can construct a `Webhook` object from an existing webhook URL without any API calls: You can construct a `Webhook` object from an existing webhook URL without any API calls:
```python ```python
from disagreement.models import Webhook from disagreement import Webhook
webhook = Webhook.from_url("https://discord.com/api/webhooks/123/token") webhook = Webhook.from_url("https://discord.com/api/webhooks/123/token")
print(webhook.id, webhook.token) print(webhook.id, webhook.token)

View File

@ -27,9 +27,17 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
try: try:
import disagreement from disagreement import (
from disagreement.models import Guild Client,
from disagreement.ext import commands # Import the new commands extension GatewayIntent,
Message,
Guild,
AuthenticationError,
DisagreementException,
Cog,
command,
CommandContext,
)
except ImportError: except ImportError:
print( print(
"Failed to import disagreement. Make sure it's installed or PYTHONPATH is set correctly." "Failed to import disagreement. Make sure it's installed or PYTHONPATH is set correctly."
@ -59,15 +67,13 @@ BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
# --- Intents Configuration --- # --- Intents Configuration ---
# Define the intents your bot needs. For basic message reading and responding: # Define the intents your bot needs. For basic message reading and responding:
intents = ( intents = (
disagreement.GatewayIntent.GUILDS GatewayIntent.GUILDS | GatewayIntent.GUILD_MESSAGES | GatewayIntent.MESSAGE_CONTENT
| disagreement.GatewayIntent.GUILD_MESSAGES
| disagreement.GatewayIntent.MESSAGE_CONTENT
) # MESSAGE_CONTENT is privileged! ) # MESSAGE_CONTENT is privileged!
# If you don't need message content and only react to commands/mentions, # If you don't need message content and only react to commands/mentions,
# you might not need MESSAGE_CONTENT intent. # you might not need MESSAGE_CONTENT intent.
# intents = disagreement.GatewayIntent.default() # A good starting point without privileged intents # intents = GatewayIntent.default() # A good starting point without privileged intents
# intents |= disagreement.GatewayIntent.MESSAGE_CONTENT # Add if needed # intents |= GatewayIntent.MESSAGE_CONTENT # Add if needed
# --- Initialize the Client --- # --- Initialize the Client ---
if not BOT_TOKEN: if not BOT_TOKEN:
@ -76,30 +82,30 @@ if not BOT_TOKEN:
sys.exit(1) sys.exit(1)
# Initialize Client with a command prefix # Initialize Client with a command prefix
client = disagreement.Client(token=BOT_TOKEN, intents=intents, command_prefix="!") client = Client(token=BOT_TOKEN, intents=intents, command_prefix="!")
# --- Define a Cog for example commands --- # --- Define a Cog for example commands ---
class ExampleCog(commands.Cog): # Ensuring this uses commands.Cog class ExampleCog(Cog): # Ensuring this uses commands.Cog
def __init__( def __init__(
self, bot_client self, bot_client
): # Renamed client to bot_client to avoid conflict with self.client ): # Renamed client to bot_client to avoid conflict with self.client
super().__init__(bot_client) # Pass the client instance to the base Cog super().__init__(bot_client) # Pass the client instance to the base Cog
@commands.command(name="hello", aliases=["hi"]) @command(name="hello", aliases=["hi"])
async def hello_command(self, ctx: commands.CommandContext, *, who: str = "world"): async def hello_command(self, ctx: CommandContext, *, who: str = "world"):
"""Greets someone.""" """Greets someone."""
await ctx.reply(f"Hello {ctx.author.mention} and {who}!") await ctx.reply(f"Hello {ctx.author.mention} and {who}!")
print(f"Executed 'hello' command for {ctx.author.username}, greeting {who}") print(f"Executed 'hello' command for {ctx.author.username}, greeting {who}")
@commands.command() @command()
async def ping(self, ctx: commands.CommandContext): async def ping(self, ctx: CommandContext):
"""Responds with Pong!""" """Responds with Pong!"""
await ctx.reply("Pong!") await ctx.reply("Pong!")
print(f"Executed 'ping' command for {ctx.author.username}") print(f"Executed 'ping' command for {ctx.author.username}")
@commands.command() @command()
async def me(self, ctx: commands.CommandContext): async def me(self, ctx: CommandContext):
"""Shows information about the invoking user.""" """Shows information about the invoking user."""
reply_content = ( reply_content = (
f"Hello {ctx.author.mention}!\n" f"Hello {ctx.author.mention}!\n"
@ -110,8 +116,8 @@ class ExampleCog(commands.Cog): # Ensuring this uses commands.Cog
await ctx.reply(reply_content) await ctx.reply(reply_content)
print(f"Executed 'me' command for {ctx.author.username}") print(f"Executed 'me' command for {ctx.author.username}")
@commands.command(name="add") @command(name="add")
async def add_numbers(self, ctx: commands.CommandContext, num1: int, num2: int): async def add_numbers(self, ctx: CommandContext, num1: int, num2: int):
"""Adds two numbers.""" """Adds two numbers."""
result = num1 + num2 result = num1 + num2
await ctx.reply(f"The sum of {num1} and {num2} is {result}.") await ctx.reply(f"The sum of {num1} and {num2} is {result}.")
@ -119,16 +125,16 @@ class ExampleCog(commands.Cog): # Ensuring this uses commands.Cog
f"Executed 'add' command for {ctx.author.username}: {num1} + {num2} = {result}" f"Executed 'add' command for {ctx.author.username}: {num1} + {num2} = {result}"
) )
@commands.command(name="say") @command(name="say")
async def say_something(self, ctx: commands.CommandContext, *, text_to_say: str): async def say_something(self, ctx: CommandContext, *, text_to_say: str):
"""Repeats the text you provide.""" """Repeats the text you provide."""
await ctx.reply(f"You said: {text_to_say}") await ctx.reply(f"You said: {text_to_say}")
print( print(
f"Executed 'say' command for {ctx.author.username}, saying: {text_to_say}" f"Executed 'say' command for {ctx.author.username}, saying: {text_to_say}"
) )
@commands.command(name="whois") @command(name="whois")
async def whois(self, ctx: commands.CommandContext, *, name: str): async def whois(self, ctx: CommandContext, *, name: str):
"""Looks up a member by username or nickname using the guild cache.""" """Looks up a member by username or nickname using the guild cache."""
if not ctx.guild: if not ctx.guild:
await ctx.reply("This command can only be used in a guild.") await ctx.reply("This command can only be used in a guild.")
@ -142,8 +148,8 @@ class ExampleCog(commands.Cog): # Ensuring this uses commands.Cog
else: else:
await ctx.reply("Member not found in cache.") await ctx.reply("Member not found in cache.")
@commands.command(name="quit") @command(name="quit")
async def quit_command(self, ctx: commands.CommandContext): async def quit_command(self, ctx: CommandContext):
"""Shuts down the bot (requires YOUR_USER_ID to be set).""" """Shuts down the bot (requires YOUR_USER_ID to be set)."""
# Replace YOUR_USER_ID with your actual Discord User ID for a safe shutdown command # Replace YOUR_USER_ID with your actual Discord User ID for a safe shutdown command
your_user_id = "YOUR_USER_ID_REPLACE_ME" # IMPORTANT: Replace this your_user_id = "YOUR_USER_ID_REPLACE_ME" # IMPORTANT: Replace this
@ -177,7 +183,7 @@ async def on_ready():
@client.event @client.event
async def on_message(message: disagreement.Message): async def on_message(message: Message):
"""Called when a message is created and received.""" """Called when a message is created and received."""
# Command processing is now handled by the CommandHandler via client._process_message_for_commands # Command processing is now handled by the CommandHandler via client._process_message_for_commands
# This on_message can be used for other message-related logic if needed, # This on_message can be used for other message-related logic if needed,
@ -202,19 +208,19 @@ async def on_guild_available(guild: Guild):
# --- Main Execution --- # --- Main Execution ---
async def main(): def main():
print("Starting Disagreement Bot...") print("Starting Disagreement Bot...")
try: try:
# Add the Cog to the client # Add the Cog to the client
client.add_cog(ExampleCog(client)) # Pass client instance to Cog constructor client.add_cog(ExampleCog(client)) # Pass client instance to Cog constructor
# client.add_cog is synchronous, but it schedules cog.cog_load() if it's async. # client.add_cog is synchronous, but it schedules cog.cog_load() if it's async.
await client.run() client.run()
except disagreement.AuthenticationError: except AuthenticationError:
print( print(
"Authentication failed. Please check your bot token and ensure it's correct." "Authentication failed. Please check your bot token and ensure it's correct."
) )
except disagreement.DisagreementException as e: except DisagreementException as e:
print(f"A Disagreement library error occurred: {e}") print(f"A Disagreement library error occurred: {e}")
except KeyboardInterrupt: except KeyboardInterrupt:
print("Bot shutting down due to KeyboardInterrupt...") print("Bot shutting down due to KeyboardInterrupt...")
@ -224,7 +230,7 @@ async def main():
finally: finally:
if not client.is_closed(): if not client.is_closed():
print("Ensuring client is closed...") print("Ensuring client is closed...")
await client.close() asyncio.run(client.close())
print("Bot has been shut down.") print("Bot has been shut down.")
@ -236,4 +242,4 @@ if __name__ == "__main__":
# if os.name == 'nt': # if os.name == 'nt':
# asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
asyncio.run(main()) main()

View File

@ -1,8 +1,9 @@
import os import os
import asyncio import asyncio
from typing import Union, Optional from typing import Union
from disagreement import Client, ui, HybridContext from disagreement import (
from disagreement.models import ( Client,
HybridContext,
Message, Message,
SelectOption, SelectOption,
User, User,
@ -19,23 +20,18 @@ from disagreement.models import (
MediaGallery, MediaGallery,
MediaGalleryItem, MediaGalleryItem,
Container, Container,
)
from disagreement.enums import (
ButtonStyle, ButtonStyle,
GatewayIntent, GatewayIntent,
ChannelType, ChannelType,
MessageFlags, MessageFlags,
InteractionCallbackType,
MessageFlags,
)
from disagreement.ext.commands.cog import Cog
from disagreement.ext.commands.core import CommandContext
from disagreement.ext.app_commands.decorators import hybrid_command, slash_command
from disagreement.ext.app_commands.context import AppCommandContext
from disagreement.interactions import (
Interaction, Interaction,
InteractionResponsePayload, Cog,
InteractionCallbackData, CommandContext,
AppCommandContext,
hybrid_command,
View,
button,
select,
) )
try: try:
@ -76,12 +72,12 @@ STOCKS = [
# Define a View class that contains our components # Define a View class that contains our components
class MyView(ui.View): class MyView(View):
def __init__(self): def __init__(self):
super().__init__(timeout=180) # 180-second timeout super().__init__(timeout=180) # 180-second timeout
self.click_count = 0 self.click_count = 0
@ui.button(label="Click Me!", style=ButtonStyle.SUCCESS, emoji="🖱️") @button(label="Click Me!", style=ButtonStyle.SUCCESS, emoji="🖱️")
async def click_me(self, interaction: Interaction): async def click_me(self, interaction: Interaction):
self.click_count += 1 self.click_count += 1
await interaction.respond( await interaction.respond(
@ -89,7 +85,7 @@ class MyView(ui.View):
ephemeral=True, ephemeral=True,
) )
@ui.select( @select(
custom_id="string_select", custom_id="string_select",
placeholder="Choose an option", placeholder="Choose an option",
options=[ options=[
@ -115,12 +111,12 @@ class MyView(ui.View):
# View for cycling through available stocks # View for cycling through available stocks
class StockView(ui.View): class StockView(View):
def __init__(self): def __init__(self):
super().__init__(timeout=180) super().__init__(timeout=180)
self.index = 0 self.index = 0
@ui.button(label="Next Stock", style=ButtonStyle.PRIMARY) @button(label="Next Stock", style=ButtonStyle.PRIMARY)
async def next_stock(self, interaction: Interaction): async def next_stock(self, interaction: Interaction):
self.index = (self.index + 1) % len(STOCKS) self.index = (self.index + 1) % len(STOCKS)
stock = STOCKS[self.index] stock = STOCKS[self.index]
@ -267,7 +263,7 @@ class ComponentCommandsCog(Cog):
) )
async def main(): def main():
@client.event @client.event
async def on_ready(): async def on_ready():
if client.user: if client.user:
@ -287,8 +283,8 @@ async def main():
) )
client.add_cog(ComponentCommandsCog(client)) client.add_cog(ComponentCommandsCog(client))
await client.run() client.run()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) main()

View File

@ -7,13 +7,12 @@ import sys
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)): if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from disagreement.client import Client from disagreement import Client, User, Message
from disagreement.ext.app_commands import ( from disagreement.ext.app_commands import (
user_command, user_command,
message_command, message_command,
AppCommandContext, AppCommandContext,
) )
from disagreement.models import User, Message
try: try:
from dotenv import load_dotenv from dotenv import load_dotenv
@ -66,11 +65,9 @@ client.app_command_handler.add_command(user_info)
client.app_command_handler.add_command(quote) client.app_command_handler.add_command(quote)
async def main() -> None: def main() -> None:
await client.run() client.run()
if __name__ == "__main__": if __name__ == "__main__":
import asyncio main()
asyncio.run(main())

View File

@ -4,19 +4,18 @@
import asyncio import asyncio
import os import os
import disagreement from disagreement import Client, GatewayIntent, Cog, command, CommandContext
from disagreement.ext import commands
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
class Basics(commands.Cog): class Basics(Cog):
def __init__(self, client: disagreement.Client) -> None: def __init__(self, client: Client) -> None:
super().__init__(client) super().__init__(client)
@commands.command() @command()
async def ping(self, ctx: commands.CommandContext) -> None: async def ping(self, ctx: CommandContext) -> None:
await ctx.reply(f"Pong! Gateway Latency: {self.client.latency_ms} ms.") # type: ignore (latency is None during static analysis) await ctx.reply(f"Pong! Gateway Latency: {self.client.latency_ms} ms.") # type: ignore (latency is None during static analysis)
@ -24,18 +23,14 @@ token = os.getenv("DISCORD_BOT_TOKEN")
if not token: if not token:
raise RuntimeError("DISCORD_BOT_TOKEN environment variable not set") raise RuntimeError("DISCORD_BOT_TOKEN environment variable not set")
intents = ( intents = GatewayIntent.default() | GatewayIntent.MESSAGE_CONTENT
disagreement.GatewayIntent.default() | disagreement.GatewayIntent.MESSAGE_CONTENT client = Client(token=token, command_prefix="!", intents=intents, mention_replies=True)
)
client = disagreement.Client(
token=token, command_prefix="!", intents=intents, mention_replies=True
)
async def main() -> None: def main() -> None:
client.add_cog(Basics(client)) client.add_cog(Basics(client))
await client.run() client.run()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) main()

View File

@ -3,32 +3,27 @@ import os
import logging import logging
from typing import Any, Optional, Literal, Union from typing import Any, Optional, Literal, Union
from disagreement import HybridContext from disagreement import (
HybridContext,
from disagreement.client import Client Client,
from disagreement.ext.commands.cog import Cog
from disagreement.ext.commands.core import CommandContext
from disagreement.ext.app_commands.decorators import (
slash_command,
user_command,
message_command,
hybrid_command,
)
from disagreement.ext.app_commands.commands import (
AppCommandGroup,
) # For defining groups
from disagreement.ext.app_commands.context import AppCommandContext # Added
from disagreement.models import (
User, User,
Member, Member,
Role, Role,
Attachment, Attachment,
Message, Message,
Channel, Channel,
) # For type hints
from disagreement.enums import (
ChannelType, ChannelType,
) # For channel option type hints, assuming it exists )
from disagreement.ext import commands
from disagreement.ext.commands import Cog, CommandContext
from disagreement.ext.app_commands import (
AppCommandContext,
AppCommandGroup,
slash_command,
user_command,
message_command,
hybrid_command,
)
# from disagreement.interactions import Interaction # Replaced by AppCommandContext # from disagreement.interactions import Interaction # Replaced by AppCommandContext
@ -235,7 +230,7 @@ class TestCog(Cog):
# --- Main Bot Script --- # --- Main Bot Script ---
async def main(): def main():
bot_token = os.getenv("DISCORD_BOT_TOKEN") bot_token = os.getenv("DISCORD_BOT_TOKEN")
application_id = os.getenv("DISCORD_APPLICATION_ID") application_id = os.getenv("DISCORD_APPLICATION_ID")
@ -296,7 +291,7 @@ async def main():
client.add_cog(TestCog(client)) client.add_cog(TestCog(client))
try: try:
await client.run() client.run()
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("Bot shutting down...") logger.info("Bot shutting down...")
except Exception as e: except Exception as e:
@ -305,7 +300,7 @@ async def main():
) )
finally: finally:
if not client.is_closed(): if not client.is_closed():
await client.close() asyncio.run(client.close())
logger.info("Bot has been closed.") logger.info("Bot has been closed.")
@ -315,6 +310,6 @@ if __name__ == "__main__":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
try: try:
asyncio.run(main()) main()
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("Main loop interrupted. Exiting.") logger.info("Main loop interrupted. Exiting.")

View File

@ -8,7 +8,7 @@ import sys
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)): if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from disagreement.client import Client from disagreement import Client, Channel
from disagreement.models import TextChannel from disagreement.models import TextChannel
try: try:

View File

@ -9,10 +9,9 @@ except ImportError: # pragma: no cover - example helper
load_dotenv = None load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded") print("python-dotenv is not installed. Environment variables will not be loaded")
from disagreement import Client, ui from disagreement import Client, ui, GatewayIntent
from disagreement.enums import GatewayIntent, TextInputStyle from disagreement.enums import TextInputStyle
from disagreement.ext.app_commands.decorators import slash_command from disagreement.ext.app_commands import slash_command, AppCommandContext
from disagreement.ext.app_commands.context import AppCommandContext
if load_dotenv: if load_dotenv:
load_dotenv() load_dotenv()
@ -63,9 +62,9 @@ async def on_ready():
print("------") print("------")
async def main(): def main():
await client.run() client.run()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) main()

View File

@ -11,9 +11,8 @@ except ImportError: # pragma: no cover - example helper
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from disagreement import Client, GatewayIntent, ui # type: ignore from disagreement import Client, GatewayIntent, ui
from disagreement.ext.app_commands.decorators import slash_command from disagreement.ext.app_commands import slash_command, AppCommandContext
from disagreement.ext.app_commands.context import AppCommandContext
if load_dotenv: if load_dotenv:
load_dotenv() load_dotenv()
@ -64,6 +63,4 @@ async def on_ready():
if __name__ == "__main__": if __name__ == "__main__":
import asyncio client.run()
asyncio.run(client.run())

View File

@ -9,9 +9,15 @@ from typing import Set
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)): if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import disagreement from disagreement import (
from disagreement.ext import commands Client,
from disagreement.models import Member, Message GatewayIntent,
Member,
Message,
Cog,
command,
CommandContext,
)
try: try:
from dotenv import load_dotenv from dotenv import load_dotenv
@ -28,31 +34,29 @@ if not BOT_TOKEN:
sys.exit(1) sys.exit(1)
intents = ( intents = (
disagreement.GatewayIntent.GUILDS GatewayIntent.GUILDS | GatewayIntent.GUILD_MESSAGES | GatewayIntent.MESSAGE_CONTENT
| disagreement.GatewayIntent.GUILD_MESSAGES
| disagreement.GatewayIntent.MESSAGE_CONTENT
) )
client = disagreement.Client(token=BOT_TOKEN, command_prefix="!", intents=intents) client = Client(token=BOT_TOKEN, command_prefix="!", intents=intents)
# Simple list of banned words # Simple list of banned words
BANNED_WORDS: Set[str] = {"badword1", "badword2"} BANNED_WORDS: Set[str] = {"badword1", "badword2"}
class ModerationCog(commands.Cog): class ModerationCog(Cog):
def __init__(self, bot: disagreement.Client) -> None: def __init__(self, bot: Client) -> None:
super().__init__(bot) super().__init__(bot)
@commands.command() @command()
async def kick( async def kick(
self, ctx: commands.CommandContext, member: Member, *, reason: str = "" self, ctx: CommandContext, member: Member, *, reason: str = ""
) -> None: ) -> None:
"""Kick a member from the guild.""" """Kick a member from the guild."""
await member.kick(reason=reason or None) await member.kick(reason=reason or None)
await ctx.reply(f"Kicked {member.display_name}") await ctx.reply(f"Kicked {member.display_name}")
@commands.command() @command()
async def ban( async def ban(
self, ctx: commands.CommandContext, member: Member, *, reason: str = "" self, ctx: CommandContext, member: Member, *, reason: str = ""
) -> None: ) -> None:
"""Ban a member from the guild.""" """Ban a member from the guild."""
await member.ban(reason=reason or None) await member.ban(reason=reason or None)
@ -80,10 +84,10 @@ async def on_message(message: Message) -> None:
) )
async def main() -> None: def main() -> None:
client.add_cog(ModerationCog(client)) client.add_cog(ModerationCog(client))
await client.run() client.run()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) main()

View File

@ -25,9 +25,18 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
try: try:
import disagreement from disagreement import (
from disagreement.ext import commands Client,
from disagreement.models import Reaction, User, Member GatewayIntent,
Reaction,
User,
Member,
HTTPException,
AuthenticationError,
Cog,
command,
CommandContext,
)
except ImportError: except ImportError:
print( print(
"Failed to import disagreement. Make sure it's installed or PYTHONPATH is set correctly." "Failed to import disagreement. Make sure it's installed or PYTHONPATH is set correctly."
@ -50,10 +59,10 @@ BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
# We need GUILDS for server context, GUILD_MESSAGES to receive messages, # We need GUILDS for server context, GUILD_MESSAGES to receive messages,
# and GUILD_MESSAGE_REACTIONS to listen for reaction events. # and GUILD_MESSAGE_REACTIONS to listen for reaction events.
intents = ( intents = (
disagreement.GatewayIntent.GUILDS GatewayIntent.GUILDS
| disagreement.GatewayIntent.GUILD_MESSAGES | GatewayIntent.GUILD_MESSAGES
| disagreement.GatewayIntent.GUILD_MESSAGE_REACTIONS | GatewayIntent.GUILD_MESSAGE_REACTIONS
| disagreement.GatewayIntent.MESSAGE_CONTENT # For commands | GatewayIntent.MESSAGE_CONTENT # For commands
) )
# --- Initialize the Client --- # --- Initialize the Client ---
@ -61,22 +70,22 @@ if not BOT_TOKEN:
print("Error: The DISCORD_BOT_TOKEN environment variable is not set.") print("Error: The DISCORD_BOT_TOKEN environment variable is not set.")
sys.exit(1) sys.exit(1)
client = disagreement.Client(token=BOT_TOKEN, intents=intents, command_prefix="!") client = Client(token=BOT_TOKEN, intents=intents, command_prefix="!")
# --- Define a Cog for reaction-related commands --- # --- Define a Cog for reaction-related commands ---
class ReactionCog(commands.Cog): class ReactionCog(Cog):
def __init__(self, bot_client): def __init__(self, bot_client):
super().__init__(bot_client) super().__init__(bot_client)
@commands.command(name="react") @command(name="react")
async def react_command(self, ctx: commands.CommandContext): async def react_command(self, ctx: CommandContext):
"""Reacts to the command message with a thumbs up.""" """Reacts to the command message with a thumbs up."""
try: try:
# The emoji can be a standard Unicode emoji or a custom one in the format '<:name:id>' # The emoji can be a standard Unicode emoji or a custom one in the format '<:name:id>'
await ctx.message.add_reaction("👍") await ctx.message.add_reaction("👍")
print(f"Reacted to command from {ctx.author.username}") print(f"Reacted to command from {ctx.author.username}")
except disagreement.HTTPException as e: except HTTPException as e:
print(f"Failed to add reaction: {e}") print(f"Failed to add reaction: {e}")
await ctx.reply( await ctx.reply(
"I couldn't add the reaction. I might be missing permissions." "I couldn't add the reaction. I might be missing permissions."
@ -128,21 +137,21 @@ async def on_reaction_remove(reaction: Reaction, user: User | Member):
# --- Main Execution --- # --- Main Execution ---
async def main(): def main():
print("Starting Reaction Bot...") print("Starting Reaction Bot...")
try: try:
client.add_cog(ReactionCog(client)) client.add_cog(ReactionCog(client))
await client.run() client.run()
except disagreement.AuthenticationError: except AuthenticationError:
print("Authentication failed. Check your bot token.") print("Authentication failed. Check your bot token.")
except Exception as e: except Exception as e:
print(f"An unexpected error occurred: {e}") print(f"An unexpected error occurred: {e}")
traceback.print_exc() traceback.print_exc()
finally: finally:
if not client.is_closed(): if not client.is_closed():
await client.close() asyncio.run(client.close())
print("Bot has been shut down.") print("Bot has been shut down.")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) main()

View File

@ -1,7 +1,7 @@
from disagreement.ext import tasks from disagreement import loop
@tasks.loop(seconds=2.0) @loop(seconds=2.0)
async def ticker() -> None: async def ticker() -> None:
print("Extension tick") print("Extension tick")

View File

@ -8,7 +8,7 @@ import sys
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)): if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import disagreement from disagreement import Client
try: try:
from dotenv import load_dotenv from dotenv import load_dotenv
@ -23,7 +23,7 @@ TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
if not TOKEN: if not TOKEN:
raise RuntimeError("DISCORD_BOT_TOKEN environment variable not set") raise RuntimeError("DISCORD_BOT_TOKEN environment variable not set")
client = disagreement.Client(token=TOKEN, shard_count=2) client = Client(token=TOKEN, shard_count=2)
@client.event @client.event
@ -34,12 +34,12 @@ async def on_ready():
print("Shard bot ready") print("Shard bot ready")
async def main(): def main():
if not TOKEN: if not TOKEN:
print("DISCORD_BOT_TOKEN environment variable not set") print("DISCORD_BOT_TOKEN environment variable not set")
return return
await client.run() client.run()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) main()

View File

@ -8,12 +8,12 @@ import sys
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)): if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from disagreement.ext import tasks from disagreement import loop
counter = 0 counter = 0
@tasks.loop(seconds=1.0) @loop(seconds=1.0)
async def ticker() -> None: async def ticker() -> None:
global counter global counter
counter += 1 counter += 1

View File

@ -24,8 +24,15 @@ if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__fi
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
try: try:
import disagreement from disagreement import (
from disagreement.ext import commands Client,
GatewayIntent,
HTTPException,
AuthenticationError,
Cog,
command,
CommandContext,
)
except ImportError: except ImportError:
print( print(
"Failed to import disagreement. Make sure it's installed or PYTHONPATH is set correctly." "Failed to import disagreement. Make sure it's installed or PYTHONPATH is set correctly."
@ -46,9 +53,7 @@ BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
# --- Intents Configuration --- # --- Intents Configuration ---
intents = ( intents = (
disagreement.GatewayIntent.GUILDS GatewayIntent.GUILDS | GatewayIntent.GUILD_MESSAGES | GatewayIntent.MESSAGE_CONTENT
| disagreement.GatewayIntent.GUILD_MESSAGES
| disagreement.GatewayIntent.MESSAGE_CONTENT
) )
# --- Initialize the Client --- # --- Initialize the Client ---
@ -56,16 +61,16 @@ if not BOT_TOKEN:
print("Error: The DISCORD_BOT_TOKEN environment variable is not set.") print("Error: The DISCORD_BOT_TOKEN environment variable is not set.")
sys.exit(1) sys.exit(1)
client = disagreement.Client(token=BOT_TOKEN, intents=intents, command_prefix="!") client = Client(token=BOT_TOKEN, intents=intents, command_prefix="!")
# --- Define a Cog for the typing indicator command --- # --- Define a Cog for the typing indicator command ---
class TypingCog(commands.Cog): class TypingCog(Cog):
def __init__(self, bot_client): def __init__(self, bot_client):
super().__init__(bot_client) super().__init__(bot_client)
@commands.command(name="typing") @command(name="typing")
async def typing_test_command(self, ctx: commands.CommandContext): async def typing_test_command(self, ctx: CommandContext):
"""Shows a typing indicator for 5 seconds.""" """Shows a typing indicator for 5 seconds."""
await ctx.reply("Showing typing indicator for 5 seconds...") await ctx.reply("Showing typing indicator for 5 seconds...")
try: try:
@ -76,7 +81,7 @@ class TypingCog(commands.Cog):
await asyncio.sleep(5) await asyncio.sleep(5)
print("Typing indicator stopped.") print("Typing indicator stopped.")
await ctx.send("Done!") await ctx.send("Done!")
except disagreement.HTTPException as e: except HTTPException as e:
print(f"Failed to send typing indicator: {e}") print(f"Failed to send typing indicator: {e}")
await ctx.reply( await ctx.reply(
"I couldn't show the typing indicator. I might be missing permissions." "I couldn't show the typing indicator. I might be missing permissions."
@ -99,21 +104,21 @@ async def on_ready():
# --- Main Execution --- # --- Main Execution ---
async def main(): def main():
print("Starting Typing Indicator Bot...") print("Starting Typing Indicator Bot...")
try: try:
client.add_cog(TypingCog(client)) client.add_cog(TypingCog(client))
await client.run() client.run()
except disagreement.AuthenticationError: except AuthenticationError:
print("Authentication failed. Check your bot token.") print("Authentication failed. Check your bot token.")
except Exception as e: except Exception as e:
print(f"An unexpected error occurred: {e}") print(f"An unexpected error occurred: {e}")
traceback.print_exc() traceback.print_exc()
finally: finally:
if not client.is_closed(): if not client.is_closed():
await client.close() asyncio.run(client.close())
print("Bot has been shut down.") print("Bot has been shut down.")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) main()

View File

@ -16,7 +16,7 @@ except ImportError: # pragma: no cover - example helper
load_dotenv = None load_dotenv = None
print("python-dotenv is not installed. Environment variables will not be loaded") print("python-dotenv is not installed. Environment variables will not be loaded")
import disagreement from disagreement import Client
if load_dotenv: if load_dotenv:
load_dotenv() load_dotenv()
@ -39,7 +39,7 @@ CHANNEL_ID = cast(str, _CHANNEL_ID)
async def main() -> None: async def main() -> None:
client = disagreement.Client(TOKEN) client = Client(TOKEN)
await client.connect() await client.connect()
voice = await client.join_voice(GUILD_ID, CHANNEL_ID) voice = await client.join_voice(GUILD_ID, CHANNEL_ID)
try: try:

View File

@ -62,4 +62,5 @@ nav:
- 'Mentions': 'mentions.md' - 'Mentions': 'mentions.md'
- 'OAuth2': 'oauth2.md' - 'OAuth2': 'oauth2.md'
- 'Presence': 'presence.md' - 'Presence': 'presence.md'
- 'Voice Client': 'voice_client.md' - 'Voice Client': 'voice_client.md'
- 'Utility Helpers': 'utils.md'

View File

@ -1,6 +1,6 @@
[project] [project]
name = "disagreement" name = "disagreement"
version = "0.4.2" version = "0.8.1"
description = "A Python library for the Discord API." description = "A Python library for the Discord API."
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"

View File

@ -1,6 +1,6 @@
{ {
"include": ["."], "include": ["."],
"exclude": ["**/node_modules", "**/__pycache__", "**/.venv", "**/.git", "**/dist", "**/build", "**/tests/**", "tavilytool.py"], "exclude": ["**/node_modules", "**/__pycache__", "**/.venv", "**/venv", "**/.git", "**/dist", "**/build", "**/tests/**", "tavilytool.py"],
"ignore": [], "ignore": [],
"reportMissingImports": true, "reportMissingImports": true,
"reportMissingTypeStubs": false, "reportMissingTypeStubs": false,

View File

@ -3,7 +3,16 @@ import pytest
from disagreement.ext.commands.converters import run_converters from disagreement.ext.commands.converters import run_converters
from disagreement.ext.commands.core import CommandContext, Command from disagreement.ext.commands.core import CommandContext, Command
from disagreement.ext.commands.errors import BadArgument from disagreement.ext.commands.errors import BadArgument
from disagreement.models import Message, Member, Role, Guild from disagreement.models import (
Message,
Member,
Role,
Guild,
User,
TextChannel,
VoiceChannel,
PartialEmoji,
)
from disagreement.enums import ( from disagreement.enums import (
VerificationLevel, VerificationLevel,
MessageNotificationLevel, MessageNotificationLevel,
@ -11,21 +20,27 @@ from disagreement.enums import (
MFALevel, MFALevel,
GuildNSFWLevel, GuildNSFWLevel,
PremiumTier, PremiumTier,
ChannelType,
) )
from disagreement.client import Client from disagreement.client import Client
from disagreement.cache import GuildCache from disagreement.cache import GuildCache, Cache, ChannelCache
class DummyBot(Client): class DummyBot(Client):
def __init__(self): def __init__(self):
super().__init__(token="test") super().__init__(token="test")
self._guilds = GuildCache() self._guilds = GuildCache()
self._users = Cache()
self._channels = ChannelCache()
def get_guild(self, guild_id): def get_guild(self, guild_id):
return self._guilds.get(guild_id) return self._guilds.get(guild_id)
def get_channel(self, channel_id):
return self._channels.get(channel_id)
async def fetch_member(self, guild_id, member_id): async def fetch_member(self, guild_id, member_id):
guild = self._guilds.get(guild_id) guild = self._guilds.get(guild_id)
return guild.get_member(member_id) if guild else None return guild.get_member(member_id) if guild else None
@ -37,6 +52,12 @@ class DummyBot(Client):
async def fetch_guild(self, guild_id): async def fetch_guild(self, guild_id):
return self._guilds.get(guild_id) return self._guilds.get(guild_id)
async def fetch_user(self, user_id):
return self._users.get(user_id)
async def fetch_channel(self, channel_id):
return self._channels.get(channel_id)
@pytest.fixture() @pytest.fixture()
def guild_objects(): def guild_objects():
@ -60,6 +81,9 @@ def guild_objects():
guild = Guild(guild_data, client_instance=bot) guild = Guild(guild_data, client_instance=bot)
bot._guilds.set(guild.id, guild) bot._guilds.set(guild.id, guild)
user = User({"id": "7", "username": "u", "discriminator": "0001"})
bot._users.set(user.id, user)
member = Member( member = Member(
{ {
"user": {"id": "3", "username": "m", "discriminator": "0001"}, "user": {"id": "3", "username": "m", "discriminator": "0001"},
@ -86,12 +110,38 @@ def guild_objects():
guild._members.set(member.id, member) guild._members.set(member.id, member)
guild.roles.append(role) guild.roles.append(role)
return guild, member, role text_channel = TextChannel(
{
"id": "20",
"type": ChannelType.GUILD_TEXT.value,
"guild_id": guild.id,
"permission_overwrites": [],
},
client_instance=bot,
)
voice_channel = VoiceChannel(
{
"id": "21",
"type": ChannelType.GUILD_VOICE.value,
"guild_id": guild.id,
"permission_overwrites": [],
},
client_instance=bot,
)
guild._channels.set(text_channel.id, text_channel)
guild.text_channels.append(text_channel)
guild._channels.set(voice_channel.id, voice_channel)
guild.voice_channels.append(voice_channel)
bot._channels.set(text_channel.id, text_channel)
bot._channels.set(voice_channel.id, voice_channel)
return guild, member, role, user, text_channel, voice_channel
@pytest.fixture() @pytest.fixture()
def command_context(guild_objects): def command_context(guild_objects):
guild, member, role = guild_objects guild, member, role, _, _, _ = guild_objects
bot = guild._client bot = guild._client
message_data = { message_data = {
"id": "10", "id": "10",
@ -114,7 +164,7 @@ def command_context(guild_objects):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_member_converter(command_context, guild_objects): async def test_member_converter(command_context, guild_objects):
_, member, _ = guild_objects _, member, _, _, _, _ = guild_objects
mention = f"<@!{member.id}>" mention = f"<@!{member.id}>"
result = await run_converters(command_context, Member, mention) result = await run_converters(command_context, Member, mention)
assert result is member assert result is member
@ -124,7 +174,7 @@ async def test_member_converter(command_context, guild_objects):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_role_converter(command_context, guild_objects): async def test_role_converter(command_context, guild_objects):
_, _, role = guild_objects _, _, role, _, _, _ = guild_objects
mention = f"<@&{role.id}>" mention = f"<@&{role.id}>"
result = await run_converters(command_context, Role, mention) result = await run_converters(command_context, Role, mention)
assert result is role assert result is role
@ -132,13 +182,55 @@ async def test_role_converter(command_context, guild_objects):
assert result is role assert result is role
@pytest.mark.asyncio
async def test_user_converter(command_context, guild_objects):
_, _, _, user, _, _ = guild_objects
mention = f"<@{user.id}>"
result = await run_converters(command_context, User, mention)
assert result is user
result = await run_converters(command_context, User, user.id)
assert result is user
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_guild_converter(command_context, guild_objects): async def test_guild_converter(command_context, guild_objects):
guild, _, _ = guild_objects guild, _, _, _, _, _ = guild_objects
result = await run_converters(command_context, Guild, guild.id) result = await run_converters(command_context, Guild, guild.id)
assert result is guild assert result is guild
@pytest.mark.asyncio
async def test_text_channel_converter(command_context, guild_objects):
_, _, _, _, text_channel, _ = guild_objects
mention = f"<#{text_channel.id}>"
result = await run_converters(command_context, TextChannel, mention)
assert result is text_channel
result = await run_converters(command_context, TextChannel, text_channel.id)
assert result is text_channel
@pytest.mark.asyncio
async def test_voice_channel_converter(command_context, guild_objects):
_, _, _, _, _, voice_channel = guild_objects
mention = f"<#{voice_channel.id}>"
result = await run_converters(command_context, VoiceChannel, mention)
assert result is voice_channel
result = await run_converters(command_context, VoiceChannel, voice_channel.id)
assert result is voice_channel
@pytest.mark.asyncio
async def test_emoji_converter(command_context):
result = await run_converters(command_context, PartialEmoji, "<:smile:1>")
assert isinstance(result, PartialEmoji)
assert result.id == "1"
assert result.name == "smile"
result = await run_converters(command_context, PartialEmoji, "😄")
assert result.id is None
assert result.name == "😄"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_member_converter_no_guild(): async def test_member_converter_no_guild():
guild_data = { guild_data = {

14
tests/test_asset.py Normal file
View File

@ -0,0 +1,14 @@
from disagreement.models import User
from disagreement.asset import Asset
def test_user_avatar_returns_asset():
user = User({"id": "1", "username": "u", "discriminator": "0001", "avatar": "abc"})
avatar = user.avatar
assert isinstance(avatar, Asset)
assert avatar.url == "https://cdn.discordapp.com/avatars/1/abc.png"
def test_user_avatar_none():
user = User({"id": "1", "username": "u", "discriminator": "0001"})
assert user.avatar is None

View File

@ -0,0 +1,83 @@
import asyncio
from unittest.mock import AsyncMock
import pytest
from disagreement.client import Client
from disagreement.gateway import GatewayClient
from disagreement.event_dispatcher import EventDispatcher
class DummyHTTP:
pass
class DummyUser:
username = "u"
discriminator = "0001"
@pytest.mark.asyncio
async def test_auto_sync_on_ready(monkeypatch):
client = Client(token="t", application_id="123")
http = DummyHTTP()
dispatcher = EventDispatcher(client)
gw = GatewayClient(
http_client=http,
event_dispatcher=dispatcher,
token="t",
intents=0,
client_instance=client,
)
monkeypatch.setattr(client, "parse_user", lambda d: DummyUser())
monkeypatch.setattr(gw._dispatcher, "dispatch", AsyncMock())
sync_mock = AsyncMock()
monkeypatch.setattr(client, "sync_application_commands", sync_mock)
data = {
"t": "READY",
"s": 1,
"d": {
"session_id": "s1",
"resume_gateway_url": "url",
"application": {"id": "123"},
"user": {"id": "1"},
},
}
await gw._handle_dispatch(data)
await asyncio.sleep(0)
sync_mock.assert_awaited_once()
@pytest.mark.asyncio
async def test_auto_sync_disabled(monkeypatch):
client = Client(token="t", application_id="123", sync_commands_on_ready=False)
http = DummyHTTP()
dispatcher = EventDispatcher(client)
gw = GatewayClient(
http_client=http,
event_dispatcher=dispatcher,
token="t",
intents=0,
client_instance=client,
)
monkeypatch.setattr(client, "parse_user", lambda d: DummyUser())
monkeypatch.setattr(gw._dispatcher, "dispatch", AsyncMock())
sync_mock = AsyncMock()
monkeypatch.setattr(client, "sync_application_commands", sync_mock)
data = {
"t": "READY",
"s": 1,
"d": {
"session_id": "s1",
"resume_gateway_url": "url",
"application": {"id": "123"},
"user": {"id": "1"},
},
}
await gw._handle_dispatch(data)
await asyncio.sleep(0)
sync_mock.assert_not_called()

View File

@ -1,6 +1,60 @@
import time import time
from disagreement.cache import Cache from disagreement.cache import Cache
from disagreement.client import Client
from disagreement.caching import MemberCacheFlags
from disagreement.enums import (
ChannelType,
ExplicitContentFilterLevel,
GuildNSFWLevel,
MFALevel,
MessageNotificationLevel,
PremiumTier,
VerificationLevel,
)
def _guild_payload(gid: str, channel_count: int, member_count: int) -> dict:
base = {
"id": gid,
"name": f"g{gid}",
"owner_id": "1",
"afk_timeout": 60,
"verification_level": VerificationLevel.NONE.value,
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
"roles": [],
"emojis": [],
"features": [],
"mfa_level": MFALevel.NONE.value,
"system_channel_flags": 0,
"premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
"channels": [],
"members": [],
}
for i in range(channel_count):
base["channels"].append(
{
"id": f"{gid}-c{i}",
"type": ChannelType.GUILD_TEXT.value,
"guild_id": gid,
"permission_overwrites": [],
}
)
for i in range(member_count):
base["members"].append(
{
"user": {
"id": f"{gid}-m{i}",
"username": f"u{i}",
"discriminator": "0001",
},
"joined_at": "t",
"roles": [],
}
)
return base
def test_cache_store_and_get(): def test_cache_store_and_get():
@ -65,3 +119,22 @@ def test_get_or_fetch_fetches_expired_item():
assert cache.get_or_fetch("c", fetch) == 3 assert cache.get_or_fetch("c", fetch) == 3
assert called assert called
def test_client_get_all_channels_and_members():
client = Client(token="t")
client.parse_guild(_guild_payload("1", 2, 2))
client.parse_guild(_guild_payload("2", 1, 1))
channels = {c.id for c in client.get_all_channels()}
members = {m.id for m in client.get_all_members()}
assert channels == {"1-c0", "1-c1", "2-c0"}
assert members == {"1-m0", "1-m1", "2-m0"}
def test_client_get_all_members_disabled_cache():
client = Client(token="t", member_cache_flags=MemberCacheFlags.none())
client.parse_guild(_guild_payload("1", 1, 2))
assert client.get_all_members() == []

View File

@ -1,7 +1,10 @@
import asyncio import asyncio
import pytest
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
import pytest
# pylint: disable=no-member
from disagreement.client import Client from disagreement.client import Client

View File

@ -0,0 +1,42 @@
import pytest
from datetime import datetime, timedelta, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock
from disagreement.client import Client
@pytest.mark.asyncio
async def test_client_records_start_time(monkeypatch):
start = datetime(2020, 1, 1, tzinfo=timezone.utc)
monkeypatch.setattr("disagreement.client.utcnow", lambda: start)
client = Client(token="t")
monkeypatch.setattr(client, "_initialize_gateway", AsyncMock())
client._gateway = SimpleNamespace(connect=AsyncMock())
monkeypatch.setattr(client, "wait_until_ready", AsyncMock())
assert client.start_time is None
await client.connect()
assert client.start_time == start
@pytest.mark.asyncio
async def test_client_uptime(monkeypatch):
start = datetime(2020, 1, 1, tzinfo=timezone.utc)
end = start + timedelta(seconds=5)
times = [start, end]
def fake_now():
return times.pop(0)
monkeypatch.setattr("disagreement.client.utcnow", fake_now)
client = Client(token="t")
monkeypatch.setattr(client, "_initialize_gateway", AsyncMock())
client._gateway = SimpleNamespace(connect=AsyncMock())
monkeypatch.setattr(client, "wait_until_ready", AsyncMock())
await client.connect()
assert client.uptime() == timedelta(seconds=5)

View File

@ -6,6 +6,7 @@ from disagreement.ext.commands.decorators import (
check, check,
cooldown, cooldown,
requires_permissions, requires_permissions,
is_owner,
) )
from disagreement.ext.commands.errors import CheckFailure, CommandOnCooldown from disagreement.ext.commands.errors import CheckFailure, CommandOnCooldown
from disagreement.permissions import Permissions from disagreement.permissions import Permissions
@ -133,3 +134,44 @@ async def test_requires_permissions_fail(message):
with pytest.raises(CheckFailure): with pytest.raises(CheckFailure):
await cmd.invoke(ctx) await cmd.invoke(ctx)
@pytest.mark.asyncio
async def test_is_owner_pass(message):
message._client.owner_ids = ["2"]
@is_owner()
async def cb(ctx):
pass
cmd = Command(cb)
ctx = CommandContext(
message=message,
bot=message._client,
prefix="!",
command=cmd,
invoked_with="test",
)
await cmd.invoke(ctx)
@pytest.mark.asyncio
async def test_is_owner_fail(message):
message._client.owner_ids = ["1"]
@is_owner()
async def cb(ctx):
pass
cmd = Command(cb)
ctx = CommandContext(
message=message,
bot=message._client,
prefix="!",
command=cmd,
invoked_with="test",
)
with pytest.raises(CheckFailure):
await cmd.invoke(ctx)

View File

@ -0,0 +1,59 @@
import asyncio
import pytest
from unittest.mock import AsyncMock
from disagreement.shard_manager import ShardManager
from disagreement.event_dispatcher import EventDispatcher
class DummyGateway:
def __init__(self, *args, **kwargs):
self.connect = AsyncMock()
self.close = AsyncMock()
dispatcher = kwargs.get("event_dispatcher")
shard_id = kwargs.get("shard_id")
async def emit_connect():
await dispatcher.dispatch("CONNECT", {"shard_id": shard_id})
async def emit_close():
await dispatcher.dispatch("DISCONNECT", {"shard_id": shard_id})
self.connect.side_effect = emit_connect
self.close.side_effect = emit_close
class DummyClient:
def __init__(self):
self._http = object()
self._event_dispatcher = EventDispatcher(self)
self.token = "t"
self.intents = 0
self.verbose = False
self.gateway_max_retries = 5
self.gateway_max_backoff = 60.0
@pytest.mark.asyncio
async def test_connect_disconnect_events(monkeypatch):
monkeypatch.setattr("disagreement.shard_manager.GatewayClient", DummyGateway)
client = DummyClient()
manager = ShardManager(client, shard_count=1)
events: list[tuple[str, int | None]] = []
async def on_connect(info):
events.append(("connect", info.get("shard_id")))
async def on_disconnect(info):
events.append(("disconnect", info.get("shard_id")))
client._event_dispatcher.register("CONNECT", on_connect)
client._event_dispatcher.register("DISCONNECT", on_disconnect)
await manager.start()
await manager.close()
assert ("connect", 0) in events
assert ("disconnect", 0) in events

View File

@ -0,0 +1,41 @@
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock
from disagreement.http import HTTPClient
from disagreement.client import Client
from disagreement.models import Message
@pytest.mark.asyncio
async def test_http_crosspost_message_calls_request():
http = HTTPClient(token="t")
http.request = AsyncMock(return_value={"id": "m"})
data = await http.crosspost_message("c", "m")
http.request.assert_called_once_with(
"POST",
"/channels/c/messages/m/crosspost",
)
assert data == {"id": "m"}
@pytest.mark.asyncio
async def test_message_crosspost_returns_message():
payload = {
"id": "2",
"channel_id": "1",
"author": {"id": "3", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
http = SimpleNamespace(crosspost_message=AsyncMock(return_value=payload))
client = Client.__new__(Client)
client._http = http
client.parse_message = lambda d: Message(d, client_instance=client)
message = Message(payload, client_instance=client)
new_msg = await message.crosspost()
http.crosspost_message.assert_awaited_once_with("1", "2")
assert isinstance(new_msg, Message)
assert new_msg._client is client

View File

@ -22,6 +22,38 @@ def create_dummy_module(name):
return called return called
def create_async_module(name):
mod = types.ModuleType(name)
called = {"setup": False, "teardown": False}
async def setup():
called["setup"] = True
def teardown():
called["teardown"] = True
mod.setup = setup
mod.teardown = teardown
sys.modules[name] = mod
return called
def create_async_teardown_module(name):
mod = types.ModuleType(name)
called = {"setup": False, "teardown": False}
def setup():
called["setup"] = True
async def teardown():
called["teardown"] = True
mod.setup = setup
mod.teardown = teardown
sys.modules[name] = mod
return called
def test_load_and_unload_extension(): def test_load_and_unload_extension():
called = create_dummy_module("dummy_ext") called = create_dummy_module("dummy_ext")
@ -75,3 +107,23 @@ def test_reload_extension(monkeypatch):
loader.unload_extension("reload_ext") loader.unload_extension("reload_ext")
assert called_second["teardown"] is True assert called_second["teardown"] is True
def test_async_setup():
called = create_async_module("async_ext")
loader.load_extension("async_ext")
assert called["setup"] is True
loader.unload_extension("async_ext")
assert called["teardown"] is True
def test_async_teardown():
called = create_async_teardown_module("async_teardown_ext")
loader.load_extension("async_teardown_ext")
assert called["setup"] is True
loader.unload_extension("async_teardown_ext")
assert called["teardown"] is True

29
tests/test_get_cog.py Normal file
View File

@ -0,0 +1,29 @@
import asyncio
import pytest
from disagreement.client import Client
from disagreement.ext import commands
class DummyCog(commands.Cog):
def __init__(self, client: Client) -> None:
super().__init__(client)
@pytest.mark.asyncio()
async def test_command_handler_get_cog():
bot = object()
handler = commands.core.CommandHandler(client=bot, prefix="!")
cog = DummyCog(bot) # type: ignore[arg-type]
handler.add_cog(cog)
await asyncio.sleep(0) # allow any scheduled tasks to start
assert handler.get_cog("DummyCog") is cog
@pytest.mark.asyncio()
async def test_client_get_cog():
client = Client(token="t")
cog = DummyCog(client)
client.add_cog(cog)
await asyncio.sleep(0)
assert client.get_cog("DummyCog") is cog

59
tests/test_get_context.py Normal file
View File

@ -0,0 +1,59 @@
import pytest
from disagreement.client import Client
from disagreement.ext.commands.core import Command, CommandHandler
from disagreement.models import Message
class DummyBot:
def __init__(self):
self.executed = False
@pytest.mark.asyncio
async def test_get_context_parses_without_execution():
bot = DummyBot()
handler = CommandHandler(client=bot, prefix="!")
async def foo(ctx, number: int, word: str):
bot.executed = True
handler.add_command(Command(foo, name="foo"))
msg_data = {
"id": "1",
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "!foo 1 bar",
"timestamp": "t",
}
msg = Message(msg_data, client_instance=bot)
ctx = await handler.get_context(msg)
assert ctx is not None
assert ctx.command.name == "foo"
assert ctx.args == [1, "bar"]
assert bot.executed is False
@pytest.mark.asyncio
async def test_client_get_context():
client = Client(token="t")
async def foo(ctx):
raise RuntimeError("should not run")
client.command_handler.add_command(Command(foo, name="foo"))
msg_data = {
"id": "1",
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "!foo",
"timestamp": "t",
}
msg = Message(msg_data, client_instance=client)
ctx = await client.get_context(msg)
assert ctx is not None
assert ctx.command.name == "foo"

View File

@ -0,0 +1,126 @@
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock
from disagreement.http import HTTPClient
from disagreement.client import Client
from disagreement.models import Guild, TextChannel, VoiceChannel, CategoryChannel
from disagreement.enums import (
VerificationLevel,
MessageNotificationLevel,
ExplicitContentFilterLevel,
MFALevel,
GuildNSFWLevel,
PremiumTier,
ChannelType,
)
def _guild_data():
return {
"id": "1",
"name": "g",
"owner_id": "1",
"afk_timeout": 60,
"verification_level": VerificationLevel.NONE.value,
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
"roles": [],
"emojis": [],
"features": [],
"mfa_level": MFALevel.NONE.value,
"system_channel_flags": 0,
"premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
}
@pytest.mark.asyncio
async def test_http_create_guild_channel_calls_request():
http = HTTPClient(token="t")
http.request = AsyncMock(return_value={})
payload = {"name": "chan", "type": ChannelType.GUILD_TEXT.value}
await http.create_guild_channel("1", payload, reason="r")
http.request.assert_called_once_with(
"POST",
"/guilds/1/channels",
payload=payload,
custom_headers={"X-Audit-Log-Reason": "r"},
)
@pytest.mark.asyncio
async def test_guild_create_text_channel_returns_channel():
http = SimpleNamespace(
create_guild_channel=AsyncMock(
return_value={
"id": "10",
"type": ChannelType.GUILD_TEXT.value,
"guild_id": "1",
"permission_overwrites": [],
}
)
)
client = Client(token="t")
client._http = http
guild = Guild(_guild_data(), client_instance=client)
channel = await guild.create_text_channel("general")
http.create_guild_channel.assert_awaited_once_with(
"1", {"name": "general", "type": ChannelType.GUILD_TEXT.value}, reason=None
)
assert isinstance(channel, TextChannel)
assert client._channels.get("10") is channel
@pytest.mark.asyncio
async def test_guild_create_voice_channel_returns_channel():
http = SimpleNamespace(
create_guild_channel=AsyncMock(
return_value={
"id": "11",
"type": ChannelType.GUILD_VOICE.value,
"guild_id": "1",
"permission_overwrites": [],
}
)
)
client = Client(token="t")
client._http = http
guild = Guild(_guild_data(), client_instance=client)
channel = await guild.create_voice_channel("Voice")
http.create_guild_channel.assert_awaited_once_with(
"1", {"name": "Voice", "type": ChannelType.GUILD_VOICE.value}, reason=None
)
assert isinstance(channel, VoiceChannel)
assert client._channels.get("11") is channel
@pytest.mark.asyncio
async def test_guild_create_category_returns_channel():
http = SimpleNamespace(
create_guild_channel=AsyncMock(
return_value={
"id": "12",
"type": ChannelType.GUILD_CATEGORY.value,
"guild_id": "1",
"permission_overwrites": [],
}
)
)
client = Client(token="t")
client._http = http
guild = Guild(_guild_data(), client_instance=client)
channel = await guild.create_category("Cat")
http.create_guild_channel.assert_awaited_once_with(
"1", {"name": "Cat", "type": ChannelType.GUILD_CATEGORY.value}, reason=None
)
assert isinstance(channel, CategoryChannel)
assert client._channels.get("12") is channel

View File

@ -0,0 +1,63 @@
import pytest
from disagreement.client import Client
from disagreement.enums import (
ChannelType,
VerificationLevel,
MessageNotificationLevel,
ExplicitContentFilterLevel,
MFALevel,
GuildNSFWLevel,
PremiumTier,
)
from disagreement.models import TextChannel, VoiceChannel, CategoryChannel
@pytest.mark.asyncio
async def test_guild_channel_lists_populated():
client = Client(token="t")
guild_data = {
"id": "1",
"name": "g",
"owner_id": "1",
"afk_timeout": 60,
"verification_level": VerificationLevel.NONE.value,
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
"roles": [],
"emojis": [],
"features": [],
"mfa_level": MFALevel.NONE.value,
"system_channel_flags": 0,
"premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
"channels": [
{
"id": "10",
"type": ChannelType.GUILD_TEXT.value,
"guild_id": "1",
"permission_overwrites": [],
},
{
"id": "11",
"type": ChannelType.GUILD_VOICE.value,
"guild_id": "1",
"permission_overwrites": [],
},
{
"id": "12",
"type": ChannelType.GUILD_CATEGORY.value,
"guild_id": "1",
"permission_overwrites": [],
},
],
}
guild = client.parse_guild(guild_data)
assert len(guild.text_channels) == 1
assert isinstance(guild.text_channels[0], TextChannel)
assert len(guild.voice_channels) == 1
assert isinstance(guild.voice_channels[0], VoiceChannel)
assert len(guild.category_channels) == 1
assert isinstance(guild.category_channels[0], CategoryChannel)

64
tests/test_guild_prune.py Normal file
View File

@ -0,0 +1,64 @@
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock
from disagreement.http import HTTPClient
from disagreement.client import Client
from disagreement.enums import (
VerificationLevel,
MessageNotificationLevel,
ExplicitContentFilterLevel,
MFALevel,
GuildNSFWLevel,
PremiumTier,
)
from disagreement.models import Guild
@pytest.mark.asyncio
async def test_http_get_guild_prune_count_calls_request():
http = HTTPClient(token="t")
http.request = AsyncMock(return_value={"pruned": 3})
count = await http.get_guild_prune_count("1", days=7)
http.request.assert_called_once_with("GET", f"/guilds/1/prune", params={"days": 7})
assert count == 3
@pytest.mark.asyncio
async def test_http_begin_guild_prune_calls_request():
http = HTTPClient(token="t")
http.request = AsyncMock(return_value={"pruned": 2})
count = await http.begin_guild_prune("1", days=1, compute_count=True)
http.request.assert_called_once_with(
"POST",
f"/guilds/1/prune",
payload={"days": 1, "compute_prune_count": True},
)
assert count == 2
@pytest.mark.asyncio
async def test_guild_prune_members_calls_http():
http = SimpleNamespace(begin_guild_prune=AsyncMock(return_value=1))
client = Client(token="t")
client._http = http
guild_data = {
"id": "1",
"name": "g",
"owner_id": "1",
"afk_timeout": 60,
"verification_level": VerificationLevel.NONE.value,
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
"roles": [],
"emojis": [],
"features": [],
"mfa_level": MFALevel.NONE.value,
"system_channel_flags": 0,
"premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
}
guild = Guild(guild_data, client_instance=client)
count = await guild.prune_members(2)
http.begin_guild_prune.assert_awaited_once_with("1", days=2, compute_count=True)
assert count == 1

View File

@ -0,0 +1,56 @@
import pytest
from unittest.mock import Mock
from disagreement.models import Guild
from disagreement.enums import (
VerificationLevel,
MessageNotificationLevel,
ExplicitContentFilterLevel,
MFALevel,
GuildNSFWLevel,
PremiumTier,
)
class DummyShard:
def __init__(self, shard_id):
self.id = shard_id
self.count = 1
self.gateway = Mock()
class DummyManager:
def __init__(self):
self.shards = [DummyShard(0)]
class DummyClient:
pass
def _guild_data():
return {
"id": "1",
"name": "g",
"owner_id": "1",
"afk_timeout": 60,
"verification_level": VerificationLevel.NONE.value,
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
"roles": [],
"emojis": [],
"features": [],
"mfa_level": MFALevel.NONE.value,
"system_channel_flags": 0,
"premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
"shard_id": 0,
}
def test_guild_shard_property():
client = DummyClient()
client._shard_manager = DummyManager()
guild = Guild(_guild_data(), client_instance=client, shard_id=0)
assert guild.shard_id == 0
assert guild.shard is client._shard_manager.shards[0]

View File

@ -0,0 +1,86 @@
import types
from disagreement.models import User, Guild, Channel, Message
from disagreement.enums import (
VerificationLevel,
MessageNotificationLevel,
ExplicitContentFilterLevel,
MFALevel,
GuildNSFWLevel,
PremiumTier,
ChannelType,
)
def _guild_data(gid="1"):
return {
"id": gid,
"name": "g",
"owner_id": gid,
"afk_timeout": 60,
"verification_level": VerificationLevel.NONE.value,
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
"roles": [],
"emojis": [],
"features": [],
"mfa_level": MFALevel.NONE.value,
"system_channel_flags": 0,
"premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
}
def _user(uid="1"):
return User({"id": uid, "username": "u", "discriminator": "0001"})
def _message(mid="1"):
data = {
"id": mid,
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
return Message(data, client_instance=types.SimpleNamespace())
def _channel(cid="1"):
data = {"id": cid, "type": ChannelType.GUILD_TEXT.value}
return Channel(data, client_instance=types.SimpleNamespace())
def test_user_hash_and_eq():
a = _user()
b = _user()
c = _user("2")
assert a == b
assert hash(a) == hash(b)
assert a != c
def test_guild_hash_and_eq():
a = Guild(_guild_data(), client_instance=types.SimpleNamespace())
b = Guild(_guild_data(), client_instance=types.SimpleNamespace())
c = Guild(_guild_data("2"), client_instance=types.SimpleNamespace())
assert a == b
assert hash(a) == hash(b)
assert a != c
def test_channel_hash_and_eq():
a = _channel()
b = _channel()
c = _channel("2")
assert a == b
assert hash(a) == hash(b)
assert a != c
def test_message_hash_and_eq():
a = _message()
b = _message()
c = _message("2")
assert a == b
assert hash(a) == hash(b)
assert a != c

View File

@ -1,7 +1,9 @@
import pytest import pytest
from disagreement.ext.commands.core import CommandHandler, Command from disagreement.ext import commands
from disagreement.ext.commands.core import CommandHandler, Command, Group
from disagreement.models import Message from disagreement.models import Message
from disagreement.ext.commands.help import HelpCommand
class DummyBot: class DummyBot:
@ -13,15 +15,21 @@ class DummyBot:
return {"id": "1", "channel_id": channel_id, "content": content} return {"id": "1", "channel_id": channel_id, "content": content}
class MyCog(commands.Cog):
def __init__(self, client) -> None:
super().__init__(client)
@commands.command()
async def foo(self, ctx: commands.CommandContext) -> None:
pass
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_help_lists_commands(): async def test_help_lists_commands():
bot = DummyBot() bot = DummyBot()
handler = CommandHandler(client=bot, prefix="!") handler = CommandHandler(client=bot, prefix="!")
async def foo(ctx): handler.add_cog(MyCog(bot))
pass
handler.add_command(Command(foo, name="foo", brief="Foo cmd"))
msg_data = { msg_data = {
"id": "1", "id": "1",
@ -33,6 +41,7 @@ async def test_help_lists_commands():
msg = Message(msg_data, client_instance=bot) msg = Message(msg_data, client_instance=bot)
await handler.process_commands(msg) await handler.process_commands(msg)
assert any("foo" in m for m in bot.sent) assert any("foo" in m for m in bot.sent)
assert any("MyCog" in m for m in bot.sent)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -55,3 +64,65 @@ async def test_help_specific_command():
msg = Message(msg_data, client_instance=bot) msg = Message(msg_data, client_instance=bot)
await handler.process_commands(msg) await handler.process_commands(msg)
assert any("Bar desc" in m for m in bot.sent) assert any("Bar desc" in m for m in bot.sent)
class CustomHelp(HelpCommand):
async def send_command_help(self, ctx, command):
await ctx.send(f"custom {command.name}")
async def send_group_help(self, ctx, group):
await ctx.send(f"group {group.name}")
@pytest.mark.asyncio
async def test_custom_help_methods():
bot = DummyBot()
handler = CommandHandler(client=bot, prefix="!")
handler.remove_command("help")
handler.add_command(CustomHelp(handler))
async def sub(ctx):
pass
group = Group(sub, name="grp")
handler.add_command(group)
msg_data = {
"id": "1",
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "!help grp",
"timestamp": "t",
}
msg = Message(msg_data, client_instance=bot)
await handler.process_commands(msg)
assert any("group grp" in m for m in bot.sent)
@pytest.mark.asyncio
async def test_help_lists_subcommands():
bot = DummyBot()
handler = CommandHandler(client=bot, prefix="!")
async def root(ctx):
pass
group = Group(root, name="root")
@group.command(name="child")
async def child(ctx):
pass
handler.add_command(group)
msg_data = {
"id": "1",
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "!help",
"timestamp": "t",
}
msg = Message(msg_data, client_instance=bot)
await handler.process_commands(msg)
assert any("root" in m for m in bot.sent)
assert any("child" in m for m in bot.sent)

31
tests/test_invite.py Normal file
View File

@ -0,0 +1,31 @@
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock
from disagreement.http import HTTPClient
from disagreement.client import Client
from disagreement.models import Invite
@pytest.mark.asyncio
async def test_http_get_invite_calls_request():
http = HTTPClient(token="t")
http.request = AsyncMock(return_value={"code": "abc"})
result = await http.get_invite("abc")
http.request.assert_called_once_with("GET", "/invites/abc")
assert result == {"code": "abc"}
@pytest.mark.asyncio
async def test_client_fetch_invite_returns_invite():
http = SimpleNamespace(get_invite=AsyncMock(return_value={"code": "abc"}))
client = Client.__new__(Client)
client._http = http
client._closed = False
invite = await client.fetch_invite("abc")
http.get_invite.assert_awaited_once_with("abc")
assert isinstance(invite, Invite)

39
tests/test_invites.py Normal file
View File

@ -0,0 +1,39 @@
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock
from disagreement.client import Client
from disagreement.http import HTTPClient
from disagreement.models import TextChannel, Invite
@pytest.mark.asyncio
async def test_create_channel_invite_calls_request_and_returns_model():
http = HTTPClient(token="t")
http.request = AsyncMock(return_value={"code": "abc"})
invite = await http.create_channel_invite("123", {"max_age": 60}, reason="r")
http.request.assert_called_once_with(
"POST",
"/channels/123/invites",
payload={"max_age": 60},
custom_headers={"X-Audit-Log-Reason": "r"},
)
assert isinstance(invite, Invite)
@pytest.mark.asyncio
async def test_textchannel_create_invite_uses_http():
http = SimpleNamespace(
create_channel_invite=AsyncMock(return_value=Invite.from_dict({"code": "a"}))
)
client = Client(token="t")
client._http = http
channel = TextChannel({"id": "c", "type": 0}, client)
invite = await channel.create_invite(max_age=30, reason="why")
http.create_channel_invite.assert_awaited_once_with(
"c", {"max_age": 30}, reason="why"
)
assert isinstance(invite, Invite)

View File

@ -1,4 +1,21 @@
from disagreement.models import Member import pytest # pylint: disable=E0401
from disagreement.client import Client
from disagreement.enums import (
VerificationLevel,
MessageNotificationLevel,
ExplicitContentFilterLevel,
MFALevel,
GuildNSFWLevel,
PremiumTier,
)
from disagreement.models import Member, Guild, Role
from disagreement.permissions import Permissions
class DummyClient(Client):
def __init__(self):
super().__init__(token="test")
def _make_member(member_id: str, username: str, nick: str | None): def _make_member(member_id: str, username: str, nick: str | None):
@ -12,6 +29,58 @@ def _make_member(member_id: str, username: str, nick: str | None):
return Member(data, client_instance=None) return Member(data, client_instance=None)
def _base_guild(client: Client) -> Guild:
data = {
"id": "1",
"name": "g",
"owner_id": "1",
"afk_timeout": 60,
"verification_level": VerificationLevel.NONE.value,
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
"roles": [],
"emojis": [],
"features": [],
"mfa_level": MFALevel.NONE.value,
"system_channel_flags": 0,
"premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
}
guild = Guild(data, client_instance=client)
client._guilds.set(guild.id, guild)
return guild
def _role(guild: Guild, rid: str, perms: Permissions) -> Role:
role = Role(
{
"id": rid,
"name": f"r{rid}",
"color": 0,
"hoist": False,
"position": 0,
"permissions": str(int(perms)),
"managed": False,
"mentionable": False,
}
)
guild.roles.append(role)
return role
def _member(guild: Guild, client: Client, *roles: Role) -> Member:
data = {
"user": {"id": "10", "username": "u", "discriminator": "0001"},
"joined_at": "t",
"roles": [r.id for r in roles] or [guild.id],
}
member = Member(data, client_instance=client)
member.guild_id = guild.id
member._client = client
guild._members.set(member.id, member)
return member
def test_display_name_prefers_nick(): def test_display_name_prefers_nick():
member = _make_member("1", "u", "nickname") member = _make_member("1", "u", "nickname")
assert member.display_name == "nickname" assert member.display_name == "nickname"
@ -20,3 +89,25 @@ def test_display_name_prefers_nick():
def test_display_name_falls_back_to_username(): def test_display_name_falls_back_to_username():
member = _make_member("2", "u2", None) member = _make_member("2", "u2", None)
assert member.display_name == "u2" assert member.display_name == "u2"
def test_guild_permissions_from_roles():
client = DummyClient()
guild = _base_guild(client)
everyone = _role(guild, guild.id, Permissions.VIEW_CHANNEL)
mod = _role(guild, "2", Permissions.MANAGE_MESSAGES)
member = _member(guild, client, everyone, mod)
perms = member.guild_permissions
assert perms & Permissions.VIEW_CHANNEL
assert perms & Permissions.MANAGE_MESSAGES
assert not perms & Permissions.BAN_MEMBERS
def test_guild_permissions_administrator_role_grants_all():
client = DummyClient()
guild = _base_guild(client)
admin = _role(guild, "2", Permissions.ADMINISTRATOR)
member = _member(guild, client, admin)
assert member.guild_permissions == Permissions(~0)

View File

@ -0,0 +1,36 @@
from disagreement.models import Member, VoiceState
def test_member_voice_dataclass():
data = {
"user": {"id": "1", "username": "u", "discriminator": "0001"},
"joined_at": "t",
"roles": [],
"voice_state": {
"guild_id": "g",
"channel_id": "c",
"user_id": "1",
"session_id": "s",
"deaf": False,
"mute": True,
"self_deaf": False,
"self_mute": False,
"self_video": False,
"suppress": False,
},
}
member = Member(data, client_instance=None)
voice = member.voice
assert isinstance(voice, VoiceState)
assert voice.channel_id == "c"
assert voice.mute is True
def test_member_voice_none():
data = {
"user": {"id": "2", "username": "u2", "discriminator": "0001"},
"joined_at": "t",
"roles": [],
}
member = Member(data, client_instance=None)
assert member.voice is None

View File

@ -21,3 +21,19 @@ def test_clean_content_removes_mentions():
def test_clean_content_no_mentions(): def test_clean_content_no_mentions():
msg = make_message("Just text") msg = make_message("Just text")
assert msg.clean_content == "Just text" assert msg.clean_content == "Just text"
def test_created_at_parses_timestamp():
ts = "2024-05-04T12:34:56+00:00"
msg = make_message("hi")
msg.timestamp = ts
assert msg.created_at.isoformat() == ts
def test_edited_at_parses_timestamp_or_none():
ts = "2024-05-04T12:35:56+00:00"
msg = make_message("hi")
msg.timestamp = ts
assert msg.edited_at is None
msg.edited_timestamp = ts
assert msg.edited_at.isoformat() == ts

15
tests/test_object.py Normal file
View File

@ -0,0 +1,15 @@
from disagreement.object import Object
def test_object_int():
obj = Object(123)
assert int(obj) == 123
def test_object_equality_and_hash():
a = Object(1)
b = Object(1)
c = Object(2)
assert a == b
assert a != c
assert hash(a) == hash(b)

23
tests/test_paginator.py Normal file
View File

@ -0,0 +1,23 @@
from disagreement.utils import Paginator
def test_paginator_single_page():
p = Paginator(limit=10)
p.add_line("hi")
p.add_line("there")
assert p.pages == ["hi\nthere"]
def test_paginator_splits_pages():
p = Paginator(limit=10)
p.add_line("12345")
p.add_line("67890")
assert p.pages == ["12345", "67890"]
p.add_line("xyz")
assert p.pages == ["12345", "67890\nxyz"]
def test_paginator_handles_long_line():
p = Paginator(limit=5)
p.add_line("abcdef")
assert p.pages == ["abcde", "f"]

View File

@ -32,3 +32,11 @@ def test_missing_permissions():
current, Permissions.SEND_MESSAGES, Permissions.MANAGE_MESSAGES current, Permissions.SEND_MESSAGES, Permissions.MANAGE_MESSAGES
) )
assert missing == [Permissions.MANAGE_MESSAGES] assert missing == [Permissions.MANAGE_MESSAGES]
def test_permissions_all():
all_value = Permissions.all()
union = Permissions(0)
for perm in Permissions:
union |= perm
assert all_value == union

View File

@ -1,3 +1,4 @@
import io
import pytest import pytest
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
@ -38,7 +39,9 @@ async def test_http_send_message_with_files_uses_formdata():
"timestamp": "t", "timestamp": "t",
} }
) )
await http.send_message("c", "hi", files=[File("f.txt", b"data")]) await http.send_message(
"c", "hi", files=[File(io.BytesIO(b"data"), filename="f.txt")]
)
args, kwargs = http.request.call_args args, kwargs = http.request.call_args
assert kwargs["is_json"] is False assert kwargs["is_json"] is False
@ -75,7 +78,33 @@ async def test_client_send_message_passes_files():
"timestamp": "t", "timestamp": "t",
} }
) )
await client.send_message("c", "hi", files=[File("f.txt", b"data")]) await client.send_message(
"c", "hi", files=[File(io.BytesIO(b"data"), filename="f.txt")]
)
client._http.send_message.assert_awaited_once() client._http.send_message.assert_awaited_once()
kwargs = client._http.send_message.call_args.kwargs kwargs = client._http.send_message.call_args.kwargs
assert kwargs["files"][0].filename == "f.txt" assert kwargs["files"][0].filename == "f.txt"
@pytest.mark.asyncio
async def test_file_from_path(tmp_path):
file_path = tmp_path / "path.txt"
file_path.write_bytes(b"ok")
http = HTTPClient(token="t")
http.request = AsyncMock(
return_value={
"id": "1",
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
)
await http.send_message("c", "hi", files=[File(file_path)])
_, kwargs = http.request.call_args
assert kwargs["is_json"] is False
def test_file_spoiler():
f = File(io.BytesIO(b"d"), filename="a.txt", spoiler=True)
assert f.filename == "SPOILER_a.txt"

View File

@ -82,3 +82,24 @@ async def test_before_after_loop_callbacks() -> None:
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
assert events and events[0] == "before" assert events and events[0] == "before"
assert "after" in events assert "after" in events
@pytest.mark.asyncio
async def test_change_interval_and_current_loop() -> None:
count = 0
@tasks.loop(seconds=0.01)
async def ticker() -> None:
nonlocal count
count += 1
ticker.start()
await asyncio.sleep(0.03)
initial = ticker.current_loop
ticker.change_interval(seconds=0.02)
await asyncio.sleep(0.05)
ticker.stop()
assert initial >= 2
assert ticker.current_loop > initial
assert count == ticker.current_loop

View File

@ -1,8 +1,54 @@
from datetime import timezone from datetime import datetime, timezone
from types import SimpleNamespace
from disagreement.utils import utcnow from disagreement.utils import (
escape_markdown,
escape_mentions,
utcnow,
snowflake_time,
find,
get
)
def test_utcnow_timezone(): def test_utcnow_timezone():
now = utcnow() now = utcnow()
assert now.tzinfo == timezone.utc assert now.tzinfo == timezone.utc
def test_find_returns_matching_element():
seq = [1, 2, 3]
assert find(lambda x: x > 1, seq) == 2
assert find(lambda x: x > 3, seq) is None
def test_get_matches_attributes():
items = [
SimpleNamespace(id=1, name="a"),
SimpleNamespace(id=2, name="b"),
]
assert get(items, id=2) is items[1]
assert get(items, id=1, name="a") is items[0]
assert get(items, name="c") is None
def test_snowflake_time():
dt = datetime(2020, 1, 1, tzinfo=timezone.utc)
ms = int(dt.timestamp() * 1000) - 1420070400000
snowflake = ms << 22
assert snowflake_time(snowflake) == dt
def test_escape_markdown():
text = "**bold** _under_ ~strike~ `code` > quote | pipe"
escaped = escape_markdown(text)
assert (
escaped
== "\\*\\*bold\\*\\* \\_under\\_ \\~strike\\~ \\`code\\` \\> quote \\| pipe"
)
def test_escape_mentions():
text = "Hello @everyone and <@123>!"
escaped = escape_mentions(text)
assert escaped == "Hello @\u200beveryone and <@\u200b123>!"

Some files were not shown because too many files have changed in this diff Show More