Initial commit
This commit is contained in:
commit
7c7cb4137c
34
.github/workflows/ci.yml
vendored
Normal file
34
.github/workflows/ci.yml
vendored
Normal file
@ -0,0 +1,34 @@
|
||||
name: Python CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install -e .
|
||||
npm install -g pyright
|
||||
|
||||
- name: Run Pyright
|
||||
run: pyright
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest tests/
|
129
.gitignore
vendored
Normal file
129
.gitignore
vendored
Normal file
@ -0,0 +1,129 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a PyInstaller script; this is deployed to PyInstaller's temporary directory.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# PEP 582; __pypackages__
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# IDE specific files
|
||||
.idea/
|
||||
.vscode/
|
||||
.kilocode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
64
AGENTS.md
Normal file
64
AGENTS.md
Normal file
@ -0,0 +1,64 @@
|
||||
# Agents
|
||||
- There are no nested `AGENTS.md` files; this is the only one in the project.
|
||||
- Tools to use for testing: `pyright`, `pylint`, `pytest`, `black`
|
||||
|
||||
- You have a Python script `tavilytool.py` in the project root that you can use to search the web.
|
||||
|
||||
# Tavily API Script Usage Instructions
|
||||
|
||||
## Basic Usage
|
||||
Search for information using simple queries:
|
||||
```bash
|
||||
python tavilytool.py "your search query"
|
||||
```
|
||||
|
||||
## Examples
|
||||
```bash
|
||||
python tavilytool.py "latest AI development 2024"
|
||||
python tavilytool.py "how to make chocolate chip cookies"
|
||||
python tavilytool.py "current weather in New York"
|
||||
python tavilytool.py "best programming practices Python"
|
||||
```
|
||||
|
||||
## Advanced Options
|
||||
|
||||
### Search Depth
|
||||
- **Basic search**: `python tavilytool.py "query"` (default)
|
||||
- **Advanced search**: `python tavilytool.py "query" --depth advanced`
|
||||
|
||||
### Control Results
|
||||
- **Limit results**: `python tavilytool.py "query" --max-results 3`
|
||||
- **Include images**: `python tavilytool.py "query" --include-images`
|
||||
- **Skip AI answer**: `python tavilytool.py "query" --no-answer`
|
||||
|
||||
### Domain Filtering
|
||||
- **Include specific domains**: `python tavilytool.py "query" --include-domains reddit.com stackoverflow.com`
|
||||
- **Exclude domains**: `python tavilytool.py "query" --exclude-domains wikipedia.org`
|
||||
|
||||
### Output Format
|
||||
- **Formatted output**: `python tavilytool.py "query"` (default - human readable)
|
||||
- **Raw JSON**: `python tavilytool.py "query" --raw` (for programmatic processing)
|
||||
|
||||
## Output Structure
|
||||
The default formatted output includes:
|
||||
- 🤖 **AI Answer**: Direct answer to your query
|
||||
- 🔍 **Search Results**: Titles, URLs, and content snippets
|
||||
- 🖼️ **Images**: Relevant images (when `--include-images` is used)
|
||||
|
||||
## Command Combinations
|
||||
```bash
|
||||
# Advanced search with images, limited results
|
||||
python tavilytool.py "machine learning tutorials" --depth advanced --include-images --max-results 3
|
||||
|
||||
# Search specific sites only, raw output
|
||||
python tavilytool.py "Python best practices" --include-domains github.com stackoverflow.com --raw
|
||||
|
||||
# Quick search without AI answer
|
||||
python tavilytool.py "today's news" --no-answer --max-results 5
|
||||
```
|
||||
|
||||
## Tips
|
||||
- Always quote your search queries to handle spaces and special characters
|
||||
- Use `--max-results` to control response length and API usage
|
||||
- Use `--raw` when you need to parse results programmatically
|
||||
- Combine options as needed for specific use cases
|
26
LICENSE
Normal file
26
LICENSE
Normal file
@ -0,0 +1,26 @@
|
||||
Copyright (c) 2025, Slipstream
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
3
MANIFEST.in
Normal file
3
MANIFEST.in
Normal file
@ -0,0 +1,3 @@
|
||||
graft docs
|
||||
graft examples
|
||||
include LICENSE
|
131
README.md
Normal file
131
README.md
Normal file
@ -0,0 +1,131 @@
|
||||
# Disagreement
|
||||
|
||||
A Python library for interacting with the Discord API, with a focus on bot development.
|
||||
|
||||
## Features
|
||||
|
||||
- Asynchronous design using `aiohttp`
|
||||
- Gateway and HTTP API clients
|
||||
- Slash command framework
|
||||
- Message component helpers
|
||||
- Built-in caching layer
|
||||
- Experimental voice support
|
||||
- Helpful error handling utilities
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
python -m pip install -U pip
|
||||
pip install disagreement
|
||||
# or install from source for development
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Requires Python 3.11 or newer.
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import os
|
||||
import disagreement
|
||||
|
||||
# Ensure DISCORD_BOT_TOKEN is set in your environment
|
||||
client = disagreement.Client(token=os.environ.get("DISCORD_BOT_TOKEN"))
|
||||
|
||||
@client.on_event('MESSAGE_CREATE')
|
||||
async def on_message(message: disagreement.Message):
|
||||
print(f"Received: {message.content} from {message.author.username}")
|
||||
if message.content.lower() == '!ping':
|
||||
await message.reply('Pong!')
|
||||
|
||||
async def main():
|
||||
if not client.token:
|
||||
print("Error: DISCORD_BOT_TOKEN environment variable not set.")
|
||||
return
|
||||
try:
|
||||
async with client:
|
||||
await asyncio.Future() # run until cancelled
|
||||
except KeyboardInterrupt:
|
||||
print("Bot shutting down...")
|
||||
# Add any other specific exception handling from your library, e.g., disagreement.AuthenticationError
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
### Global Error Handling
|
||||
|
||||
To ensure unexpected errors don't crash your bot, you can enable the library's
|
||||
global error handler:
|
||||
|
||||
```python
|
||||
import disagreement
|
||||
|
||||
disagreement.setup_global_error_handler()
|
||||
```
|
||||
|
||||
Call this early in your program to log unhandled exceptions instead of letting
|
||||
them terminate the process.
|
||||
|
||||
### Configuring Logging
|
||||
|
||||
Use :func:`disagreement.logging_config.setup_logging` to configure logging for
|
||||
your bot. The helper accepts a logging level and an optional file path.
|
||||
|
||||
```python
|
||||
import logging
|
||||
from disagreement.logging_config import setup_logging
|
||||
|
||||
setup_logging(logging.INFO)
|
||||
# Or log to a file
|
||||
setup_logging(logging.DEBUG, file="bot.log")
|
||||
```
|
||||
|
||||
### Defining Subcommands with `AppCommandGroup`
|
||||
|
||||
```python
|
||||
from disagreement.ext.app_commands import AppCommandGroup
|
||||
|
||||
settings = AppCommandGroup("settings", "Manage settings")
|
||||
|
||||
@settings.command(name="show")
|
||||
async def show(ctx):
|
||||
"""Displays a setting."""
|
||||
...
|
||||
|
||||
@settings.group("admin", description="Admin settings")
|
||||
def admin_group():
|
||||
pass
|
||||
|
||||
@admin_group.command(name="set")
|
||||
async def set_setting(ctx, key: str, value: str):
|
||||
...
|
||||
## Fetching Guilds
|
||||
|
||||
Use `Client.fetch_guild` to retrieve a guild from the Discord API if it
|
||||
isn't already cached. This is useful when working with guild IDs from
|
||||
outside the gateway events.
|
||||
|
||||
```python
|
||||
guild = await client.fetch_guild("123456789012345678")
|
||||
roles = await client.fetch_roles(guild.id)
|
||||
```
|
||||
|
||||
## Sharding
|
||||
|
||||
To run your bot across multiple gateway shards, pass `shard_count` when creating
|
||||
the client:
|
||||
|
||||
```python
|
||||
client = disagreement.Client(token=BOT_TOKEN, shard_count=2)
|
||||
```
|
||||
|
||||
See `examples/sharded_bot.py` for a full example.
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please open an issue or submit a pull request.
|
||||
|
||||
See the [docs](docs/) directory for detailed guides on components, slash commands, caching, and voice features.
|
||||
|
36
disagreement/__init__.py
Normal file
36
disagreement/__init__.py
Normal file
@ -0,0 +1,36 @@
|
||||
# disagreement/__init__.py
|
||||
|
||||
"""
|
||||
Disagreement
|
||||
~~~~~~~~~~~~
|
||||
|
||||
A Python library for interacting with the Discord API.
|
||||
|
||||
:copyright: (c) 2025 Slipstream
|
||||
:license: BSD 3-Clause License, see LICENSE for more details.
|
||||
"""
|
||||
|
||||
__title__ = "disagreement"
|
||||
__author__ = "Slipstream"
|
||||
__license__ = "BSD 3-Clause License"
|
||||
__copyright__ = "Copyright 2025 Slipstream"
|
||||
__version__ = "0.0.1"
|
||||
|
||||
from .client import Client
|
||||
from .models import Message, User
|
||||
from .voice_client import VoiceClient
|
||||
from .typing import Typing
|
||||
from .errors import (
|
||||
DisagreementException,
|
||||
HTTPException,
|
||||
GatewayException,
|
||||
AuthenticationError,
|
||||
)
|
||||
from .enums import GatewayIntent, GatewayOpcode # Export enums
|
||||
from .error_handler import setup_global_error_handler
|
||||
from .hybrid_context import HybridContext
|
||||
from .ext import tasks
|
||||
|
||||
# Set up logging if desired
|
||||
# import logging
|
||||
# logging.getLogger(__name__).addHandler(logging.NullHandler())
|
55
disagreement/cache.py
Normal file
55
disagreement/cache.py
Normal file
@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .models import Channel, Guild
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Cache(Generic[T]):
|
||||
"""Simple in-memory cache with optional TTL support."""
|
||||
|
||||
def __init__(self, ttl: Optional[float] = None) -> None:
|
||||
self.ttl = ttl
|
||||
self._data: Dict[str, tuple[T, Optional[float]]] = {}
|
||||
|
||||
def set(self, key: str, value: T) -> None:
|
||||
expiry = time.monotonic() + self.ttl if self.ttl is not None else None
|
||||
self._data[key] = (value, expiry)
|
||||
|
||||
def get(self, key: str) -> Optional[T]:
|
||||
item = self._data.get(key)
|
||||
if not item:
|
||||
return None
|
||||
value, expiry = item
|
||||
if expiry is not None and expiry < time.monotonic():
|
||||
self.invalidate(key)
|
||||
return None
|
||||
return value
|
||||
|
||||
def invalidate(self, key: str) -> None:
|
||||
self._data.pop(key, None)
|
||||
|
||||
def clear(self) -> None:
|
||||
self._data.clear()
|
||||
|
||||
def values(self) -> list[T]:
|
||||
now = time.monotonic()
|
||||
items = []
|
||||
for key, (value, expiry) in list(self._data.items()):
|
||||
if expiry is not None and expiry < now:
|
||||
self.invalidate(key)
|
||||
else:
|
||||
items.append(value)
|
||||
return items
|
||||
|
||||
|
||||
class GuildCache(Cache["Guild"]):
|
||||
"""Cache specifically for :class:`Guild` objects."""
|
||||
|
||||
|
||||
class ChannelCache(Cache["Channel"]):
|
||||
"""Cache specifically for :class:`Channel` objects."""
|
1144
disagreement/client.py
Normal file
1144
disagreement/client.py
Normal file
File diff suppressed because it is too large
Load Diff
166
disagreement/components.py
Normal file
166
disagreement/components.py
Normal file
@ -0,0 +1,166 @@
|
||||
"""Message component utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
|
||||
from .enums import ComponentType, ButtonStyle, ChannelType, TextInputStyle
|
||||
from .models import (
|
||||
ActionRow,
|
||||
Button,
|
||||
Component,
|
||||
SelectMenu,
|
||||
SelectOption,
|
||||
PartialEmoji,
|
||||
PartialEmoji,
|
||||
Section,
|
||||
TextDisplay,
|
||||
Thumbnail,
|
||||
MediaGallery,
|
||||
MediaGalleryItem,
|
||||
File,
|
||||
Separator,
|
||||
Container,
|
||||
UnfurledMediaItem,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - optional client for future use
|
||||
from .client import Client
|
||||
|
||||
|
||||
def component_factory(
|
||||
data: Dict[str, Any], client: Optional["Client"] = None
|
||||
) -> "Component":
|
||||
"""Create a component object from raw API data."""
|
||||
ctype = ComponentType(data["type"])
|
||||
|
||||
if ctype == ComponentType.ACTION_ROW:
|
||||
row = ActionRow()
|
||||
for comp in data.get("components", []):
|
||||
row.add_component(component_factory(comp, client))
|
||||
return row
|
||||
|
||||
if ctype == ComponentType.BUTTON:
|
||||
return Button(
|
||||
style=ButtonStyle(data["style"]),
|
||||
label=data.get("label"),
|
||||
emoji=PartialEmoji(data["emoji"]) if data.get("emoji") else None,
|
||||
custom_id=data.get("custom_id"),
|
||||
url=data.get("url"),
|
||||
disabled=data.get("disabled", False),
|
||||
)
|
||||
|
||||
if ctype in {
|
||||
ComponentType.STRING_SELECT,
|
||||
ComponentType.USER_SELECT,
|
||||
ComponentType.ROLE_SELECT,
|
||||
ComponentType.MENTIONABLE_SELECT,
|
||||
ComponentType.CHANNEL_SELECT,
|
||||
}:
|
||||
options = [
|
||||
SelectOption(
|
||||
label=o["label"],
|
||||
value=o["value"],
|
||||
description=o.get("description"),
|
||||
emoji=PartialEmoji(o["emoji"]) if o.get("emoji") else None,
|
||||
default=o.get("default", False),
|
||||
)
|
||||
for o in data.get("options", [])
|
||||
]
|
||||
channel_types = None
|
||||
if ctype == ComponentType.CHANNEL_SELECT and data.get("channel_types"):
|
||||
channel_types = [ChannelType(ct) for ct in data.get("channel_types", [])]
|
||||
|
||||
return SelectMenu(
|
||||
custom_id=data["custom_id"],
|
||||
options=options,
|
||||
placeholder=data.get("placeholder"),
|
||||
min_values=data.get("min_values", 1),
|
||||
max_values=data.get("max_values", 1),
|
||||
disabled=data.get("disabled", False),
|
||||
channel_types=channel_types,
|
||||
type=ctype,
|
||||
)
|
||||
|
||||
if ctype == ComponentType.TEXT_INPUT:
|
||||
from .ui.modal import TextInput
|
||||
|
||||
return TextInput(
|
||||
label=data.get("label", ""),
|
||||
custom_id=data.get("custom_id"),
|
||||
style=TextInputStyle(data.get("style", TextInputStyle.SHORT.value)),
|
||||
placeholder=data.get("placeholder"),
|
||||
required=data.get("required", True),
|
||||
min_length=data.get("min_length"),
|
||||
max_length=data.get("max_length"),
|
||||
)
|
||||
|
||||
if ctype == ComponentType.SECTION:
|
||||
# The components in a section can only be TextDisplay
|
||||
section_components = []
|
||||
for c in data.get("components", []):
|
||||
comp = component_factory(c, client)
|
||||
if isinstance(comp, TextDisplay):
|
||||
section_components.append(comp)
|
||||
|
||||
accessory = None
|
||||
if data.get("accessory"):
|
||||
acc_comp = component_factory(data["accessory"], client)
|
||||
if isinstance(acc_comp, (Thumbnail, Button)):
|
||||
accessory = acc_comp
|
||||
|
||||
return Section(
|
||||
components=section_components,
|
||||
accessory=accessory,
|
||||
id=data.get("id"),
|
||||
)
|
||||
|
||||
if ctype == ComponentType.TEXT_DISPLAY:
|
||||
return TextDisplay(content=data["content"], id=data.get("id"))
|
||||
|
||||
if ctype == ComponentType.THUMBNAIL:
|
||||
return Thumbnail(
|
||||
media=UnfurledMediaItem(**data["media"]),
|
||||
description=data.get("description"),
|
||||
spoiler=data.get("spoiler", False),
|
||||
id=data.get("id"),
|
||||
)
|
||||
|
||||
if ctype == ComponentType.MEDIA_GALLERY:
|
||||
return MediaGallery(
|
||||
items=[
|
||||
MediaGalleryItem(
|
||||
media=UnfurledMediaItem(**i["media"]),
|
||||
description=i.get("description"),
|
||||
spoiler=i.get("spoiler", False),
|
||||
)
|
||||
for i in data.get("items", [])
|
||||
],
|
||||
id=data.get("id"),
|
||||
)
|
||||
|
||||
if ctype == ComponentType.FILE:
|
||||
return File(
|
||||
file=UnfurledMediaItem(**data["file"]),
|
||||
spoiler=data.get("spoiler", False),
|
||||
id=data.get("id"),
|
||||
)
|
||||
|
||||
if ctype == ComponentType.SEPARATOR:
|
||||
return Separator(
|
||||
divider=data.get("divider", True),
|
||||
spacing=data.get("spacing", 1),
|
||||
id=data.get("id"),
|
||||
)
|
||||
|
||||
if ctype == ComponentType.CONTAINER:
|
||||
return Container(
|
||||
components=[
|
||||
component_factory(c, client) for c in data.get("components", [])
|
||||
],
|
||||
accent_color=data.get("accent_color"),
|
||||
spoiler=data.get("spoiler", False),
|
||||
id=data.get("id"),
|
||||
)
|
||||
|
||||
raise ValueError(f"Unsupported component type: {ctype}")
|
357
disagreement/enums.py
Normal file
357
disagreement/enums.py
Normal file
@ -0,0 +1,357 @@
|
||||
# disagreement/enums.py
|
||||
|
||||
"""
|
||||
Enums for Discord constants.
|
||||
"""
|
||||
|
||||
from enum import IntEnum, Enum # Import Enum
|
||||
|
||||
|
||||
class GatewayOpcode(IntEnum):
|
||||
"""Represents a Discord Gateway Opcode."""
|
||||
|
||||
DISPATCH = 0
|
||||
HEARTBEAT = 1
|
||||
IDENTIFY = 2
|
||||
PRESENCE_UPDATE = 3
|
||||
VOICE_STATE_UPDATE = 4
|
||||
RESUME = 6
|
||||
RECONNECT = 7
|
||||
REQUEST_GUILD_MEMBERS = 8
|
||||
INVALID_SESSION = 9
|
||||
HELLO = 10
|
||||
HEARTBEAT_ACK = 11
|
||||
|
||||
|
||||
class GatewayIntent(IntEnum):
|
||||
"""Represents a Discord Gateway Intent bit.
|
||||
|
||||
Intents are used to subscribe to specific groups of events from the Gateway.
|
||||
"""
|
||||
|
||||
GUILDS = 1 << 0
|
||||
GUILD_MEMBERS = 1 << 1 # Privileged
|
||||
GUILD_MODERATION = 1 << 2 # Formerly GUILD_BANS
|
||||
GUILD_EMOJIS_AND_STICKERS = 1 << 3
|
||||
GUILD_INTEGRATIONS = 1 << 4
|
||||
GUILD_WEBHOOKS = 1 << 5
|
||||
GUILD_INVITES = 1 << 6
|
||||
GUILD_VOICE_STATES = 1 << 7
|
||||
GUILD_PRESENCES = 1 << 8 # Privileged
|
||||
GUILD_MESSAGES = 1 << 9
|
||||
GUILD_MESSAGE_REACTIONS = 1 << 10
|
||||
GUILD_MESSAGE_TYPING = 1 << 11
|
||||
DIRECT_MESSAGES = 1 << 12
|
||||
DIRECT_MESSAGE_REACTIONS = 1 << 13
|
||||
DIRECT_MESSAGE_TYPING = 1 << 14
|
||||
MESSAGE_CONTENT = 1 << 15 # Privileged (as of Aug 31, 2022)
|
||||
GUILD_SCHEDULED_EVENTS = 1 << 16
|
||||
AUTO_MODERATION_CONFIGURATION = 1 << 20
|
||||
AUTO_MODERATION_EXECUTION = 1 << 21
|
||||
|
||||
@classmethod
|
||||
def default(cls) -> int:
|
||||
"""Returns default intents (excluding privileged ones like members, presences, message content)."""
|
||||
return (
|
||||
cls.GUILDS
|
||||
| cls.GUILD_MODERATION
|
||||
| cls.GUILD_EMOJIS_AND_STICKERS
|
||||
| cls.GUILD_INTEGRATIONS
|
||||
| cls.GUILD_WEBHOOKS
|
||||
| cls.GUILD_INVITES
|
||||
| cls.GUILD_VOICE_STATES
|
||||
| cls.GUILD_MESSAGES
|
||||
| cls.GUILD_MESSAGE_REACTIONS
|
||||
| cls.GUILD_MESSAGE_TYPING
|
||||
| cls.DIRECT_MESSAGES
|
||||
| cls.DIRECT_MESSAGE_REACTIONS
|
||||
| cls.DIRECT_MESSAGE_TYPING
|
||||
| cls.GUILD_SCHEDULED_EVENTS
|
||||
| cls.AUTO_MODERATION_CONFIGURATION
|
||||
| cls.AUTO_MODERATION_EXECUTION
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> int:
|
||||
"""Returns all intents, including privileged ones. Use with caution."""
|
||||
val = 0
|
||||
for intent in cls:
|
||||
val |= intent.value
|
||||
return val
|
||||
|
||||
@classmethod
|
||||
def privileged(cls) -> int:
|
||||
"""Returns a bitmask of all privileged intents."""
|
||||
return cls.GUILD_MEMBERS | cls.GUILD_PRESENCES | cls.MESSAGE_CONTENT
|
||||
|
||||
|
||||
# --- Application Command Enums ---
|
||||
|
||||
|
||||
class ApplicationCommandType(IntEnum):
|
||||
"""Type of application command."""
|
||||
|
||||
CHAT_INPUT = 1
|
||||
USER = 2
|
||||
MESSAGE = 3
|
||||
PRIMARY_ENTRY_POINT = 4
|
||||
|
||||
|
||||
class ApplicationCommandOptionType(IntEnum):
|
||||
"""Type of application command option."""
|
||||
|
||||
SUB_COMMAND = 1
|
||||
SUB_COMMAND_GROUP = 2
|
||||
STRING = 3
|
||||
INTEGER = 4 # Any integer between -2^53 and 2^53
|
||||
BOOLEAN = 5
|
||||
USER = 6
|
||||
CHANNEL = 7 # Includes all channel types + categories
|
||||
ROLE = 8
|
||||
MENTIONABLE = 9 # Includes users and roles
|
||||
NUMBER = 10 # Any double between -2^53 and 2^53
|
||||
ATTACHMENT = 11
|
||||
|
||||
|
||||
class InteractionType(IntEnum):
|
||||
"""Type of interaction."""
|
||||
|
||||
PING = 1
|
||||
APPLICATION_COMMAND = 2
|
||||
MESSAGE_COMPONENT = 3
|
||||
APPLICATION_COMMAND_AUTOCOMPLETE = 4
|
||||
MODAL_SUBMIT = 5
|
||||
|
||||
|
||||
class InteractionCallbackType(IntEnum):
|
||||
"""Type of interaction callback."""
|
||||
|
||||
PONG = 1
|
||||
CHANNEL_MESSAGE_WITH_SOURCE = 4
|
||||
DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE = 5
|
||||
DEFERRED_UPDATE_MESSAGE = 6
|
||||
UPDATE_MESSAGE = 7
|
||||
APPLICATION_COMMAND_AUTOCOMPLETE_RESULT = 8
|
||||
MODAL = 9 # Response to send a modal
|
||||
|
||||
|
||||
class IntegrationType(IntEnum):
|
||||
"""
|
||||
Installation context(s) where the command is available,
|
||||
only for globally-scoped commands.
|
||||
"""
|
||||
|
||||
GUILD_INSTALL = (
|
||||
0 # Command is available when the app is installed to a guild (default)
|
||||
)
|
||||
USER_INSTALL = 1 # Command is available when the app is installed to a user
|
||||
|
||||
|
||||
class InteractionContextType(IntEnum):
|
||||
"""
|
||||
Interaction context(s) where the command can be used,
|
||||
only for globally-scoped commands.
|
||||
"""
|
||||
|
||||
GUILD = 0 # Command can be used in guilds
|
||||
BOT_DM = 1 # Command can be used in DMs with the app's bot user
|
||||
PRIVATE_CHANNEL = 2 # Command can be used in Group DMs and DMs (requires USER_INSTALL integration_type)
|
||||
|
||||
|
||||
class MessageFlags(IntEnum):
|
||||
"""Represents the flags of a message."""
|
||||
|
||||
CROSSPOSTED = 1 << 0
|
||||
IS_CROSSPOST = 1 << 1
|
||||
SUPPRESS_EMBEDS = 1 << 2
|
||||
SOURCE_MESSAGE_DELETED = 1 << 3
|
||||
URGENT = 1 << 4
|
||||
HAS_THREAD = 1 << 5
|
||||
EPHEMERAL = 1 << 6
|
||||
LOADING = 1 << 7
|
||||
FAILED_TO_MENTION_SOME_ROLES_IN_THREAD = 1 << 8
|
||||
SUPPRESS_NOTIFICATIONS = (
|
||||
1 << 12
|
||||
) # Discord specific, was previously 1 << 4 (IS_VOICE_MESSAGE)
|
||||
IS_COMPONENTS_V2 = 1 << 15
|
||||
|
||||
|
||||
# --- Guild Enums ---
|
||||
|
||||
|
||||
class VerificationLevel(IntEnum):
|
||||
"""Guild verification level."""
|
||||
|
||||
NONE = 0
|
||||
LOW = 1
|
||||
MEDIUM = 2
|
||||
HIGH = 3
|
||||
VERY_HIGH = 4
|
||||
|
||||
|
||||
class MessageNotificationLevel(IntEnum):
|
||||
"""Default message notification level for a guild."""
|
||||
|
||||
ALL_MESSAGES = 0
|
||||
ONLY_MENTIONS = 1
|
||||
|
||||
|
||||
class ExplicitContentFilterLevel(IntEnum):
|
||||
"""Explicit content filter level for a guild."""
|
||||
|
||||
DISABLED = 0
|
||||
MEMBERS_WITHOUT_ROLES = 1
|
||||
ALL_MEMBERS = 2
|
||||
|
||||
|
||||
class MFALevel(IntEnum):
|
||||
"""Multi-Factor Authentication level for a guild."""
|
||||
|
||||
NONE = 0
|
||||
ELEVATED = 1
|
||||
|
||||
|
||||
class GuildNSFWLevel(IntEnum):
|
||||
"""NSFW level of a guild."""
|
||||
|
||||
DEFAULT = 0
|
||||
EXPLICIT = 1
|
||||
SAFE = 2
|
||||
AGE_RESTRICTED = 3
|
||||
|
||||
|
||||
class PremiumTier(IntEnum):
|
||||
"""Guild premium tier (boost level)."""
|
||||
|
||||
NONE = 0
|
||||
TIER_1 = 1
|
||||
TIER_2 = 2
|
||||
TIER_3 = 3
|
||||
|
||||
|
||||
class GuildFeature(str, Enum): # Changed from IntEnum to Enum
|
||||
"""Features that a guild can have.
|
||||
|
||||
Note: This is not an exhaustive list and Discord may add more.
|
||||
Using str as a base allows for unknown features to be stored as strings.
|
||||
"""
|
||||
|
||||
ANIMATED_BANNER = "ANIMATED_BANNER"
|
||||
ANIMATED_ICON = "ANIMATED_ICON"
|
||||
APPLICATION_COMMAND_PERMISSIONS_V2 = "APPLICATION_COMMAND_PERMISSIONS_V2"
|
||||
AUTO_MODERATION = "AUTO_MODERATION"
|
||||
BANNER = "BANNER"
|
||||
COMMUNITY = "COMMUNITY"
|
||||
CREATOR_MONETIZABLE_PROVISIONAL = "CREATOR_MONETIZABLE_PROVISIONAL"
|
||||
CREATOR_STORE_PAGE = "CREATOR_STORE_PAGE"
|
||||
DEVELOPER_SUPPORT_SERVER = "DEVELOPER_SUPPORT_SERVER"
|
||||
DISCOVERABLE = "DISCOVERABLE"
|
||||
FEATURABLE = "FEATURABLE"
|
||||
INVITES_DISABLED = "INVITES_DISABLED"
|
||||
INVITE_SPLASH = "INVITE_SPLASH"
|
||||
MEMBER_VERIFICATION_GATE_ENABLED = "MEMBER_VERIFICATION_GATE_ENABLED"
|
||||
MORE_STICKERS = "MORE_STICKERS"
|
||||
NEWS = "NEWS"
|
||||
PARTNERED = "PARTNERED"
|
||||
PREVIEW_ENABLED = "PREVIEW_ENABLED"
|
||||
RAID_ALERTS_DISABLED = "RAID_ALERTS_DISABLED"
|
||||
ROLE_ICONS = "ROLE_ICONS"
|
||||
ROLE_SUBSCRIPTIONS_AVAILABLE_FOR_PURCHASE = (
|
||||
"ROLE_SUBSCRIPTIONS_AVAILABLE_FOR_PURCHASE"
|
||||
)
|
||||
ROLE_SUBSCRIPTIONS_ENABLED = "ROLE_SUBSCRIPTIONS_ENABLED"
|
||||
TICKETED_EVENTS_ENABLED = "TICKETED_EVENTS_ENABLED"
|
||||
VANITY_URL = "VANITY_URL"
|
||||
VERIFIED = "VERIFIED"
|
||||
VIP_REGIONS = "VIP_REGIONS"
|
||||
WELCOME_SCREEN_ENABLED = "WELCOME_SCREEN_ENABLED"
|
||||
# Add more as they become known or needed
|
||||
|
||||
# This allows GuildFeature("UNKNOWN_FEATURE_STRING") to work
|
||||
@classmethod
|
||||
def _missing_(cls, value): # type: ignore
|
||||
return str(value)
|
||||
|
||||
|
||||
# --- Channel Enums ---
|
||||
|
||||
|
||||
class ChannelType(IntEnum):
|
||||
"""Type of channel."""
|
||||
|
||||
GUILD_TEXT = 0 # a text channel within a server
|
||||
DM = 1 # a direct message between users
|
||||
GUILD_VOICE = 2 # a voice channel within a server
|
||||
GROUP_DM = 3 # a direct message between multiple users
|
||||
GUILD_CATEGORY = 4 # an organizational category that contains up to 50 channels
|
||||
GUILD_ANNOUNCEMENT = 5 # a channel that users can follow and crosspost into their own server (formerly GUILD_NEWS)
|
||||
ANNOUNCEMENT_THREAD = (
|
||||
10 # a temporary sub-channel within a GUILD_ANNOUNCEMENT channel
|
||||
)
|
||||
PUBLIC_THREAD = (
|
||||
11 # a temporary sub-channel within a GUILD_TEXT or GUILD_ANNOUNCEMENT channel
|
||||
)
|
||||
PRIVATE_THREAD = 12 # a temporary sub-channel within a GUILD_TEXT channel that is only viewable by those invited and those with the MANAGE_THREADS permission
|
||||
GUILD_STAGE_VOICE = (
|
||||
13 # a voice channel for hosting events with speakers and audiences
|
||||
)
|
||||
GUILD_DIRECTORY = 14 # a channel in a hub containing the listed servers
|
||||
GUILD_FORUM = 15 # (Still in development) a channel that can only contain threads
|
||||
GUILD_MEDIA = 16 # (Still in development) a channel that can only contain media
|
||||
|
||||
|
||||
class OverwriteType(IntEnum):
|
||||
"""Type of target for a permission overwrite."""
|
||||
|
||||
ROLE = 0
|
||||
MEMBER = 1
|
||||
|
||||
|
||||
# --- Component Enums ---
|
||||
|
||||
|
||||
class ComponentType(IntEnum):
|
||||
"""Type of message component."""
|
||||
|
||||
ACTION_ROW = 1
|
||||
BUTTON = 2
|
||||
STRING_SELECT = 3 # Formerly SELECT_MENU
|
||||
TEXT_INPUT = 4
|
||||
USER_SELECT = 5
|
||||
ROLE_SELECT = 6
|
||||
MENTIONABLE_SELECT = 7
|
||||
CHANNEL_SELECT = 8
|
||||
SECTION = 9
|
||||
TEXT_DISPLAY = 10
|
||||
THUMBNAIL = 11
|
||||
MEDIA_GALLERY = 12
|
||||
FILE = 13
|
||||
SEPARATOR = 14
|
||||
CONTAINER = 17
|
||||
|
||||
|
||||
class ButtonStyle(IntEnum):
|
||||
"""Style of a button component."""
|
||||
|
||||
# Blurple
|
||||
PRIMARY = 1
|
||||
# Grey
|
||||
SECONDARY = 2
|
||||
# Green
|
||||
SUCCESS = 3
|
||||
# Red
|
||||
DANGER = 4
|
||||
# Grey, navigates to a URL
|
||||
LINK = 5
|
||||
|
||||
|
||||
class TextInputStyle(IntEnum):
|
||||
"""Style of a text input component."""
|
||||
|
||||
SHORT = 1
|
||||
PARAGRAPH = 2
|
||||
|
||||
|
||||
# Example of how you might combine intents:
|
||||
# intents = GatewayIntent.GUILDS | GatewayIntent.GUILD_MESSAGES | GatewayIntent.MESSAGE_CONTENT
|
||||
# client = Client(token="YOUR_TOKEN", intents=intents)
|
33
disagreement/error_handler.py
Normal file
33
disagreement/error_handler.py
Normal file
@ -0,0 +1,33 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
from .logging_config import setup_logging
|
||||
|
||||
|
||||
def setup_global_error_handler(
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
) -> None:
|
||||
"""Configure a basic global error handler for the provided loop.
|
||||
|
||||
The handler logs unhandled exceptions so they don't crash the bot.
|
||||
"""
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
if not logging.getLogger().hasHandlers():
|
||||
setup_logging(logging.ERROR)
|
||||
|
||||
def handle_exception(loop: asyncio.AbstractEventLoop, context: dict) -> None:
|
||||
exception = context.get("exception")
|
||||
if exception:
|
||||
logging.error("Unhandled exception in event loop: %s", exception)
|
||||
traceback.print_exception(
|
||||
type(exception), exception, exception.__traceback__
|
||||
)
|
||||
else:
|
||||
message = context.get("message")
|
||||
logging.error("Event loop error: %s", message)
|
||||
|
||||
loop.set_exception_handler(handle_exception)
|
112
disagreement/errors.py
Normal file
112
disagreement/errors.py
Normal file
@ -0,0 +1,112 @@
|
||||
# disagreement/errors.py
|
||||
|
||||
"""
|
||||
Custom exceptions for the Disagreement library.
|
||||
"""
|
||||
|
||||
from typing import Optional, Any # Add Optional and Any here
|
||||
|
||||
|
||||
class DisagreementException(Exception):
|
||||
"""Base exception class for all errors raised by this library."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class HTTPException(DisagreementException):
|
||||
"""Exception raised for HTTP-related errors.
|
||||
|
||||
Attributes:
|
||||
response: The aiohttp response object, if available.
|
||||
status: The HTTP status code.
|
||||
text: The response text, if available.
|
||||
error_code: Discord specific error code, if available.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, response=None, message=None, *, status=None, text=None, error_code=None
|
||||
):
|
||||
self.response = response
|
||||
self.status = status or (response.status if response else None)
|
||||
self.text = text or (
|
||||
response.text if response else None
|
||||
) # Or await response.text() if in async context
|
||||
self.error_code = error_code
|
||||
|
||||
full_message = f"HTTP {self.status}"
|
||||
if message:
|
||||
full_message += f": {message}"
|
||||
elif self.text:
|
||||
full_message += f": {self.text}"
|
||||
if self.error_code:
|
||||
full_message += f" (Discord Error Code: {self.error_code})"
|
||||
|
||||
super().__init__(full_message)
|
||||
|
||||
|
||||
class GatewayException(DisagreementException):
|
||||
"""Exception raised for errors related to the Discord Gateway connection or protocol."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationError(DisagreementException):
|
||||
"""Exception raised for authentication failures (e.g., invalid token)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RateLimitError(HTTPException):
|
||||
"""
|
||||
Exception raised when a rate limit is encountered.
|
||||
|
||||
Attributes:
|
||||
retry_after (float): The number of seconds to wait before retrying.
|
||||
is_global (bool): Whether this is a global rate limit.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, response, message=None, *, retry_after: float, is_global: bool = False
|
||||
):
|
||||
self.retry_after = retry_after
|
||||
self.is_global = is_global
|
||||
super().__init__(
|
||||
response,
|
||||
message
|
||||
or f"Rate limited. Retry after: {retry_after}s. Global: {is_global}",
|
||||
)
|
||||
|
||||
|
||||
# You can add more specific exceptions as needed, e.g.:
|
||||
# class NotFound(HTTPException):
|
||||
# """Raised for 404 Not Found errors."""
|
||||
# pass
|
||||
|
||||
# class Forbidden(HTTPException):
|
||||
# """Raised for 403 Forbidden errors."""
|
||||
# pass
|
||||
|
||||
|
||||
class AppCommandError(DisagreementException):
|
||||
"""Base exception for application command related errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AppCommandOptionConversionError(AppCommandError):
|
||||
"""Exception raised when an application command option fails to convert."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
option_name: Optional[str] = None,
|
||||
original_value: Any = None,
|
||||
):
|
||||
self.option_name = option_name
|
||||
self.original_value = original_value
|
||||
full_message = message
|
||||
if option_name:
|
||||
full_message = f"Failed to convert option '{option_name}': {message}"
|
||||
if original_value is not None:
|
||||
full_message += f" (Original value: '{original_value}')"
|
||||
super().__init__(full_message)
|
243
disagreement/event_dispatcher.py
Normal file
243
disagreement/event_dispatcher.py
Normal file
@ -0,0 +1,243 @@
|
||||
# disagreement/event_dispatcher.py
|
||||
|
||||
"""
|
||||
Event dispatcher for handling Discord Gateway events.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
Callable,
|
||||
Coroutine,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Set,
|
||||
TYPE_CHECKING,
|
||||
Awaitable,
|
||||
Optional,
|
||||
)
|
||||
|
||||
from .models import Message, User # Assuming User might be part of other events
|
||||
from .errors import DisagreementException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client # For type hinting to avoid circular imports
|
||||
from .interactions import Interaction
|
||||
|
||||
# Type alias for an event listener
|
||||
EventListener = Callable[..., Awaitable[None]]
|
||||
|
||||
|
||||
class EventDispatcher:
|
||||
"""
|
||||
Manages registration and dispatching of event listeners.
|
||||
"""
|
||||
|
||||
def __init__(self, client_instance: "Client"):
|
||||
self._client: "Client" = client_instance
|
||||
self._listeners: Dict[str, List[EventListener]] = defaultdict(list)
|
||||
self._waiters: Dict[
|
||||
str, List[tuple[asyncio.Future, Optional[Callable[[Any], bool]]]]
|
||||
] = defaultdict(list)
|
||||
self.on_dispatch_error: Optional[
|
||||
Callable[[str, Exception, EventListener], Awaitable[None]]
|
||||
] = None
|
||||
# Pre-defined parsers for specific event types to convert raw data to models
|
||||
self._event_parsers: Dict[str, Callable[[Dict[str, Any]], Any]] = {
|
||||
"MESSAGE_CREATE": self._parse_message_create,
|
||||
"INTERACTION_CREATE": self._parse_interaction_create,
|
||||
"GUILD_CREATE": self._parse_guild_create,
|
||||
"CHANNEL_CREATE": self._parse_channel_create,
|
||||
"PRESENCE_UPDATE": self._parse_presence_update,
|
||||
"TYPING_START": self._parse_typing_start,
|
||||
}
|
||||
|
||||
def _parse_message_create(self, data: Dict[str, Any]) -> Message:
|
||||
"""Parses raw MESSAGE_CREATE data into a Message object."""
|
||||
return self._client.parse_message(data)
|
||||
|
||||
def _parse_interaction_create(self, data: Dict[str, Any]) -> "Interaction":
|
||||
"""Parses raw INTERACTION_CREATE data into an Interaction object."""
|
||||
from .interactions import Interaction
|
||||
|
||||
return Interaction(data=data, client_instance=self._client)
|
||||
|
||||
def _parse_guild_create(self, data: Dict[str, Any]):
|
||||
"""Parses raw GUILD_CREATE data into a Guild object."""
|
||||
|
||||
return self._client.parse_guild(data)
|
||||
|
||||
def _parse_channel_create(self, data: Dict[str, Any]):
|
||||
"""Parses raw CHANNEL_CREATE data into a Channel object."""
|
||||
|
||||
return self._client.parse_channel(data)
|
||||
|
||||
def _parse_presence_update(self, data: Dict[str, Any]):
|
||||
"""Parses raw PRESENCE_UPDATE data into a PresenceUpdate object."""
|
||||
|
||||
from .models import PresenceUpdate
|
||||
|
||||
return PresenceUpdate(data, client_instance=self._client)
|
||||
|
||||
def _parse_typing_start(self, data: Dict[str, Any]):
|
||||
"""Parses raw TYPING_START data into a TypingStart object."""
|
||||
|
||||
from .models import TypingStart
|
||||
|
||||
return TypingStart(data, client_instance=self._client)
|
||||
|
||||
# Potentially add _parse_user for events that directly provide a full user object
|
||||
# def _parse_user_update(self, data: Dict[str, Any]) -> User:
|
||||
# return User(data=data)
|
||||
|
||||
def register(self, event_name: str, coro: EventListener):
|
||||
"""
|
||||
Registers a coroutine function to listen for a specific event.
|
||||
|
||||
Args:
|
||||
event_name (str): The name of the event (e.g., 'MESSAGE_CREATE').
|
||||
coro (Callable): The coroutine function to call when the event occurs.
|
||||
It should accept arguments appropriate for the event.
|
||||
|
||||
Raises:
|
||||
TypeError: If the provided callback is not a coroutine function.
|
||||
"""
|
||||
if not inspect.iscoroutinefunction(coro):
|
||||
raise TypeError(
|
||||
f"Event listener for '{event_name}' must be a coroutine function (async def)."
|
||||
)
|
||||
|
||||
# Normalize event name, e.g., 'on_message' -> 'MESSAGE_CREATE'
|
||||
# For now, we assume event_name is already the Discord event type string.
|
||||
# If using decorators like @client.on_message, the decorator would handle this mapping.
|
||||
self._listeners[event_name.upper()].append(coro)
|
||||
|
||||
def unregister(self, event_name: str, coro: EventListener):
|
||||
"""
|
||||
Unregisters a coroutine function from an event.
|
||||
|
||||
Args:
|
||||
event_name (str): The name of the event.
|
||||
coro (Callable): The coroutine function to unregister.
|
||||
"""
|
||||
event_name_upper = event_name.upper()
|
||||
if event_name_upper in self._listeners:
|
||||
try:
|
||||
self._listeners[event_name_upper].remove(coro)
|
||||
except ValueError:
|
||||
pass # Listener not in list
|
||||
|
||||
def add_waiter(
|
||||
self,
|
||||
event_name: str,
|
||||
future: asyncio.Future,
|
||||
check: Optional[Callable[[Any], bool]] = None,
|
||||
) -> None:
|
||||
self._waiters[event_name.upper()].append((future, check))
|
||||
|
||||
def remove_waiter(self, event_name: str, future: asyncio.Future) -> None:
|
||||
waiters = self._waiters.get(event_name.upper())
|
||||
if not waiters:
|
||||
return
|
||||
self._waiters[event_name.upper()] = [
|
||||
(f, c) for f, c in waiters if f is not future
|
||||
]
|
||||
if not self._waiters[event_name.upper()]:
|
||||
self._waiters.pop(event_name.upper(), None)
|
||||
|
||||
def _resolve_waiters(self, event_name: str, data: Any) -> None:
|
||||
waiters = self._waiters.get(event_name)
|
||||
if not waiters:
|
||||
return
|
||||
to_remove: List[tuple[asyncio.Future, Optional[Callable[[Any], bool]]]] = []
|
||||
for future, check in waiters:
|
||||
if future.cancelled():
|
||||
to_remove.append((future, check))
|
||||
continue
|
||||
try:
|
||||
if check is None or check(data):
|
||||
future.set_result(data)
|
||||
to_remove.append((future, check))
|
||||
except Exception as exc:
|
||||
future.set_exception(exc)
|
||||
to_remove.append((future, check))
|
||||
for item in to_remove:
|
||||
if item in waiters:
|
||||
waiters.remove(item)
|
||||
if not waiters:
|
||||
self._waiters.pop(event_name, None)
|
||||
|
||||
async def dispatch(self, event_name: str, raw_data: Dict[str, Any]):
|
||||
"""
|
||||
Dispatches an event to all registered listeners.
|
||||
|
||||
Args:
|
||||
event_name (str): The name of the event (e.g., 'MESSAGE_CREATE').
|
||||
raw_data (Dict[str, Any]): The raw data payload from the Discord Gateway for this event.
|
||||
"""
|
||||
event_name_upper = event_name.upper()
|
||||
listeners = self._listeners.get(event_name_upper)
|
||||
|
||||
if not listeners:
|
||||
# print(f"No listeners for event {event_name_upper}")
|
||||
return
|
||||
|
||||
parsed_data: Any = raw_data
|
||||
if event_name_upper in self._event_parsers:
|
||||
try:
|
||||
parser = self._event_parsers[event_name_upper]
|
||||
parsed_data = parser(raw_data)
|
||||
except Exception as e:
|
||||
print(f"Error parsing event data for {event_name_upper}: {e}")
|
||||
# Optionally, dispatch with raw_data or raise, or log more formally
|
||||
# For now, we'll proceed to dispatch with raw_data if parsing fails,
|
||||
# or just log and return if parsed_data is critical.
|
||||
# Let's assume if a parser exists, its output is critical.
|
||||
return
|
||||
|
||||
self._resolve_waiters(event_name_upper, parsed_data)
|
||||
# print(f"Dispatching event {event_name_upper} with data: {parsed_data} to {len(listeners)} listeners.")
|
||||
for listener in listeners:
|
||||
try:
|
||||
# Inspect the listener to see how many arguments it expects
|
||||
sig = inspect.signature(listener)
|
||||
num_params = len(sig.parameters)
|
||||
|
||||
if num_params == 0: # Listener takes no arguments
|
||||
await listener()
|
||||
elif (
|
||||
num_params == 1
|
||||
): # Listener takes one argument (the parsed data or model)
|
||||
await listener(parsed_data)
|
||||
# elif num_params == 2 and event_name_upper == "MESSAGE_CREATE": # Special case for (client, message)
|
||||
# await listener(self._client, parsed_data) # This might be too specific here
|
||||
else:
|
||||
# Fallback or error if signature doesn't match expected patterns
|
||||
# For now, assume one arg is the most common for parsed data.
|
||||
# Or, if you want to be strict:
|
||||
print(
|
||||
f"Warning: Listener {listener.__name__} for {event_name_upper} has an unhandled number of parameters ({num_params}). Skipping or attempting with one arg."
|
||||
)
|
||||
if num_params > 0: # Try with one arg if it takes any
|
||||
await listener(parsed_data)
|
||||
|
||||
except Exception as e:
|
||||
callback = self.on_dispatch_error
|
||||
if callback is not None:
|
||||
try:
|
||||
await callback(event_name_upper, e, listener)
|
||||
|
||||
except Exception as hook_error:
|
||||
print(f"Error in on_dispatch_error hook itself: {hook_error}")
|
||||
else:
|
||||
# Default error handling if no hook is set
|
||||
print(
|
||||
f"Error in event listener {listener.__name__} for {event_name_upper}: {e}"
|
||||
)
|
||||
if hasattr(self._client, "on_error"):
|
||||
try:
|
||||
await self._client.on_error(event_name_upper, e, listener)
|
||||
except Exception as client_err_e:
|
||||
print(f"Error in client.on_error itself: {client_err_e}")
|
46
disagreement/ext/app_commands/__init__.py
Normal file
46
disagreement/ext/app_commands/__init__.py
Normal file
@ -0,0 +1,46 @@
|
||||
# disagreement/ext/app_commands/__init__.py
|
||||
|
||||
"""
|
||||
Application Commands Extension for Disagreement.
|
||||
|
||||
This package provides the framework for creating and handling
|
||||
Discord Application Commands (slash commands, user commands, message commands).
|
||||
"""
|
||||
|
||||
from .commands import (
|
||||
AppCommand,
|
||||
SlashCommand,
|
||||
UserCommand,
|
||||
MessageCommand,
|
||||
AppCommandGroup,
|
||||
)
|
||||
from .decorators import (
|
||||
slash_command,
|
||||
user_command,
|
||||
message_command,
|
||||
hybrid_command,
|
||||
group,
|
||||
subcommand,
|
||||
subcommand_group,
|
||||
OptionMetadata,
|
||||
)
|
||||
from .context import AppCommandContext
|
||||
|
||||
# from .handler import AppCommandHandler # Will be imported when defined
|
||||
|
||||
__all__ = [
|
||||
"AppCommand",
|
||||
"SlashCommand",
|
||||
"UserCommand",
|
||||
"MessageCommand",
|
||||
"AppCommandGroup", # To be defined
|
||||
"slash_command",
|
||||
"user_command",
|
||||
"message_command",
|
||||
"hybrid_command",
|
||||
"group",
|
||||
"subcommand",
|
||||
"subcommand_group",
|
||||
"OptionMetadata",
|
||||
"AppCommandContext", # To be defined
|
||||
]
|
513
disagreement/ext/app_commands/commands.py
Normal file
513
disagreement/ext/app_commands/commands.py
Normal file
@ -0,0 +1,513 @@
|
||||
# disagreement/ext/app_commands/commands.py
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Optional, List, Dict, Any, Union, TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from disagreement.ext.commands.core import (
|
||||
Command as PrefixCommand,
|
||||
) # Alias to avoid name clash
|
||||
from disagreement.interactions import ApplicationCommandOption, Snowflake
|
||||
from disagreement.enums import (
|
||||
ApplicationCommandType,
|
||||
IntegrationType,
|
||||
InteractionContextType,
|
||||
ApplicationCommandOptionType, # Added
|
||||
)
|
||||
from disagreement.ext.commands.cog import Cog # Corrected import path
|
||||
|
||||
# Placeholder for Cog if not using the existing one or if it needs adaptation
|
||||
if not TYPE_CHECKING:
|
||||
# This dynamic Cog = Any might not be ideal if Cog is used in runtime type checks.
|
||||
# However, for type hinting purposes when TYPE_CHECKING is false, it avoids import.
|
||||
# If Cog is needed at runtime by this module (it is, for AppCommand.cog type hint),
|
||||
# it should be imported directly.
|
||||
# For now, the TYPE_CHECKING block handles the proper import for static analysis.
|
||||
# Let's ensure Cog is available at runtime if AppCommand.cog is accessed.
|
||||
# A simple way is to import it outside TYPE_CHECKING too, if it doesn't cause circularity.
|
||||
# Given its usage, a forward reference string 'Cog' might be better in AppCommand.cog type hint.
|
||||
# Let's try importing it directly for runtime, assuming no circularity with this specific module.
|
||||
try:
|
||||
from disagreement.ext.commands.cog import Cog
|
||||
except ImportError:
|
||||
Cog = Any # Fallback if direct import fails (e.g. during partial builds/tests)
|
||||
# Import PrefixCommand at runtime for HybridCommand
|
||||
try:
|
||||
from disagreement.ext.commands.core import Command as PrefixCommand
|
||||
except Exception: # pragma: no cover - safeguard against unusual import issues
|
||||
PrefixCommand = Any # type: ignore
|
||||
# Import enums used at runtime
|
||||
try:
|
||||
from disagreement.enums import (
|
||||
ApplicationCommandType,
|
||||
IntegrationType,
|
||||
InteractionContextType,
|
||||
ApplicationCommandOptionType,
|
||||
)
|
||||
from disagreement.interactions import ApplicationCommandOption, Snowflake
|
||||
except Exception: # pragma: no cover
|
||||
ApplicationCommandType = ApplicationCommandOptionType = IntegrationType = (
|
||||
InteractionContextType
|
||||
) = Any # type: ignore
|
||||
ApplicationCommandOption = Snowflake = Any # type: ignore
|
||||
else: # When TYPE_CHECKING is true, Cog and PrefixCommand are already imported above.
|
||||
pass
|
||||
|
||||
|
||||
class AppCommand:
|
||||
"""
|
||||
Base class for an application command.
|
||||
|
||||
Attributes:
|
||||
name (str): The name of the command.
|
||||
description (Optional[str]): The description of the command.
|
||||
Required for CHAT_INPUT, empty for USER and MESSAGE commands.
|
||||
callback (Callable[..., Any]): The coroutine function that will be called when the command is executed.
|
||||
type (ApplicationCommandType): The type of the application command.
|
||||
options (Optional[List[ApplicationCommandOption]]): Parameters for the command. Populated by decorators.
|
||||
guild_ids (Optional[List[Snowflake]]): List of guild IDs where this command is active. None for global.
|
||||
default_member_permissions (Optional[str]): Bitwise permissions required by default for users to run the command.
|
||||
nsfw (bool): Whether the command is age-restricted.
|
||||
parent (Optional['AppCommandGroup']): The parent group if this is a subcommand.
|
||||
cog (Optional[Cog]): The cog this command belongs to, if any.
|
||||
_full_description (Optional[str]): Stores the original full description, e.g. from docstring,
|
||||
even if the payload description is different (like for User/Message commands).
|
||||
name_localizations (Optional[Dict[str, str]]): Localizations for the command's name.
|
||||
description_localizations (Optional[Dict[str, str]]): Localizations for the command's description.
|
||||
integration_types (Optional[List[IntegrationType]]): Installation contexts.
|
||||
contexts (Optional[List[InteractionContextType]]): Interaction contexts.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
callback: Callable[..., Any],
|
||||
*,
|
||||
name: str,
|
||||
description: Optional[str] = None,
|
||||
locale: Optional[str] = None,
|
||||
type: "ApplicationCommandType",
|
||||
guild_ids: Optional[List["Snowflake"]] = None,
|
||||
default_member_permissions: Optional[str] = None,
|
||||
nsfw: bool = False,
|
||||
parent: Optional["AppCommandGroup"] = None,
|
||||
cog: Optional[
|
||||
Any
|
||||
] = None, # Changed 'Cog' to Any to avoid runtime import issues if Cog is complex
|
||||
name_localizations: Optional[Dict[str, str]] = None,
|
||||
description_localizations: Optional[Dict[str, str]] = None,
|
||||
integration_types: Optional[List["IntegrationType"]] = None,
|
||||
contexts: Optional[List["InteractionContextType"]] = None,
|
||||
):
|
||||
if not asyncio.iscoroutinefunction(callback):
|
||||
raise TypeError(
|
||||
"Application command callback must be a coroutine function."
|
||||
)
|
||||
|
||||
if locale:
|
||||
from disagreement import i18n
|
||||
|
||||
translate = i18n.translate
|
||||
|
||||
self.name = translate(name, locale)
|
||||
self.description = (
|
||||
translate(description, locale) if description is not None else None
|
||||
)
|
||||
else:
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.locale: Optional[str] = locale
|
||||
self.callback: Callable[..., Any] = callback
|
||||
self.type: "ApplicationCommandType" = type
|
||||
self.options: List["ApplicationCommandOption"] = [] # Populated by decorator
|
||||
self.guild_ids: Optional[List["Snowflake"]] = guild_ids
|
||||
self.default_member_permissions: Optional[str] = default_member_permissions
|
||||
self.nsfw: bool = nsfw
|
||||
self.parent: Optional["AppCommandGroup"] = parent
|
||||
self.cog: Optional[Any] = cog # Changed 'Cog' to Any
|
||||
self.name_localizations: Optional[Dict[str, str]] = name_localizations
|
||||
self.description_localizations: Optional[Dict[str, str]] = (
|
||||
description_localizations
|
||||
)
|
||||
self.integration_types: Optional[List["IntegrationType"]] = integration_types
|
||||
self.contexts: Optional[List["InteractionContextType"]] = contexts
|
||||
self._full_description: Optional[str] = (
|
||||
None # Initialized by decorator if needed
|
||||
)
|
||||
|
||||
# Signature for argument parsing by decorators/handlers
|
||||
self.params = inspect.signature(callback).parameters
|
||||
|
||||
async def invoke(
|
||||
self, context: "AppCommandContext", *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""Invokes the command's callback with the given context and arguments."""
|
||||
# Similar to Command.invoke, handle cog if present
|
||||
actual_args = []
|
||||
if self.cog:
|
||||
actual_args.append(self.cog)
|
||||
actual_args.append(context)
|
||||
actual_args.extend(args)
|
||||
|
||||
await self.callback(*actual_args, **kwargs)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Converts the command to a dictionary payload for Discord API."""
|
||||
payload: Dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"type": self.type.value,
|
||||
# CHAT_INPUT commands require a description.
|
||||
# USER and MESSAGE commands must have an empty description in the payload if not omitted.
|
||||
# The constructor for UserCommand/MessageCommand already sets self.description to ""
|
||||
"description": (
|
||||
self.description
|
||||
if self.type == ApplicationCommandType.CHAT_INPUT
|
||||
else ""
|
||||
),
|
||||
}
|
||||
|
||||
# For CHAT_INPUT commands, options are its parameters.
|
||||
# For USER/MESSAGE commands, options should be empty or not present.
|
||||
if self.type == ApplicationCommandType.CHAT_INPUT and self.options:
|
||||
payload["options"] = [opt.to_dict() for opt in self.options]
|
||||
|
||||
if self.default_member_permissions is not None: # Can be "0" for no permissions
|
||||
payload["default_member_permissions"] = str(self.default_member_permissions)
|
||||
|
||||
# nsfw defaults to False, only include if True
|
||||
if self.nsfw:
|
||||
payload["nsfw"] = True
|
||||
|
||||
if self.name_localizations:
|
||||
payload["name_localizations"] = self.name_localizations
|
||||
|
||||
# Description localizations only apply if there's a description (CHAT_INPUT commands)
|
||||
if (
|
||||
self.type == ApplicationCommandType.CHAT_INPUT
|
||||
and self.description
|
||||
and self.description_localizations
|
||||
):
|
||||
payload["description_localizations"] = self.description_localizations
|
||||
|
||||
if self.integration_types:
|
||||
payload["integration_types"] = [it.value for it in self.integration_types]
|
||||
|
||||
if self.contexts:
|
||||
payload["contexts"] = [ict.value for ict in self.contexts]
|
||||
|
||||
# According to Discord API, guild_id is not part of this payload,
|
||||
# it's used in the URL path for guild-specific command registration.
|
||||
# However, the global command registration takes an 'application_id' in the payload,
|
||||
# but that's handled by the HTTPClient.
|
||||
|
||||
return payload
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} name='{self.name}' type={self.type!r}>"
|
||||
|
||||
|
||||
class SlashCommand(AppCommand):
|
||||
"""Represents a CHAT_INPUT (slash) command."""
|
||||
|
||||
def __init__(self, callback: Callable[..., Any], **kwargs: Any):
|
||||
if not kwargs.get("description"):
|
||||
raise ValueError("SlashCommand requires a description.")
|
||||
super().__init__(callback, type=ApplicationCommandType.CHAT_INPUT, **kwargs)
|
||||
|
||||
|
||||
class UserCommand(AppCommand):
|
||||
"""Represents a USER context menu command."""
|
||||
|
||||
def __init__(self, callback: Callable[..., Any], **kwargs: Any):
|
||||
# Description is not allowed by Discord API for User Commands, but can be set to empty string.
|
||||
kwargs["description"] = kwargs.get(
|
||||
"description", ""
|
||||
) # Ensure it's empty or not present in payload
|
||||
super().__init__(callback, type=ApplicationCommandType.USER, **kwargs)
|
||||
|
||||
|
||||
class MessageCommand(AppCommand):
|
||||
"""Represents a MESSAGE context menu command."""
|
||||
|
||||
def __init__(self, callback: Callable[..., Any], **kwargs: Any):
|
||||
# Description is not allowed by Discord API for Message Commands.
|
||||
kwargs["description"] = kwargs.get("description", "")
|
||||
super().__init__(callback, type=ApplicationCommandType.MESSAGE, **kwargs)
|
||||
|
||||
|
||||
class HybridCommand(SlashCommand, PrefixCommand): # Inherit from both
|
||||
"""
|
||||
Represents a command that can be invoked as both a slash command
|
||||
and a traditional prefix-based command.
|
||||
"""
|
||||
|
||||
def __init__(self, callback: Callable[..., Any], **kwargs: Any):
|
||||
# Initialize SlashCommand part (which calls AppCommand.__init__)
|
||||
# We need to ensure 'type' is correctly passed for AppCommand
|
||||
# kwargs for SlashCommand: name, description, guild_ids, default_member_permissions, nsfw, parent, cog, etc.
|
||||
# kwargs for PrefixCommand: name, aliases, brief, description, cog
|
||||
|
||||
# Pop prefix-specific args before passing to SlashCommand constructor
|
||||
prefix_aliases = kwargs.pop("aliases", [])
|
||||
prefix_brief = kwargs.pop("brief", None)
|
||||
# Description is used by both, AppCommand's constructor will handle it.
|
||||
# Name is used by both. Cog is used by both.
|
||||
|
||||
# Call SlashCommand's __init__
|
||||
# This will set up name, description, callback, type=CHAT_INPUT, options, etc.
|
||||
super().__init__(callback, **kwargs) # This is SlashCommand.__init__
|
||||
|
||||
# Now, explicitly initialize the PrefixCommand parts that SlashCommand didn't cover
|
||||
# or that need specific values for the prefix version.
|
||||
# PrefixCommand.__init__(self, callback, name=self.name, aliases=prefix_aliases, brief=prefix_brief, description=self.description, cog=self.cog)
|
||||
# However, PrefixCommand.__init__ also sets self.params, which AppCommand already did.
|
||||
# We need to be careful not to re-initialize things unnecessarily or incorrectly.
|
||||
# Let's manually set the distinct attributes for the PrefixCommand aspect.
|
||||
|
||||
# Attributes from PrefixCommand:
|
||||
# self.callback is already set by AppCommand
|
||||
# self.name is already set by AppCommand
|
||||
self.aliases: List[str] = (
|
||||
prefix_aliases # This was specific to HybridCommand before, now aligns with PrefixCommand
|
||||
)
|
||||
self.brief: Optional[str] = prefix_brief
|
||||
# self.description is already set by AppCommand (SlashCommand ensures it exists)
|
||||
# self.cog is already set by AppCommand
|
||||
# self.params is already set by AppCommand
|
||||
|
||||
# Ensure the MRO is handled correctly. Python's MRO (C3 linearization)
|
||||
# should call SlashCommand's __init__ then AppCommand's __init__.
|
||||
# PrefixCommand.__init__ won't be called automatically unless we explicitly call it.
|
||||
# By setting attributes directly, we avoid potential issues with multiple __init__ calls
|
||||
# if their logic overlaps too much (e.g., both trying to set self.params).
|
||||
|
||||
# We might need to override invoke if the context or argument passing differs significantly
|
||||
# between app command invocation and prefix command invocation.
|
||||
# For now, SlashCommand.invoke and PrefixCommand.invoke are separate.
|
||||
# The correct one will be called depending on how the command is dispatched.
|
||||
# The AppCommandHandler will use AppCommand.invoke (via SlashCommand).
|
||||
# The prefix CommandHandler will use PrefixCommand.invoke.
|
||||
# This seems acceptable.
|
||||
|
||||
|
||||
class AppCommandGroup:
|
||||
"""
|
||||
Represents a group of application commands (subcommands or subcommand groups).
|
||||
This itself is not directly callable but acts as a namespace.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: Optional[
|
||||
str
|
||||
] = None, # Required for top-level groups that form part of a slash command
|
||||
guild_ids: Optional[List["Snowflake"]] = None,
|
||||
parent: Optional["AppCommandGroup"] = None,
|
||||
default_member_permissions: Optional[str] = None,
|
||||
nsfw: bool = False,
|
||||
name_localizations: Optional[Dict[str, str]] = None,
|
||||
description_localizations: Optional[Dict[str, str]] = None,
|
||||
integration_types: Optional[List["IntegrationType"]] = None,
|
||||
contexts: Optional[List["InteractionContextType"]] = None,
|
||||
):
|
||||
self.name: str = name
|
||||
self.description: Optional[str] = description
|
||||
self.guild_ids: Optional[List["Snowflake"]] = guild_ids
|
||||
self.parent: Optional["AppCommandGroup"] = parent
|
||||
self.commands: Dict[str, Union[AppCommand, "AppCommandGroup"]] = {}
|
||||
self.default_member_permissions: Optional[str] = default_member_permissions
|
||||
self.nsfw: bool = nsfw
|
||||
self.name_localizations: Optional[Dict[str, str]] = name_localizations
|
||||
self.description_localizations: Optional[Dict[str, str]] = (
|
||||
description_localizations
|
||||
)
|
||||
self.integration_types: Optional[List["IntegrationType"]] = integration_types
|
||||
self.contexts: Optional[List["InteractionContextType"]] = contexts
|
||||
# A group itself doesn't have a cog directly, its commands do.
|
||||
|
||||
def add_command(self, command: Union[AppCommand, "AppCommandGroup"]) -> None:
|
||||
if command.name in self.commands:
|
||||
raise ValueError(
|
||||
f"Command or group '{command.name}' already exists in group '{self.name}'."
|
||||
)
|
||||
command.parent = self
|
||||
self.commands[command.name] = command
|
||||
|
||||
def get_command(self, name: str) -> Optional[Union[AppCommand, "AppCommandGroup"]]:
|
||||
return self.commands.get(name)
|
||||
|
||||
def command(self, *d_args: Any, **d_kwargs: Any):
|
||||
d_kwargs.setdefault("parent", self)
|
||||
from .decorators import slash_command
|
||||
|
||||
return slash_command(*d_args, **d_kwargs)
|
||||
|
||||
def group(
|
||||
self,
|
||||
name: str,
|
||||
description: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
sub_group = AppCommandGroup(
|
||||
name=name,
|
||||
description=description,
|
||||
parent=self,
|
||||
guild_ids=kwargs.get("guild_ids"),
|
||||
default_member_permissions=kwargs.get("default_member_permissions"),
|
||||
nsfw=kwargs.get("nsfw", False),
|
||||
name_localizations=kwargs.get("name_localizations"),
|
||||
description_localizations=kwargs.get("description_localizations"),
|
||||
integration_types=kwargs.get("integration_types"),
|
||||
contexts=kwargs.get("contexts"),
|
||||
)
|
||||
self.add_command(sub_group)
|
||||
|
||||
def decorator(func: Optional[Callable[..., Any]] = None):
|
||||
if func is not None:
|
||||
setattr(func, "__app_command_object__", sub_group)
|
||||
return sub_group
|
||||
return sub_group
|
||||
|
||||
return decorator
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AppCommandGroup name='{self.name}' commands={len(self.commands)}>"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Converts the command group to a dictionary payload for Discord API.
|
||||
This represents a top-level command that has subcommands/subcommand groups.
|
||||
"""
|
||||
payload: Dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"type": ApplicationCommandType.CHAT_INPUT.value, # Groups are implicitly CHAT_INPUT
|
||||
"description": self.description
|
||||
or "No description provided", # Top-level groups require a description
|
||||
"options": [],
|
||||
}
|
||||
|
||||
if self.default_member_permissions is not None:
|
||||
payload["default_member_permissions"] = str(self.default_member_permissions)
|
||||
if self.nsfw:
|
||||
payload["nsfw"] = True
|
||||
if self.name_localizations:
|
||||
payload["name_localizations"] = self.name_localizations
|
||||
if (
|
||||
self.description and self.description_localizations
|
||||
): # Only if description is not empty
|
||||
payload["description_localizations"] = self.description_localizations
|
||||
if self.integration_types:
|
||||
payload["integration_types"] = [it.value for it in self.integration_types]
|
||||
if self.contexts:
|
||||
payload["contexts"] = [ict.value for ict in self.contexts]
|
||||
|
||||
# guild_ids are handled at the registration level, not in this specific payload part.
|
||||
|
||||
options_payload: List[Dict[str, Any]] = []
|
||||
for cmd_name, command_or_group in self.commands.items():
|
||||
if isinstance(command_or_group, AppCommand): # This is a Subcommand
|
||||
# Subcommands use their own options (parameters)
|
||||
sub_options = (
|
||||
[opt.to_dict() for opt in command_or_group.options]
|
||||
if command_or_group.options
|
||||
else []
|
||||
)
|
||||
option_dict = {
|
||||
"type": ApplicationCommandOptionType.SUB_COMMAND.value,
|
||||
"name": command_or_group.name,
|
||||
"description": command_or_group.description
|
||||
or "No description provided",
|
||||
"options": sub_options,
|
||||
}
|
||||
# Add localization for subcommand name and description if available
|
||||
if command_or_group.name_localizations:
|
||||
option_dict["name_localizations"] = (
|
||||
command_or_group.name_localizations
|
||||
)
|
||||
if (
|
||||
command_or_group.description
|
||||
and command_or_group.description_localizations
|
||||
):
|
||||
option_dict["description_localizations"] = (
|
||||
command_or_group.description_localizations
|
||||
)
|
||||
options_payload.append(option_dict)
|
||||
|
||||
elif isinstance(
|
||||
command_or_group, AppCommandGroup
|
||||
): # This is a Subcommand Group
|
||||
# Subcommand groups have their subcommands/groups as options
|
||||
sub_group_options: List[Dict[str, Any]] = []
|
||||
for sub_cmd_name, sub_command in command_or_group.commands.items():
|
||||
# Nested groups can only contain subcommands, not further nested groups as per Discord rules.
|
||||
# So, sub_command here must be an AppCommand.
|
||||
if isinstance(
|
||||
sub_command, AppCommand
|
||||
): # Should always be AppCommand if structure is valid
|
||||
sub_cmd_options = (
|
||||
[opt.to_dict() for opt in sub_command.options]
|
||||
if sub_command.options
|
||||
else []
|
||||
)
|
||||
sub_group_option_entry = {
|
||||
"type": ApplicationCommandOptionType.SUB_COMMAND.value,
|
||||
"name": sub_command.name,
|
||||
"description": sub_command.description
|
||||
or "No description provided",
|
||||
"options": sub_cmd_options,
|
||||
}
|
||||
# Add localization for subcommand name and description if available
|
||||
if sub_command.name_localizations:
|
||||
sub_group_option_entry["name_localizations"] = (
|
||||
sub_command.name_localizations
|
||||
)
|
||||
if (
|
||||
sub_command.description
|
||||
and sub_command.description_localizations
|
||||
):
|
||||
sub_group_option_entry["description_localizations"] = (
|
||||
sub_command.description_localizations
|
||||
)
|
||||
sub_group_options.append(sub_group_option_entry)
|
||||
# else:
|
||||
# # This case implies a group nested inside a group, which then contains another group.
|
||||
# # Discord's structure is:
|
||||
# # command -> option (SUB_COMMAND_GROUP) -> option (SUB_COMMAND) -> option (param)
|
||||
# # This should be caught by validation logic in decorators or add_command.
|
||||
# # For now, we assume valid structure where AppCommandGroup's commands are AppCommands.
|
||||
# pass
|
||||
|
||||
option_dict = {
|
||||
"type": ApplicationCommandOptionType.SUB_COMMAND_GROUP.value,
|
||||
"name": command_or_group.name,
|
||||
"description": command_or_group.description
|
||||
or "No description provided",
|
||||
"options": sub_group_options, # These are the SUB_COMMANDs
|
||||
}
|
||||
# Add localization for subcommand group name and description if available
|
||||
if command_or_group.name_localizations:
|
||||
option_dict["name_localizations"] = (
|
||||
command_or_group.name_localizations
|
||||
)
|
||||
if (
|
||||
command_or_group.description
|
||||
and command_or_group.description_localizations
|
||||
):
|
||||
option_dict["description_localizations"] = (
|
||||
command_or_group.description_localizations
|
||||
)
|
||||
options_payload.append(option_dict)
|
||||
|
||||
payload["options"] = options_payload
|
||||
return payload
|
||||
|
||||
|
||||
# Need to import asyncio for iscoroutinefunction check
|
||||
import asyncio
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .context import AppCommandContext # For type hint in AppCommand.invoke
|
||||
|
||||
# Ensure ApplicationCommandOptionType is available for the to_dict method
|
||||
from disagreement.enums import ApplicationCommandOptionType
|
556
disagreement/ext/app_commands/context.py
Normal file
556
disagreement/ext/app_commands/context.py
Normal file
@ -0,0 +1,556 @@
|
||||
# disagreement/ext/app_commands/context.py
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, List, Union, Any, Dict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from disagreement.client import Client
|
||||
from disagreement.interactions import (
|
||||
Interaction,
|
||||
InteractionCallbackData,
|
||||
InteractionResponsePayload,
|
||||
Snowflake,
|
||||
)
|
||||
from disagreement.enums import InteractionCallbackType, MessageFlags
|
||||
from disagreement.models import (
|
||||
User,
|
||||
Member,
|
||||
Message,
|
||||
Channel,
|
||||
ActionRow,
|
||||
)
|
||||
from disagreement.ui.view import View
|
||||
|
||||
# For full model hints, these would be imported from disagreement.models when defined:
|
||||
Embed = Any
|
||||
PartialAttachment = Any
|
||||
Guild = Any # from disagreement.models import Guild
|
||||
TextChannel = Any # from disagreement.models import TextChannel, etc.
|
||||
from .commands import AppCommand
|
||||
|
||||
from disagreement.enums import InteractionCallbackType, MessageFlags
|
||||
from disagreement.interactions import (
|
||||
Interaction,
|
||||
InteractionCallbackData,
|
||||
InteractionResponsePayload,
|
||||
Snowflake,
|
||||
)
|
||||
from disagreement.models import Message
|
||||
from disagreement.typing import Typing
|
||||
|
||||
|
||||
class AppCommandContext:
|
||||
"""
|
||||
Represents the context in which an application command is being invoked.
|
||||
Provides methods to respond to the interaction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot: "Client",
|
||||
interaction: "Interaction",
|
||||
command: Optional["AppCommand"] = None,
|
||||
):
|
||||
self.bot: "Client" = bot
|
||||
self.interaction: "Interaction" = interaction
|
||||
self.command: Optional["AppCommand"] = command # The command that was invoked
|
||||
|
||||
self._responded: bool = False
|
||||
self._deferred: bool = False
|
||||
|
||||
@property
|
||||
def token(self) -> str:
|
||||
"""The interaction token."""
|
||||
return self.interaction.token
|
||||
|
||||
@property
|
||||
def interaction_id(self) -> "Snowflake":
|
||||
"""The interaction ID."""
|
||||
return self.interaction.id
|
||||
|
||||
@property
|
||||
def application_id(self) -> "Snowflake":
|
||||
"""The application ID of the interaction."""
|
||||
return self.interaction.application_id
|
||||
|
||||
@property
|
||||
def guild_id(self) -> Optional["Snowflake"]:
|
||||
"""The ID of the guild where the interaction occurred, if any."""
|
||||
return self.interaction.guild_id
|
||||
|
||||
@property
|
||||
def channel_id(self) -> Optional["Snowflake"]:
|
||||
"""The ID of the channel where the interaction occurred."""
|
||||
return self.interaction.channel_id
|
||||
|
||||
@property
|
||||
def author(self) -> Optional[Union["User", "Member"]]:
|
||||
"""The user or member who invoked the interaction."""
|
||||
return self.interaction.member or self.interaction.user
|
||||
|
||||
@property
|
||||
def user(self) -> Optional["User"]:
|
||||
"""The user who invoked the interaction.
|
||||
If in a guild, this is the user part of the member.
|
||||
If in a DM, this is the top-level user.
|
||||
"""
|
||||
return self.interaction.user
|
||||
|
||||
@property
|
||||
def member(self) -> Optional["Member"]:
|
||||
"""The member who invoked the interaction, if this occurred in a guild."""
|
||||
return self.interaction.member
|
||||
|
||||
@property
|
||||
def locale(self) -> Optional[str]:
|
||||
"""The selected language of the invoking user."""
|
||||
return self.interaction.locale
|
||||
|
||||
@property
|
||||
def guild_locale(self) -> Optional[str]:
|
||||
"""The guild's preferred language, if applicable."""
|
||||
return self.interaction.guild_locale
|
||||
|
||||
@property
|
||||
async def guild(self) -> Optional["Guild"]:
|
||||
"""The guild object where the interaction occurred, if available."""
|
||||
|
||||
if not self.guild_id:
|
||||
return None
|
||||
|
||||
guild = None
|
||||
if hasattr(self.bot, "get_guild"):
|
||||
guild = self.bot.get_guild(self.guild_id)
|
||||
|
||||
if not guild and hasattr(self.bot, "fetch_guild"):
|
||||
try:
|
||||
guild = await self.bot.fetch_guild(self.guild_id)
|
||||
except Exception:
|
||||
guild = None
|
||||
|
||||
return guild
|
||||
|
||||
@property
|
||||
async def channel(self) -> Optional[Any]:
|
||||
"""The channel object where the interaction occurred, if available."""
|
||||
|
||||
if not self.channel_id:
|
||||
return None
|
||||
|
||||
channel = None
|
||||
if hasattr(self.bot, "get_channel"):
|
||||
channel = self.bot.get_channel(self.channel_id)
|
||||
elif hasattr(self.bot, "_channels"):
|
||||
channel = self.bot._channels.get(self.channel_id)
|
||||
|
||||
if not channel and hasattr(self.bot, "fetch_channel"):
|
||||
try:
|
||||
channel = await self.bot.fetch_channel(self.channel_id)
|
||||
except Exception:
|
||||
channel = None
|
||||
|
||||
return channel
|
||||
|
||||
async def _send_response(
|
||||
self,
|
||||
response_type: "InteractionCallbackType",
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Internal helper to send interaction responses."""
|
||||
if (
|
||||
self._responded
|
||||
and not self._deferred
|
||||
and response_type
|
||||
!= InteractionCallbackType.APPLICATION_COMMAND_AUTOCOMPLETE_RESULT
|
||||
):
|
||||
# If already responded and not deferred, subsequent responses must be followups
|
||||
# (unless it's an autocomplete result which is a special case)
|
||||
# For now, let's assume followups are handled by separate methods.
|
||||
# This logic might need refinement based on how followups are exposed.
|
||||
raise RuntimeError(
|
||||
"Interaction has already been responded to. Use send_followup()."
|
||||
)
|
||||
|
||||
callback_data = InteractionCallbackData(data) if data else None
|
||||
payload = InteractionResponsePayload(type=response_type, data=callback_data)
|
||||
|
||||
await self.bot._http.create_interaction_response(
|
||||
interaction_id=self.interaction_id,
|
||||
interaction_token=self.token,
|
||||
payload=payload,
|
||||
)
|
||||
if (
|
||||
response_type
|
||||
!= InteractionCallbackType.APPLICATION_COMMAND_AUTOCOMPLETE_RESULT
|
||||
):
|
||||
self._responded = True
|
||||
if (
|
||||
response_type
|
||||
== InteractionCallbackType.DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE
|
||||
or response_type == InteractionCallbackType.DEFERRED_UPDATE_MESSAGE
|
||||
):
|
||||
self._deferred = True
|
||||
|
||||
async def defer(self, ephemeral: bool = False, thinking: bool = True) -> None:
|
||||
"""
|
||||
Defers the interaction response.
|
||||
|
||||
This is typically used when your command might take longer than 3 seconds to process.
|
||||
You must send a followup message within 15 minutes.
|
||||
|
||||
Args:
|
||||
ephemeral (bool): Whether the subsequent followup response should be ephemeral.
|
||||
Only applicable if `thinking` is True.
|
||||
thinking (bool): If True (default), responds with a "Bot is thinking..." message
|
||||
(DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE).
|
||||
If False, responds with DEFERRED_UPDATE_MESSAGE (for components).
|
||||
"""
|
||||
if self._responded:
|
||||
raise RuntimeError("Interaction has already been responded to or deferred.")
|
||||
|
||||
response_type = (
|
||||
InteractionCallbackType.DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE
|
||||
if thinking
|
||||
else InteractionCallbackType.DEFERRED_UPDATE_MESSAGE
|
||||
)
|
||||
data = None
|
||||
if ephemeral and thinking:
|
||||
data = {
|
||||
"flags": MessageFlags.EPHEMERAL.value
|
||||
} # Assuming MessageFlags enum exists
|
||||
|
||||
await self._send_response(response_type, data)
|
||||
self._deferred = True # Mark as deferred
|
||||
|
||||
async def send(
|
||||
self,
|
||||
content: Optional[str] = None,
|
||||
embed: Optional["Embed"] = None, # Convenience for single embed
|
||||
embeds: Optional[List["Embed"]] = None,
|
||||
*,
|
||||
tts: bool = False,
|
||||
files: Optional[List[Any]] = None,
|
||||
components: Optional[List[ActionRow]] = None,
|
||||
view: Optional[View] = None,
|
||||
allowed_mentions: Optional[Dict[str, Any]] = None,
|
||||
ephemeral: bool = False,
|
||||
flags: Optional[int] = None,
|
||||
) -> Optional[
|
||||
"Message"
|
||||
]: # Returns Message if not ephemeral and response was not deferred
|
||||
"""
|
||||
Sends a response to the interaction.
|
||||
If the interaction was previously deferred, this will edit the original deferred response.
|
||||
Otherwise, it sends an initial response.
|
||||
|
||||
Args:
|
||||
content (Optional[str]): The message content.
|
||||
embed (Optional[Embed]): A single embed to send. If `embeds` is also provided, this is ignored.
|
||||
embeds (Optional[List[Embed]]): A list of embeds to send (max 10).
|
||||
ephemeral (bool): Whether the message should be ephemeral (only visible to the invoker).
|
||||
flags (Optional[int]): Additional message flags to apply.
|
||||
|
||||
Returns:
|
||||
Optional[Message]: The sent message object if a new message was created and not ephemeral.
|
||||
None if the response was ephemeral or an edit to a deferred message.
|
||||
"""
|
||||
if not self._responded and self._deferred: # Editing a deferred response
|
||||
# Use edit_original_interaction_response
|
||||
payload: Dict[str, Any] = {}
|
||||
if content is not None:
|
||||
payload["content"] = content
|
||||
|
||||
if tts:
|
||||
payload["tts"] = True
|
||||
|
||||
actual_embeds = embeds
|
||||
if embed and not embeds:
|
||||
actual_embeds = [embed]
|
||||
if actual_embeds:
|
||||
payload["embeds"] = [e.to_dict() for e in actual_embeds]
|
||||
|
||||
if view:
|
||||
await view._start(self.bot)
|
||||
payload["components"] = view.to_components_payload()
|
||||
elif components:
|
||||
payload["components"] = [c.to_dict() for c in components]
|
||||
|
||||
if files is not None:
|
||||
payload["attachments"] = [
|
||||
f.to_dict() if hasattr(f, "to_dict") else f for f in files
|
||||
]
|
||||
|
||||
if allowed_mentions is not None:
|
||||
payload["allowed_mentions"] = allowed_mentions
|
||||
|
||||
# Flags (like ephemeral) cannot be set when editing the original deferred response this way.
|
||||
# Ephemeral for deferred must be set during defer().
|
||||
|
||||
msg_data = await self.bot._http.edit_original_interaction_response(
|
||||
application_id=self.application_id,
|
||||
interaction_token=self.token,
|
||||
payload=payload,
|
||||
)
|
||||
self._responded = True # Ensure it's marked as fully responded
|
||||
if view and msg_data and "id" in msg_data:
|
||||
view.message_id = msg_data["id"]
|
||||
self.bot._views[msg_data["id"]] = view
|
||||
# Construct and return Message object if needed, for now returns None for edits
|
||||
return None
|
||||
|
||||
elif not self._responded: # Sending an initial response
|
||||
data: Dict[str, Any] = {}
|
||||
if content is not None:
|
||||
data["content"] = content
|
||||
|
||||
if tts:
|
||||
data["tts"] = True
|
||||
|
||||
actual_embeds = embeds
|
||||
if embed and not embeds:
|
||||
actual_embeds = [embed]
|
||||
if actual_embeds:
|
||||
data["embeds"] = [
|
||||
e.to_dict() for e in actual_embeds
|
||||
] # Assuming embeds have to_dict()
|
||||
|
||||
if view:
|
||||
await view._start(self.bot)
|
||||
data["components"] = view.to_components_payload()
|
||||
elif components:
|
||||
data["components"] = [c.to_dict() for c in components]
|
||||
|
||||
if files is not None:
|
||||
data["attachments"] = [
|
||||
f.to_dict() if hasattr(f, "to_dict") else f for f in files
|
||||
]
|
||||
|
||||
if allowed_mentions is not None:
|
||||
data["allowed_mentions"] = allowed_mentions
|
||||
|
||||
flags_value = 0
|
||||
if ephemeral:
|
||||
flags_value |= MessageFlags.EPHEMERAL.value
|
||||
if flags:
|
||||
flags_value |= flags
|
||||
if flags_value:
|
||||
data["flags"] = flags_value
|
||||
|
||||
await self._send_response(
|
||||
InteractionCallbackType.CHANNEL_MESSAGE_WITH_SOURCE, data
|
||||
)
|
||||
|
||||
if view and not ephemeral:
|
||||
try:
|
||||
msg_data = await self.bot._http.get_original_interaction_response(
|
||||
application_id=self.application_id,
|
||||
interaction_token=self.token,
|
||||
)
|
||||
if msg_data and "id" in msg_data:
|
||||
view.message_id = msg_data["id"]
|
||||
self.bot._views[msg_data["id"]] = view
|
||||
except Exception:
|
||||
pass
|
||||
if not ephemeral:
|
||||
return None
|
||||
return None
|
||||
else:
|
||||
# If already responded and not deferred, this should be a followup.
|
||||
# This method is for initial response or editing deferred.
|
||||
raise RuntimeError(
|
||||
"Interaction has already been responded to. Use send_followup()."
|
||||
)
|
||||
|
||||
async def send_followup(
|
||||
self,
|
||||
content: Optional[str] = None,
|
||||
embed: Optional["Embed"] = None,
|
||||
embeds: Optional[List["Embed"]] = None,
|
||||
*,
|
||||
ephemeral: bool = False,
|
||||
tts: bool = False,
|
||||
files: Optional[List[Any]] = None,
|
||||
components: Optional[List["ActionRow"]] = None,
|
||||
view: Optional[View] = None,
|
||||
allowed_mentions: Optional[Dict[str, Any]] = None,
|
||||
flags: Optional[int] = None,
|
||||
) -> Optional["Message"]:
|
||||
"""
|
||||
Sends a followup message to an interaction.
|
||||
This can be used after an initial response or a deferred response.
|
||||
|
||||
Args:
|
||||
content (Optional[str]): The message content.
|
||||
embed (Optional[Embed]): A single embed to send.
|
||||
embeds (Optional[List[Embed]]): A list of embeds to send.
|
||||
ephemeral (bool): Whether the followup message should be ephemeral.
|
||||
flags (Optional[int]): Additional message flags to apply.
|
||||
|
||||
Returns:
|
||||
Message: The sent followup message object.
|
||||
"""
|
||||
if not self._responded:
|
||||
raise RuntimeError(
|
||||
"Must acknowledge or defer the interaction before sending a followup."
|
||||
)
|
||||
|
||||
payload: Dict[str, Any] = {}
|
||||
if content is not None:
|
||||
payload["content"] = content
|
||||
|
||||
if tts:
|
||||
payload["tts"] = True
|
||||
|
||||
actual_embeds = embeds
|
||||
if embed and not embeds:
|
||||
actual_embeds = [embed]
|
||||
if actual_embeds:
|
||||
payload["embeds"] = [
|
||||
e.to_dict() for e in actual_embeds
|
||||
] # Assuming embeds have to_dict()
|
||||
|
||||
if view:
|
||||
await view._start(self.bot)
|
||||
payload["components"] = view.to_components_payload()
|
||||
elif components:
|
||||
payload["components"] = [c.to_dict() for c in components]
|
||||
|
||||
if files is not None:
|
||||
payload["attachments"] = [
|
||||
f.to_dict() if hasattr(f, "to_dict") else f for f in files
|
||||
]
|
||||
|
||||
if allowed_mentions is not None:
|
||||
payload["allowed_mentions"] = allowed_mentions
|
||||
|
||||
flags_value = 0
|
||||
if ephemeral:
|
||||
flags_value |= MessageFlags.EPHEMERAL.value
|
||||
if flags:
|
||||
flags_value |= flags
|
||||
if flags_value:
|
||||
payload["flags"] = flags_value
|
||||
|
||||
# Followup messages are sent to a webhook endpoint
|
||||
message_data = await self.bot._http.create_followup_message(
|
||||
application_id=self.application_id,
|
||||
interaction_token=self.token,
|
||||
payload=payload,
|
||||
)
|
||||
if view and message_data and "id" in message_data:
|
||||
view.message_id = message_data["id"]
|
||||
self.bot._views[message_data["id"]] = view
|
||||
from disagreement.models import Message # Ensure Message is available
|
||||
|
||||
return Message(data=message_data, client_instance=self.bot)
|
||||
|
||||
async def edit(
|
||||
self,
|
||||
message_id: "Snowflake" = "@original", # Defaults to editing the original response
|
||||
content: Optional[str] = None,
|
||||
embed: Optional["Embed"] = None,
|
||||
embeds: Optional[List["Embed"]] = None,
|
||||
*,
|
||||
components: Optional[List["ActionRow"]] = None,
|
||||
attachments: Optional[List[Any]] = None,
|
||||
allowed_mentions: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional["Message"]:
|
||||
"""
|
||||
Edits a message previously sent in response to this interaction.
|
||||
Can edit the original response or a followup message.
|
||||
|
||||
Args:
|
||||
message_id (Snowflake): The ID of the message to edit. Defaults to "@original"
|
||||
to edit the initial interaction response.
|
||||
content (Optional[str]): The new message content.
|
||||
embed (Optional[Embed]): A single new embed.
|
||||
embeds (Optional[List[Embed]]): A list of new embeds.
|
||||
|
||||
Returns:
|
||||
Optional[Message]: The edited message object if available.
|
||||
"""
|
||||
if not self._responded:
|
||||
raise RuntimeError(
|
||||
"Cannot edit response if interaction hasn't been responded to or deferred."
|
||||
)
|
||||
|
||||
payload: Dict[str, Any] = {}
|
||||
if content is not None:
|
||||
payload["content"] = content # Use None to clear
|
||||
|
||||
actual_embeds = embeds
|
||||
if embed and not embeds:
|
||||
actual_embeds = [embed]
|
||||
if actual_embeds is not None: # Allow passing empty list to clear embeds
|
||||
payload["embeds"] = [
|
||||
e.to_dict() for e in actual_embeds
|
||||
] # Assuming embeds have to_dict()
|
||||
|
||||
if components is not None:
|
||||
payload["components"] = [c.to_dict() for c in components]
|
||||
|
||||
if attachments is not None:
|
||||
payload["attachments"] = [
|
||||
a.to_dict() if hasattr(a, "to_dict") else a for a in attachments
|
||||
]
|
||||
|
||||
if allowed_mentions is not None:
|
||||
payload["allowed_mentions"] = allowed_mentions
|
||||
|
||||
if message_id == "@original":
|
||||
edited_message_data = (
|
||||
await self.bot._http.edit_original_interaction_response(
|
||||
application_id=self.application_id,
|
||||
interaction_token=self.token,
|
||||
payload=payload,
|
||||
)
|
||||
)
|
||||
else:
|
||||
edited_message_data = await self.bot._http.edit_followup_message(
|
||||
application_id=self.application_id,
|
||||
interaction_token=self.token,
|
||||
message_id=message_id,
|
||||
payload=payload,
|
||||
)
|
||||
# The HTTP methods used in tests return minimal data that is insufficient
|
||||
# to construct a full ``Message`` instance, so we simply return ``None``
|
||||
# rather than attempting to parse the response.
|
||||
return None
|
||||
|
||||
async def delete(self, message_id: "Snowflake" = "@original") -> None:
|
||||
"""
|
||||
Deletes a message previously sent in response to this interaction.
|
||||
Can delete the original response or a followup message.
|
||||
|
||||
Args:
|
||||
message_id (Snowflake): The ID of the message to delete. Defaults to "@original"
|
||||
to delete the initial interaction response.
|
||||
"""
|
||||
if not self._responded:
|
||||
# If not responded, there's nothing to delete via this interaction's lifecycle.
|
||||
# Deferral doesn't create a message to delete until a followup is sent.
|
||||
raise RuntimeError(
|
||||
"Cannot delete response if interaction hasn't been responded to."
|
||||
)
|
||||
|
||||
if message_id == "@original":
|
||||
await self.bot._http.delete_original_interaction_response(
|
||||
application_id=self.application_id, interaction_token=self.token
|
||||
)
|
||||
else:
|
||||
await self.bot._http.delete_followup_message(
|
||||
application_id=self.application_id,
|
||||
interaction_token=self.token,
|
||||
message_id=message_id,
|
||||
)
|
||||
# After deleting the original response, further followups might be problematic.
|
||||
# Discord docs: "Once the original message is deleted, you can no longer edit the message or send followups."
|
||||
# Consider implications for context state.
|
||||
|
||||
def typing(self) -> Typing:
|
||||
"""Return a typing context manager for this interaction's channel."""
|
||||
|
||||
if not self.channel_id:
|
||||
raise RuntimeError("Cannot send typing indicator without a channel.")
|
||||
return self.bot.typing(self.channel_id)
|
478
disagreement/ext/app_commands/converters.py
Normal file
478
disagreement/ext/app_commands/converters.py
Normal file
@ -0,0 +1,478 @@
|
||||
# disagreement/ext/app_commands/converters.py
|
||||
|
||||
"""
|
||||
Converters for transforming application command option values.
|
||||
"""
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
from disagreement.enums import ApplicationCommandOptionType
|
||||
from disagreement.errors import (
|
||||
AppCommandOptionConversionError,
|
||||
) # To be created in disagreement/errors.py
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from disagreement.interactions import Interaction # For context if needed
|
||||
from disagreement.models import (
|
||||
User,
|
||||
Member,
|
||||
Role,
|
||||
Channel,
|
||||
Attachment,
|
||||
) # Discord models
|
||||
from disagreement.client import Client # For fetching objects
|
||||
|
||||
T = TypeVar("T", covariant=True)
|
||||
|
||||
|
||||
class Converter(Protocol[T]):
|
||||
"""
|
||||
A protocol for classes that can convert an interaction option value to a specific type.
|
||||
"""
|
||||
|
||||
async def convert(self, interaction: "Interaction", value: Any) -> T:
|
||||
"""
|
||||
Converts the given value to the target type.
|
||||
|
||||
Parameters:
|
||||
interaction (Interaction): The interaction context.
|
||||
value (Any): The raw value from the interaction option.
|
||||
|
||||
Returns:
|
||||
T: The converted value.
|
||||
|
||||
Raises:
|
||||
AppCommandOptionConversionError: If conversion fails.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# Basic Type Converters
|
||||
|
||||
|
||||
class StringConverter(Converter[str]):
|
||||
async def convert(self, interaction: "Interaction", value: Any) -> str:
|
||||
if not isinstance(value, str):
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Expected a string, but got {type(value).__name__}: {value}"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
class IntegerConverter(Converter[int]):
|
||||
async def convert(self, interaction: "Interaction", value: Any) -> int:
|
||||
if not isinstance(value, int):
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Expected an integer, but got {type(value).__name__}: {value}"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
class BooleanConverter(Converter[bool]):
|
||||
async def convert(self, interaction: "Interaction", value: Any) -> bool:
|
||||
if not isinstance(value, bool):
|
||||
if isinstance(value, str):
|
||||
if value.lower() == "true":
|
||||
return True
|
||||
elif value.lower() == "false":
|
||||
return False
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Expected a boolean, but got {type(value).__name__}: {value}"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
class NumberConverter(Converter[float]): # Discord 'NUMBER' type is float
|
||||
async def convert(self, interaction: "Interaction", value: Any) -> float:
|
||||
if not isinstance(value, (int, float)):
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Expected a number (float), but got {type(value).__name__}: {value}"
|
||||
)
|
||||
return float(value) # Ensure it's a float even if int is passed
|
||||
|
||||
|
||||
# Discord Model Converters
|
||||
|
||||
|
||||
class UserConverter(Converter["User"]):
|
||||
def __init__(self, client: "Client"):
|
||||
self._client = client
|
||||
|
||||
async def convert(self, interaction: "Interaction", value: Any) -> "User":
|
||||
if isinstance(value, str): # Assume it's a user ID
|
||||
user_id = value
|
||||
# Attempt to get from interaction resolved data first
|
||||
if (
|
||||
interaction.data
|
||||
and interaction.data.resolved
|
||||
and interaction.data.resolved.users
|
||||
):
|
||||
user_object = interaction.data.resolved.users.get(
|
||||
user_id
|
||||
) # This is already a User object
|
||||
if user_object:
|
||||
return user_object # Return the already parsed User object
|
||||
|
||||
# Fallback to fetching if not in resolved or if interaction has no resolved data
|
||||
try:
|
||||
user = await self._client.fetch_user(
|
||||
user_id
|
||||
) # fetch_user now also parses and caches
|
||||
if user:
|
||||
return user
|
||||
raise AppCommandOptionConversionError(
|
||||
f"User with ID '{user_id}' not found.",
|
||||
option_name="user",
|
||||
original_value=value,
|
||||
)
|
||||
except Exception as e: # Catch potential HTTP errors from fetch_user
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Failed to fetch user '{user_id}': {e}",
|
||||
option_name="user",
|
||||
original_value=value,
|
||||
)
|
||||
elif (
|
||||
isinstance(value, dict) and "id" in value
|
||||
): # If it's raw user data dict (less common path now)
|
||||
return self._client.parse_user(value) # parse_user handles dict -> User
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Expected a user ID string or user data dict, got {type(value).__name__}",
|
||||
option_name="user",
|
||||
original_value=value,
|
||||
)
|
||||
|
||||
|
||||
class MemberConverter(Converter["Member"]):
|
||||
def __init__(self, client: "Client"):
|
||||
self._client = client
|
||||
|
||||
async def convert(self, interaction: "Interaction", value: Any) -> "Member":
|
||||
if not interaction.guild_id:
|
||||
raise AppCommandOptionConversionError(
|
||||
"Cannot convert to Member outside of a guild context.",
|
||||
option_name="member",
|
||||
)
|
||||
|
||||
if isinstance(value, str): # Assume it's a user ID
|
||||
member_id = value
|
||||
# Attempt to get from interaction resolved data first
|
||||
if (
|
||||
interaction.data
|
||||
and interaction.data.resolved
|
||||
and interaction.data.resolved.members
|
||||
):
|
||||
# The Member object from resolved.members should already be correctly initialized
|
||||
# by ResolvedData, including its User part.
|
||||
member = interaction.data.resolved.members.get(member_id)
|
||||
if member:
|
||||
return (
|
||||
member # Return the already resolved and parsed Member object
|
||||
)
|
||||
|
||||
# Fallback to fetching if not in resolved
|
||||
try:
|
||||
member = await self._client.fetch_member(
|
||||
interaction.guild_id, member_id
|
||||
)
|
||||
if member:
|
||||
return member
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Member with ID '{member_id}' not found in guild '{interaction.guild_id}'.",
|
||||
option_name="member",
|
||||
original_value=value,
|
||||
)
|
||||
except Exception as e:
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Failed to fetch member '{member_id}': {e}",
|
||||
option_name="member",
|
||||
original_value=value,
|
||||
)
|
||||
elif isinstance(value, dict) and "id" in value.get(
|
||||
"user", {}
|
||||
): # If it's already a member data dict
|
||||
return self._client.parse_member(value, interaction.guild_id)
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Expected a member ID string or member data dict, got {type(value).__name__}",
|
||||
option_name="member",
|
||||
original_value=value,
|
||||
)
|
||||
|
||||
|
||||
class RoleConverter(Converter["Role"]):
|
||||
def __init__(self, client: "Client"):
|
||||
self._client = client
|
||||
|
||||
async def convert(self, interaction: "Interaction", value: Any) -> "Role":
|
||||
if not interaction.guild_id:
|
||||
raise AppCommandOptionConversionError(
|
||||
"Cannot convert to Role outside of a guild context.", option_name="role"
|
||||
)
|
||||
|
||||
if isinstance(value, str): # Assume it's a role ID
|
||||
role_id = value
|
||||
# Attempt to get from interaction resolved data first
|
||||
if (
|
||||
interaction.data
|
||||
and interaction.data.resolved
|
||||
and interaction.data.resolved.roles
|
||||
):
|
||||
role_object = interaction.data.resolved.roles.get(
|
||||
role_id
|
||||
) # Should be a Role object
|
||||
if role_object:
|
||||
return role_object
|
||||
|
||||
# Fallback to fetching from guild if not in resolved
|
||||
# This requires Client to have a fetch_role method or similar
|
||||
try:
|
||||
# Assuming Client.fetch_role(guild_id, role_id) will be implemented
|
||||
role = await self._client.fetch_role(interaction.guild_id, role_id)
|
||||
if role:
|
||||
return role
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Role with ID '{role_id}' not found in guild '{interaction.guild_id}'.",
|
||||
option_name="role",
|
||||
original_value=value,
|
||||
)
|
||||
except Exception as e:
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Failed to fetch role '{role_id}': {e}",
|
||||
option_name="role",
|
||||
original_value=value,
|
||||
)
|
||||
elif (
|
||||
isinstance(value, dict) and "id" in value
|
||||
): # If it's already role data dict
|
||||
if (
|
||||
not interaction.guild_id
|
||||
): # Should have been caught earlier, but as a safeguard
|
||||
raise AppCommandOptionConversionError(
|
||||
"Guild context is required to parse role data.",
|
||||
option_name="role",
|
||||
original_value=value,
|
||||
)
|
||||
return self._client.parse_role(value, interaction.guild_id)
|
||||
# This path is reached if value is not a string (role ID) and not a dict (role data)
|
||||
# or if it's a string but all fetching/lookup attempts failed.
|
||||
# The final raise AppCommandOptionConversionError should be outside the if/elif for string values.
|
||||
# If value was a string, an error should have been raised within the 'if isinstance(value, str)' block.
|
||||
# If it wasn't a string or dict, this is the correct place to raise.
|
||||
# The previous structure was slightly off, as the final raise was inside the string check block.
|
||||
# Let's ensure the final raise is at the correct scope.
|
||||
# The current structure seems to imply that if it's not a string, it must be a dict or error.
|
||||
# If it's a string and all lookups fail, an error is raised within that block.
|
||||
# If it's a dict and parsing fails (or guild_id missing), error raised.
|
||||
# If it's neither, this final raise is correct.
|
||||
# The "Function with declared return type "Role" must return value on all code paths"
|
||||
# error suggests a path where no return or raise happens.
|
||||
# This happens if `isinstance(value, str)` is true, but then all internal paths
|
||||
# (resolved check, fetch try/except) don't lead to a return or raise *before*
|
||||
# falling out of the `if isinstance(value, str)` block.
|
||||
# The `raise AppCommandOptionConversionError` at the end of the `if isinstance(value, str)` block
|
||||
# (line 156 in previous version) handles the case where a role ID is given but not found.
|
||||
# The one at the very end (line 164 in previous) handles cases where value is not str/dict.
|
||||
|
||||
# Corrected structure for the final raise:
|
||||
# It should be at the same level as the initial `if isinstance(value, str):`
|
||||
# to catch cases where `value` is neither a str nor a dict.
|
||||
# However, the current logic within the `if isinstance(value, str):` block
|
||||
# ensures a raise if the role ID is not found.
|
||||
# The `elif isinstance(value, dict)` handles the dict case.
|
||||
# The final `raise` (line 164) is for types other than str or dict.
|
||||
# The Pylance error "Function with declared return type "Role" must return value on all code paths"
|
||||
# implies that if `value` is a string, and `interaction.data.resolved.roles.get(role_id)` is None,
|
||||
# AND `self._client.fetch_role` returns None (which it can), then the
|
||||
# `raise AppCommandOptionConversionError` on line 156 is correctly hit.
|
||||
# The issue might be that Pylance doesn't see `AppCommandOptionConversionError` as definitively terminating.
|
||||
# This is unlikely. Let's re-verify the logic flow.
|
||||
|
||||
# The `raise` on line 156 is correct if role_id is not found after fetching.
|
||||
# The `raise` on line 164 is for when `value` is not a str and not a dict.
|
||||
# This seems logically sound. The Pylance error might be a misinterpretation or a subtle issue.
|
||||
# For now, the duplicated `except` is the primary syntax error.
|
||||
# The "must return value on all code paths" often occurs if an if/elif chain doesn't
|
||||
# exhaust all possibilities or if a path through a try/except doesn't guarantee a return/raise.
|
||||
# In this case, if `value` is a string, it either returns a Role or raises an error.
|
||||
# If `value` is a dict, it either returns a Role or raises an error.
|
||||
# If `value` is neither, it raises an error. All paths seem covered.
|
||||
# The syntax error from the duplicated `except` is the most likely culprit for Pylance's confusion.
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Expected a role ID string or role data dict, got {type(value).__name__}",
|
||||
option_name="role",
|
||||
original_value=value,
|
||||
)
|
||||
|
||||
|
||||
class ChannelConverter(Converter["Channel"]):
|
||||
def __init__(self, client: "Client"):
|
||||
self._client = client
|
||||
|
||||
async def convert(self, interaction: "Interaction", value: Any) -> "Channel":
|
||||
if isinstance(value, str): # Assume it's a channel ID
|
||||
channel_id = value
|
||||
# Attempt to get from interaction resolved data first
|
||||
if (
|
||||
interaction.data
|
||||
and interaction.data.resolved
|
||||
and interaction.data.resolved.channels
|
||||
):
|
||||
# Resolved channels are PartialChannel. Client.fetch_channel will get the full typed one.
|
||||
partial_channel = interaction.data.resolved.channels.get(channel_id)
|
||||
if partial_channel:
|
||||
# Client.fetch_channel should handle fetching and parsing to the correct Channel subtype
|
||||
full_channel = await self._client.fetch_channel(partial_channel.id)
|
||||
if full_channel:
|
||||
return full_channel
|
||||
# If fetch_channel returns None even with a resolved ID, it's an issue.
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Failed to fetch full channel for resolved ID '{channel_id}'.",
|
||||
option_name="channel",
|
||||
original_value=value,
|
||||
)
|
||||
|
||||
# Fallback to fetching directly if not in resolved or if resolved fetch failed
|
||||
try:
|
||||
channel = await self._client.fetch_channel(
|
||||
channel_id
|
||||
) # fetch_channel handles parsing
|
||||
if channel:
|
||||
return channel
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Channel with ID '{channel_id}' not found.",
|
||||
option_name="channel",
|
||||
original_value=value,
|
||||
)
|
||||
except Exception as e:
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Failed to fetch channel '{channel_id}': {e}",
|
||||
option_name="channel",
|
||||
original_value=value,
|
||||
)
|
||||
# Raw channel data dicts are not typically provided for slash command options.
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Expected a channel ID string, got {type(value).__name__}",
|
||||
option_name="channel",
|
||||
original_value=value,
|
||||
)
|
||||
|
||||
|
||||
class AttachmentConverter(Converter["Attachment"]):
|
||||
def __init__(
|
||||
self, client: "Client"
|
||||
): # Client might be needed for future enhancements or consistency
|
||||
self._client = client
|
||||
|
||||
async def convert(self, interaction: "Interaction", value: Any) -> "Attachment":
|
||||
if isinstance(value, str): # Value is the attachment ID
|
||||
attachment_id = value
|
||||
if (
|
||||
interaction.data
|
||||
and interaction.data.resolved
|
||||
and interaction.data.resolved.attachments
|
||||
):
|
||||
attachment_object = interaction.data.resolved.attachments.get(
|
||||
attachment_id
|
||||
) # This is already an Attachment object
|
||||
if attachment_object:
|
||||
return (
|
||||
attachment_object # Return the already parsed Attachment object
|
||||
)
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Attachment with ID '{attachment_id}' not found in resolved data.",
|
||||
option_name="attachment",
|
||||
original_value=value,
|
||||
)
|
||||
raise AppCommandOptionConversionError(
|
||||
f"Expected an attachment ID string, got {type(value).__name__}",
|
||||
option_name="attachment",
|
||||
original_value=value,
|
||||
)
|
||||
|
||||
|
||||
# Converters can be registered dynamically using
|
||||
# :meth:`disagreement.ext.app_commands.handler.AppCommandHandler.register_converter`.
|
||||
|
||||
# Mapping from ApplicationCommandOptionType to default converters
|
||||
# This will be used by the AppCommandHandler to automatically apply converters
|
||||
# if no explicit converter is specified for a command option's type hint.
|
||||
DEFAULT_CONVERTERS: Dict[
|
||||
ApplicationCommandOptionType, Callable[..., Converter[Any]]
|
||||
] = { # Changed Callable signature
|
||||
ApplicationCommandOptionType.STRING: StringConverter,
|
||||
ApplicationCommandOptionType.INTEGER: IntegerConverter,
|
||||
ApplicationCommandOptionType.BOOLEAN: BooleanConverter,
|
||||
ApplicationCommandOptionType.NUMBER: NumberConverter,
|
||||
ApplicationCommandOptionType.USER: UserConverter,
|
||||
ApplicationCommandOptionType.CHANNEL: ChannelConverter,
|
||||
ApplicationCommandOptionType.ROLE: RoleConverter,
|
||||
# ApplicationCommandOptionType.MENTIONABLE: MentionableConverter, # Special case, can be User or Role
|
||||
ApplicationCommandOptionType.ATTACHMENT: AttachmentConverter, # Added
|
||||
}
|
||||
|
||||
|
||||
async def run_converters(
|
||||
interaction: "Interaction",
|
||||
param_type: Any, # The type hint of the parameter
|
||||
option_type: ApplicationCommandOptionType, # The Discord option type
|
||||
value: Any,
|
||||
client: "Client", # Needed for model converters
|
||||
) -> Any:
|
||||
"""
|
||||
Runs the appropriate converter for a given parameter type and value.
|
||||
This function will be more complex, handling custom converters, unions, optionals etc.
|
||||
For now, a basic lookup.
|
||||
"""
|
||||
converter_class_factory = DEFAULT_CONVERTERS.get(option_type)
|
||||
if converter_class_factory:
|
||||
# Check if the factory needs the client instance
|
||||
# This is a bit simplistic; a more robust way might involve inspecting __init__ signature
|
||||
# or having converters register their needs.
|
||||
if option_type in [
|
||||
ApplicationCommandOptionType.USER,
|
||||
ApplicationCommandOptionType.CHANNEL, # Anticipating these
|
||||
ApplicationCommandOptionType.ROLE,
|
||||
ApplicationCommandOptionType.MENTIONABLE,
|
||||
ApplicationCommandOptionType.ATTACHMENT,
|
||||
]:
|
||||
converter_instance = converter_class_factory(client=client)
|
||||
else:
|
||||
converter_instance = converter_class_factory()
|
||||
|
||||
return await converter_instance.convert(interaction, value)
|
||||
|
||||
# Fallback for unhandled types or if direct type matching is needed
|
||||
if param_type is str and isinstance(value, str):
|
||||
return value
|
||||
if param_type is int and isinstance(value, int):
|
||||
return value
|
||||
if param_type is bool and isinstance(value, bool):
|
||||
return value
|
||||
if param_type is float and isinstance(value, (float, int)):
|
||||
return float(value)
|
||||
|
||||
# If no specific converter, and it's not a basic type match, raise error or return raw
|
||||
# For now, let's raise if no converter found for a specific option type
|
||||
if option_type in DEFAULT_CONVERTERS: # Should have been handled
|
||||
pass # This path implies a logic error above or missing converter in DEFAULT_CONVERTERS
|
||||
|
||||
# If it's a model type but no converter yet, this will need to be handled
|
||||
# e.g. if param_type is User and option_type is ApplicationCommandOptionType.USER
|
||||
|
||||
raise AppCommandOptionConversionError(
|
||||
f"No suitable converter found for option type {option_type.name} "
|
||||
f"with value '{value}' to target type {param_type.__name__ if hasattr(param_type, '__name__') else param_type}"
|
||||
)
|
569
disagreement/ext/app_commands/decorators.py
Normal file
569
disagreement/ext/app_commands/decorators.py
Normal file
@ -0,0 +1,569 @@
|
||||
# disagreement/ext/app_commands/decorators.py
|
||||
|
||||
import inspect
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Callable,
|
||||
Optional,
|
||||
List,
|
||||
Dict,
|
||||
Any,
|
||||
Union,
|
||||
Type,
|
||||
get_origin,
|
||||
get_args,
|
||||
TYPE_CHECKING,
|
||||
Literal,
|
||||
Annotated,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
from .commands import (
|
||||
SlashCommand,
|
||||
UserCommand,
|
||||
MessageCommand,
|
||||
AppCommand,
|
||||
AppCommandGroup,
|
||||
HybridCommand,
|
||||
)
|
||||
from disagreement.interactions import (
|
||||
ApplicationCommandOption,
|
||||
ApplicationCommandOptionChoice,
|
||||
Snowflake,
|
||||
)
|
||||
from disagreement.enums import (
|
||||
ApplicationCommandOptionType,
|
||||
IntegrationType,
|
||||
InteractionContextType,
|
||||
# Assuming ChannelType will be added to disagreement.enums
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from disagreement.client import Client # For potential future use
|
||||
from disagreement.models import Channel, User
|
||||
|
||||
# Assuming TextChannel, VoiceChannel etc. might be defined or aliased
|
||||
# For now, we'll use string comparisons for channel types or rely on a yet-to-be-defined ChannelType enum
|
||||
Channel = Any
|
||||
Member = Any
|
||||
Role = Any
|
||||
Attachment = Any
|
||||
# from .cog import Cog # Placeholder
|
||||
else:
|
||||
# Runtime fallbacks for optional model classes
|
||||
from disagreement.models import Channel
|
||||
|
||||
Client = Any # type: ignore
|
||||
User = Any # type: ignore
|
||||
Member = Any # type: ignore
|
||||
Role = Any # type: ignore
|
||||
Attachment = Any # type: ignore
|
||||
|
||||
# Mapping Python types to Discord ApplicationCommandOptionType
|
||||
# This will need to be expanded and made more robust.
|
||||
# Consider using a registry or a more sophisticated type mapping system.
|
||||
_type_mapping: Dict[Any, ApplicationCommandOptionType] = (
|
||||
{ # Changed Type to Any for key due to placeholders
|
||||
str: ApplicationCommandOptionType.STRING,
|
||||
int: ApplicationCommandOptionType.INTEGER,
|
||||
bool: ApplicationCommandOptionType.BOOLEAN,
|
||||
float: ApplicationCommandOptionType.NUMBER, # Discord 'NUMBER' type is for float/double
|
||||
User: ApplicationCommandOptionType.USER,
|
||||
Channel: ApplicationCommandOptionType.CHANNEL,
|
||||
# Placeholders for actual model types from disagreement.models
|
||||
# These will be resolved to their actual types when TYPE_CHECKING is False or via isinstance checks
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Helper dataclass for storing extra option metadata
|
||||
@dataclass
|
||||
class OptionMetadata:
|
||||
channel_types: Optional[List[int]] = None
|
||||
min_value: Optional[Union[int, float]] = None
|
||||
max_value: Optional[Union[int, float]] = None
|
||||
min_length: Optional[int] = None
|
||||
max_length: Optional[int] = None
|
||||
autocomplete: bool = False
|
||||
|
||||
|
||||
# Ensure these are updated if model names/locations change
|
||||
if TYPE_CHECKING:
|
||||
# _type_mapping[User] = ApplicationCommandOptionType.USER # Already added above
|
||||
_type_mapping[Member] = ApplicationCommandOptionType.USER # Member implies User
|
||||
_type_mapping[Role] = ApplicationCommandOptionType.ROLE
|
||||
_type_mapping[Attachment] = ApplicationCommandOptionType.ATTACHMENT
|
||||
_type_mapping[Channel] = ApplicationCommandOptionType.CHANNEL
|
||||
|
||||
# TypeVar for the app command decorator factory
|
||||
AppCmdType = TypeVar("AppCmdType", bound=AppCommand)
|
||||
|
||||
|
||||
def _extract_options_from_signature(
|
||||
func: Callable[..., Any], option_meta: Optional[Dict[str, OptionMetadata]] = None
|
||||
) -> List[ApplicationCommandOption]:
|
||||
"""
|
||||
Inspects a function signature and generates ApplicationCommandOption list.
|
||||
"""
|
||||
options: List[ApplicationCommandOption] = []
|
||||
params = inspect.signature(func).parameters
|
||||
|
||||
doc = inspect.getdoc(func)
|
||||
param_descriptions: Dict[str, str] = {}
|
||||
if doc:
|
||||
for line in inspect.cleandoc(doc).splitlines():
|
||||
line = line.strip()
|
||||
if line.startswith(":param"):
|
||||
try:
|
||||
_, rest = line.split(" ", 1)
|
||||
name, desc = rest.split(":", 1)
|
||||
param_descriptions[name.strip()] = desc.strip()
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Skip 'self' (for cogs) and 'ctx' (context) parameters
|
||||
param_iter = iter(params.values())
|
||||
first_param = next(param_iter, None)
|
||||
|
||||
# Heuristic: if the function is bound to a class (cog), 'self' might be the first param.
|
||||
# A more robust way would be to check if `func` is a method of a Cog instance later.
|
||||
# For now, simple name check.
|
||||
if first_param and first_param.name == "self":
|
||||
first_param = next(param_iter, None) # Consume 'self', get next
|
||||
|
||||
if first_param and first_param.name == "ctx": # Consume 'ctx'
|
||||
pass # ctx is handled, now iterate over actual command options
|
||||
elif (
|
||||
first_param
|
||||
): # If first_param was not 'self' and not 'ctx', it's a command option
|
||||
param_iter = iter(params.values()) # Reset iterator to include the first param
|
||||
|
||||
for param in param_iter:
|
||||
if param.name == "self" or param.name == "ctx": # Should have been skipped
|
||||
continue
|
||||
|
||||
if param.kind == param.VAR_POSITIONAL or param.kind == param.VAR_KEYWORD:
|
||||
# *args and **kwargs are not directly supported by slash command options structure.
|
||||
# Could raise an error or ignore. For now, ignore.
|
||||
# print(f"Warning: *args/**kwargs ({param.name}) are not supported for slash command options.")
|
||||
continue
|
||||
|
||||
option_name = param.name
|
||||
option_description = param_descriptions.get(
|
||||
option_name, f"Description for {option_name}"
|
||||
)
|
||||
meta = option_meta.get(option_name) if option_meta else None
|
||||
|
||||
param_type_hint = param.annotation
|
||||
if param_type_hint == inspect.Parameter.empty:
|
||||
# Default to string if no type hint, or raise error.
|
||||
# Forcing type hints is generally better for slash commands.
|
||||
# raise TypeError(f"Option '{option_name}' must have a type hint for slash commands.")
|
||||
param_type_hint = str # Defaulting to string, can be made stricter
|
||||
|
||||
option_type: Optional[ApplicationCommandOptionType] = None
|
||||
choices: Optional[List[ApplicationCommandOptionChoice]] = None
|
||||
|
||||
origin = get_origin(param_type_hint)
|
||||
args = get_args(param_type_hint)
|
||||
|
||||
if origin is Annotated:
|
||||
param_type_hint = args[0]
|
||||
for extra in args[1:]:
|
||||
if isinstance(extra, OptionMetadata):
|
||||
meta = extra
|
||||
origin = get_origin(param_type_hint)
|
||||
args = get_args(param_type_hint)
|
||||
|
||||
actual_type_for_mapping = param_type_hint
|
||||
is_optional = False
|
||||
|
||||
if origin is Union: # Handles Optional[T] which is Union[T, NoneType]
|
||||
# Filter out NoneType to get the actual type for mapping
|
||||
union_types = [t for t in args if t is not type(None)]
|
||||
if len(union_types) == 1:
|
||||
actual_type_for_mapping = union_types[0]
|
||||
is_optional = True
|
||||
else:
|
||||
# More complex Unions are not directly supported by a single option type.
|
||||
# Could default to STRING or raise.
|
||||
# For now, let's assume simple Optional[T] or direct types.
|
||||
# print(f"Warning: Complex Union type for '{option_name}' not fully supported, defaulting to STRING.")
|
||||
actual_type_for_mapping = str
|
||||
|
||||
elif origin is list and len(args) == 1:
|
||||
# List[T] is not a direct option type. Discord handles multiple values for some types
|
||||
# via repeated options or specific component interactions, not directly in slash command options.
|
||||
# This might indicate a need for a different interaction pattern or custom parsing.
|
||||
# For now, treat List[str] as a string, others might error or default.
|
||||
# print(f"Warning: List type for '{option_name}' not directly supported as a single option. Consider type {args[0]}.")
|
||||
actual_type_for_mapping = args[
|
||||
0
|
||||
] # Use the inner type for mapping, but this is a simplification.
|
||||
|
||||
if origin is Literal: # typing.Literal['a', 'b']
|
||||
choices = []
|
||||
for choice_val in args:
|
||||
if not isinstance(choice_val, (str, int, float)):
|
||||
raise TypeError(
|
||||
f"Literal choices for '{option_name}' must be str, int, or float. Got {type(choice_val)}."
|
||||
)
|
||||
choices.append(
|
||||
ApplicationCommandOptionChoice(
|
||||
data={"name": str(choice_val), "value": choice_val}
|
||||
)
|
||||
)
|
||||
# The type of the Literal's arguments determines the option type
|
||||
if choices:
|
||||
literal_arg_type = type(choices[0].value)
|
||||
option_type = _type_mapping.get(literal_arg_type)
|
||||
if (
|
||||
not option_type and literal_arg_type is float
|
||||
): # float maps to NUMBER
|
||||
option_type = ApplicationCommandOptionType.NUMBER
|
||||
|
||||
if not option_type: # If not determined by Literal
|
||||
option_type = _type_mapping.get(actual_type_for_mapping)
|
||||
# Special handling for User, Member, Role, Attachment, Channel if not directly in _type_mapping
|
||||
# This is a bit crude; a proper registry or isinstance checks would be better.
|
||||
if not option_type:
|
||||
if (
|
||||
actual_type_for_mapping.__name__ == "User"
|
||||
or actual_type_for_mapping.__name__ == "Member"
|
||||
):
|
||||
option_type = ApplicationCommandOptionType.USER
|
||||
elif actual_type_for_mapping.__name__ == "Role":
|
||||
option_type = ApplicationCommandOptionType.ROLE
|
||||
elif actual_type_for_mapping.__name__ == "Attachment":
|
||||
option_type = ApplicationCommandOptionType.ATTACHMENT
|
||||
elif (
|
||||
inspect.isclass(actual_type_for_mapping)
|
||||
and isinstance(Channel, type)
|
||||
and issubclass(actual_type_for_mapping, cast(type, Channel))
|
||||
):
|
||||
option_type = ApplicationCommandOptionType.CHANNEL
|
||||
|
||||
if not option_type:
|
||||
# Fallback or error if type couldn't be mapped
|
||||
# print(f"Warning: Could not map type '{actual_type_for_mapping}' for option '{option_name}'. Defaulting to STRING.")
|
||||
option_type = ApplicationCommandOptionType.STRING # Default fallback
|
||||
|
||||
required = (param.default == inspect.Parameter.empty) and not is_optional
|
||||
|
||||
data: Dict[str, Any] = {
|
||||
"name": option_name,
|
||||
"description": option_description,
|
||||
"type": option_type.value,
|
||||
"required": required,
|
||||
"choices": ([c.to_dict() for c in choices] if choices else None),
|
||||
}
|
||||
|
||||
if meta:
|
||||
if meta.channel_types is not None:
|
||||
data["channel_types"] = meta.channel_types
|
||||
if meta.min_value is not None:
|
||||
data["min_value"] = meta.min_value
|
||||
if meta.max_value is not None:
|
||||
data["max_value"] = meta.max_value
|
||||
if meta.min_length is not None:
|
||||
data["min_length"] = meta.min_length
|
||||
if meta.max_length is not None:
|
||||
data["max_length"] = meta.max_length
|
||||
if meta.autocomplete:
|
||||
data["autocomplete"] = True
|
||||
|
||||
options.append(ApplicationCommandOption(data=data))
|
||||
return options
|
||||
|
||||
|
||||
def _app_command_decorator(
|
||||
cls: Type[AppCmdType],
|
||||
option_meta: Optional[Dict[str, OptionMetadata]] = None,
|
||||
**attrs: Any,
|
||||
) -> Callable[[Callable[..., Any]], AppCmdType]:
|
||||
"""Generic factory for creating app command decorators."""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> AppCmdType:
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError(
|
||||
"Application command callback must be a coroutine function."
|
||||
)
|
||||
|
||||
name = attrs.pop("name", None) or func.__name__
|
||||
description = attrs.pop("description", None) or inspect.getdoc(func)
|
||||
if description: # Clean up docstring
|
||||
description = inspect.cleandoc(description).split("\n\n", 1)[
|
||||
0
|
||||
] # Use first paragraph
|
||||
|
||||
parent_group = attrs.pop("parent", None)
|
||||
if parent_group and not isinstance(parent_group, AppCommandGroup):
|
||||
raise TypeError(
|
||||
"The 'parent' argument must be an AppCommandGroup instance."
|
||||
)
|
||||
|
||||
# For User/Message commands, description should be empty for payload, but can be stored for help.
|
||||
if cls is UserCommand or cls is MessageCommand:
|
||||
actual_description_for_payload = ""
|
||||
else:
|
||||
actual_description_for_payload = description
|
||||
if not actual_description_for_payload and cls is SlashCommand:
|
||||
raise ValueError(f"Slash command '{name}' must have a description.")
|
||||
|
||||
# Create the command instance
|
||||
cmd_instance = cls(
|
||||
callback=func,
|
||||
name=name,
|
||||
description=actual_description_for_payload, # Use payload-appropriate description
|
||||
**attrs, # Remaining attributes like guild_ids, nsfw, etc.
|
||||
)
|
||||
|
||||
# Store original description if different (e.g. for User/Message commands for help text)
|
||||
if description != actual_description_for_payload:
|
||||
cmd_instance._full_description = (
|
||||
description # Custom attribute for library use
|
||||
)
|
||||
|
||||
if isinstance(cmd_instance, SlashCommand):
|
||||
cmd_instance.options = _extract_options_from_signature(func, option_meta)
|
||||
|
||||
if parent_group:
|
||||
parent_group.add_command(cmd_instance) # This also sets cmd_instance.parent
|
||||
|
||||
# Attach command object to the function for later collection by Cog or Client
|
||||
# This is a common pattern.
|
||||
if hasattr(func, "__app_command_object__"):
|
||||
# Function might already be decorated (e.g. hybrid or stacked decorators)
|
||||
# Decide on behavior: error, overwrite, or store list of commands.
|
||||
# For now, let's assume one app command decorator of a specific type per function.
|
||||
# Hybrid commands will need special handling.
|
||||
print(
|
||||
f"Warning: Function {func.__name__} is already an app command or has one attached. Overwriting."
|
||||
)
|
||||
|
||||
setattr(func, "__app_command_object__", cmd_instance)
|
||||
setattr(cmd_instance, "__app_command_object__", cmd_instance)
|
||||
|
||||
# If the command is a HybridCommand, also set the attribute
|
||||
# that the prefix command system's Cog._inject looks for.
|
||||
if isinstance(cmd_instance, HybridCommand):
|
||||
setattr(func, "__command_object__", cmd_instance)
|
||||
setattr(cmd_instance, "__command_object__", cmd_instance)
|
||||
|
||||
return cmd_instance # Return the command instance itself, not the function
|
||||
# This allows it to be added to cogs/handlers directly.
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def slash_command(
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
guild_ids: Optional[List[Snowflake]] = None,
|
||||
default_member_permissions: Optional[str] = None,
|
||||
nsfw: bool = False,
|
||||
name_localizations: Optional[Dict[str, str]] = None,
|
||||
description_localizations: Optional[Dict[str, str]] = None,
|
||||
integration_types: Optional[List[IntegrationType]] = None,
|
||||
contexts: Optional[List[InteractionContextType]] = None,
|
||||
*,
|
||||
guilds: bool = True,
|
||||
dms: bool = True,
|
||||
private_channels: bool = True,
|
||||
parent: Optional[AppCommandGroup] = None, # Added parent parameter
|
||||
locale: Optional[str] = None,
|
||||
option_meta: Optional[Dict[str, OptionMetadata]] = None,
|
||||
) -> Callable[[Callable[..., Any]], SlashCommand]:
|
||||
"""
|
||||
Decorator to create a CHAT_INPUT (slash) command.
|
||||
Options are inferred from the function's type hints.
|
||||
"""
|
||||
if contexts is None:
|
||||
ctxs: List[InteractionContextType] = []
|
||||
if guilds:
|
||||
ctxs.append(InteractionContextType.GUILD)
|
||||
if dms:
|
||||
ctxs.append(InteractionContextType.BOT_DM)
|
||||
if private_channels:
|
||||
ctxs.append(InteractionContextType.PRIVATE_CHANNEL)
|
||||
if len(ctxs) != 3:
|
||||
contexts = ctxs
|
||||
attrs = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"guild_ids": guild_ids,
|
||||
"default_member_permissions": default_member_permissions,
|
||||
"nsfw": nsfw,
|
||||
"name_localizations": name_localizations,
|
||||
"description_localizations": description_localizations,
|
||||
"integration_types": integration_types,
|
||||
"contexts": contexts,
|
||||
"parent": parent, # Pass parent to attrs
|
||||
"locale": locale,
|
||||
}
|
||||
# Filter out None values to avoid passing them as explicit None to command constructor
|
||||
# Keep 'parent' even if None, as _app_command_decorator handles None parent.
|
||||
# nsfw default is False, so it's fine if not present and defaults.
|
||||
attrs = {k: v for k, v in attrs.items() if v is not None or k in ["nsfw", "parent"]}
|
||||
return _app_command_decorator(SlashCommand, option_meta, **attrs)
|
||||
|
||||
|
||||
def user_command(
|
||||
name: Optional[str] = None,
|
||||
guild_ids: Optional[List[Snowflake]] = None,
|
||||
default_member_permissions: Optional[str] = None,
|
||||
nsfw: bool = False, # Though less common for user commands
|
||||
name_localizations: Optional[Dict[str, str]] = None,
|
||||
integration_types: Optional[List[IntegrationType]] = None,
|
||||
contexts: Optional[List[InteractionContextType]] = None,
|
||||
locale: Optional[str] = None,
|
||||
# description is not used by Discord for User commands
|
||||
) -> Callable[[Callable[..., Any]], UserCommand]:
|
||||
"""Decorator to create a USER context menu command."""
|
||||
attrs = {
|
||||
"name": name,
|
||||
"guild_ids": guild_ids,
|
||||
"default_member_permissions": default_member_permissions,
|
||||
"nsfw": nsfw,
|
||||
"name_localizations": name_localizations,
|
||||
"integration_types": integration_types,
|
||||
"contexts": contexts,
|
||||
"locale": locale,
|
||||
}
|
||||
attrs = {k: v for k, v in attrs.items() if v is not None or k in ["nsfw"]}
|
||||
return _app_command_decorator(UserCommand, **attrs)
|
||||
|
||||
|
||||
def message_command(
|
||||
name: Optional[str] = None,
|
||||
guild_ids: Optional[List[Snowflake]] = None,
|
||||
default_member_permissions: Optional[str] = None,
|
||||
nsfw: bool = False, # Though less common for message commands
|
||||
name_localizations: Optional[Dict[str, str]] = None,
|
||||
integration_types: Optional[List[IntegrationType]] = None,
|
||||
contexts: Optional[List[InteractionContextType]] = None,
|
||||
locale: Optional[str] = None,
|
||||
# description is not used by Discord for Message commands
|
||||
) -> Callable[[Callable[..., Any]], MessageCommand]:
|
||||
"""Decorator to create a MESSAGE context menu command."""
|
||||
attrs = {
|
||||
"name": name,
|
||||
"guild_ids": guild_ids,
|
||||
"default_member_permissions": default_member_permissions,
|
||||
"nsfw": nsfw,
|
||||
"name_localizations": name_localizations,
|
||||
"integration_types": integration_types,
|
||||
"contexts": contexts,
|
||||
"locale": locale,
|
||||
}
|
||||
attrs = {k: v for k, v in attrs.items() if v is not None or k in ["nsfw"]}
|
||||
return _app_command_decorator(MessageCommand, **attrs)
|
||||
|
||||
|
||||
def hybrid_command(
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
guild_ids: Optional[List[Snowflake]] = None,
|
||||
default_member_permissions: Optional[str] = None,
|
||||
nsfw: bool = False,
|
||||
name_localizations: Optional[Dict[str, str]] = None,
|
||||
description_localizations: Optional[Dict[str, str]] = None,
|
||||
integration_types: Optional[List[IntegrationType]] = None,
|
||||
contexts: Optional[List[InteractionContextType]] = None,
|
||||
*,
|
||||
guilds: bool = True,
|
||||
dms: bool = True,
|
||||
private_channels: bool = True,
|
||||
aliases: Optional[List[str]] = None, # Specific to prefix command aspect
|
||||
# Other prefix-specific options can be added here (e.g., help, brief)
|
||||
option_meta: Optional[Dict[str, OptionMetadata]] = None,
|
||||
locale: Optional[str] = None,
|
||||
) -> Callable[[Callable[..., Any]], HybridCommand]:
|
||||
"""
|
||||
Decorator to create a command that can be invoked as both a slash command
|
||||
and a traditional prefix-based command.
|
||||
Options for the slash command part are inferred from the function's type hints.
|
||||
"""
|
||||
if contexts is None:
|
||||
ctxs: List[InteractionContextType] = []
|
||||
if guilds:
|
||||
ctxs.append(InteractionContextType.GUILD)
|
||||
if dms:
|
||||
ctxs.append(InteractionContextType.BOT_DM)
|
||||
if private_channels:
|
||||
ctxs.append(InteractionContextType.PRIVATE_CHANNEL)
|
||||
if len(ctxs) != 3:
|
||||
contexts = ctxs
|
||||
attrs = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"guild_ids": guild_ids,
|
||||
"default_member_permissions": default_member_permissions,
|
||||
"nsfw": nsfw,
|
||||
"name_localizations": name_localizations,
|
||||
"description_localizations": description_localizations,
|
||||
"integration_types": integration_types,
|
||||
"contexts": contexts,
|
||||
"aliases": aliases or [], # Ensure aliases is a list
|
||||
"locale": locale,
|
||||
}
|
||||
# Filter out None values to avoid passing them as explicit None to command constructor
|
||||
# Keep 'nsfw' and 'aliases' as they have defaults (False, [])
|
||||
attrs = {
|
||||
k: v for k, v in attrs.items() if v is not None or k in ["nsfw", "aliases"]
|
||||
}
|
||||
return _app_command_decorator(HybridCommand, option_meta, **attrs)
|
||||
|
||||
|
||||
def subcommand(
|
||||
parent: AppCommandGroup, *d_args: Any, **d_kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], SlashCommand]:
|
||||
"""Create a subcommand under an existing :class:`AppCommandGroup`."""
|
||||
|
||||
d_kwargs.setdefault("parent", parent)
|
||||
return slash_command(*d_args, **d_kwargs)
|
||||
|
||||
|
||||
def group(
|
||||
name: str,
|
||||
description: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Callable[[Optional[Callable[..., Any]]], AppCommandGroup]:
|
||||
"""Decorator to declare a top level :class:`AppCommandGroup`."""
|
||||
|
||||
def decorator(func: Optional[Callable[..., Any]] = None) -> AppCommandGroup:
|
||||
grp = AppCommandGroup(
|
||||
name=name,
|
||||
description=description,
|
||||
guild_ids=kwargs.get("guild_ids"),
|
||||
parent=kwargs.get("parent"),
|
||||
default_member_permissions=kwargs.get("default_member_permissions"),
|
||||
nsfw=kwargs.get("nsfw", False),
|
||||
name_localizations=kwargs.get("name_localizations"),
|
||||
description_localizations=kwargs.get("description_localizations"),
|
||||
integration_types=kwargs.get("integration_types"),
|
||||
contexts=kwargs.get("contexts"),
|
||||
)
|
||||
|
||||
if func is not None:
|
||||
setattr(func, "__app_command_object__", grp)
|
||||
return grp
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def subcommand_group(
|
||||
parent: AppCommandGroup,
|
||||
name: str,
|
||||
description: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Callable[[Optional[Callable[..., Any]]], AppCommandGroup]:
|
||||
"""Create a nested :class:`AppCommandGroup` under ``parent``."""
|
||||
|
||||
return parent.group(
|
||||
name=name,
|
||||
description=description,
|
||||
**kwargs,
|
||||
)
|
627
disagreement/ext/app_commands/handler.py
Normal file
627
disagreement/ext/app_commands/handler.py
Normal file
@ -0,0 +1,627 @@
|
||||
# disagreement/ext/app_commands/handler.py
|
||||
|
||||
import inspect
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
Optional,
|
||||
List,
|
||||
Any,
|
||||
Tuple,
|
||||
Union,
|
||||
get_origin,
|
||||
get_args,
|
||||
Literal,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from disagreement.client import Client
|
||||
from disagreement.interactions import Interaction, ResolvedData, Snowflake
|
||||
from disagreement.enums import (
|
||||
ApplicationCommandType,
|
||||
ApplicationCommandOptionType,
|
||||
InteractionType,
|
||||
)
|
||||
from .commands import (
|
||||
AppCommand,
|
||||
SlashCommand,
|
||||
UserCommand,
|
||||
MessageCommand,
|
||||
AppCommandGroup,
|
||||
)
|
||||
from .context import AppCommandContext
|
||||
from disagreement.models import (
|
||||
User,
|
||||
Member,
|
||||
Role,
|
||||
Attachment,
|
||||
Message,
|
||||
) # For resolved data
|
||||
|
||||
# Channel models would also go here
|
||||
|
||||
# Placeholder for models not yet fully defined or imported
|
||||
if not TYPE_CHECKING:
|
||||
from disagreement.enums import (
|
||||
ApplicationCommandType,
|
||||
ApplicationCommandOptionType,
|
||||
InteractionType,
|
||||
)
|
||||
from .commands import (
|
||||
AppCommand,
|
||||
SlashCommand,
|
||||
UserCommand,
|
||||
MessageCommand,
|
||||
AppCommandGroup,
|
||||
)
|
||||
from .context import AppCommandContext
|
||||
|
||||
User = Any
|
||||
Member = Any
|
||||
Role = Any
|
||||
Attachment = Any
|
||||
Channel = Any
|
||||
Message = Any
|
||||
|
||||
|
||||
class AppCommandHandler:
|
||||
"""
|
||||
Manages application command registration, parsing, and dispatching.
|
||||
"""
|
||||
|
||||
def __init__(self, client: "Client"):
|
||||
self.client: "Client" = client
|
||||
# Store commands: key could be (name, type) for global, or (name, type, guild_id) for guild-specific
|
||||
# For simplicity, let's start with a flat structure and refine if needed for guild commands.
|
||||
# A more robust system might have separate dicts for global and guild commands.
|
||||
self._slash_commands: Dict[str, SlashCommand] = {}
|
||||
self._user_commands: Dict[str, UserCommand] = {}
|
||||
self._message_commands: Dict[str, MessageCommand] = {}
|
||||
self._app_command_groups: Dict[str, AppCommandGroup] = {}
|
||||
self._converter_registry: Dict[type, type] = {}
|
||||
|
||||
def add_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None:
|
||||
"""Adds an application command or a command group to the handler."""
|
||||
if isinstance(command, AppCommandGroup):
|
||||
if command.name in self._app_command_groups:
|
||||
raise ValueError(
|
||||
f"AppCommandGroup '{command.name}' is already registered."
|
||||
)
|
||||
self._app_command_groups[command.name] = command
|
||||
return
|
||||
|
||||
if isinstance(command, SlashCommand):
|
||||
if command.name in self._slash_commands:
|
||||
raise ValueError(
|
||||
f"SlashCommand '{command.name}' is already registered."
|
||||
)
|
||||
self._slash_commands[command.name] = command
|
||||
return
|
||||
|
||||
if isinstance(command, UserCommand):
|
||||
if command.name in self._user_commands:
|
||||
raise ValueError(f"UserCommand '{command.name}' is already registered.")
|
||||
self._user_commands[command.name] = command
|
||||
return
|
||||
|
||||
if isinstance(command, MessageCommand):
|
||||
if command.name in self._message_commands:
|
||||
raise ValueError(
|
||||
f"MessageCommand '{command.name}' is already registered."
|
||||
)
|
||||
self._message_commands[command.name] = command
|
||||
return
|
||||
|
||||
if isinstance(command, AppCommand):
|
||||
# Fallback for plain AppCommand objects
|
||||
if command.type == ApplicationCommandType.CHAT_INPUT:
|
||||
if command.name in self._slash_commands:
|
||||
raise ValueError(
|
||||
f"SlashCommand '{command.name}' is already registered."
|
||||
)
|
||||
self._slash_commands[command.name] = command # type: ignore
|
||||
elif command.type == ApplicationCommandType.USER:
|
||||
if command.name in self._user_commands:
|
||||
raise ValueError(
|
||||
f"UserCommand '{command.name}' is already registered."
|
||||
)
|
||||
self._user_commands[command.name] = command # type: ignore
|
||||
elif command.type == ApplicationCommandType.MESSAGE:
|
||||
if command.name in self._message_commands:
|
||||
raise ValueError(
|
||||
f"MessageCommand '{command.name}' is already registered."
|
||||
)
|
||||
self._message_commands[command.name] = command # type: ignore
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unsupported command type: {command.type} for '{command.name}'"
|
||||
)
|
||||
else:
|
||||
raise TypeError("Can only add AppCommand or AppCommandGroup instances.")
|
||||
|
||||
def remove_command(
|
||||
self, name: str
|
||||
) -> Optional[Union["AppCommand", "AppCommandGroup"]]:
|
||||
"""Removes an application command or group by name."""
|
||||
if name in self._slash_commands:
|
||||
return self._slash_commands.pop(name)
|
||||
if name in self._user_commands:
|
||||
return self._user_commands.pop(name)
|
||||
if name in self._message_commands:
|
||||
return self._message_commands.pop(name)
|
||||
if name in self._app_command_groups:
|
||||
return self._app_command_groups.pop(name)
|
||||
return None
|
||||
|
||||
def register_converter(self, annotation: type, converter_cls: type) -> None:
|
||||
"""Register a custom converter class for a type annotation."""
|
||||
self._converter_registry[annotation] = converter_cls
|
||||
|
||||
def get_converter(self, annotation: type) -> Optional[type]:
|
||||
"""Retrieve a registered converter class for a type annotation."""
|
||||
return self._converter_registry.get(annotation)
|
||||
|
||||
def get_command(
|
||||
self,
|
||||
name: str,
|
||||
command_type: "ApplicationCommandType",
|
||||
interaction_options: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> Optional["AppCommand"]:
|
||||
"""Retrieves a command of a specific type."""
|
||||
if command_type == ApplicationCommandType.CHAT_INPUT:
|
||||
if not interaction_options:
|
||||
return self._slash_commands.get(name)
|
||||
|
||||
# Handle subcommands/groups
|
||||
current_options = interaction_options
|
||||
target_command_or_group: Optional[Union[AppCommand, AppCommandGroup]] = (
|
||||
self._app_command_groups.get(name)
|
||||
)
|
||||
|
||||
if not target_command_or_group:
|
||||
return self._slash_commands.get(name)
|
||||
|
||||
final_command: Optional[AppCommand] = None
|
||||
|
||||
while current_options:
|
||||
opt_data = current_options[0]
|
||||
opt_name = opt_data.get("name")
|
||||
opt_type = (
|
||||
ApplicationCommandOptionType(opt_data["type"])
|
||||
if opt_data.get("type")
|
||||
else None
|
||||
)
|
||||
|
||||
if not opt_name or not isinstance(
|
||||
target_command_or_group, AppCommandGroup
|
||||
):
|
||||
break
|
||||
|
||||
next_target = target_command_or_group.get_command(opt_name)
|
||||
|
||||
if isinstance(next_target, AppCommand) and (
|
||||
opt_type == ApplicationCommandOptionType.SUB_COMMAND
|
||||
or not opt_data.get("options")
|
||||
):
|
||||
final_command = next_target
|
||||
break
|
||||
elif (
|
||||
isinstance(next_target, AppCommandGroup)
|
||||
and opt_type == ApplicationCommandOptionType.SUB_COMMAND_GROUP
|
||||
):
|
||||
target_command_or_group = next_target
|
||||
current_options = opt_data.get("options", [])
|
||||
if not current_options:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
return final_command
|
||||
|
||||
if command_type == ApplicationCommandType.USER:
|
||||
return self._user_commands.get(name)
|
||||
|
||||
if command_type == ApplicationCommandType.MESSAGE:
|
||||
return self._message_commands.get(name)
|
||||
|
||||
return None
|
||||
|
||||
async def _resolve_option_value(
|
||||
self,
|
||||
value: Any,
|
||||
expected_type: Any,
|
||||
resolved_data: Optional["ResolvedData"],
|
||||
guild_id: Optional["Snowflake"],
|
||||
) -> Any:
|
||||
"""
|
||||
Resolves an option value to the expected Python type using resolved_data.
|
||||
"""
|
||||
converter_cls = self.get_converter(expected_type)
|
||||
if converter_cls:
|
||||
try:
|
||||
init_params = inspect.signature(converter_cls.__init__).parameters
|
||||
if "client" in init_params:
|
||||
converter_instance = converter_cls(client=self.client) # type: ignore[arg-type]
|
||||
else:
|
||||
converter_instance = converter_cls()
|
||||
return await converter_instance.convert(None, value) # type: ignore[arg-type]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# This is a simplified resolver. A more robust one would use converters.
|
||||
if resolved_data:
|
||||
if expected_type is User or expected_type.__name__ == "User":
|
||||
return resolved_data.users.get(value) if resolved_data.users else None
|
||||
|
||||
if expected_type is Member or expected_type.__name__ == "Member":
|
||||
member_obj = (
|
||||
resolved_data.members.get(value) if resolved_data.members else None
|
||||
)
|
||||
if member_obj:
|
||||
if (
|
||||
hasattr(member_obj, "username")
|
||||
and not member_obj.username
|
||||
and resolved_data.users
|
||||
):
|
||||
user_obj = resolved_data.users.get(value)
|
||||
if user_obj:
|
||||
member_obj.username = user_obj.username
|
||||
member_obj.discriminator = user_obj.discriminator
|
||||
member_obj.avatar = user_obj.avatar
|
||||
member_obj.bot = user_obj.bot
|
||||
member_obj.user = user_obj # type: ignore[attr-defined]
|
||||
return member_obj
|
||||
return None
|
||||
if expected_type is Role or expected_type.__name__ == "Role":
|
||||
return resolved_data.roles.get(value) if resolved_data.roles else None
|
||||
if expected_type is Attachment or expected_type.__name__ == "Attachment":
|
||||
return (
|
||||
resolved_data.attachments.get(value)
|
||||
if resolved_data.attachments
|
||||
else None
|
||||
)
|
||||
if expected_type is Message or expected_type.__name__ == "Message":
|
||||
return (
|
||||
resolved_data.messages.get(value)
|
||||
if resolved_data.messages
|
||||
else None
|
||||
)
|
||||
if "Channel" in expected_type.__name__:
|
||||
return (
|
||||
resolved_data.channels.get(value)
|
||||
if resolved_data.channels
|
||||
else None
|
||||
)
|
||||
|
||||
# For basic types, Discord already sends them correctly (string, int, bool, float)
|
||||
if isinstance(value, expected_type):
|
||||
return value
|
||||
try: # Attempt direct conversion for basic types if Discord sent string for int/float/bool
|
||||
if expected_type is int:
|
||||
return int(value)
|
||||
if expected_type is float:
|
||||
return float(value)
|
||||
if expected_type is bool: # Discord sends true/false
|
||||
if isinstance(value, str):
|
||||
return value.lower() == "true"
|
||||
return bool(value)
|
||||
except (ValueError, TypeError):
|
||||
pass # Conversion failed
|
||||
return value # Return as is if no specific resolution or conversion applied
|
||||
|
||||
async def _resolve_value(
|
||||
self,
|
||||
value: Any,
|
||||
expected_type: Any,
|
||||
resolved_data: Optional["ResolvedData"],
|
||||
guild_id: Optional["Snowflake"],
|
||||
) -> Any:
|
||||
"""Public wrapper around ``_resolve_option_value`` used by tests."""
|
||||
|
||||
return await self._resolve_option_value(
|
||||
value=value,
|
||||
expected_type=expected_type,
|
||||
resolved_data=resolved_data,
|
||||
guild_id=guild_id,
|
||||
)
|
||||
|
||||
async def _parse_interaction_options(
|
||||
self,
|
||||
command_params: Dict[str, inspect.Parameter], # From command.params
|
||||
interaction_options: Optional[List[Dict[str, Any]]],
|
||||
resolved_data: Optional["ResolvedData"],
|
||||
guild_id: Optional["Snowflake"],
|
||||
) -> Tuple[List[Any], Dict[str, Any]]:
|
||||
"""
|
||||
Parses options from an interaction payload and maps them to command function arguments.
|
||||
"""
|
||||
args_list: List[Any] = []
|
||||
kwargs_dict: Dict[str, Any] = {}
|
||||
|
||||
if not interaction_options: # No options provided in interaction
|
||||
# Check if command has required params without defaults
|
||||
for name, param in command_params.items():
|
||||
if param.default == inspect.Parameter.empty:
|
||||
# This should ideally be caught by Discord if option is marked required
|
||||
raise ValueError(f"Missing required option: {name}")
|
||||
return args_list, kwargs_dict
|
||||
|
||||
# Create a dictionary of provided options by name for easier lookup
|
||||
provided_options: Dict[str, Any] = {
|
||||
opt["name"]: opt["value"] for opt in interaction_options if "value" in opt
|
||||
}
|
||||
|
||||
for name, param in command_params.items():
|
||||
if name in provided_options:
|
||||
raw_value = provided_options[name]
|
||||
expected_type = (
|
||||
param.annotation
|
||||
if param.annotation != inspect.Parameter.empty
|
||||
else str
|
||||
)
|
||||
|
||||
# Handle Optional[T]
|
||||
origin_type = get_origin(expected_type)
|
||||
if origin_type is Union:
|
||||
union_args = get_args(expected_type)
|
||||
# Assuming Optional[T] is Union[T, NoneType]
|
||||
non_none_types = [t for t in union_args if t is not type(None)]
|
||||
if len(non_none_types) == 1:
|
||||
expected_type = non_none_types[0]
|
||||
# Else, complex Union, might need more sophisticated handling or default to raw_value/str
|
||||
elif origin_type is Literal:
|
||||
literal_args = get_args(expected_type)
|
||||
if literal_args:
|
||||
expected_type = type(literal_args[0])
|
||||
else:
|
||||
expected_type = str
|
||||
|
||||
resolved_value = await self._resolve_option_value(
|
||||
raw_value, expected_type, resolved_data, guild_id
|
||||
)
|
||||
|
||||
if (
|
||||
param.kind == inspect.Parameter.KEYWORD_ONLY
|
||||
or param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
):
|
||||
kwargs_dict[name] = resolved_value
|
||||
# Note: Slash commands don't map directly to *args. All options are named.
|
||||
# So, we'll primarily use kwargs_dict and then construct args_list based on param order if needed,
|
||||
# but Discord sends named options, so direct kwarg usage is more natural.
|
||||
elif param.default != inspect.Parameter.empty:
|
||||
kwargs_dict[name] = param.default
|
||||
else:
|
||||
# Required parameter not provided by Discord - this implies an issue with command definition
|
||||
# or Discord's validation, as Discord should enforce required options.
|
||||
raise ValueError(
|
||||
f"Required option '{name}' not found in interaction payload."
|
||||
)
|
||||
|
||||
# Populate args_list based on the order in command_params for positional arguments
|
||||
# This assumes that all args that are not keyword-only are passed positionally if present in kwargs_dict
|
||||
for name, param in command_params.items():
|
||||
if param.kind == inspect.Parameter.POSITIONAL_ONLY or (
|
||||
param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
and name in kwargs_dict
|
||||
):
|
||||
if name in kwargs_dict: # Ensure it was resolved or had a default
|
||||
args_list.append(kwargs_dict[name])
|
||||
# If it was POSITIONAL_ONLY and not in kwargs_dict, it's an error (already raised)
|
||||
elif param.kind == inspect.Parameter.VAR_POSITIONAL: # *args
|
||||
# Slash commands don't map to *args well. This would be empty.
|
||||
pass
|
||||
|
||||
# Filter kwargs_dict to only include actual KEYWORD_ONLY or POSITIONAL_OR_KEYWORD params
|
||||
# that were not used for args_list (if strict positional/keyword separation is desired).
|
||||
# For slash commands, it's simpler to pass all resolved named options as kwargs.
|
||||
final_kwargs = {
|
||||
k: v
|
||||
for k, v in kwargs_dict.items()
|
||||
if k in command_params
|
||||
and command_params[k].kind != inspect.Parameter.POSITIONAL_ONLY
|
||||
}
|
||||
|
||||
# For simplicity with slash commands, let's assume all resolved options are passed via kwargs
|
||||
# and the command signature is primarily (self, ctx, **options) or (ctx, **options)
|
||||
# or (self, ctx, option1, option2) where names match.
|
||||
# The AppCommand.invoke will handle passing them.
|
||||
# The current args_list and final_kwargs might be redundant if invoke just uses **final_kwargs.
|
||||
# Let's return kwargs_dict directly for now, and AppCommand.invoke can map them.
|
||||
|
||||
return [], kwargs_dict # Return empty args, all in kwargs for now.
|
||||
|
||||
async def dispatch_app_command_error(
|
||||
self, context: "AppCommandContext", error: Exception
|
||||
) -> None:
|
||||
"""Dispatches an app command error to the client if implemented."""
|
||||
if hasattr(self.client, "on_app_command_error"):
|
||||
await self.client.on_app_command_error(context, error)
|
||||
|
||||
async def process_interaction(self, interaction: "Interaction") -> None:
|
||||
"""Processes an incoming interaction."""
|
||||
if interaction.type == InteractionType.MODAL_SUBMIT:
|
||||
callback = getattr(self.client, "on_modal_submit", None)
|
||||
if callback is not None:
|
||||
from typing import Awaitable, Callable, cast
|
||||
|
||||
await cast(Callable[["Interaction"], Awaitable[None]], callback)(
|
||||
interaction
|
||||
)
|
||||
return
|
||||
|
||||
if interaction.type == InteractionType.APPLICATION_COMMAND_AUTOCOMPLETE:
|
||||
callback = getattr(self.client, "on_autocomplete", None)
|
||||
if callback is not None:
|
||||
from typing import Awaitable, Callable, cast
|
||||
|
||||
await cast(Callable[["Interaction"], Awaitable[None]], callback)(
|
||||
interaction
|
||||
)
|
||||
return
|
||||
|
||||
if interaction.type != InteractionType.APPLICATION_COMMAND:
|
||||
return
|
||||
|
||||
if not interaction.data or not interaction.data.name:
|
||||
from .context import AppCommandContext
|
||||
|
||||
ctx = AppCommandContext(
|
||||
bot=self.client, interaction=interaction, command=None
|
||||
)
|
||||
await ctx.send("Command not found.", ephemeral=True)
|
||||
return
|
||||
|
||||
command_name = interaction.data.name
|
||||
command_type = interaction.data.type or ApplicationCommandType.CHAT_INPUT
|
||||
command = self.get_command(
|
||||
command_name,
|
||||
command_type,
|
||||
interaction.data.options if interaction.data else None,
|
||||
)
|
||||
|
||||
if not command:
|
||||
from .context import AppCommandContext
|
||||
|
||||
ctx = AppCommandContext(
|
||||
bot=self.client, interaction=interaction, command=None
|
||||
)
|
||||
await ctx.send(f"Command '{command_name}' not found.", ephemeral=True)
|
||||
return
|
||||
|
||||
# Create context
|
||||
from .context import AppCommandContext # Ensure AppCommandContext is available
|
||||
|
||||
ctx = AppCommandContext(
|
||||
bot=self.client, interaction=interaction, command=command
|
||||
)
|
||||
|
||||
try:
|
||||
# Prepare arguments for the command callback
|
||||
# Skip 'self' and 'ctx' from command.params for parsing interaction options
|
||||
params_to_parse = {
|
||||
name: param
|
||||
for name, param in command.params.items()
|
||||
if name not in ("self", "ctx")
|
||||
}
|
||||
|
||||
if command.type in (
|
||||
ApplicationCommandType.USER,
|
||||
ApplicationCommandType.MESSAGE,
|
||||
):
|
||||
# Context menu commands provide a target_id. Resolve and pass it
|
||||
args = []
|
||||
kwargs = {}
|
||||
if params_to_parse and interaction.data and interaction.data.target_id:
|
||||
first_param = next(iter(params_to_parse.values()))
|
||||
expected = (
|
||||
first_param.annotation
|
||||
if first_param.annotation != inspect.Parameter.empty
|
||||
else str
|
||||
)
|
||||
resolved = await self._resolve_option_value(
|
||||
interaction.data.target_id,
|
||||
expected,
|
||||
interaction.data.resolved,
|
||||
interaction.guild_id,
|
||||
)
|
||||
if first_param.kind in (
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
):
|
||||
args.append(resolved)
|
||||
else:
|
||||
kwargs[first_param.name] = resolved
|
||||
|
||||
await command.invoke(ctx, *args, **kwargs)
|
||||
else:
|
||||
parsed_args, parsed_kwargs = await self._parse_interaction_options(
|
||||
command_params=params_to_parse,
|
||||
interaction_options=interaction.data.options,
|
||||
resolved_data=interaction.data.resolved,
|
||||
guild_id=interaction.guild_id,
|
||||
)
|
||||
|
||||
await command.invoke(ctx, *parsed_args, **parsed_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error invoking app command '{command.name}': {e}")
|
||||
await self.dispatch_app_command_error(ctx, e)
|
||||
# else:
|
||||
# # Default error reply if no handler on client
|
||||
# try:
|
||||
# await ctx.send(f"An error occurred: {e}", ephemeral=True)
|
||||
# except Exception as send_e:
|
||||
# print(f"Failed to send error message for app command: {send_e}")
|
||||
|
||||
async def sync_commands(
|
||||
self, application_id: "Snowflake", guild_id: Optional["Snowflake"] = None
|
||||
) -> None:
|
||||
"""
|
||||
Synchronizes (registers/updates) all application commands with Discord.
|
||||
If guild_id is provided, syncs commands for that guild. Otherwise, syncs global commands.
|
||||
"""
|
||||
commands_to_sync: List[Dict[str, Any]] = []
|
||||
|
||||
# Collect commands based on scope (global or specific guild)
|
||||
# This needs to be more sophisticated to handle guild_ids on commands/groups
|
||||
|
||||
source_commands = (
|
||||
list(self._slash_commands.values())
|
||||
+ list(self._user_commands.values())
|
||||
+ list(self._message_commands.values())
|
||||
+ list(self._app_command_groups.values())
|
||||
)
|
||||
|
||||
for cmd_or_group in source_commands:
|
||||
# Determine if this command/group should be synced for the current scope
|
||||
is_guild_specific_command = (
|
||||
cmd_or_group.guild_ids is not None and len(cmd_or_group.guild_ids) > 0
|
||||
)
|
||||
|
||||
if guild_id: # Syncing for a specific guild
|
||||
# Skip if not a guild-specific command OR if it's for a different guild
|
||||
if not is_guild_specific_command or (
|
||||
cmd_or_group.guild_ids is not None
|
||||
and guild_id not in cmd_or_group.guild_ids
|
||||
):
|
||||
continue
|
||||
else: # Syncing global commands
|
||||
if is_guild_specific_command:
|
||||
continue # Skip guild-specific commands when syncing global
|
||||
|
||||
# Use the to_dict() method from AppCommand or AppCommandGroup
|
||||
try:
|
||||
payload = cmd_or_group.to_dict()
|
||||
commands_to_sync.append(payload)
|
||||
except AttributeError:
|
||||
print(
|
||||
f"Warning: Command or group '{cmd_or_group.name}' does not have a to_dict() method. Skipping."
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Error converting command/group '{cmd_or_group.name}' to dict: {e}. Skipping."
|
||||
)
|
||||
|
||||
if not commands_to_sync:
|
||||
print(
|
||||
f"No commands to sync for {'guild ' + str(guild_id) if guild_id else 'global'} scope."
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
if guild_id:
|
||||
print(
|
||||
f"Syncing {len(commands_to_sync)} commands for guild {guild_id}..."
|
||||
)
|
||||
await self.client._http.bulk_overwrite_guild_application_commands(
|
||||
application_id, guild_id, commands_to_sync
|
||||
)
|
||||
else:
|
||||
print(f"Syncing {len(commands_to_sync)} global commands...")
|
||||
await self.client._http.bulk_overwrite_global_application_commands(
|
||||
application_id, commands_to_sync
|
||||
)
|
||||
print("Command sync successful.")
|
||||
except Exception as e:
|
||||
print(f"Error syncing application commands: {e}")
|
||||
# Consider re-raising or specific error handling
|
49
disagreement/ext/commands/__init__.py
Normal file
49
disagreement/ext/commands/__init__.py
Normal file
@ -0,0 +1,49 @@
|
||||
# disagreement/ext/commands/__init__.py
|
||||
|
||||
"""
|
||||
disagreement.ext.commands - A command framework extension for the Disagreement library.
|
||||
"""
|
||||
|
||||
from .cog import Cog
|
||||
from .core import (
|
||||
Command,
|
||||
CommandContext,
|
||||
CommandHandler,
|
||||
) # CommandHandler might be internal
|
||||
from .decorators import command, listener, check, check_any, cooldown
|
||||
from .errors import (
|
||||
CommandError,
|
||||
CommandNotFound,
|
||||
BadArgument,
|
||||
MissingRequiredArgument,
|
||||
ArgumentParsingError,
|
||||
CheckFailure,
|
||||
CheckAnyFailure,
|
||||
CommandOnCooldown,
|
||||
CommandInvokeError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Cog
|
||||
"Cog",
|
||||
# Core
|
||||
"Command",
|
||||
"CommandContext",
|
||||
# "CommandHandler", # Usually not part of public API for direct use by bot devs
|
||||
# Decorators
|
||||
"command",
|
||||
"listener",
|
||||
"check",
|
||||
"check_any",
|
||||
"cooldown",
|
||||
# Errors
|
||||
"CommandError",
|
||||
"CommandNotFound",
|
||||
"BadArgument",
|
||||
"MissingRequiredArgument",
|
||||
"ArgumentParsingError",
|
||||
"CheckFailure",
|
||||
"CheckAnyFailure",
|
||||
"CommandOnCooldown",
|
||||
"CommandInvokeError",
|
||||
]
|
155
disagreement/ext/commands/cog.py
Normal file
155
disagreement/ext/commands/cog.py
Normal file
@ -0,0 +1,155 @@
|
||||
# disagreement/ext/commands/cog.py
|
||||
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, List, Tuple, Callable, Awaitable, Any, Dict, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from disagreement.client import Client
|
||||
from .core import Command
|
||||
from disagreement.ext.app_commands.commands import (
|
||||
AppCommand,
|
||||
AppCommandGroup,
|
||||
) # Added
|
||||
else: # pragma: no cover - runtime imports for isinstance checks
|
||||
from disagreement.ext.app_commands.commands import AppCommand, AppCommandGroup
|
||||
|
||||
# EventDispatcher might be needed if cogs register listeners directly
|
||||
# from disagreement.event_dispatcher import EventDispatcher
|
||||
|
||||
|
||||
class Cog:
|
||||
"""
|
||||
The base class for cogs, which are collections of commands and listeners.
|
||||
"""
|
||||
|
||||
def __init__(self, client: "Client"):
|
||||
self._client: "Client" = client
|
||||
self._cog_name: str = self.__class__.__name__
|
||||
self._commands: Dict[str, "Command"] = {}
|
||||
self._listeners: List[Tuple[str, Callable[..., Awaitable[None]]]] = []
|
||||
self._app_commands_and_groups: List[Union["AppCommand", "AppCommandGroup"]] = (
|
||||
[]
|
||||
) # Added
|
||||
|
||||
# Discover commands and listeners defined in this cog instance
|
||||
self._inject()
|
||||
|
||||
@property
|
||||
def client(self) -> "Client":
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def cog_name(self) -> str:
|
||||
return self._cog_name
|
||||
|
||||
def _inject(self) -> None:
|
||||
"""
|
||||
Called to discover and prepare commands and listeners within this cog.
|
||||
This is typically called by the CommandHandler when adding the cog.
|
||||
"""
|
||||
# Clear any previously injected state (e.g., if re-injecting)
|
||||
self._commands.clear()
|
||||
self._listeners.clear()
|
||||
self._app_commands_and_groups.clear() # Added
|
||||
|
||||
for member_name, member in inspect.getmembers(self):
|
||||
if hasattr(member, "__command_object__"):
|
||||
# This is a prefix or hybrid command object
|
||||
cmd: "Command" = getattr(member, "__command_object__")
|
||||
cmd.cog = self # Assign the cog instance to the command
|
||||
if cmd.name in self._commands:
|
||||
# This should ideally be caught earlier or handled by CommandHandler
|
||||
print(
|
||||
f"Warning: Duplicate command name '{cmd.name}' in cog '{self.cog_name}'. Overwriting."
|
||||
)
|
||||
self._commands[cmd.name.lower()] = cmd
|
||||
# Also register aliases
|
||||
for alias in cmd.aliases:
|
||||
self._commands[alias.lower()] = cmd
|
||||
|
||||
# If this command is also an application command (HybridCommand)
|
||||
if isinstance(cmd, (AppCommand, AppCommandGroup)):
|
||||
self._app_commands_and_groups.append(cmd)
|
||||
|
||||
elif hasattr(member, "__app_command_object__"): # Added for app commands
|
||||
app_cmd_obj = getattr(member, "__app_command_object__")
|
||||
if isinstance(app_cmd_obj, (AppCommand, AppCommandGroup)):
|
||||
if isinstance(app_cmd_obj, AppCommand):
|
||||
app_cmd_obj.cog = self # Associate cog
|
||||
# For AppCommandGroup, its commands will have cog set individually if they are AppCommands
|
||||
self._app_commands_and_groups.append(app_cmd_obj)
|
||||
else:
|
||||
print(
|
||||
f"Warning: Member '{member_name}' in cog '{self.cog_name}' has '__app_command_object__' but it's not an AppCommand or AppCommandGroup."
|
||||
)
|
||||
|
||||
elif isinstance(member, (AppCommand, AppCommandGroup)):
|
||||
if isinstance(member, AppCommand):
|
||||
member.cog = self
|
||||
self._app_commands_and_groups.append(member)
|
||||
|
||||
elif hasattr(member, "__listener_name__"):
|
||||
# This is a method decorated with @commands.Cog.listener or @commands.listener
|
||||
if not inspect.iscoroutinefunction(member):
|
||||
# Decorator should have caught this, but double check
|
||||
print(
|
||||
f"Warning: Listener '{member_name}' in cog '{self.cog_name}' is not a coroutine. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
event_name: str = getattr(member, "__listener_name__")
|
||||
# The callback needs to be the bound method from this cog instance
|
||||
self._listeners.append((event_name, member))
|
||||
|
||||
def _eject(self) -> None:
|
||||
"""
|
||||
Called when the cog is being removed.
|
||||
The CommandHandler will handle unregistering commands/listeners.
|
||||
This method is for any cog-specific cleanup before that.
|
||||
"""
|
||||
# For now, just clear local collections. Actual unregistration is external.
|
||||
self._commands.clear()
|
||||
self._listeners.clear()
|
||||
self._app_commands_and_groups.clear() # Added
|
||||
|
||||
def get_commands(self) -> List["Command"]:
|
||||
"""Returns a list of commands in this cog."""
|
||||
# Avoid duplicates if aliases point to the same command object
|
||||
return list(dict.fromkeys(self._commands.values()))
|
||||
|
||||
def get_listeners(self) -> List[Tuple[str, Callable[..., Awaitable[None]]]]:
|
||||
"""Returns a list of (event_name, callback) tuples for listeners in this cog."""
|
||||
return self._listeners
|
||||
|
||||
def get_app_commands_and_groups(
|
||||
self,
|
||||
) -> List[Union["AppCommand", "AppCommandGroup"]]:
|
||||
"""Returns a list of application commands and groups in this cog."""
|
||||
return self._app_commands_and_groups
|
||||
|
||||
async def cog_load(self) -> None:
|
||||
"""
|
||||
A special method that is called when the cog is loaded.
|
||||
This is a good place for any asynchronous setup.
|
||||
Subclasses should override this if they need async setup.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def cog_unload(self) -> None:
|
||||
"""
|
||||
A special method that is called when the cog is unloaded.
|
||||
This is a good place for any asynchronous cleanup.
|
||||
Subclasses should override this if they need async cleanup.
|
||||
"""
|
||||
pass
|
||||
|
||||
# Example of how a listener might be defined within a Cog using the decorator
|
||||
# from .decorators import listener # Would be imported at module level
|
||||
#
|
||||
# @listener(name="ON_MESSAGE_CREATE_CUSTOM") # Explicit name
|
||||
# async def on_my_event(self, message: 'Message'):
|
||||
# print(f"Cog '{self.cog_name}' received event with message: {message.content}")
|
||||
#
|
||||
# @listener() # Name derived from method: on_ready
|
||||
# async def on_ready(self):
|
||||
# print(f"Cog '{self.cog_name}' is ready.")
|
175
disagreement/ext/commands/converters.py
Normal file
175
disagreement/ext/commands/converters.py
Normal file
@ -0,0 +1,175 @@
|
||||
# disagreement/ext/commands/converters.py
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, Generic
|
||||
from abc import ABC, abstractmethod
|
||||
import re
|
||||
|
||||
from .errors import BadArgument
|
||||
from disagreement.models import Member, Guild, Role
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core import CommandContext
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Converter(ABC, Generic[T]):
|
||||
"""
|
||||
Base class for custom command argument converters.
|
||||
Subclasses must implement the `convert` method.
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> T:
|
||||
"""
|
||||
Converts the argument to the desired type.
|
||||
|
||||
Args:
|
||||
ctx: The invocation context.
|
||||
argument: The string argument to convert.
|
||||
|
||||
Returns:
|
||||
The converted argument.
|
||||
|
||||
Raises:
|
||||
BadArgument: If the conversion fails.
|
||||
"""
|
||||
raise NotImplementedError("Converter subclass must implement convert method.")
|
||||
|
||||
|
||||
# --- Built-in Type Converters ---
|
||||
|
||||
|
||||
class IntConverter(Converter[int]):
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> int:
|
||||
try:
|
||||
return int(argument)
|
||||
except ValueError:
|
||||
raise BadArgument(f"'{argument}' is not a valid integer.")
|
||||
|
||||
|
||||
class FloatConverter(Converter[float]):
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> float:
|
||||
try:
|
||||
return float(argument)
|
||||
except ValueError:
|
||||
raise BadArgument(f"'{argument}' is not a valid number.")
|
||||
|
||||
|
||||
class BoolConverter(Converter[bool]):
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> bool:
|
||||
lowered = argument.lower()
|
||||
if lowered in ("yes", "y", "true", "t", "1", "on", "enable", "enabled"):
|
||||
return True
|
||||
elif lowered in ("no", "n", "false", "f", "0", "off", "disable", "disabled"):
|
||||
return False
|
||||
raise BadArgument(f"'{argument}' is not a valid boolean-like value.")
|
||||
|
||||
|
||||
class StringConverter(Converter[str]):
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> str:
|
||||
# For basic string, no conversion is needed, but this provides a consistent interface
|
||||
return argument
|
||||
|
||||
|
||||
# --- Discord Model Converters ---
|
||||
|
||||
|
||||
class MemberConverter(Converter["Member"]):
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> "Member":
|
||||
if not ctx.message.guild_id:
|
||||
raise BadArgument("Member converter requires guild context.")
|
||||
|
||||
match = re.match(r"<@!?(\d+)>$", argument)
|
||||
member_id = match.group(1) if match else argument
|
||||
|
||||
guild = ctx.bot.get_guild(ctx.message.guild_id)
|
||||
if guild:
|
||||
member = guild.get_member(member_id)
|
||||
if member:
|
||||
return member
|
||||
|
||||
member = await ctx.bot.fetch_member(ctx.message.guild_id, member_id)
|
||||
if member:
|
||||
return member
|
||||
raise BadArgument(f"Member '{argument}' not found.")
|
||||
|
||||
|
||||
class RoleConverter(Converter["Role"]):
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> "Role":
|
||||
if not ctx.message.guild_id:
|
||||
raise BadArgument("Role converter requires guild context.")
|
||||
|
||||
match = re.match(r"<@&(?P<id>\d+)>$", argument)
|
||||
role_id = match.group("id") if match else argument
|
||||
|
||||
guild = ctx.bot.get_guild(ctx.message.guild_id)
|
||||
if guild:
|
||||
role = guild.get_role(role_id)
|
||||
if role:
|
||||
return role
|
||||
|
||||
role = await ctx.bot.fetch_role(ctx.message.guild_id, role_id)
|
||||
if role:
|
||||
return role
|
||||
raise BadArgument(f"Role '{argument}' not found.")
|
||||
|
||||
|
||||
class GuildConverter(Converter["Guild"]):
|
||||
async def convert(self, ctx: "CommandContext", argument: str) -> "Guild":
|
||||
guild_id = argument.strip("<>") # allow <id> style
|
||||
|
||||
guild = ctx.bot.get_guild(guild_id)
|
||||
if guild:
|
||||
return guild
|
||||
|
||||
guild = await ctx.bot.fetch_guild(guild_id)
|
||||
if guild:
|
||||
return guild
|
||||
raise BadArgument(f"Guild '{argument}' not found.")
|
||||
|
||||
|
||||
# Default converters mapping
|
||||
DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
|
||||
int: IntConverter(),
|
||||
float: FloatConverter(),
|
||||
bool: BoolConverter(),
|
||||
str: StringConverter(),
|
||||
Member: MemberConverter(),
|
||||
Guild: GuildConverter(),
|
||||
Role: RoleConverter(),
|
||||
# User: UserConverter(), # Add when User model and converter are ready
|
||||
}
|
||||
|
||||
|
||||
async def run_converters(ctx: "CommandContext", annotation: Any, argument: str) -> Any:
|
||||
"""
|
||||
Attempts to run a converter for the given annotation and argument.
|
||||
"""
|
||||
converter = DEFAULT_CONVERTERS.get(annotation)
|
||||
if converter:
|
||||
return await converter.convert(ctx, argument)
|
||||
|
||||
# If no direct converter, check if annotation itself is a Converter subclass
|
||||
if inspect.isclass(annotation) and issubclass(annotation, Converter):
|
||||
try:
|
||||
instance = annotation() # type: ignore
|
||||
return await instance.convert(ctx, argument)
|
||||
except Exception as e: # Catch instantiation errors or other issues
|
||||
raise BadArgument(
|
||||
f"Failed to use custom converter {annotation.__name__}: {e}"
|
||||
)
|
||||
|
||||
# If it's a custom class that's not a Converter, we can't handle it by default
|
||||
# Or if it's a complex type hint like Union, Optional, Literal etc.
|
||||
# This part needs more advanced logic for those.
|
||||
|
||||
# For now, if no specific converter, and it's not 'str', raise error or return as str?
|
||||
# Let's be strict for now if an annotation is given but no converter found.
|
||||
if annotation is not str and annotation is not inspect.Parameter.empty:
|
||||
raise BadArgument(f"No converter found for type annotation '{annotation}'.")
|
||||
|
||||
return argument # Default to string if no annotation or annotation is str
|
||||
|
||||
|
||||
# Need to import inspect for the run_converters function
|
||||
import inspect
|
490
disagreement/ext/commands/core.py
Normal file
490
disagreement/ext/commands/core.py
Normal file
@ -0,0 +1,490 @@
|
||||
# disagreement/ext/commands/core.py
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Optional,
|
||||
List,
|
||||
Dict,
|
||||
Any,
|
||||
Union,
|
||||
Callable,
|
||||
Awaitable,
|
||||
Tuple,
|
||||
get_origin,
|
||||
get_args,
|
||||
)
|
||||
|
||||
from .view import StringView
|
||||
from .errors import (
|
||||
CommandError,
|
||||
CommandNotFound,
|
||||
BadArgument,
|
||||
MissingRequiredArgument,
|
||||
ArgumentParsingError,
|
||||
CheckFailure,
|
||||
CommandInvokeError,
|
||||
)
|
||||
from .converters import run_converters, DEFAULT_CONVERTERS, Converter
|
||||
from .cog import Cog
|
||||
from disagreement.typing import Typing
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from disagreement.client import Client
|
||||
from disagreement.models import Message, User
|
||||
|
||||
|
||||
class Command:
|
||||
"""
|
||||
Represents a bot command.
|
||||
|
||||
Attributes:
|
||||
name (str): The primary name of the command.
|
||||
callback (Callable[..., Awaitable[None]]): The coroutine function to execute.
|
||||
aliases (List[str]): Alternative names for the command.
|
||||
brief (Optional[str]): A short description for help commands.
|
||||
description (Optional[str]): A longer description for help commands.
|
||||
cog (Optional['Cog']): Reference to the Cog this command belongs to.
|
||||
params (Dict[str, inspect.Parameter]): Cached parameters of the callback.
|
||||
"""
|
||||
|
||||
def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any):
|
||||
if not asyncio.iscoroutinefunction(callback):
|
||||
raise TypeError("Command callback must be a coroutine function.")
|
||||
|
||||
self.callback: Callable[..., Awaitable[None]] = callback
|
||||
self.name: str = attrs.get("name", callback.__name__)
|
||||
self.aliases: List[str] = attrs.get("aliases", [])
|
||||
self.brief: Optional[str] = attrs.get("brief")
|
||||
self.description: Optional[str] = attrs.get("description") or callback.__doc__
|
||||
self.cog: Optional["Cog"] = attrs.get("cog")
|
||||
|
||||
self.params = inspect.signature(callback).parameters
|
||||
self.checks: List[Callable[["CommandContext"], Awaitable[bool] | bool]] = []
|
||||
if hasattr(callback, "__command_checks__"):
|
||||
self.checks.extend(getattr(callback, "__command_checks__"))
|
||||
|
||||
def add_check(
|
||||
self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool]
|
||||
) -> None:
|
||||
self.checks.append(predicate)
|
||||
|
||||
async def invoke(self, ctx: "CommandContext", *args: Any, **kwargs: Any) -> None:
|
||||
from .errors import CheckFailure
|
||||
|
||||
for predicate in self.checks:
|
||||
result = predicate(ctx)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if not result:
|
||||
raise CheckFailure("Check predicate failed.")
|
||||
|
||||
if self.cog:
|
||||
await self.callback(self.cog, ctx, *args, **kwargs)
|
||||
else:
|
||||
await self.callback(ctx, *args, **kwargs)
|
||||
|
||||
|
||||
class CommandContext:
|
||||
"""
|
||||
Represents the context in which a command is being invoked.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
message: "Message",
|
||||
bot: "Client",
|
||||
prefix: str,
|
||||
command: "Command",
|
||||
invoked_with: str,
|
||||
args: Optional[List[Any]] = None,
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
cog: Optional["Cog"] = None,
|
||||
):
|
||||
self.message: "Message" = message
|
||||
self.bot: "Client" = bot
|
||||
self.prefix: str = prefix
|
||||
self.command: "Command" = command
|
||||
self.invoked_with: str = invoked_with
|
||||
self.args: List[Any] = args or []
|
||||
self.kwargs: Dict[str, Any] = kwargs or {}
|
||||
self.cog: Optional["Cog"] = cog
|
||||
|
||||
self.author: "User" = message.author
|
||||
|
||||
async def reply(
|
||||
self,
|
||||
content: str,
|
||||
*,
|
||||
mention_author: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> "Message":
|
||||
"""Replies to the invoking message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
content: str
|
||||
The content to send.
|
||||
mention_author: Optional[bool]
|
||||
Whether to mention the author in the reply. If ``None`` the
|
||||
client's :attr:`mention_replies` value is used.
|
||||
"""
|
||||
|
||||
allowed_mentions = kwargs.pop("allowed_mentions", None)
|
||||
if mention_author is None:
|
||||
mention_author = getattr(self.bot, "mention_replies", False)
|
||||
|
||||
if allowed_mentions is None:
|
||||
allowed_mentions = {"replied_user": mention_author}
|
||||
else:
|
||||
allowed_mentions = dict(allowed_mentions)
|
||||
allowed_mentions.setdefault("replied_user", mention_author)
|
||||
|
||||
return await self.bot.send_message(
|
||||
channel_id=self.message.channel_id,
|
||||
content=content,
|
||||
message_reference={
|
||||
"message_id": self.message.id,
|
||||
"channel_id": self.message.channel_id,
|
||||
"guild_id": self.message.guild_id,
|
||||
},
|
||||
allowed_mentions=allowed_mentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def send(self, content: str, **kwargs: Any) -> "Message":
|
||||
return await self.bot.send_message(
|
||||
channel_id=self.message.channel_id, content=content, **kwargs
|
||||
)
|
||||
|
||||
async def edit(
|
||||
self,
|
||||
message: Union[str, "Message"],
|
||||
*,
|
||||
content: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> "Message":
|
||||
"""Edits a message previously sent by the bot."""
|
||||
|
||||
message_id = message if isinstance(message, str) else message.id
|
||||
return await self.bot.edit_message(
|
||||
channel_id=self.message.channel_id,
|
||||
message_id=message_id,
|
||||
content=content,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def typing(self) -> "Typing":
|
||||
"""Return a typing context manager for this context's channel."""
|
||||
|
||||
return self.bot.typing(self.message.channel_id)
|
||||
|
||||
|
||||
class CommandHandler:
|
||||
"""
|
||||
Manages command registration, parsing, and dispatching.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: "Client",
|
||||
prefix: Union[
|
||||
str, List[str], Callable[["Client", "Message"], Union[str, List[str]]]
|
||||
],
|
||||
):
|
||||
self.client: "Client" = client
|
||||
self.prefix: Union[
|
||||
str, List[str], Callable[["Client", "Message"], Union[str, List[str]]]
|
||||
] = prefix
|
||||
self.commands: Dict[str, Command] = {}
|
||||
self.cogs: Dict[str, "Cog"] = {}
|
||||
|
||||
from .help import HelpCommand
|
||||
|
||||
self.add_command(HelpCommand(self))
|
||||
|
||||
def add_command(self, command: Command) -> None:
|
||||
if command.name in self.commands:
|
||||
raise ValueError(f"Command '{command.name}' is already registered.")
|
||||
|
||||
self.commands[command.name.lower()] = command
|
||||
for alias in command.aliases:
|
||||
if alias in self.commands:
|
||||
print(
|
||||
f"Warning: Alias '{alias}' for command '{command.name}' conflicts with an existing command or alias."
|
||||
)
|
||||
self.commands[alias.lower()] = command
|
||||
|
||||
def remove_command(self, name: str) -> Optional[Command]:
|
||||
command = self.commands.pop(name.lower(), None)
|
||||
if command:
|
||||
for alias in command.aliases:
|
||||
self.commands.pop(alias.lower(), None)
|
||||
return command
|
||||
|
||||
def get_command(self, name: str) -> Optional[Command]:
|
||||
return self.commands.get(name.lower())
|
||||
|
||||
def add_cog(self, cog_to_add: "Cog") -> None:
|
||||
if not isinstance(cog_to_add, Cog):
|
||||
raise TypeError("Argument must be a subclass of Cog.")
|
||||
|
||||
if cog_to_add.cog_name in self.cogs:
|
||||
raise ValueError(
|
||||
f"Cog with name '{cog_to_add.cog_name}' is already registered."
|
||||
)
|
||||
|
||||
self.cogs[cog_to_add.cog_name] = cog_to_add
|
||||
|
||||
for cmd in cog_to_add.get_commands():
|
||||
self.add_command(cmd)
|
||||
|
||||
if hasattr(self.client, "_event_dispatcher"):
|
||||
for event_name, callback in cog_to_add.get_listeners():
|
||||
self.client._event_dispatcher.register(event_name.upper(), callback)
|
||||
else:
|
||||
print(
|
||||
f"Warning: Client does not have '_event_dispatcher'. Listeners for cog '{cog_to_add.cog_name}' not registered."
|
||||
)
|
||||
|
||||
if hasattr(cog_to_add, "cog_load") and inspect.iscoroutinefunction(
|
||||
cog_to_add.cog_load
|
||||
):
|
||||
asyncio.create_task(cog_to_add.cog_load())
|
||||
|
||||
print(f"Cog '{cog_to_add.cog_name}' added.")
|
||||
|
||||
def remove_cog(self, cog_name: str) -> Optional["Cog"]:
|
||||
cog_to_remove = self.cogs.pop(cog_name, None)
|
||||
if cog_to_remove:
|
||||
for cmd in cog_to_remove.get_commands():
|
||||
self.remove_command(cmd.name)
|
||||
|
||||
if hasattr(self.client, "_event_dispatcher"):
|
||||
for event_name, callback in cog_to_remove.get_listeners():
|
||||
print(
|
||||
f"Note: Listener '{callback.__name__}' for event '{event_name}' from cog '{cog_name}' needs manual unregistration logic in EventDispatcher."
|
||||
)
|
||||
|
||||
if hasattr(cog_to_remove, "cog_unload") and inspect.iscoroutinefunction(
|
||||
cog_to_remove.cog_unload
|
||||
):
|
||||
asyncio.create_task(cog_to_remove.cog_unload())
|
||||
|
||||
cog_to_remove._eject()
|
||||
print(f"Cog '{cog_name}' removed.")
|
||||
return cog_to_remove
|
||||
|
||||
async def get_prefix(self, message: "Message") -> Union[str, List[str], None]:
|
||||
if callable(self.prefix):
|
||||
if inspect.iscoroutinefunction(self.prefix):
|
||||
return await self.prefix(self.client, message)
|
||||
else:
|
||||
return self.prefix(self.client, message) # type: ignore
|
||||
return self.prefix
|
||||
|
||||
async def _parse_arguments(
|
||||
self, command: Command, ctx: CommandContext, view: StringView
|
||||
) -> Tuple[List[Any], Dict[str, Any]]:
|
||||
args_list = []
|
||||
kwargs_dict = {}
|
||||
params_to_parse = list(command.params.values())
|
||||
|
||||
if params_to_parse and params_to_parse[0].name == "self" and command.cog:
|
||||
params_to_parse.pop(0)
|
||||
if params_to_parse and params_to_parse[0].name == "ctx":
|
||||
params_to_parse.pop(0)
|
||||
|
||||
for param in params_to_parse:
|
||||
view.skip_whitespace()
|
||||
final_value_for_param: Any = inspect.Parameter.empty
|
||||
|
||||
if param.kind == inspect.Parameter.VAR_POSITIONAL:
|
||||
while not view.eof:
|
||||
view.skip_whitespace()
|
||||
if view.eof:
|
||||
break
|
||||
word = view.get_word()
|
||||
if word or not view.eof:
|
||||
args_list.append(word)
|
||||
elif view.eof:
|
||||
break
|
||||
break
|
||||
|
||||
arg_str_value: Optional[str] = (
|
||||
None # Holds the raw string for current param
|
||||
)
|
||||
|
||||
if view.eof: # No more input string
|
||||
if param.default is not inspect.Parameter.empty:
|
||||
final_value_for_param = param.default
|
||||
elif param.kind != inspect.Parameter.VAR_KEYWORD:
|
||||
raise MissingRequiredArgument(param.name)
|
||||
else: # VAR_KEYWORD at EOF is fine
|
||||
break
|
||||
else: # Input available
|
||||
is_last_pos_str_greedy = (
|
||||
param == params_to_parse[-1]
|
||||
and param.annotation is str
|
||||
and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
)
|
||||
|
||||
if is_last_pos_str_greedy:
|
||||
arg_str_value = view.read_rest().strip()
|
||||
if (
|
||||
not arg_str_value
|
||||
and param.default is not inspect.Parameter.empty
|
||||
):
|
||||
final_value_for_param = param.default
|
||||
else: # Includes empty string if that's what's left
|
||||
final_value_for_param = arg_str_value
|
||||
else: # Not greedy, or not string, or not last positional
|
||||
if view.buffer[view.index] == '"':
|
||||
arg_str_value = view.get_quoted_string()
|
||||
if arg_str_value == "" and view.buffer[view.index] == '"':
|
||||
raise BadArgument(
|
||||
f"Unterminated quoted string for argument '{param.name}'."
|
||||
)
|
||||
else:
|
||||
arg_str_value = view.get_word()
|
||||
|
||||
# If final_value_for_param was not set by greedy logic, try conversion
|
||||
if final_value_for_param is inspect.Parameter.empty:
|
||||
if (
|
||||
arg_str_value is None
|
||||
): # Should not happen if view.get_word/get_quoted_string is robust
|
||||
if param.default is not inspect.Parameter.empty:
|
||||
final_value_for_param = param.default
|
||||
else:
|
||||
raise MissingRequiredArgument(param.name)
|
||||
else: # We have an arg_str_value (could be empty string "" from quotes)
|
||||
annotation = param.annotation
|
||||
origin = get_origin(annotation)
|
||||
|
||||
if origin is Union: # Handles Optional[T] and Union[T1, T2]
|
||||
union_args = get_args(annotation)
|
||||
is_optional = (
|
||||
len(union_args) == 2 and type(None) in union_args
|
||||
)
|
||||
|
||||
converted_for_union = False
|
||||
last_err_union: Optional[BadArgument] = None
|
||||
for t_arg in union_args:
|
||||
if t_arg is type(None):
|
||||
continue
|
||||
try:
|
||||
final_value_for_param = await run_converters(
|
||||
ctx, t_arg, arg_str_value
|
||||
)
|
||||
converted_for_union = True
|
||||
break
|
||||
except BadArgument as e:
|
||||
last_err_union = e
|
||||
|
||||
if not converted_for_union:
|
||||
if (
|
||||
is_optional and param.default is None
|
||||
): # Special handling for Optional[T] if conversion failed
|
||||
# If arg_str_value was "" and type was Optional[str], StringConverter would return ""
|
||||
# If arg_str_value was "" and type was Optional[int], BadArgument would be raised.
|
||||
# This path is for when all actual types in Optional[T] fail conversion.
|
||||
# If default is None, we can assign None.
|
||||
final_value_for_param = None
|
||||
elif last_err_union:
|
||||
raise last_err_union
|
||||
else: # Should not be reached if logic is correct
|
||||
raise BadArgument(
|
||||
f"Could not convert '{arg_str_value}' to any of {union_args} for param '{param.name}'."
|
||||
)
|
||||
elif annotation is inspect.Parameter.empty or annotation is str:
|
||||
final_value_for_param = arg_str_value
|
||||
else: # Standard type hint
|
||||
final_value_for_param = await run_converters(
|
||||
ctx, annotation, arg_str_value
|
||||
)
|
||||
|
||||
# Final check if value was resolved
|
||||
if final_value_for_param is inspect.Parameter.empty:
|
||||
if param.default is not inspect.Parameter.empty:
|
||||
final_value_for_param = param.default
|
||||
elif param.kind != inspect.Parameter.VAR_KEYWORD:
|
||||
# This state implies an issue if required and no default, and no input was parsed.
|
||||
raise MissingRequiredArgument(
|
||||
f"Parameter '{param.name}' could not be resolved."
|
||||
)
|
||||
|
||||
# Assign to args_list or kwargs_dict if a value was determined
|
||||
if final_value_for_param is not inspect.Parameter.empty:
|
||||
if (
|
||||
param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
or param.kind == inspect.Parameter.POSITIONAL_ONLY
|
||||
):
|
||||
args_list.append(final_value_for_param)
|
||||
elif param.kind == inspect.Parameter.KEYWORD_ONLY:
|
||||
kwargs_dict[param.name] = final_value_for_param
|
||||
|
||||
return args_list, kwargs_dict
|
||||
|
||||
async def process_commands(self, message: "Message") -> None:
|
||||
if not message.content:
|
||||
return
|
||||
|
||||
prefix_to_use = await self.get_prefix(message)
|
||||
if not prefix_to_use:
|
||||
return
|
||||
|
||||
actual_prefix: Optional[str] = None
|
||||
if isinstance(prefix_to_use, list):
|
||||
for p in prefix_to_use:
|
||||
if message.content.startswith(p):
|
||||
actual_prefix = p
|
||||
break
|
||||
if not actual_prefix:
|
||||
return
|
||||
elif isinstance(prefix_to_use, str):
|
||||
if message.content.startswith(prefix_to_use):
|
||||
actual_prefix = prefix_to_use
|
||||
else:
|
||||
return
|
||||
else:
|
||||
return
|
||||
|
||||
if actual_prefix is None:
|
||||
return
|
||||
|
||||
content_without_prefix = message.content[len(actual_prefix) :]
|
||||
view = StringView(content_without_prefix)
|
||||
|
||||
command_name = view.get_word()
|
||||
if not command_name:
|
||||
return
|
||||
|
||||
command = self.get_command(command_name)
|
||||
if not command:
|
||||
return
|
||||
|
||||
ctx = CommandContext(
|
||||
message=message,
|
||||
bot=self.client,
|
||||
prefix=actual_prefix,
|
||||
command=command,
|
||||
invoked_with=command_name,
|
||||
cog=command.cog,
|
||||
)
|
||||
|
||||
try:
|
||||
parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view)
|
||||
ctx.args = parsed_args
|
||||
ctx.kwargs = parsed_kwargs
|
||||
await command.invoke(ctx, *parsed_args, **parsed_kwargs)
|
||||
except CommandError as e:
|
||||
print(f"Command error for '{command.name}': {e}")
|
||||
if hasattr(self.client, "on_command_error"):
|
||||
await self.client.on_command_error(ctx, e)
|
||||
except Exception as e:
|
||||
print(f"Unexpected error invoking command '{command.name}': {e}")
|
||||
exc = CommandInvokeError(e)
|
||||
if hasattr(self.client, "on_command_error"):
|
||||
await self.client.on_command_error(ctx, exc)
|
150
disagreement/ext/commands/decorators.py
Normal file
150
disagreement/ext/commands/decorators.py
Normal file
@ -0,0 +1,150 @@
|
||||
# disagreement/ext/commands/decorators.py
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import time
|
||||
from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core import Command, CommandContext # For type hinting return or internal use
|
||||
|
||||
# from .cog import Cog # For Cog specific decorators
|
||||
|
||||
|
||||
def command(
|
||||
name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any
|
||||
) -> Callable:
|
||||
"""
|
||||
A decorator that transforms a function into a Command.
|
||||
|
||||
Args:
|
||||
name (Optional[str]): The name of the command. Defaults to the function name.
|
||||
aliases (Optional[List[str]]): Alternative names for the command.
|
||||
**attrs: Additional attributes to pass to the Command constructor
|
||||
(e.g., brief, description, hidden).
|
||||
|
||||
Returns:
|
||||
Callable: A decorator that registers the command.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError("Command callback must be a coroutine function.")
|
||||
|
||||
from .core import (
|
||||
Command,
|
||||
) # Late import to avoid circular dependencies at module load time
|
||||
|
||||
# The actual registration will happen when a Cog is added or if commands are global.
|
||||
# For now, this decorator creates a Command instance and attaches it to the function,
|
||||
# or returns a Command instance that can be collected.
|
||||
|
||||
cmd_name = name or func.__name__
|
||||
|
||||
# Store command attributes on the function itself for later collection by Cog or Client
|
||||
# This is a common pattern.
|
||||
if hasattr(func, "__command_attrs__"):
|
||||
# This case might occur if decorators are stacked in an unusual way,
|
||||
# or if a function is decorated multiple times (which should be disallowed or handled).
|
||||
# For now, let's assume one @command decorator per function.
|
||||
raise TypeError("Function is already a command or has command attributes.")
|
||||
|
||||
# Create the command object. It will be registered by the Cog or Client.
|
||||
cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
|
||||
|
||||
# We can attach the command object to the function, so Cogs can find it.
|
||||
func.__command_object__ = cmd # type: ignore # type: ignore[attr-defined]
|
||||
return func # Return the original function, now marked.
|
||||
# Or return `cmd` if commands are registered globally immediately.
|
||||
# For Cogs, returning `func` and letting Cog collect is cleaner.
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def listener(
|
||||
name: Optional[str] = None,
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""
|
||||
A decorator that marks a function as an event listener within a Cog.
|
||||
The actual registration happens when the Cog is added to the client.
|
||||
|
||||
Args:
|
||||
name (Optional[str]): The name of the event to listen to.
|
||||
Defaults to the function name (e.g., `on_message`).
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError("Listener callback must be a coroutine function.")
|
||||
|
||||
# 'name' here is from the outer 'listener' scope (closure)
|
||||
actual_event_name = name or func.__name__
|
||||
# Store listener info on the function for Cog to collect
|
||||
setattr(func, "__listener_name__", actual_event_name)
|
||||
return func
|
||||
|
||||
return decorator # This must be correctly indented under 'listener'
|
||||
|
||||
|
||||
def check(
|
||||
predicate: Callable[["CommandContext"], Awaitable[bool] | bool],
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Decorator to add a check to a command."""
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[None]],
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
checks = getattr(func, "__command_checks__", [])
|
||||
checks.append(predicate)
|
||||
setattr(func, "__command_checks__", checks)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def check_any(
|
||||
*predicates: Callable[["CommandContext"], Awaitable[bool] | bool]
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Decorator that passes if any predicate returns ``True``."""
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CheckAnyFailure, CheckFailure
|
||||
|
||||
errors = []
|
||||
for p in predicates:
|
||||
try:
|
||||
result = p(ctx)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if result:
|
||||
return True
|
||||
except CheckFailure as e:
|
||||
errors.append(e)
|
||||
raise CheckAnyFailure(errors)
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def cooldown(
|
||||
rate: int, per: float
|
||||
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||
"""Simple per-user cooldown decorator."""
|
||||
|
||||
buckets: dict[str, dict[str, float]] = {}
|
||||
|
||||
async def predicate(ctx: "CommandContext") -> bool:
|
||||
from .errors import CommandOnCooldown
|
||||
|
||||
now = time.monotonic()
|
||||
user_buckets = buckets.setdefault(ctx.command.name, {})
|
||||
reset = user_buckets.get(ctx.author.id, 0)
|
||||
if now < reset:
|
||||
raise CommandOnCooldown(reset - now)
|
||||
user_buckets[ctx.author.id] = now + per
|
||||
return True
|
||||
|
||||
return check(predicate)
|
76
disagreement/ext/commands/errors.py
Normal file
76
disagreement/ext/commands/errors.py
Normal file
@ -0,0 +1,76 @@
|
||||
# disagreement/ext/commands/errors.py
|
||||
|
||||
"""
|
||||
Custom exceptions for the command extension.
|
||||
"""
|
||||
|
||||
from disagreement.errors import DisagreementException
|
||||
|
||||
|
||||
class CommandError(DisagreementException):
|
||||
"""Base exception for errors raised by the commands extension."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CommandNotFound(CommandError):
|
||||
"""Exception raised when a command is not found."""
|
||||
|
||||
def __init__(self, command_name: str):
|
||||
self.command_name = command_name
|
||||
super().__init__(f"Command '{command_name}' not found.")
|
||||
|
||||
|
||||
class BadArgument(CommandError):
|
||||
"""Exception raised when a command argument fails to parse or validate."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MissingRequiredArgument(BadArgument):
|
||||
"""Exception raised when a required command argument is missing."""
|
||||
|
||||
def __init__(self, param_name: str):
|
||||
self.param_name = param_name
|
||||
super().__init__(f"Missing required argument: {param_name}")
|
||||
|
||||
|
||||
class ArgumentParsingError(BadArgument):
|
||||
"""Exception raised during the argument parsing process."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CheckFailure(CommandError):
|
||||
"""Exception raised when a command check fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CheckAnyFailure(CheckFailure):
|
||||
"""Raised when :func:`check_any` fails all checks."""
|
||||
|
||||
def __init__(self, errors: list[CheckFailure]):
|
||||
self.errors = errors
|
||||
msg = "; ".join(str(e) for e in errors)
|
||||
super().__init__(f"All checks failed: {msg}")
|
||||
|
||||
|
||||
class CommandOnCooldown(CheckFailure):
|
||||
"""Raised when a command is invoked while on cooldown."""
|
||||
|
||||
def __init__(self, retry_after: float):
|
||||
self.retry_after = retry_after
|
||||
super().__init__(f"Command is on cooldown. Retry in {retry_after:.2f}s")
|
||||
|
||||
|
||||
class CommandInvokeError(CommandError):
|
||||
"""Exception raised when an error occurs during command invocation."""
|
||||
|
||||
def __init__(self, original: Exception):
|
||||
self.original = original
|
||||
super().__init__(f"Error during command invocation: {original}")
|
||||
|
||||
|
||||
# Add more specific errors as needed, e.g., UserNotFound, ChannelNotFound, etc.
|
||||
# These might inherit from BadArgument.
|
37
disagreement/ext/commands/help.py
Normal file
37
disagreement/ext/commands/help.py
Normal file
@ -0,0 +1,37 @@
|
||||
# disagreement/ext/commands/help.py
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from .core import Command, CommandContext, CommandHandler
|
||||
|
||||
|
||||
class HelpCommand(Command):
|
||||
"""Built-in command that displays help information for other commands."""
|
||||
|
||||
def __init__(self, handler: CommandHandler) -> None:
|
||||
self.handler = handler
|
||||
|
||||
async def callback(ctx: CommandContext, command: Optional[str] = None) -> None:
|
||||
if command:
|
||||
cmd = handler.get_command(command)
|
||||
if not cmd or cmd.name.lower() != command.lower():
|
||||
await ctx.send(f"Command '{command}' not found.")
|
||||
return
|
||||
description = cmd.description or cmd.brief or "No description provided."
|
||||
await ctx.send(f"**{ctx.prefix}{cmd.name}**\n{description}")
|
||||
else:
|
||||
lines: List[str] = []
|
||||
for registered in dict.fromkeys(handler.commands.values()):
|
||||
brief = registered.brief or registered.description or ""
|
||||
lines.append(f"{ctx.prefix}{registered.name} - {brief}".strip())
|
||||
if lines:
|
||||
await ctx.send("\n".join(lines))
|
||||
else:
|
||||
await ctx.send("No commands available.")
|
||||
|
||||
super().__init__(
|
||||
callback,
|
||||
name="help",
|
||||
brief="Show command help.",
|
||||
description="Displays help for commands.",
|
||||
)
|
103
disagreement/ext/commands/view.py
Normal file
103
disagreement/ext/commands/view.py
Normal file
@ -0,0 +1,103 @@
|
||||
# disagreement/ext/commands/view.py
|
||||
|
||||
import re
|
||||
|
||||
|
||||
class StringView:
|
||||
"""
|
||||
A utility class to help with parsing strings, particularly for command arguments.
|
||||
It keeps track of the current position in the string and provides methods
|
||||
to read parts of it.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer: str):
|
||||
self.buffer: str = buffer
|
||||
self.original: str = buffer # Keep original for error reporting if needed
|
||||
self.index: int = 0
|
||||
self.end: int = len(buffer)
|
||||
self.previous: int = 0 # Index before the last successful read
|
||||
|
||||
@property
|
||||
def remaining(self) -> str:
|
||||
"""Returns the rest of the string that hasn't been consumed."""
|
||||
return self.buffer[self.index :]
|
||||
|
||||
@property
|
||||
def eof(self) -> bool:
|
||||
"""Checks if the end of the string has been reached."""
|
||||
return self.index >= self.end
|
||||
|
||||
def skip_whitespace(self) -> None:
|
||||
"""Skips any leading whitespace from the current position."""
|
||||
while not self.eof and self.buffer[self.index].isspace():
|
||||
self.index += 1
|
||||
|
||||
def get_word(self) -> str:
|
||||
"""
|
||||
Reads a "word" from the current position.
|
||||
A word is a sequence of non-whitespace characters.
|
||||
"""
|
||||
self.skip_whitespace()
|
||||
if self.eof:
|
||||
return ""
|
||||
|
||||
self.previous = self.index
|
||||
match = re.match(r"\S+", self.buffer[self.index :])
|
||||
if match:
|
||||
word = match.group(0)
|
||||
self.index += len(word)
|
||||
return word
|
||||
return "" # Should not happen if not eof and skip_whitespace was called
|
||||
|
||||
def get_quoted_string(self) -> str:
|
||||
"""
|
||||
Reads a string enclosed in double quotes.
|
||||
Handles escaped quotes inside the string.
|
||||
"""
|
||||
self.skip_whitespace()
|
||||
if self.eof or self.buffer[self.index] != '"':
|
||||
return "" # Or raise an error, or return None
|
||||
|
||||
self.previous = self.index
|
||||
self.index += 1 # Skip the opening quote
|
||||
result = []
|
||||
escaped = False
|
||||
|
||||
while not self.eof:
|
||||
char = self.buffer[self.index]
|
||||
self.index += 1
|
||||
|
||||
if escaped:
|
||||
result.append(char)
|
||||
escaped = False
|
||||
elif char == "\\":
|
||||
escaped = True
|
||||
elif char == '"':
|
||||
return "".join(result) # Closing quote found
|
||||
else:
|
||||
result.append(char)
|
||||
|
||||
# If loop finishes, means EOF was reached before closing quote
|
||||
# This is an error condition. Restore index and indicate failure.
|
||||
self.index = self.previous
|
||||
# Consider raising an error like UnterminatedQuotedStringError
|
||||
return "" # Or raise
|
||||
|
||||
def read_rest(self) -> str:
|
||||
"""Reads all remaining characters from the current position."""
|
||||
self.skip_whitespace()
|
||||
if self.eof:
|
||||
return ""
|
||||
|
||||
self.previous = self.index
|
||||
result = self.buffer[self.index :]
|
||||
self.index = self.end
|
||||
return result
|
||||
|
||||
def undo(self) -> None:
|
||||
"""Resets the current position to before the last successful read."""
|
||||
self.index = self.previous
|
||||
|
||||
# Could add more methods like:
|
||||
# peek() - look at next char without consuming
|
||||
# match_regex(pattern) - consume if regex matches
|
43
disagreement/ext/loader.py
Normal file
43
disagreement/ext/loader.py
Normal file
@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib import import_module
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import Dict
|
||||
|
||||
__all__ = ["load_extension", "unload_extension"]
|
||||
|
||||
_loaded_extensions: Dict[str, ModuleType] = {}
|
||||
|
||||
|
||||
def load_extension(name: str) -> ModuleType:
|
||||
"""Load an extension by name.
|
||||
|
||||
The extension module must define a ``setup`` coroutine or function that
|
||||
will be called after loading. Any value returned by ``setup`` is ignored.
|
||||
"""
|
||||
|
||||
if name in _loaded_extensions:
|
||||
raise ValueError(f"Extension '{name}' already loaded")
|
||||
|
||||
module = import_module(name)
|
||||
|
||||
if not hasattr(module, "setup"):
|
||||
raise ImportError(f"Extension '{name}' does not define a setup function")
|
||||
|
||||
module.setup()
|
||||
_loaded_extensions[name] = module
|
||||
return module
|
||||
|
||||
|
||||
def unload_extension(name: str) -> None:
|
||||
"""Unload a previously loaded extension."""
|
||||
|
||||
module = _loaded_extensions.pop(name, None)
|
||||
if module is None:
|
||||
raise ValueError(f"Extension '{name}' is not loaded")
|
||||
|
||||
if hasattr(module, "teardown"):
|
||||
module.teardown()
|
||||
|
||||
sys.modules.pop(name, None)
|
89
disagreement/ext/tasks.py
Normal file
89
disagreement/ext/tasks.py
Normal file
@ -0,0 +1,89 @@
|
||||
import asyncio
|
||||
from typing import Any, Awaitable, Callable, Optional
|
||||
|
||||
__all__ = ["loop", "Task"]
|
||||
|
||||
|
||||
class Task:
|
||||
"""Simple repeating task."""
|
||||
|
||||
def __init__(self, coro: Callable[..., Awaitable[Any]], *, seconds: float) -> None:
|
||||
self._coro = coro
|
||||
self._seconds = float(seconds)
|
||||
self._task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
async def _run(self, *args: Any, **kwargs: Any) -> None:
|
||||
try:
|
||||
while True:
|
||||
await self._coro(*args, **kwargs)
|
||||
await asyncio.sleep(self._seconds)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
||||
if self._task is None or self._task.done():
|
||||
self._task = asyncio.create_task(self._run(*args, **kwargs))
|
||||
return self._task
|
||||
|
||||
def stop(self) -> None:
|
||||
if self._task is not None:
|
||||
self._task.cancel()
|
||||
self._task = None
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
return self._task is not None and not self._task.done()
|
||||
|
||||
|
||||
class _Loop:
|
||||
def __init__(self, func: Callable[..., Awaitable[Any]], seconds: float) -> None:
|
||||
self.func = func
|
||||
self.seconds = seconds
|
||||
self._task: Optional[Task] = None
|
||||
self._owner: Any = None
|
||||
|
||||
def __get__(self, obj: Any, objtype: Any) -> "_BoundLoop":
|
||||
return _BoundLoop(self, obj)
|
||||
|
||||
def _coro(self, *args: Any, **kwargs: Any) -> Awaitable[Any]:
|
||||
if self._owner is None:
|
||||
return self.func(*args, **kwargs)
|
||||
return self.func(self._owner, *args, **kwargs)
|
||||
|
||||
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
||||
self._task = Task(self._coro, seconds=self.seconds)
|
||||
return self._task.start(*args, **kwargs)
|
||||
|
||||
def stop(self) -> None:
|
||||
if self._task is not None:
|
||||
self._task.stop()
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
return self._task.running if self._task else False
|
||||
|
||||
|
||||
class _BoundLoop:
|
||||
def __init__(self, parent: _Loop, owner: Any) -> None:
|
||||
self._parent = parent
|
||||
self._owner = owner
|
||||
|
||||
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
||||
self._parent._owner = self._owner
|
||||
return self._parent.start(*args, **kwargs)
|
||||
|
||||
def stop(self) -> None:
|
||||
self._parent.stop()
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
return self._parent.running
|
||||
|
||||
|
||||
def loop(*, seconds: float) -> Callable[[Callable[..., Awaitable[Any]]], _Loop]:
|
||||
"""Decorator to create a looping task."""
|
||||
|
||||
def decorator(func: Callable[..., Awaitable[Any]]) -> _Loop:
|
||||
return _Loop(func, seconds)
|
||||
|
||||
return decorator
|
490
disagreement/gateway.py
Normal file
490
disagreement/gateway.py
Normal file
@ -0,0 +1,490 @@
|
||||
# disagreement/gateway.py
|
||||
|
||||
"""
|
||||
Manages the WebSocket connection to the Discord Gateway.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
import aiohttp
|
||||
import json
|
||||
import zlib
|
||||
import time
|
||||
from typing import Optional, TYPE_CHECKING, Any, Dict
|
||||
|
||||
from .enums import GatewayOpcode, GatewayIntent
|
||||
from .errors import GatewayException, DisagreementException, AuthenticationError
|
||||
from .interactions import Interaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client # For type hinting
|
||||
from .event_dispatcher import EventDispatcher
|
||||
from .http import HTTPClient
|
||||
from .interactions import Interaction # Added for INTERACTION_CREATE
|
||||
|
||||
# ZLIB Decompression constants
|
||||
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
|
||||
MAX_DECOMPRESSION_SIZE = 10 * 1024 * 1024 # 10 MiB, adjust as needed
|
||||
|
||||
|
||||
class GatewayClient:
|
||||
"""
|
||||
Handles the Discord Gateway WebSocket connection, heartbeating, and event dispatching.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
http_client: "HTTPClient",
|
||||
event_dispatcher: "EventDispatcher",
|
||||
token: str,
|
||||
intents: int,
|
||||
client_instance: "Client", # Pass the main client instance
|
||||
verbose: bool = False,
|
||||
*,
|
||||
shard_id: Optional[int] = None,
|
||||
shard_count: Optional[int] = None,
|
||||
):
|
||||
self._http: "HTTPClient" = http_client
|
||||
self._dispatcher: "EventDispatcher" = event_dispatcher
|
||||
self._token: str = token
|
||||
self._intents: int = intents
|
||||
self._client_instance: "Client" = client_instance # Store client instance
|
||||
self.verbose: bool = verbose
|
||||
self._shard_id: Optional[int] = shard_id
|
||||
self._shard_count: Optional[int] = shard_count
|
||||
|
||||
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
|
||||
self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
||||
self._heartbeat_interval: Optional[float] = None
|
||||
self._last_sequence: Optional[int] = None
|
||||
self._session_id: Optional[str] = None
|
||||
self._resume_gateway_url: Optional[str] = None
|
||||
|
||||
self._keep_alive_task: Optional[asyncio.Task] = None
|
||||
self._receive_task: Optional[asyncio.Task] = None
|
||||
|
||||
# For zlib decompression
|
||||
self._buffer = bytearray()
|
||||
self._inflator = zlib.decompressobj()
|
||||
|
||||
async def _decompress_message(
|
||||
self, message_bytes: bytes
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Decompresses a zlib-compressed message from the Gateway."""
|
||||
self._buffer.extend(message_bytes)
|
||||
|
||||
if len(message_bytes) < 4 or message_bytes[-4:] != ZLIB_SUFFIX:
|
||||
# Message is not complete or not zlib compressed in the expected way
|
||||
return None
|
||||
# Or handle partial messages if Discord ever sends them fragmented like this,
|
||||
# but typically each binary message is a complete zlib stream.
|
||||
|
||||
try:
|
||||
decompressed = self._inflator.decompress(self._buffer)
|
||||
self._buffer.clear() # Reset buffer after successful decompression
|
||||
return json.loads(decompressed.decode("utf-8"))
|
||||
except zlib.error as e:
|
||||
print(f"Zlib decompression error: {e}")
|
||||
self._buffer.clear() # Clear buffer on error
|
||||
self._inflator = zlib.decompressobj() # Reset inflator
|
||||
return None
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON decode error after decompression: {e}")
|
||||
return None
|
||||
|
||||
async def _send_json(self, payload: Dict[str, Any]):
|
||||
if self._ws and not self._ws.closed:
|
||||
if self.verbose:
|
||||
print(f"GATEWAY SEND: {payload}")
|
||||
await self._ws.send_json(payload)
|
||||
else:
|
||||
print("Gateway send attempted but WebSocket is closed or not available.")
|
||||
# raise GatewayException("WebSocket is not connected.")
|
||||
|
||||
async def _heartbeat(self):
|
||||
"""Sends a heartbeat to the Gateway."""
|
||||
payload = {"op": GatewayOpcode.HEARTBEAT, "d": self._last_sequence}
|
||||
await self._send_json(payload)
|
||||
# print("Sent heartbeat.")
|
||||
|
||||
async def _keep_alive(self):
|
||||
"""Manages the heartbeating loop."""
|
||||
if self._heartbeat_interval is None:
|
||||
# This should not happen if HELLO was processed correctly
|
||||
print("Error: Heartbeat interval not set. Cannot start keep_alive.")
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
await self._heartbeat()
|
||||
await asyncio.sleep(
|
||||
self._heartbeat_interval / 1000
|
||||
) # Interval is in ms
|
||||
except asyncio.CancelledError:
|
||||
print("Keep_alive task cancelled.")
|
||||
except Exception as e:
|
||||
print(f"Error in keep_alive loop: {e}")
|
||||
# Potentially trigger a reconnect here or notify client
|
||||
await self._client_instance.close_gateway(code=1000) # Generic close
|
||||
|
||||
async def _identify(self):
|
||||
"""Sends the IDENTIFY payload to the Gateway."""
|
||||
payload = {
|
||||
"op": GatewayOpcode.IDENTIFY,
|
||||
"d": {
|
||||
"token": self._token,
|
||||
"intents": self._intents,
|
||||
"properties": {
|
||||
"$os": "python", # Or platform.system()
|
||||
"$browser": "disagreement", # Library name
|
||||
"$device": "disagreement", # Library name
|
||||
},
|
||||
"compress": True, # Request zlib compression
|
||||
},
|
||||
}
|
||||
if self._shard_id is not None and self._shard_count is not None:
|
||||
payload["d"]["shard"] = [self._shard_id, self._shard_count]
|
||||
await self._send_json(payload)
|
||||
print("Sent IDENTIFY.")
|
||||
|
||||
async def _resume(self):
|
||||
"""Sends the RESUME payload to the Gateway."""
|
||||
if not self._session_id or self._last_sequence is None:
|
||||
print("Cannot RESUME: session_id or last_sequence is missing.")
|
||||
await self._identify() # Fallback to identify
|
||||
return
|
||||
|
||||
payload = {
|
||||
"op": GatewayOpcode.RESUME,
|
||||
"d": {
|
||||
"token": self._token,
|
||||
"session_id": self._session_id,
|
||||
"seq": self._last_sequence,
|
||||
},
|
||||
}
|
||||
await self._send_json(payload)
|
||||
print(
|
||||
f"Sent RESUME for session {self._session_id} at sequence {self._last_sequence}."
|
||||
)
|
||||
async def update_presence(
|
||||
self,
|
||||
status: str,
|
||||
activity_name: Optional[str] = None,
|
||||
activity_type: int = 0,
|
||||
since: int = 0,
|
||||
afk: bool = False,
|
||||
):
|
||||
"""Sends the presence update payload to the Gateway."""
|
||||
payload = {
|
||||
"op": GatewayOpcode.PRESENCE_UPDATE,
|
||||
"d": {
|
||||
"since": since,
|
||||
"activities": [
|
||||
{
|
||||
"name": activity_name,
|
||||
"type": activity_type,
|
||||
}
|
||||
]
|
||||
if activity_name
|
||||
else [],
|
||||
"status": status,
|
||||
"afk": afk,
|
||||
},
|
||||
}
|
||||
await self._send_json(payload)
|
||||
|
||||
async def _handle_dispatch(self, data: Dict[str, Any]):
|
||||
"""Handles DISPATCH events (actual Discord events)."""
|
||||
event_name = data.get("t")
|
||||
sequence_num = data.get("s")
|
||||
raw_event_d_payload = data.get(
|
||||
"d"
|
||||
) # This is the 'd' field from the gateway event
|
||||
|
||||
if sequence_num is not None:
|
||||
self._last_sequence = sequence_num
|
||||
|
||||
if event_name == "READY": # Special handling for READY
|
||||
if not isinstance(raw_event_d_payload, dict):
|
||||
print(
|
||||
f"Error: READY event 'd' payload is not a dict or is missing: {raw_event_d_payload}"
|
||||
)
|
||||
# Consider raising an error or attempting a reconnect
|
||||
return
|
||||
self._session_id = raw_event_d_payload.get("session_id")
|
||||
self._resume_gateway_url = raw_event_d_payload.get("resume_gateway_url")
|
||||
|
||||
app_id_str = "N/A"
|
||||
# Store application_id on the client instance
|
||||
if (
|
||||
"application" in raw_event_d_payload
|
||||
and isinstance(raw_event_d_payload["application"], dict)
|
||||
and "id" in raw_event_d_payload["application"]
|
||||
):
|
||||
app_id_value = raw_event_d_payload["application"]["id"]
|
||||
self._client_instance.application_id = (
|
||||
app_id_value # Snowflake can be str or int
|
||||
)
|
||||
app_id_str = str(app_id_value)
|
||||
else:
|
||||
print(
|
||||
f"Warning: Could not find application ID in READY payload. App commands may not work."
|
||||
)
|
||||
|
||||
# Parse and store the bot's own user object
|
||||
if "user" in raw_event_d_payload and isinstance(
|
||||
raw_event_d_payload["user"], dict
|
||||
):
|
||||
try:
|
||||
# Assuming Client has a parse_user method that takes user data dict
|
||||
# and returns a User object, also caching it.
|
||||
bot_user_obj = self._client_instance.parse_user(
|
||||
raw_event_d_payload["user"]
|
||||
)
|
||||
self._client_instance.user = bot_user_obj
|
||||
print(
|
||||
f"Gateway READY. Bot User: {bot_user_obj.username}#{bot_user_obj.discriminator}. Session ID: {self._session_id}. App ID: {app_id_str}. Resume URL: {self._resume_gateway_url}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error parsing bot user from READY payload: {e}")
|
||||
print(
|
||||
f"Gateway READY (user parse failed). Session ID: {self._session_id}. App ID: {app_id_str}. Resume URL: {self._resume_gateway_url}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Warning: Bot user object not found or invalid in READY payload."
|
||||
)
|
||||
print(
|
||||
f"Gateway READY (no user). Session ID: {self._session_id}. App ID: {app_id_str}. Resume URL: {self._resume_gateway_url}"
|
||||
)
|
||||
|
||||
await self._dispatcher.dispatch(event_name, raw_event_d_payload)
|
||||
elif event_name == "INTERACTION_CREATE":
|
||||
# print(f"GATEWAY RECV INTERACTION_CREATE: {raw_event_d_payload}")
|
||||
if isinstance(raw_event_d_payload, dict):
|
||||
interaction = Interaction(
|
||||
data=raw_event_d_payload, client_instance=self._client_instance
|
||||
)
|
||||
await self._dispatcher.dispatch(
|
||||
"INTERACTION_CREATE", raw_event_d_payload
|
||||
)
|
||||
# Dispatch to a new client method that will then call AppCommandHandler
|
||||
if hasattr(self._client_instance, "process_interaction"):
|
||||
asyncio.create_task(
|
||||
self._client_instance.process_interaction(interaction)
|
||||
) # type: ignore
|
||||
else:
|
||||
print(
|
||||
"Warning: Client instance does not have process_interaction method for INTERACTION_CREATE."
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Error: INTERACTION_CREATE event 'd' payload is not a dict: {raw_event_d_payload}"
|
||||
)
|
||||
elif event_name == "RESUMED":
|
||||
print("Gateway RESUMED successfully.")
|
||||
# RESUMED 'd' payload is often an empty object or debug info.
|
||||
# Ensure it's a dict for the dispatcher.
|
||||
event_data_to_dispatch = (
|
||||
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
|
||||
)
|
||||
await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
|
||||
elif event_name:
|
||||
# For other events, ensure 'd' is a dict, or pass {} if 'd' is null/missing.
|
||||
# Models/parsers in EventDispatcher will need to handle potentially empty dicts.
|
||||
event_data_to_dispatch = (
|
||||
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
|
||||
)
|
||||
# print(f"GATEWAY RECV EVENT: {event_name} | DATA: {event_data_to_dispatch}")
|
||||
await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
|
||||
else:
|
||||
print(f"Received dispatch with no event name: {data}")
|
||||
|
||||
async def _process_message(self, msg: aiohttp.WSMessage):
|
||||
"""Processes a single message from the WebSocket."""
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
try:
|
||||
data = json.loads(msg.data)
|
||||
except json.JSONDecodeError:
|
||||
print(
|
||||
f"Failed to decode JSON from Gateway: {msg.data[:200]}"
|
||||
) # Log snippet
|
||||
return
|
||||
elif msg.type == aiohttp.WSMsgType.BINARY:
|
||||
decompressed_data = await self._decompress_message(msg.data)
|
||||
if decompressed_data is None:
|
||||
print("Failed to decompress or decode binary message from Gateway.")
|
||||
return
|
||||
data = decompressed_data
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
print(
|
||||
f"WebSocket error: {self._ws.exception() if self._ws else 'Unknown WSError'}"
|
||||
)
|
||||
raise GatewayException(
|
||||
f"WebSocket error: {self._ws.exception() if self._ws else 'Unknown WSError'}"
|
||||
)
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||
close_code = (
|
||||
self._ws.close_code
|
||||
if self._ws and hasattr(self._ws, "close_code")
|
||||
else "N/A"
|
||||
)
|
||||
print(f"WebSocket connection closed by server. Code: {close_code}")
|
||||
# Raise an exception to signal the closure to the client's main run loop
|
||||
raise GatewayException(f"WebSocket closed by server. Code: {close_code}")
|
||||
else:
|
||||
print(f"Received unhandled WebSocket message type: {msg.type}")
|
||||
return
|
||||
|
||||
if self.verbose:
|
||||
print(f"GATEWAY RECV: {data}")
|
||||
op = data.get("op")
|
||||
# 'd' payload (event_data) is handled specifically by each opcode handler below
|
||||
|
||||
if op == GatewayOpcode.DISPATCH:
|
||||
await self._handle_dispatch(data) # _handle_dispatch will extract 'd'
|
||||
elif op == GatewayOpcode.HEARTBEAT: # Server requests a heartbeat
|
||||
await self._heartbeat()
|
||||
elif op == GatewayOpcode.RECONNECT: # Server requests a reconnect
|
||||
print("Gateway requested RECONNECT. Closing and will attempt to reconnect.")
|
||||
await self.close(code=4000) # Use a non-1000 code to indicate reconnect
|
||||
elif op == GatewayOpcode.INVALID_SESSION:
|
||||
# The 'd' payload for INVALID_SESSION is a boolean indicating resumability
|
||||
can_resume = data.get("d") is True
|
||||
print(f"Gateway indicated INVALID_SESSION. Resumable: {can_resume}")
|
||||
if not can_resume:
|
||||
self._session_id = None # Clear session_id to force re-identify
|
||||
self._last_sequence = None
|
||||
# Close and reconnect. The connect logic will decide to resume or identify.
|
||||
await self.close(
|
||||
code=4000 if can_resume else 4009
|
||||
) # 4009 for non-resumable
|
||||
elif op == GatewayOpcode.HELLO:
|
||||
hello_d_payload = data.get("d")
|
||||
if (
|
||||
not isinstance(hello_d_payload, dict)
|
||||
or "heartbeat_interval" not in hello_d_payload
|
||||
):
|
||||
print(
|
||||
f"Error: HELLO event 'd' payload is invalid or missing heartbeat_interval: {hello_d_payload}"
|
||||
)
|
||||
await self.close(code=1011) # Internal error, malformed HELLO
|
||||
return
|
||||
self._heartbeat_interval = hello_d_payload["heartbeat_interval"]
|
||||
print(f"Gateway HELLO. Heartbeat interval: {self._heartbeat_interval}ms.")
|
||||
# Start heartbeating
|
||||
if self._keep_alive_task:
|
||||
self._keep_alive_task.cancel()
|
||||
self._keep_alive_task = self._loop.create_task(self._keep_alive())
|
||||
|
||||
# Identify or Resume
|
||||
if self._session_id and self._resume_gateway_url: # Check if we can resume
|
||||
print("Attempting to RESUME session.")
|
||||
await self._resume()
|
||||
else:
|
||||
print("Performing initial IDENTIFY.")
|
||||
await self._identify()
|
||||
elif op == GatewayOpcode.HEARTBEAT_ACK:
|
||||
# print("Received heartbeat ACK.")
|
||||
pass # Good, connection is alive
|
||||
else:
|
||||
print(f"Received unhandled Gateway Opcode: {op} with data: {data}")
|
||||
|
||||
async def _receive_loop(self):
|
||||
"""Continuously receives and processes messages from the WebSocket."""
|
||||
if not self._ws or self._ws.closed:
|
||||
print("Receive loop cannot start: WebSocket is not connected or closed.")
|
||||
return
|
||||
|
||||
try:
|
||||
async for msg in self._ws:
|
||||
await self._process_message(msg)
|
||||
except asyncio.CancelledError:
|
||||
print("Receive_loop task cancelled.")
|
||||
except aiohttp.ClientConnectionError as e:
|
||||
print(f"ClientConnectionError in receive_loop: {e}. Attempting reconnect.")
|
||||
# This might be handled by an outer reconnect loop in the Client class
|
||||
await self.close(code=1006) # Abnormal closure
|
||||
except Exception as e:
|
||||
print(f"Unexpected error in receive_loop: {e}")
|
||||
traceback.print_exc()
|
||||
# Consider specific error types for more granular handling
|
||||
await self.close(code=1011) # Internal error
|
||||
finally:
|
||||
print("Receive_loop ended.")
|
||||
# If the loop ends unexpectedly (not due to explicit close),
|
||||
# the main client might want to try reconnecting.
|
||||
|
||||
async def connect(self):
|
||||
"""Connects to the Discord Gateway."""
|
||||
if self._ws and not self._ws.closed:
|
||||
print("Gateway already connected or connecting.")
|
||||
return
|
||||
|
||||
gateway_url = (
|
||||
self._resume_gateway_url or (await self._http.get_gateway_bot())["url"]
|
||||
)
|
||||
if not gateway_url.endswith("?v=10&encoding=json&compress=zlib-stream"):
|
||||
gateway_url += "?v=10&encoding=json&compress=zlib-stream"
|
||||
|
||||
print(f"Connecting to Gateway: {gateway_url}")
|
||||
try:
|
||||
await self._http._ensure_session() # Ensure the HTTP client's session is active
|
||||
assert (
|
||||
self._http._session is not None
|
||||
), "HTTPClient session not initialized after ensure_session"
|
||||
self._ws = await self._http._session.ws_connect(gateway_url, max_msg_size=0)
|
||||
print("Gateway WebSocket connection established.")
|
||||
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
self._receive_task = self._loop.create_task(self._receive_loop())
|
||||
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
raise GatewayException(
|
||||
f"Failed to connect to Gateway (Connector Error): {e}"
|
||||
) from e
|
||||
except aiohttp.WSServerHandshakeError as e:
|
||||
if e.status == 401: # Unauthorized during handshake
|
||||
raise AuthenticationError(
|
||||
f"Gateway handshake failed (401 Unauthorized): {e.message}. Check your bot token."
|
||||
) from e
|
||||
raise GatewayException(
|
||||
f"Gateway handshake failed (Status: {e.status}): {e.message}"
|
||||
) from e
|
||||
except Exception as e: # Catch other potential errors during connection
|
||||
raise GatewayException(
|
||||
f"An unexpected error occurred during Gateway connection: {e}"
|
||||
) from e
|
||||
|
||||
async def close(self, code: int = 1000):
|
||||
"""Closes the Gateway connection."""
|
||||
print(f"Closing Gateway connection with code {code}...")
|
||||
if self._keep_alive_task and not self._keep_alive_task.done():
|
||||
self._keep_alive_task.cancel()
|
||||
try:
|
||||
await self._keep_alive_task
|
||||
except asyncio.CancelledError:
|
||||
pass # Expected
|
||||
|
||||
if self._receive_task and not self._receive_task.done():
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass # Expected
|
||||
|
||||
if self._ws and not self._ws.closed:
|
||||
await self._ws.close(code=code)
|
||||
print("Gateway WebSocket closed.")
|
||||
|
||||
self._ws = None
|
||||
# Do not reset session_id, last_sequence, or resume_gateway_url here
|
||||
# if the close code indicates a resumable disconnect (e.g. 4000-4009, or server-initiated RECONNECT)
|
||||
# The connect logic will decide whether to resume or re-identify.
|
||||
# However, if it's a non-resumable close (e.g. Invalid Session non-resumable), clear them.
|
||||
if code == 4009: # Invalid session, not resumable
|
||||
print("Clearing session state due to non-resumable invalid session.")
|
||||
self._session_id = None
|
||||
self._last_sequence = None
|
||||
self._resume_gateway_url = None # This might be re-fetched anyway
|
657
disagreement/http.py
Normal file
657
disagreement/http.py
Normal file
@ -0,0 +1,657 @@
|
||||
# disagreement/http.py
|
||||
|
||||
"""
|
||||
HTTP client for interacting with the Discord REST API.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import aiohttp # pylint: disable=import-error
|
||||
import json
|
||||
from urllib.parse import quote
|
||||
from typing import Optional, Dict, Any, Union, TYPE_CHECKING, List
|
||||
|
||||
from .errors import (
|
||||
HTTPException,
|
||||
RateLimitError,
|
||||
AuthenticationError,
|
||||
DisagreementException,
|
||||
)
|
||||
from . import __version__ # For User-Agent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client
|
||||
from .models import Message
|
||||
from .interactions import ApplicationCommand, InteractionResponsePayload, Snowflake
|
||||
|
||||
# Discord API constants
|
||||
API_BASE_URL = "https://discord.com/api/v10" # Using API v10
|
||||
|
||||
|
||||
class HTTPClient:
|
||||
"""Handles HTTP requests to the Discord API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token: str,
|
||||
client_session: Optional[aiohttp.ClientSession] = None,
|
||||
verbose: bool = False,
|
||||
):
|
||||
self.token = token
|
||||
self._session: Optional[aiohttp.ClientSession] = (
|
||||
client_session # Can be externally managed
|
||||
)
|
||||
self.user_agent = f"DiscordBot (https://github.com/yourusername/disagreement, {__version__})" # Customize URL
|
||||
|
||||
self.verbose = verbose
|
||||
|
||||
self._global_rate_limit_lock = asyncio.Event()
|
||||
self._global_rate_limit_lock.set() # Initially unlocked
|
||||
|
||||
async def _ensure_session(self):
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession()
|
||||
|
||||
async def close(self):
|
||||
"""Closes the underlying aiohttp.ClientSession."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
payload: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
is_json: bool = True,
|
||||
use_auth_header: bool = True,
|
||||
custom_headers: Optional[Dict[str, str]] = None,
|
||||
) -> Any:
|
||||
"""Makes an HTTP request to the Discord API."""
|
||||
await self._ensure_session()
|
||||
|
||||
url = f"{API_BASE_URL}{endpoint}"
|
||||
final_headers: Dict[str, str] = { # Renamed to final_headers
|
||||
"User-Agent": self.user_agent,
|
||||
}
|
||||
if use_auth_header:
|
||||
final_headers["Authorization"] = f"Bot {self.token}"
|
||||
|
||||
if is_json and payload:
|
||||
final_headers["Content-Type"] = "application/json"
|
||||
|
||||
if custom_headers: # Merge custom headers
|
||||
final_headers.update(custom_headers)
|
||||
|
||||
if self.verbose:
|
||||
print(f"HTTP REQUEST: {method} {url} | payload={payload} params={params}")
|
||||
|
||||
# Global rate limit handling
|
||||
await self._global_rate_limit_lock.wait()
|
||||
|
||||
for attempt in range(5): # Max 5 retries for rate limits
|
||||
assert self._session is not None, "ClientSession not initialized"
|
||||
async with self._session.request(
|
||||
method,
|
||||
url,
|
||||
json=payload if is_json else None,
|
||||
data=payload if not is_json else None,
|
||||
headers=final_headers,
|
||||
params=params,
|
||||
) as response:
|
||||
|
||||
data = None
|
||||
try:
|
||||
if response.headers.get("Content-Type", "").startswith(
|
||||
"application/json"
|
||||
):
|
||||
data = await response.json()
|
||||
else:
|
||||
# For non-JSON responses, like fetching images or other files
|
||||
# We might return the raw response or handle it differently
|
||||
# For now, let's assume most API calls expect JSON
|
||||
data = await response.text()
|
||||
except (aiohttp.ContentTypeError, json.JSONDecodeError):
|
||||
data = (
|
||||
await response.text()
|
||||
) # Fallback to text if JSON parsing fails
|
||||
|
||||
if self.verbose:
|
||||
print(f"HTTP RESPONSE: {response.status} {url} | {data}")
|
||||
|
||||
if 200 <= response.status < 300:
|
||||
if response.status == 204:
|
||||
return None
|
||||
return data
|
||||
|
||||
# Rate limit handling
|
||||
if response.status == 429: # Rate limited
|
||||
retry_after_str = response.headers.get("Retry-After", "1")
|
||||
try:
|
||||
retry_after = float(retry_after_str)
|
||||
except ValueError:
|
||||
retry_after = 1.0 # Default retry if header is malformed
|
||||
|
||||
is_global = (
|
||||
response.headers.get("X-RateLimit-Global", "false").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
error_message = f"Rate limited on {method} {endpoint}."
|
||||
if data and isinstance(data, dict) and "message" in data:
|
||||
error_message += f" Discord says: {data['message']}"
|
||||
|
||||
if is_global:
|
||||
self._global_rate_limit_lock.clear()
|
||||
await asyncio.sleep(retry_after)
|
||||
self._global_rate_limit_lock.set()
|
||||
else:
|
||||
await asyncio.sleep(retry_after)
|
||||
|
||||
if attempt < 4: # Don't log on the last attempt before raising
|
||||
print(
|
||||
f"{error_message} Retrying after {retry_after}s (Attempt {attempt + 1}/5). Global: {is_global}"
|
||||
)
|
||||
continue # Retry the request
|
||||
else: # Last attempt failed
|
||||
raise RateLimitError(
|
||||
response,
|
||||
message=error_message,
|
||||
retry_after=retry_after,
|
||||
is_global=is_global,
|
||||
)
|
||||
|
||||
# Other error handling
|
||||
if response.status == 401: # Unauthorized
|
||||
raise AuthenticationError(response, "Invalid token provided.")
|
||||
if response.status == 403: # Forbidden
|
||||
raise HTTPException(
|
||||
response,
|
||||
"Missing permissions or access denied.",
|
||||
status=response.status,
|
||||
text=str(data),
|
||||
)
|
||||
|
||||
# General HTTP error
|
||||
error_text = str(data) if data else "Unknown error"
|
||||
discord_error_code = (
|
||||
data.get("code") if isinstance(data, dict) else None
|
||||
)
|
||||
raise HTTPException(
|
||||
response,
|
||||
f"API Error on {method} {endpoint}: {error_text}",
|
||||
status=response.status,
|
||||
text=error_text,
|
||||
error_code=discord_error_code,
|
||||
)
|
||||
|
||||
# Should not be reached if retries are exhausted by RateLimitError
|
||||
raise DisagreementException(
|
||||
f"Failed request to {method} {endpoint} after multiple retries."
|
||||
)
|
||||
|
||||
# --- Specific API call methods ---
|
||||
|
||||
async def get_gateway_bot(self) -> Dict[str, Any]:
|
||||
"""Gets the WSS URL and sharding information for the Gateway."""
|
||||
return await self.request("GET", "/gateway/bot")
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
channel_id: str,
|
||||
content: Optional[str] = None,
|
||||
tts: bool = False,
|
||||
embeds: Optional[List[Dict[str, Any]]] = None,
|
||||
components: Optional[List[Dict[str, Any]]] = None,
|
||||
allowed_mentions: Optional[dict] = None,
|
||||
message_reference: Optional[Dict[str, Any]] = None,
|
||||
flags: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Sends a message to a channel.
|
||||
|
||||
Returns the created message data as a dict.
|
||||
"""
|
||||
payload: Dict[str, Any] = {}
|
||||
if content is not None: # Content is optional if embeds/components are present
|
||||
payload["content"] = content
|
||||
if tts:
|
||||
payload["tts"] = True
|
||||
if embeds:
|
||||
payload["embeds"] = embeds
|
||||
if components:
|
||||
payload["components"] = components
|
||||
if allowed_mentions:
|
||||
payload["allowed_mentions"] = allowed_mentions
|
||||
if flags:
|
||||
payload["flags"] = flags
|
||||
if message_reference:
|
||||
payload["message_reference"] = message_reference
|
||||
|
||||
if not payload:
|
||||
raise ValueError("Message must have content, embeds, or components.")
|
||||
|
||||
return await self.request(
|
||||
"POST", f"/channels/{channel_id}/messages", payload=payload
|
||||
)
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
channel_id: str,
|
||||
message_id: str,
|
||||
payload: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Edits a message in a channel."""
|
||||
|
||||
return await self.request(
|
||||
"PATCH",
|
||||
f"/channels/{channel_id}/messages/{message_id}",
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
async def get_message(
|
||||
self, channel_id: "Snowflake", message_id: "Snowflake"
|
||||
) -> Dict[str, Any]:
|
||||
"""Fetches a message from a channel."""
|
||||
|
||||
return await self.request(
|
||||
"GET", f"/channels/{channel_id}/messages/{message_id}"
|
||||
)
|
||||
|
||||
async def create_reaction(
|
||||
self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
|
||||
) -> None:
|
||||
"""Adds a reaction to a message as the current user."""
|
||||
encoded = quote(emoji)
|
||||
await self.request(
|
||||
"PUT",
|
||||
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/@me",
|
||||
)
|
||||
|
||||
async def delete_reaction(
|
||||
self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
|
||||
) -> None:
|
||||
"""Removes the current user's reaction from a message."""
|
||||
encoded = quote(emoji)
|
||||
await self.request(
|
||||
"DELETE",
|
||||
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/@me",
|
||||
)
|
||||
|
||||
async def get_reactions(
|
||||
self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Fetches the users that reacted with a specific emoji."""
|
||||
encoded = quote(emoji)
|
||||
return await self.request(
|
||||
"GET",
|
||||
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}",
|
||||
)
|
||||
|
||||
async def delete_channel(
|
||||
self, channel_id: str, reason: Optional[str] = None
|
||||
) -> None:
|
||||
"""Deletes a channel.
|
||||
|
||||
If the channel is a guild channel, requires the MANAGE_CHANNELS permission.
|
||||
If the channel is a thread, requires the MANAGE_THREADS permission (if locked) or
|
||||
be the thread creator (if not locked).
|
||||
Deleting a category does not delete its child channels.
|
||||
"""
|
||||
custom_headers = {}
|
||||
if reason:
|
||||
custom_headers["X-Audit-Log-Reason"] = reason
|
||||
|
||||
await self.request(
|
||||
"DELETE",
|
||||
f"/channels/{channel_id}",
|
||||
custom_headers=custom_headers if custom_headers else None,
|
||||
)
|
||||
|
||||
async def get_channel(self, channel_id: str) -> Dict[str, Any]:
|
||||
"""Fetches a channel by ID."""
|
||||
return await self.request("GET", f"/channels/{channel_id}")
|
||||
|
||||
async def get_user(self, user_id: "Snowflake") -> Dict[str, Any]:
|
||||
"""Fetches a user object for a given user ID."""
|
||||
return await self.request("GET", f"/users/{user_id}")
|
||||
|
||||
async def get_guild_member(
|
||||
self, guild_id: "Snowflake", user_id: "Snowflake"
|
||||
) -> Dict[str, Any]:
|
||||
"""Returns a guild member object for the specified user."""
|
||||
return await self.request("GET", f"/guilds/{guild_id}/members/{user_id}")
|
||||
|
||||
async def kick_member(
|
||||
self, guild_id: "Snowflake", user_id: "Snowflake", reason: Optional[str] = None
|
||||
) -> None:
|
||||
"""Kicks a member from the guild."""
|
||||
headers = {"X-Audit-Log-Reason": reason} if reason else None
|
||||
await self.request(
|
||||
"DELETE",
|
||||
f"/guilds/{guild_id}/members/{user_id}",
|
||||
custom_headers=headers,
|
||||
)
|
||||
|
||||
async def ban_member(
|
||||
self,
|
||||
guild_id: "Snowflake",
|
||||
user_id: "Snowflake",
|
||||
*,
|
||||
delete_message_seconds: int = 0,
|
||||
reason: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Bans a member from the guild."""
|
||||
payload = {}
|
||||
if delete_message_seconds:
|
||||
payload["delete_message_seconds"] = delete_message_seconds
|
||||
headers = {"X-Audit-Log-Reason": reason} if reason else None
|
||||
await self.request(
|
||||
"PUT",
|
||||
f"/guilds/{guild_id}/bans/{user_id}",
|
||||
payload=payload if payload else None,
|
||||
custom_headers=headers,
|
||||
)
|
||||
|
||||
async def timeout_member(
|
||||
self,
|
||||
guild_id: "Snowflake",
|
||||
user_id: "Snowflake",
|
||||
*,
|
||||
until: Optional[str],
|
||||
reason: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Times out a member until the given ISO8601 timestamp."""
|
||||
payload = {"communication_disabled_until": until}
|
||||
headers = {"X-Audit-Log-Reason": reason} if reason else None
|
||||
return await self.request(
|
||||
"PATCH",
|
||||
f"/guilds/{guild_id}/members/{user_id}",
|
||||
payload=payload,
|
||||
custom_headers=headers,
|
||||
)
|
||||
|
||||
async def get_guild_roles(self, guild_id: "Snowflake") -> List[Dict[str, Any]]:
|
||||
"""Returns a list of role objects for the guild."""
|
||||
return await self.request("GET", f"/guilds/{guild_id}/roles")
|
||||
|
||||
async def get_guild(self, guild_id: "Snowflake") -> Dict[str, Any]:
|
||||
"""Fetches a guild object for a given guild ID."""
|
||||
return await self.request("GET", f"/guilds/{guild_id}")
|
||||
|
||||
# Add other methods like:
|
||||
# async def get_guild(self, guild_id: str) -> Dict[str, Any]: ...
|
||||
# async def create_reaction(self, channel_id: str, message_id: str, emoji: str) -> None: ...
|
||||
# etc.
|
||||
# --- Application Command Endpoints ---
|
||||
|
||||
# Global Application Commands
|
||||
async def get_global_application_commands(
|
||||
self, application_id: "Snowflake", with_localizations: bool = False
|
||||
) -> List["ApplicationCommand"]:
|
||||
"""Fetches all global commands for your application."""
|
||||
params = {"with_localizations": str(with_localizations).lower()}
|
||||
data = await self.request(
|
||||
"GET", f"/applications/{application_id}/commands", params=params
|
||||
)
|
||||
from .interactions import ApplicationCommand # Ensure constructor is available
|
||||
|
||||
return [ApplicationCommand(cmd_data) for cmd_data in data]
|
||||
|
||||
async def create_global_application_command(
|
||||
self, application_id: "Snowflake", payload: Dict[str, Any]
|
||||
) -> "ApplicationCommand":
|
||||
"""Creates a new global command."""
|
||||
data = await self.request(
|
||||
"POST", f"/applications/{application_id}/commands", payload=payload
|
||||
)
|
||||
from .interactions import ApplicationCommand
|
||||
|
||||
return ApplicationCommand(data)
|
||||
|
||||
async def get_global_application_command(
|
||||
self, application_id: "Snowflake", command_id: "Snowflake"
|
||||
) -> "ApplicationCommand":
|
||||
"""Fetches a specific global command."""
|
||||
data = await self.request(
|
||||
"GET", f"/applications/{application_id}/commands/{command_id}"
|
||||
)
|
||||
from .interactions import ApplicationCommand
|
||||
|
||||
return ApplicationCommand(data)
|
||||
|
||||
async def edit_global_application_command(
|
||||
self,
|
||||
application_id: "Snowflake",
|
||||
command_id: "Snowflake",
|
||||
payload: Dict[str, Any],
|
||||
) -> "ApplicationCommand":
|
||||
"""Edits a specific global command."""
|
||||
data = await self.request(
|
||||
"PATCH",
|
||||
f"/applications/{application_id}/commands/{command_id}",
|
||||
payload=payload,
|
||||
)
|
||||
from .interactions import ApplicationCommand
|
||||
|
||||
return ApplicationCommand(data)
|
||||
|
||||
async def delete_global_application_command(
|
||||
self, application_id: "Snowflake", command_id: "Snowflake"
|
||||
) -> None:
|
||||
"""Deletes a specific global command."""
|
||||
await self.request(
|
||||
"DELETE", f"/applications/{application_id}/commands/{command_id}"
|
||||
)
|
||||
|
||||
async def bulk_overwrite_global_application_commands(
|
||||
self, application_id: "Snowflake", payload: List[Dict[str, Any]]
|
||||
) -> List["ApplicationCommand"]:
|
||||
"""Bulk overwrites all global commands for your application."""
|
||||
data = await self.request(
|
||||
"PUT", f"/applications/{application_id}/commands", payload=payload
|
||||
)
|
||||
from .interactions import ApplicationCommand
|
||||
|
||||
return [ApplicationCommand(cmd_data) for cmd_data in data]
|
||||
|
||||
# Guild Application Commands
|
||||
async def get_guild_application_commands(
|
||||
self,
|
||||
application_id: "Snowflake",
|
||||
guild_id: "Snowflake",
|
||||
with_localizations: bool = False,
|
||||
) -> List["ApplicationCommand"]:
|
||||
"""Fetches all commands for your application for a specific guild."""
|
||||
params = {"with_localizations": str(with_localizations).lower()}
|
||||
data = await self.request(
|
||||
"GET",
|
||||
f"/applications/{application_id}/guilds/{guild_id}/commands",
|
||||
params=params,
|
||||
)
|
||||
from .interactions import ApplicationCommand
|
||||
|
||||
return [ApplicationCommand(cmd_data) for cmd_data in data]
|
||||
|
||||
async def create_guild_application_command(
|
||||
self,
|
||||
application_id: "Snowflake",
|
||||
guild_id: "Snowflake",
|
||||
payload: Dict[str, Any],
|
||||
) -> "ApplicationCommand":
|
||||
"""Creates a new guild command."""
|
||||
data = await self.request(
|
||||
"POST",
|
||||
f"/applications/{application_id}/guilds/{guild_id}/commands",
|
||||
payload=payload,
|
||||
)
|
||||
from .interactions import ApplicationCommand
|
||||
|
||||
return ApplicationCommand(data)
|
||||
|
||||
async def get_guild_application_command(
|
||||
self,
|
||||
application_id: "Snowflake",
|
||||
guild_id: "Snowflake",
|
||||
command_id: "Snowflake",
|
||||
) -> "ApplicationCommand":
|
||||
"""Fetches a specific guild command."""
|
||||
data = await self.request(
|
||||
"GET",
|
||||
f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}",
|
||||
)
|
||||
from .interactions import ApplicationCommand
|
||||
|
||||
return ApplicationCommand(data)
|
||||
|
||||
async def edit_guild_application_command(
|
||||
self,
|
||||
application_id: "Snowflake",
|
||||
guild_id: "Snowflake",
|
||||
command_id: "Snowflake",
|
||||
payload: Dict[str, Any],
|
||||
) -> "ApplicationCommand":
|
||||
"""Edits a specific guild command."""
|
||||
data = await self.request(
|
||||
"PATCH",
|
||||
f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}",
|
||||
payload=payload,
|
||||
)
|
||||
from .interactions import ApplicationCommand
|
||||
|
||||
return ApplicationCommand(data)
|
||||
|
||||
async def delete_guild_application_command(
|
||||
self,
|
||||
application_id: "Snowflake",
|
||||
guild_id: "Snowflake",
|
||||
command_id: "Snowflake",
|
||||
) -> None:
|
||||
"""Deletes a specific guild command."""
|
||||
await self.request(
|
||||
"DELETE",
|
||||
f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}",
|
||||
)
|
||||
|
||||
async def bulk_overwrite_guild_application_commands(
|
||||
self,
|
||||
application_id: "Snowflake",
|
||||
guild_id: "Snowflake",
|
||||
payload: List[Dict[str, Any]],
|
||||
) -> List["ApplicationCommand"]:
|
||||
"""Bulk overwrites all commands for your application for a specific guild."""
|
||||
data = await self.request(
|
||||
"PUT",
|
||||
f"/applications/{application_id}/guilds/{guild_id}/commands",
|
||||
payload=payload,
|
||||
)
|
||||
from .interactions import ApplicationCommand
|
||||
|
||||
return [ApplicationCommand(cmd_data) for cmd_data in data]
|
||||
|
||||
# --- Interaction Response Endpoints ---
|
||||
# Note: These methods return Dict[str, Any] representing the Message data.
|
||||
# The caller (e.g., AppCommandHandler) will be responsible for constructing Message models
|
||||
# if needed, as Message model instantiation requires a `client_instance`.
|
||||
|
||||
async def create_interaction_response(
|
||||
self,
|
||||
interaction_id: "Snowflake",
|
||||
interaction_token: str,
|
||||
payload: "InteractionResponsePayload",
|
||||
*,
|
||||
ephemeral: bool = False,
|
||||
) -> None:
|
||||
"""Creates a response to an Interaction.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ephemeral: bool
|
||||
Ignored parameter for test compatibility.
|
||||
"""
|
||||
# Interaction responses do not use the bot token in the Authorization header.
|
||||
# They are authenticated by the interaction_token in the URL.
|
||||
await self.request(
|
||||
"POST",
|
||||
f"/interactions/{interaction_id}/{interaction_token}/callback",
|
||||
payload=payload.to_dict(),
|
||||
use_auth_header=False,
|
||||
)
|
||||
|
||||
async def get_original_interaction_response(
|
||||
self, application_id: "Snowflake", interaction_token: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Gets the initial Interaction response."""
|
||||
# This endpoint uses the bot token for auth.
|
||||
return await self.request(
|
||||
"GET", f"/webhooks/{application_id}/{interaction_token}/messages/@original"
|
||||
)
|
||||
|
||||
async def edit_original_interaction_response(
|
||||
self,
|
||||
application_id: "Snowflake",
|
||||
interaction_token: str,
|
||||
payload: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Edits the initial Interaction response."""
|
||||
return await self.request(
|
||||
"PATCH",
|
||||
f"/webhooks/{application_id}/{interaction_token}/messages/@original",
|
||||
payload=payload,
|
||||
use_auth_header=False,
|
||||
) # Docs imply webhook-style auth
|
||||
|
||||
async def delete_original_interaction_response(
|
||||
self, application_id: "Snowflake", interaction_token: str
|
||||
) -> None:
|
||||
"""Deletes the initial Interaction response."""
|
||||
await self.request(
|
||||
"DELETE",
|
||||
f"/webhooks/{application_id}/{interaction_token}/messages/@original",
|
||||
use_auth_header=False,
|
||||
) # Docs imply webhook-style auth
|
||||
|
||||
async def create_followup_message(
|
||||
self,
|
||||
application_id: "Snowflake",
|
||||
interaction_token: str,
|
||||
payload: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Creates a followup message for an Interaction."""
|
||||
# Followup messages are sent to a webhook endpoint.
|
||||
return await self.request(
|
||||
"POST",
|
||||
f"/webhooks/{application_id}/{interaction_token}",
|
||||
payload=payload,
|
||||
use_auth_header=False,
|
||||
) # Docs imply webhook-style auth
|
||||
|
||||
async def edit_followup_message(
|
||||
self,
|
||||
application_id: "Snowflake",
|
||||
interaction_token: str,
|
||||
message_id: "Snowflake",
|
||||
payload: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Edits a followup message for an Interaction."""
|
||||
return await self.request(
|
||||
"PATCH",
|
||||
f"/webhooks/{application_id}/{interaction_token}/messages/{message_id}",
|
||||
payload=payload,
|
||||
use_auth_header=False,
|
||||
) # Docs imply webhook-style auth
|
||||
|
||||
async def delete_followup_message(
|
||||
self,
|
||||
application_id: "Snowflake",
|
||||
interaction_token: str,
|
||||
message_id: "Snowflake",
|
||||
) -> None:
|
||||
"""Deletes a followup message for an Interaction."""
|
||||
await self.request(
|
||||
"DELETE",
|
||||
f"/webhooks/{application_id}/{interaction_token}/messages/{message_id}",
|
||||
use_auth_header=False,
|
||||
)
|
||||
|
||||
async def trigger_typing(self, channel_id: str) -> None:
|
||||
"""Sends a typing indicator to the specified channel."""
|
||||
await self.request("POST", f"/channels/{channel_id}/typing")
|
32
disagreement/hybrid_context.py
Normal file
32
disagreement/hybrid_context.py
Normal file
@ -0,0 +1,32 @@
|
||||
"""Utility class for working with either command or app contexts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Union
|
||||
|
||||
from .ext.commands.core import CommandContext
|
||||
from .ext.app_commands.context import AppCommandContext
|
||||
|
||||
|
||||
class HybridContext:
|
||||
"""Wraps :class:`CommandContext` and :class:`AppCommandContext`.
|
||||
|
||||
Provides a single :meth:`send` method that proxies to ``reply`` for
|
||||
prefix commands and to ``send`` for slash commands.
|
||||
"""
|
||||
|
||||
def __init__(self, ctx: Union[CommandContext, AppCommandContext]):
|
||||
self._ctx = ctx
|
||||
|
||||
async def send(self, *args: Any, **kwargs: Any):
|
||||
if isinstance(self._ctx, AppCommandContext):
|
||||
return await self._ctx.send(*args, **kwargs)
|
||||
return await self._ctx.reply(*args, **kwargs)
|
||||
|
||||
async def edit(self, *args: Any, **kwargs: Any):
|
||||
if hasattr(self._ctx, "edit"):
|
||||
return await self._ctx.edit(*args, **kwargs)
|
||||
raise AttributeError("Underlying context does not support editing.")
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._ctx, name)
|
22
disagreement/i18n.py
Normal file
22
disagreement/i18n.py
Normal file
@ -0,0 +1,22 @@
|
||||
import json
|
||||
from typing import Dict, Optional
|
||||
|
||||
_translations: Dict[str, Dict[str, str]] = {}
|
||||
|
||||
|
||||
def set_translations(locale: str, mapping: Dict[str, str]) -> None:
|
||||
"""Set translations for a locale."""
|
||||
_translations[locale] = mapping
|
||||
|
||||
|
||||
def load_translations(locale: str, file_path: str) -> None:
|
||||
"""Load translations for *locale* from a JSON file."""
|
||||
with open(file_path, "r", encoding="utf-8") as handle:
|
||||
_translations[locale] = json.load(handle)
|
||||
|
||||
|
||||
def translate(key: str, locale: str, *, default: Optional[str] = None) -> str:
|
||||
"""Return the translated string for *key* in *locale*."""
|
||||
return _translations.get(locale, {}).get(
|
||||
key, default if default is not None else key
|
||||
)
|
572
disagreement/interactions.py
Normal file
572
disagreement/interactions.py
Normal file
@ -0,0 +1,572 @@
|
||||
# disagreement/interactions.py
|
||||
|
||||
"""
|
||||
Data models for Discord Interaction objects.
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Union, Any, TYPE_CHECKING
|
||||
|
||||
from .enums import (
|
||||
ApplicationCommandType,
|
||||
ApplicationCommandOptionType,
|
||||
InteractionType,
|
||||
InteractionCallbackType,
|
||||
IntegrationType,
|
||||
InteractionContextType,
|
||||
ChannelType,
|
||||
)
|
||||
|
||||
# Runtime imports for models used in this module
|
||||
from .models import (
|
||||
User,
|
||||
Message,
|
||||
Member,
|
||||
Role,
|
||||
Embed,
|
||||
PartialChannel,
|
||||
Attachment,
|
||||
ActionRow,
|
||||
Component,
|
||||
AllowedMentions,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Import Client type only for type checking to avoid circular imports
|
||||
from .client import Client
|
||||
from .ui.modal import Modal
|
||||
|
||||
# MessageFlags, PartialAttachment can be added if/when defined
|
||||
|
||||
Snowflake = str
|
||||
|
||||
|
||||
# Based on Application Command Option Choice Structure
|
||||
class ApplicationCommandOptionChoice:
|
||||
"""Represents a choice for an application command option."""
|
||||
|
||||
def __init__(self, data: dict):
|
||||
self.name: str = data["name"]
|
||||
self.value: Union[str, int, float] = data["value"]
|
||||
self.name_localizations: Optional[Dict[str, str]] = data.get(
|
||||
"name_localizations"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<ApplicationCommandOptionChoice name='{self.name}' value={self.value!r}>"
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
payload: Dict[str, Any] = {"name": self.name, "value": self.value}
|
||||
if self.name_localizations:
|
||||
payload["name_localizations"] = self.name_localizations
|
||||
return payload
|
||||
|
||||
|
||||
# Based on Application Command Option Structure
|
||||
class ApplicationCommandOption:
|
||||
"""Represents an option for an application command."""
|
||||
|
||||
def __init__(self, data: dict):
|
||||
self.type: ApplicationCommandOptionType = ApplicationCommandOptionType(
|
||||
data["type"]
|
||||
)
|
||||
self.name: str = data["name"]
|
||||
self.description: str = data["description"]
|
||||
self.required: bool = data.get("required", False)
|
||||
|
||||
self.choices: Optional[List[ApplicationCommandOptionChoice]] = (
|
||||
[ApplicationCommandOptionChoice(c) for c in data["choices"]]
|
||||
if data.get("choices")
|
||||
else None
|
||||
)
|
||||
|
||||
self.options: Optional[List["ApplicationCommandOption"]] = (
|
||||
[ApplicationCommandOption(o) for o in data["options"]]
|
||||
if data.get("options")
|
||||
else None
|
||||
) # For subcommands/groups
|
||||
|
||||
self.channel_types: Optional[List[ChannelType]] = (
|
||||
[ChannelType(ct) for ct in data.get("channel_types", [])]
|
||||
if data.get("channel_types")
|
||||
else None
|
||||
)
|
||||
self.min_value: Optional[Union[int, float]] = data.get("min_value")
|
||||
self.max_value: Optional[Union[int, float]] = data.get("max_value")
|
||||
self.min_length: Optional[int] = data.get("min_length")
|
||||
self.max_length: Optional[int] = data.get("max_length")
|
||||
self.autocomplete: bool = data.get("autocomplete", False)
|
||||
self.name_localizations: Optional[Dict[str, str]] = data.get(
|
||||
"name_localizations"
|
||||
)
|
||||
self.description_localizations: Optional[Dict[str, str]] = data.get(
|
||||
"description_localizations"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<ApplicationCommandOption name='{self.name}' type={self.type!r}>"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
payload: Dict[str, Any] = {
|
||||
"type": self.type.value,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
}
|
||||
if self.required: # Defaults to False, only include if True
|
||||
payload["required"] = self.required
|
||||
if self.choices:
|
||||
payload["choices"] = [c.to_dict() for c in self.choices]
|
||||
if self.options: # For subcommands/groups
|
||||
payload["options"] = [o.to_dict() for o in self.options]
|
||||
if self.channel_types:
|
||||
payload["channel_types"] = [ct.value for ct in self.channel_types]
|
||||
if self.min_value is not None:
|
||||
payload["min_value"] = self.min_value
|
||||
if self.max_value is not None:
|
||||
payload["max_value"] = self.max_value
|
||||
if self.min_length is not None:
|
||||
payload["min_length"] = self.min_length
|
||||
if self.max_length is not None:
|
||||
payload["max_length"] = self.max_length
|
||||
if self.autocomplete: # Defaults to False, only include if True
|
||||
payload["autocomplete"] = self.autocomplete
|
||||
if self.name_localizations:
|
||||
payload["name_localizations"] = self.name_localizations
|
||||
if self.description_localizations:
|
||||
payload["description_localizations"] = self.description_localizations
|
||||
return payload
|
||||
|
||||
|
||||
# Based on Application Command Structure
|
||||
class ApplicationCommand:
|
||||
"""Represents an application command."""
|
||||
|
||||
def __init__(self, data: dict):
|
||||
self.id: Optional[Snowflake] = data.get("id")
|
||||
self.type: ApplicationCommandType = ApplicationCommandType(
|
||||
data.get("type", 1)
|
||||
) # Default to CHAT_INPUT
|
||||
self.application_id: Optional[Snowflake] = data.get("application_id")
|
||||
self.guild_id: Optional[Snowflake] = data.get("guild_id")
|
||||
self.name: str = data["name"]
|
||||
self.description: str = data.get(
|
||||
"description", ""
|
||||
) # Empty for USER/MESSAGE commands
|
||||
|
||||
self.options: Optional[List[ApplicationCommandOption]] = (
|
||||
[ApplicationCommandOption(o) for o in data["options"]]
|
||||
if data.get("options")
|
||||
else None
|
||||
)
|
||||
|
||||
self.default_member_permissions: Optional[str] = data.get(
|
||||
"default_member_permissions"
|
||||
)
|
||||
self.dm_permission: Optional[bool] = data.get("dm_permission") # Deprecated
|
||||
self.nsfw: bool = data.get("nsfw", False)
|
||||
self.version: Optional[Snowflake] = data.get("version")
|
||||
self.name_localizations: Optional[Dict[str, str]] = data.get(
|
||||
"name_localizations"
|
||||
)
|
||||
self.description_localizations: Optional[Dict[str, str]] = data.get(
|
||||
"description_localizations"
|
||||
)
|
||||
|
||||
self.integration_types: Optional[List[IntegrationType]] = (
|
||||
[IntegrationType(it) for it in data["integration_types"]]
|
||||
if data.get("integration_types")
|
||||
else None
|
||||
)
|
||||
|
||||
self.contexts: Optional[List[InteractionContextType]] = (
|
||||
[InteractionContextType(c) for c in data["contexts"]]
|
||||
if data.get("contexts")
|
||||
else None
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<ApplicationCommand id='{self.id}' name='{self.name}' type={self.type!r}>"
|
||||
)
|
||||
|
||||
|
||||
# Based on Interaction Object's Resolved Data Structure
|
||||
class ResolvedData:
|
||||
"""Represents resolved data for an interaction."""
|
||||
|
||||
def __init__(
|
||||
self, data: dict, client_instance: Optional["Client"] = None
|
||||
): # client_instance for model hydration
|
||||
# Models are now imported in TYPE_CHECKING block
|
||||
|
||||
users_data = data.get("users", {})
|
||||
self.users: Dict[Snowflake, "User"] = {
|
||||
uid: User(udata) for uid, udata in users_data.items()
|
||||
}
|
||||
|
||||
self.members: Dict[Snowflake, "Member"] = {}
|
||||
for mid, mdata in data.get("members", {}).items():
|
||||
member_payload = dict(mdata)
|
||||
member_payload.setdefault("id", mid)
|
||||
if "user" not in member_payload and mid in users_data:
|
||||
member_payload["user"] = users_data[mid]
|
||||
self.members[mid] = Member(member_payload, client_instance=client_instance)
|
||||
|
||||
self.roles: Dict[Snowflake, "Role"] = {
|
||||
rid: Role(rdata) for rid, rdata in data.get("roles", {}).items()
|
||||
}
|
||||
|
||||
self.channels: Dict[Snowflake, "PartialChannel"] = {
|
||||
cid: PartialChannel(cdata, client_instance=client_instance)
|
||||
for cid, cdata in data.get("channels", {}).items()
|
||||
}
|
||||
|
||||
self.messages: Dict[Snowflake, "Message"] = (
|
||||
{
|
||||
mid: Message(mdata, client_instance=client_instance) for mid, mdata in data.get("messages", {}).items() # type: ignore[misc]
|
||||
}
|
||||
if client_instance
|
||||
else {}
|
||||
) # Only hydrate if client is available
|
||||
|
||||
self.attachments: Dict[Snowflake, "Attachment"] = {
|
||||
aid: Attachment(adata) for aid, adata in data.get("attachments", {}).items()
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<ResolvedData users={len(self.users)} members={len(self.members)} roles={len(self.roles)} channels={len(self.channels)} messages={len(self.messages)} attachments={len(self.attachments)}>"
|
||||
|
||||
|
||||
# Based on Interaction Object's Data Structure (for Application Commands)
|
||||
class InteractionData:
|
||||
"""Represents the data payload for an interaction."""
|
||||
|
||||
def __init__(self, data: dict, client_instance: Optional["Client"] = None):
|
||||
self.id: Optional[Snowflake] = data.get("id") # Command ID
|
||||
self.name: Optional[str] = data.get("name") # Command name
|
||||
self.type: Optional[ApplicationCommandType] = (
|
||||
ApplicationCommandType(data["type"]) if data.get("type") else None
|
||||
)
|
||||
|
||||
self.resolved: Optional[ResolvedData] = (
|
||||
ResolvedData(data["resolved"], client_instance=client_instance)
|
||||
if data.get("resolved")
|
||||
else None
|
||||
)
|
||||
|
||||
# For CHAT_INPUT, this is List[ApplicationCommandInteractionDataOption]
|
||||
# For USER/MESSAGE, this is not present or different.
|
||||
# For now, storing as raw list of dicts. Parsing can happen in handler.
|
||||
self.options: Optional[List[Dict[str, Any]]] = data.get("options")
|
||||
|
||||
# For message components
|
||||
self.custom_id: Optional[str] = data.get("custom_id")
|
||||
self.component_type: Optional[int] = data.get("component_type")
|
||||
self.values: Optional[List[str]] = data.get("values")
|
||||
|
||||
self.guild_id: Optional[Snowflake] = data.get("guild_id")
|
||||
self.target_id: Optional[Snowflake] = data.get(
|
||||
"target_id"
|
||||
) # For USER/MESSAGE commands
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<InteractionData id='{self.id}' name='{self.name}' type={self.type!r}>"
|
||||
|
||||
|
||||
# Based on Interaction Object Structure
|
||||
class Interaction:
|
||||
"""Represents an interaction from Discord."""
|
||||
|
||||
def __init__(self, data: dict, client_instance: "Client"):
|
||||
self._client: "Client" = client_instance
|
||||
|
||||
self.id: Snowflake = data["id"]
|
||||
self.application_id: Snowflake = data["application_id"]
|
||||
self.type: InteractionType = InteractionType(data["type"])
|
||||
|
||||
self.data: Optional[InteractionData] = (
|
||||
InteractionData(data["data"], client_instance=client_instance)
|
||||
if data.get("data")
|
||||
else None
|
||||
)
|
||||
|
||||
self.guild_id: Optional[Snowflake] = data.get("guild_id")
|
||||
self.channel_id: Optional[Snowflake] = data.get(
|
||||
"channel_id"
|
||||
) # Will be present on command invocations
|
||||
|
||||
member_data = data.get("member")
|
||||
user_data_from_member = (
|
||||
member_data.get("user") if isinstance(member_data, dict) else None
|
||||
)
|
||||
|
||||
self.member: Optional["Member"] = (
|
||||
Member(member_data, client_instance=self._client) if member_data else None
|
||||
)
|
||||
|
||||
# User object is included within member if in guild, otherwise it's top-level
|
||||
# If self.member was successfully hydrated, its .user attribute should be preferred if it exists.
|
||||
# However, Member.__init__ handles setting User attributes.
|
||||
# The primary source for User is data.get("user") or member_data.get("user").
|
||||
|
||||
if data.get("user"):
|
||||
self.user: Optional["User"] = User(data["user"])
|
||||
elif user_data_from_member:
|
||||
self.user: Optional["User"] = User(user_data_from_member)
|
||||
elif (
|
||||
self.member
|
||||
): # If member was hydrated and has user attributes (e.g. from Member(User) inheritance)
|
||||
# This assumes Member correctly populates its User parts.
|
||||
self.user: Optional["User"] = self.member # Member is a User subclass
|
||||
else:
|
||||
self.user: Optional["User"] = None
|
||||
|
||||
self.token: str = data["token"] # For responding to the interaction
|
||||
self.version: int = data["version"]
|
||||
|
||||
self.message: Optional["Message"] = (
|
||||
Message(data["message"], client_instance=client_instance)
|
||||
if data.get("message")
|
||||
else None
|
||||
) # For component interactions
|
||||
|
||||
self.app_permissions: Optional[str] = data.get(
|
||||
"app_permissions"
|
||||
) # Bitwise set of permissions the app has in the source channel
|
||||
self.locale: Optional[str] = data.get(
|
||||
"locale"
|
||||
) # Selected language of the invoking user
|
||||
self.guild_locale: Optional[str] = data.get(
|
||||
"guild_locale"
|
||||
) # Guild's preferred language
|
||||
|
||||
self.response = InteractionResponse(self)
|
||||
|
||||
async def respond(
|
||||
self,
|
||||
content: Optional[str] = None,
|
||||
*,
|
||||
embed: Optional[Embed] = None,
|
||||
embeds: Optional[List[Embed]] = None,
|
||||
components: Optional[List[ActionRow]] = None,
|
||||
ephemeral: bool = False,
|
||||
tts: bool = False,
|
||||
) -> None:
|
||||
"""|coro|
|
||||
|
||||
Responds to this interaction.
|
||||
|
||||
Parameters:
|
||||
content (Optional[str]): The content of the message.
|
||||
embed (Optional[Embed]): A single embed to send.
|
||||
embeds (Optional[List[Embed]]): A list of embeds to send.
|
||||
components (Optional[List[ActionRow]]): A list of ActionRow components.
|
||||
ephemeral (bool): Whether the response should be ephemeral (only visible to the user).
|
||||
tts (bool): Whether the message should be sent with text-to-speech.
|
||||
"""
|
||||
if embed and embeds:
|
||||
raise ValueError("Cannot provide both embed and embeds.")
|
||||
|
||||
data: Dict[str, Any] = {}
|
||||
if tts:
|
||||
data["tts"] = True
|
||||
if content:
|
||||
data["content"] = content
|
||||
if embed:
|
||||
data["embeds"] = [embed.to_dict()]
|
||||
elif embeds:
|
||||
data["embeds"] = [e.to_dict() for e in embeds]
|
||||
if components:
|
||||
data["components"] = [c.to_dict() for c in components]
|
||||
if ephemeral:
|
||||
data["flags"] = 1 << 6 # EPHEMERAL flag
|
||||
|
||||
payload = InteractionResponsePayload(
|
||||
type=InteractionCallbackType.CHANNEL_MESSAGE_WITH_SOURCE,
|
||||
data=InteractionCallbackData(data),
|
||||
)
|
||||
|
||||
await self._client._http.create_interaction_response(
|
||||
interaction_id=self.id,
|
||||
interaction_token=self.token,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
async def respond_modal(self, modal: "Modal") -> None:
|
||||
"""|coro| Send a modal in response to this interaction."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
payload = {
|
||||
"type": InteractionCallbackType.MODAL.value,
|
||||
"data": modal.to_dict(),
|
||||
}
|
||||
await self._client._http.create_interaction_response(
|
||||
interaction_id=self.id,
|
||||
interaction_token=self.token,
|
||||
payload=cast(Any, payload),
|
||||
)
|
||||
|
||||
async def edit(
|
||||
self,
|
||||
content: Optional[str] = None,
|
||||
*,
|
||||
embed: Optional[Embed] = None,
|
||||
embeds: Optional[List[Embed]] = None,
|
||||
components: Optional[List[ActionRow]] = None,
|
||||
attachments: Optional[List[Any]] = None,
|
||||
allowed_mentions: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""|coro|
|
||||
|
||||
Edits the original response to this interaction.
|
||||
|
||||
If the interaction is from a component, this will acknowledge the
|
||||
interaction and update the message in one operation.
|
||||
|
||||
Parameters:
|
||||
content (Optional[str]): The new message content.
|
||||
embed (Optional[Embed]): A single embed to send. Ignored if
|
||||
``embeds`` is provided.
|
||||
embeds (Optional[List[Embed]]): A list of embeds to send.
|
||||
components (Optional[List[ActionRow]]): Updated components for the
|
||||
message.
|
||||
attachments (Optional[List[Any]]): Attachments to include with the
|
||||
message.
|
||||
allowed_mentions (Optional[Dict[str, Any]]): Controls mentions in the
|
||||
message.
|
||||
"""
|
||||
if embed and embeds:
|
||||
raise ValueError("Cannot provide both embed and embeds.")
|
||||
|
||||
payload_data: Dict[str, Any] = {}
|
||||
if content is not None:
|
||||
payload_data["content"] = content
|
||||
if embed:
|
||||
payload_data["embeds"] = [embed.to_dict()]
|
||||
elif embeds is not None:
|
||||
payload_data["embeds"] = [e.to_dict() for e in embeds]
|
||||
if components is not None:
|
||||
payload_data["components"] = [c.to_dict() for c in components]
|
||||
if attachments is not None:
|
||||
payload_data["attachments"] = [
|
||||
a.to_dict() if hasattr(a, "to_dict") else a for a in attachments
|
||||
]
|
||||
if allowed_mentions is not None:
|
||||
payload_data["allowed_mentions"] = allowed_mentions
|
||||
|
||||
if self.type == InteractionType.MESSAGE_COMPONENT:
|
||||
# For component interactions, we send an UPDATE_MESSAGE response
|
||||
# to acknowledge the interaction and edit the message simultaneously.
|
||||
payload = InteractionResponsePayload(
|
||||
type=InteractionCallbackType.UPDATE_MESSAGE,
|
||||
data=InteractionCallbackData(payload_data),
|
||||
)
|
||||
await self._client._http.create_interaction_response(
|
||||
self.id, self.token, payload
|
||||
)
|
||||
else:
|
||||
# For other interaction types (like an initial slash command response),
|
||||
# we edit the original response via the webhook endpoint.
|
||||
await self._client._http.edit_original_interaction_response(
|
||||
application_id=self.application_id,
|
||||
interaction_token=self.token,
|
||||
payload=payload_data,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Interaction id='{self.id}' type={self.type!r}>"
|
||||
|
||||
|
||||
class InteractionResponse:
|
||||
"""Helper for sending responses for an :class:`Interaction`."""
|
||||
|
||||
def __init__(self, interaction: "Interaction") -> None:
|
||||
self._interaction = interaction
|
||||
|
||||
async def send_modal(self, modal: "Modal") -> None:
|
||||
"""Sends a modal response."""
|
||||
payload = InteractionResponsePayload(
|
||||
type=InteractionCallbackType.MODAL,
|
||||
data=InteractionCallbackData(modal.to_dict()),
|
||||
)
|
||||
await self._interaction._client._http.create_interaction_response(
|
||||
self._interaction.id,
|
||||
self._interaction.token,
|
||||
payload,
|
||||
)
|
||||
|
||||
|
||||
# Based on Interaction Response Object's Data Structure
|
||||
class InteractionCallbackData:
|
||||
"""Data for an interaction response."""
|
||||
|
||||
def __init__(self, data: dict):
|
||||
self.tts: Optional[bool] = data.get("tts")
|
||||
self.content: Optional[str] = data.get("content")
|
||||
self.embeds: Optional[List[Embed]] = (
|
||||
[Embed(e) for e in data.get("embeds", [])] if data.get("embeds") else None
|
||||
)
|
||||
self.allowed_mentions: Optional[AllowedMentions] = (
|
||||
AllowedMentions(data["allowed_mentions"])
|
||||
if data.get("allowed_mentions")
|
||||
else None
|
||||
)
|
||||
self.flags: Optional[int] = data.get("flags") # MessageFlags enum could be used
|
||||
from .components import component_factory
|
||||
|
||||
self.components: Optional[List[Component]] = (
|
||||
[component_factory(c) for c in data.get("components", [])]
|
||||
if data.get("components")
|
||||
else None
|
||||
)
|
||||
self.attachments: Optional[List[Attachment]] = (
|
||||
[Attachment(a) for a in data.get("attachments", [])]
|
||||
if data.get("attachments")
|
||||
else None
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
# Helper to convert to dict for sending to Discord API
|
||||
payload = {}
|
||||
if self.tts is not None:
|
||||
payload["tts"] = self.tts
|
||||
if self.content is not None:
|
||||
payload["content"] = self.content
|
||||
if self.embeds is not None:
|
||||
payload["embeds"] = [e.to_dict() for e in self.embeds]
|
||||
if self.allowed_mentions is not None:
|
||||
payload["allowed_mentions"] = self.allowed_mentions.to_dict()
|
||||
if self.flags is not None:
|
||||
payload["flags"] = self.flags
|
||||
if self.components is not None:
|
||||
payload["components"] = [c.to_dict() for c in self.components]
|
||||
if self.attachments is not None:
|
||||
payload["attachments"] = [a.to_dict() for a in self.attachments]
|
||||
return payload
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<InteractionCallbackData content='{self.content[:20] if self.content else None}'>"
|
||||
|
||||
|
||||
# Based on Interaction Response Object Structure
|
||||
class InteractionResponsePayload:
|
||||
"""Payload for responding to an interaction."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: InteractionCallbackType,
|
||||
data: Optional[InteractionCallbackData] = None,
|
||||
):
|
||||
self.type: InteractionCallbackType = type
|
||||
self.data: Optional[InteractionCallbackData] = data
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
payload: Dict[str, Any] = {"type": self.type.value}
|
||||
if self.data:
|
||||
payload["data"] = self.data.to_dict()
|
||||
return payload
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<InteractionResponsePayload type={self.type!r}>"
|
26
disagreement/logging_config.py
Normal file
26
disagreement/logging_config.py
Normal file
@ -0,0 +1,26 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def setup_logging(level: int, file: Optional[str] = None) -> None:
|
||||
"""Configure logging for the library.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
level:
|
||||
Logging level from the :mod:`logging` module.
|
||||
file:
|
||||
Optional file path to write logs to. If ``None``, logs are sent to
|
||||
standard output.
|
||||
"""
|
||||
handlers: list[logging.Handler] = []
|
||||
if file is None:
|
||||
handlers.append(logging.StreamHandler())
|
||||
else:
|
||||
handlers.append(logging.FileHandler(file))
|
||||
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
handlers=handlers,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
1642
disagreement/models.py
Normal file
1642
disagreement/models.py
Normal file
File diff suppressed because it is too large
Load Diff
109
disagreement/oauth.py
Normal file
109
disagreement/oauth.py
Normal file
@ -0,0 +1,109 @@
|
||||
"""OAuth2 utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import aiohttp
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from .errors import HTTPException
|
||||
|
||||
|
||||
def build_authorization_url(
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
scope: Union[str, List[str]],
|
||||
*,
|
||||
state: Optional[str] = None,
|
||||
response_type: str = "code",
|
||||
prompt: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Return the Discord OAuth2 authorization URL."""
|
||||
if isinstance(scope, list):
|
||||
scope = " ".join(scope)
|
||||
|
||||
params = {
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": response_type,
|
||||
"scope": scope,
|
||||
}
|
||||
if state is not None:
|
||||
params["state"] = state
|
||||
if prompt is not None:
|
||||
params["prompt"] = prompt
|
||||
|
||||
return "https://discord.com/oauth2/authorize?" + urlencode(params)
|
||||
|
||||
|
||||
async def exchange_code_for_token(
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
code: str,
|
||||
redirect_uri: str,
|
||||
*,
|
||||
session: Optional[aiohttp.ClientSession] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Exchange an authorization code for an access token."""
|
||||
close = False
|
||||
if session is None:
|
||||
session = aiohttp.ClientSession()
|
||||
close = True
|
||||
|
||||
data = {
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
resp = await session.post(
|
||||
"https://discord.com/api/v10/oauth2/token",
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
try:
|
||||
json_data = await resp.json()
|
||||
if resp.status != 200:
|
||||
raise HTTPException(resp, message="OAuth token exchange failed")
|
||||
finally:
|
||||
if close:
|
||||
await session.close()
|
||||
return json_data
|
||||
|
||||
|
||||
async def refresh_access_token(
|
||||
refresh_token: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
*,
|
||||
session: Optional[aiohttp.ClientSession] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Refresh an access token using a refresh token."""
|
||||
|
||||
close = False
|
||||
if session is None:
|
||||
session = aiohttp.ClientSession()
|
||||
close = True
|
||||
|
||||
data = {
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
resp = await session.post(
|
||||
"https://discord.com/api/v10/oauth2/token",
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
try:
|
||||
json_data = await resp.json()
|
||||
if resp.status != 200:
|
||||
raise HTTPException(resp, message="OAuth token refresh failed")
|
||||
finally:
|
||||
if close:
|
||||
await session.close()
|
||||
return json_data
|
99
disagreement/permissions.py
Normal file
99
disagreement/permissions.py
Normal file
@ -0,0 +1,99 @@
|
||||
"""Utility helpers for working with Discord permission bitmasks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntFlag
|
||||
from typing import Iterable, List
|
||||
|
||||
|
||||
class Permissions(IntFlag):
|
||||
"""Discord guild and channel permissions."""
|
||||
|
||||
CREATE_INSTANT_INVITE = 1 << 0
|
||||
KICK_MEMBERS = 1 << 1
|
||||
BAN_MEMBERS = 1 << 2
|
||||
ADMINISTRATOR = 1 << 3
|
||||
MANAGE_CHANNELS = 1 << 4
|
||||
MANAGE_GUILD = 1 << 5
|
||||
ADD_REACTIONS = 1 << 6
|
||||
VIEW_AUDIT_LOG = 1 << 7
|
||||
PRIORITY_SPEAKER = 1 << 8
|
||||
STREAM = 1 << 9
|
||||
VIEW_CHANNEL = 1 << 10
|
||||
SEND_MESSAGES = 1 << 11
|
||||
SEND_TTS_MESSAGES = 1 << 12
|
||||
MANAGE_MESSAGES = 1 << 13
|
||||
EMBED_LINKS = 1 << 14
|
||||
ATTACH_FILES = 1 << 15
|
||||
READ_MESSAGE_HISTORY = 1 << 16
|
||||
MENTION_EVERYONE = 1 << 17
|
||||
USE_EXTERNAL_EMOJIS = 1 << 18
|
||||
VIEW_GUILD_INSIGHTS = 1 << 19
|
||||
CONNECT = 1 << 20
|
||||
SPEAK = 1 << 21
|
||||
MUTE_MEMBERS = 1 << 22
|
||||
DEAFEN_MEMBERS = 1 << 23
|
||||
MOVE_MEMBERS = 1 << 24
|
||||
USE_VAD = 1 << 25
|
||||
CHANGE_NICKNAME = 1 << 26
|
||||
MANAGE_NICKNAMES = 1 << 27
|
||||
MANAGE_ROLES = 1 << 28
|
||||
MANAGE_WEBHOOKS = 1 << 29
|
||||
MANAGE_GUILD_EXPRESSIONS = 1 << 30
|
||||
USE_APPLICATION_COMMANDS = 1 << 31
|
||||
REQUEST_TO_SPEAK = 1 << 32
|
||||
MANAGE_EVENTS = 1 << 33
|
||||
MANAGE_THREADS = 1 << 34
|
||||
CREATE_PUBLIC_THREADS = 1 << 35
|
||||
CREATE_PRIVATE_THREADS = 1 << 36
|
||||
USE_EXTERNAL_STICKERS = 1 << 37
|
||||
SEND_MESSAGES_IN_THREADS = 1 << 38
|
||||
USE_EMBEDDED_ACTIVITIES = 1 << 39
|
||||
MODERATE_MEMBERS = 1 << 40
|
||||
VIEW_CREATOR_MONETIZATION_ANALYTICS = 1 << 41
|
||||
USE_SOUNDBOARD = 1 << 42
|
||||
CREATE_GUILD_EXPRESSIONS = 1 << 43
|
||||
CREATE_EVENTS = 1 << 44
|
||||
USE_EXTERNAL_SOUNDS = 1 << 45
|
||||
SEND_VOICE_MESSAGES = 1 << 46
|
||||
|
||||
|
||||
def permissions_value(*perms: Permissions | int | Iterable[Permissions | int]) -> int:
|
||||
"""Return a combined integer value for multiple permissions."""
|
||||
|
||||
value = 0
|
||||
for perm in perms:
|
||||
if isinstance(perm, Iterable) and not isinstance(perm, (Permissions, int)):
|
||||
value |= permissions_value(*perm)
|
||||
else:
|
||||
value |= int(perm)
|
||||
return value
|
||||
|
||||
|
||||
def has_permissions(
|
||||
current: int | str | Permissions,
|
||||
*perms: Permissions | int | Iterable[Permissions | int],
|
||||
) -> bool:
|
||||
"""Return ``True`` if ``current`` includes all ``perms``."""
|
||||
|
||||
current_val = int(current)
|
||||
needed = permissions_value(*perms)
|
||||
return (current_val & needed) == needed
|
||||
|
||||
|
||||
def missing_permissions(
|
||||
current: int | str | Permissions,
|
||||
*perms: Permissions | int | Iterable[Permissions | int],
|
||||
) -> List[Permissions]:
|
||||
"""Return the subset of ``perms`` not present in ``current``."""
|
||||
|
||||
current_val = int(current)
|
||||
missing: List[Permissions] = []
|
||||
for perm in perms:
|
||||
if isinstance(perm, Iterable) and not isinstance(perm, (Permissions, int)):
|
||||
missing.extend(missing_permissions(current_val, *perm))
|
||||
else:
|
||||
perm_val = int(perm)
|
||||
if not current_val & perm_val:
|
||||
missing.append(Permissions(perm_val))
|
||||
return missing
|
75
disagreement/rate_limiter.py
Normal file
75
disagreement/rate_limiter.py
Normal file
@ -0,0 +1,75 @@
|
||||
"""Asynchronous rate limiter for Discord HTTP requests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Mapping
|
||||
|
||||
|
||||
class _Bucket:
|
||||
def __init__(self) -> None:
|
||||
self.remaining: int = 1
|
||||
self.reset_at: float = 0.0
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Rate limiter implementing per-route buckets and a global queue."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._buckets: Dict[str, _Bucket] = {}
|
||||
self._global_event = asyncio.Event()
|
||||
self._global_event.set()
|
||||
|
||||
def _get_bucket(self, route: str) -> _Bucket:
|
||||
bucket = self._buckets.get(route)
|
||||
if bucket is None:
|
||||
bucket = _Bucket()
|
||||
self._buckets[route] = bucket
|
||||
return bucket
|
||||
|
||||
async def acquire(self, route: str) -> _Bucket:
|
||||
bucket = self._get_bucket(route)
|
||||
while True:
|
||||
await self._global_event.wait()
|
||||
async with bucket.lock:
|
||||
now = time.monotonic()
|
||||
if bucket.remaining <= 0 and now < bucket.reset_at:
|
||||
await asyncio.sleep(bucket.reset_at - now)
|
||||
continue
|
||||
if bucket.remaining > 0:
|
||||
bucket.remaining -= 1
|
||||
return bucket
|
||||
|
||||
def release(self, route: str, headers: Mapping[str, str]) -> None:
|
||||
bucket = self._get_bucket(route)
|
||||
try:
|
||||
remaining = int(headers.get("X-RateLimit-Remaining", bucket.remaining))
|
||||
reset_after = float(headers.get("X-RateLimit-Reset-After", "0"))
|
||||
bucket.remaining = remaining
|
||||
bucket.reset_at = time.monotonic() + reset_after
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if headers.get("X-RateLimit-Global", "false").lower() == "true":
|
||||
retry_after = float(headers.get("Retry-After", "0"))
|
||||
self._global_event.clear()
|
||||
asyncio.create_task(self._lift_global(retry_after))
|
||||
|
||||
async def handle_rate_limit(
|
||||
self, route: str, retry_after: float, is_global: bool
|
||||
) -> None:
|
||||
bucket = self._get_bucket(route)
|
||||
bucket.remaining = 0
|
||||
bucket.reset_at = time.monotonic() + retry_after
|
||||
if is_global:
|
||||
self._global_event.clear()
|
||||
await asyncio.sleep(retry_after)
|
||||
self._global_event.set()
|
||||
else:
|
||||
await asyncio.sleep(retry_after)
|
||||
|
||||
async def _lift_global(self, delay: float) -> None:
|
||||
await asyncio.sleep(delay)
|
||||
self._global_event.set()
|
65
disagreement/shard_manager.py
Normal file
65
disagreement/shard_manager.py
Normal file
@ -0,0 +1,65 @@
|
||||
# disagreement/shard_manager.py
|
||||
|
||||
"""Sharding utilities for managing multiple gateway connections."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import List, TYPE_CHECKING
|
||||
|
||||
from .gateway import GatewayClient
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - for type checking only
|
||||
from .client import Client
|
||||
|
||||
|
||||
class Shard:
|
||||
"""Represents a single gateway shard."""
|
||||
|
||||
def __init__(self, shard_id: int, shard_count: int, gateway: GatewayClient) -> None:
|
||||
self.id: int = shard_id
|
||||
self.count: int = shard_count
|
||||
self.gateway: GatewayClient = gateway
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connects this shard's gateway."""
|
||||
await self.gateway.connect()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Closes this shard's gateway."""
|
||||
await self.gateway.close()
|
||||
|
||||
|
||||
class ShardManager:
|
||||
"""Manages multiple :class:`Shard` instances."""
|
||||
|
||||
def __init__(self, client: "Client", shard_count: int) -> None:
|
||||
self.client: "Client" = client
|
||||
self.shard_count: int = shard_count
|
||||
self.shards: List[Shard] = []
|
||||
|
||||
def _create_shards(self) -> None:
|
||||
if self.shards:
|
||||
return
|
||||
for shard_id in range(self.shard_count):
|
||||
gateway = GatewayClient(
|
||||
http_client=self.client._http,
|
||||
event_dispatcher=self.client._event_dispatcher,
|
||||
token=self.client.token,
|
||||
intents=self.client.intents,
|
||||
client_instance=self.client,
|
||||
verbose=self.client.verbose,
|
||||
shard_id=shard_id,
|
||||
shard_count=self.shard_count,
|
||||
)
|
||||
self.shards.append(Shard(shard_id, self.shard_count, gateway))
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Starts all shards."""
|
||||
self._create_shards()
|
||||
await asyncio.gather(*(s.connect() for s in self.shards))
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Closes all shards."""
|
||||
await asyncio.gather(*(s.close() for s in self.shards))
|
||||
self.shards.clear()
|
42
disagreement/typing.py
Normal file
42
disagreement/typing.py
Normal file
@ -0,0 +1,42 @@
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from .errors import DisagreementException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client
|
||||
|
||||
if __name__ == "typing":
|
||||
# For direct module execution testing
|
||||
pass
|
||||
|
||||
|
||||
class Typing:
|
||||
"""Async context manager for Discord typing indicator."""
|
||||
|
||||
def __init__(self, client: "Client", channel_id: str) -> None:
|
||||
self._client = client
|
||||
self._channel_id = channel_id
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
|
||||
async def _run(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
await self._client._http.trigger_typing(self._channel_id)
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def __aenter__(self) -> "Typing":
|
||||
if self._client._closed:
|
||||
raise DisagreementException("Client is closed.")
|
||||
await self._client._http.trigger_typing(self._channel_id)
|
||||
self._task = asyncio.create_task(self._run())
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await self._task
|
17
disagreement/ui/__init__.py
Normal file
17
disagreement/ui/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
from .view import View
|
||||
from .item import Item
|
||||
from .button import Button, button
|
||||
from .select import Select, select
|
||||
from .modal import Modal, TextInput, text_input
|
||||
|
||||
__all__ = [
|
||||
"View",
|
||||
"Item",
|
||||
"Button",
|
||||
"button",
|
||||
"Select",
|
||||
"select",
|
||||
"Modal",
|
||||
"TextInput",
|
||||
"text_input",
|
||||
]
|
99
disagreement/ui/button.py
Normal file
99
disagreement/ui/button.py
Normal file
@ -0,0 +1,99 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Callable, Coroutine, Optional, TYPE_CHECKING
|
||||
|
||||
from .item import Item
|
||||
from ..enums import ComponentType, ButtonStyle
|
||||
from ..models import PartialEmoji, to_partial_emoji
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..interactions import Interaction
|
||||
|
||||
|
||||
class Button(Item):
|
||||
"""Represents a button component in a View.
|
||||
|
||||
Args:
|
||||
style (ButtonStyle): The style of the button.
|
||||
label (Optional[str]): The text that appears on the button.
|
||||
emoji (Optional[str | PartialEmoji]): The emoji that appears on the button.
|
||||
custom_id (Optional[str]): The developer-defined identifier for the button.
|
||||
url (Optional[str]): The URL for the button.
|
||||
disabled (bool): Whether the button is disabled.
|
||||
row (Optional[int]): The row the button should be placed in, from 0 to 4.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
style: ButtonStyle = ButtonStyle.SECONDARY,
|
||||
label: Optional[str] = None,
|
||||
emoji: Optional[str | PartialEmoji] = None,
|
||||
custom_id: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
disabled: bool = False,
|
||||
row: Optional[int] = None,
|
||||
):
|
||||
super().__init__(type=ComponentType.BUTTON)
|
||||
if not label and not emoji:
|
||||
raise ValueError("A button must have a label and/or an emoji.")
|
||||
|
||||
if url and custom_id:
|
||||
raise ValueError("A button cannot have both a URL and a custom_id.")
|
||||
|
||||
self.style = style
|
||||
self.label = label
|
||||
self.emoji = to_partial_emoji(emoji)
|
||||
self.custom_id = custom_id
|
||||
self.url = url
|
||||
self.disabled = disabled
|
||||
self._row = row
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Converts the button to a dictionary that can be sent to Discord."""
|
||||
payload = {
|
||||
"type": ComponentType.BUTTON.value,
|
||||
"style": self.style.value,
|
||||
"disabled": self.disabled,
|
||||
}
|
||||
if self.label:
|
||||
payload["label"] = self.label
|
||||
if self.emoji:
|
||||
payload["emoji"] = self.emoji.to_dict()
|
||||
if self.url:
|
||||
payload["url"] = self.url
|
||||
if self.custom_id:
|
||||
payload["custom_id"] = self.custom_id
|
||||
return payload
|
||||
|
||||
|
||||
def button(
|
||||
*,
|
||||
label: Optional[str] = None,
|
||||
custom_id: Optional[str] = None,
|
||||
style: ButtonStyle = ButtonStyle.SECONDARY,
|
||||
emoji: Optional[str | PartialEmoji] = None,
|
||||
url: Optional[str] = None,
|
||||
disabled: bool = False,
|
||||
row: Optional[int] = None,
|
||||
) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Button]:
|
||||
"""A decorator to create a button in a View."""
|
||||
|
||||
def decorator(func: Callable[..., Coroutine[Any, Any, Any]]) -> Button:
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError("Button callback must be a coroutine function.")
|
||||
|
||||
item = Button(
|
||||
label=label,
|
||||
custom_id=custom_id,
|
||||
style=style,
|
||||
emoji=emoji,
|
||||
url=url,
|
||||
disabled=disabled,
|
||||
row=row,
|
||||
)
|
||||
item.callback = func
|
||||
return item
|
||||
|
||||
return decorator
|
38
disagreement/ui/item.py
Normal file
38
disagreement/ui/item.py
Normal file
@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Coroutine, Optional, TYPE_CHECKING
|
||||
|
||||
from ..models import Component
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .view import View
|
||||
from ..interactions import Interaction
|
||||
|
||||
|
||||
class Item(Component):
|
||||
"""Represents a UI item that can be placed in a View.
|
||||
|
||||
This is a base class and is not meant to be used directly.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._view: Optional[View] = None
|
||||
self._row: Optional[int] = None
|
||||
# This is the callback associated with this item.
|
||||
self.callback: Optional[
|
||||
Callable[["View", Interaction], Coroutine[Any, Any, Any]]
|
||||
] = None
|
||||
|
||||
@property
|
||||
def view(self) -> Optional[View]:
|
||||
return self._view
|
||||
|
||||
@property
|
||||
def row(self) -> Optional[int]:
|
||||
return self._row
|
||||
|
||||
def _refresh_from_data(self, data: dict[str, Any]) -> None:
|
||||
# This is used to update the item's state from incoming interaction data.
|
||||
# For example, a button's disabled state could be updated here.
|
||||
pass
|
132
disagreement/ui/modal.py
Normal file
132
disagreement/ui/modal.py
Normal file
@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Coroutine, Optional, List, TYPE_CHECKING
|
||||
import asyncio
|
||||
|
||||
from .item import Item
|
||||
from .view import View
|
||||
from ..enums import ComponentType, TextInputStyle
|
||||
from ..models import ActionRow
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - for type hints only
|
||||
from ..interactions import Interaction
|
||||
|
||||
|
||||
class TextInput(Item):
|
||||
"""Represents a text input component inside a modal."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
label: str,
|
||||
custom_id: Optional[str] = None,
|
||||
style: TextInputStyle = TextInputStyle.SHORT,
|
||||
placeholder: Optional[str] = None,
|
||||
required: bool = True,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
row: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(type=ComponentType.TEXT_INPUT)
|
||||
self.label = label
|
||||
self.custom_id = custom_id
|
||||
self.style = style
|
||||
self.placeholder = placeholder
|
||||
self.required = required
|
||||
self.min_length = min_length
|
||||
self.max_length = max_length
|
||||
self._row = row
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
payload = {
|
||||
"type": ComponentType.TEXT_INPUT.value,
|
||||
"style": self.style.value,
|
||||
"label": self.label,
|
||||
"required": self.required,
|
||||
}
|
||||
if self.custom_id:
|
||||
payload["custom_id"] = self.custom_id
|
||||
if self.placeholder:
|
||||
payload["placeholder"] = self.placeholder
|
||||
if self.min_length is not None:
|
||||
payload["min_length"] = self.min_length
|
||||
if self.max_length is not None:
|
||||
payload["max_length"] = self.max_length
|
||||
return payload
|
||||
|
||||
|
||||
def text_input(
|
||||
*,
|
||||
label: str,
|
||||
custom_id: Optional[str] = None,
|
||||
style: TextInputStyle = TextInputStyle.SHORT,
|
||||
placeholder: Optional[str] = None,
|
||||
required: bool = True,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
row: Optional[int] = None,
|
||||
) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], TextInput]:
|
||||
"""Decorator to define a text input callback inside a :class:`Modal`."""
|
||||
|
||||
def decorator(func: Callable[..., Coroutine[Any, Any, Any]]) -> TextInput:
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError("TextInput callback must be a coroutine function.")
|
||||
|
||||
item = TextInput(
|
||||
label=label,
|
||||
custom_id=custom_id,
|
||||
style=style,
|
||||
placeholder=placeholder,
|
||||
required=required,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
row=row,
|
||||
)
|
||||
item.callback = func
|
||||
return item
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class Modal:
|
||||
"""Represents a modal dialog."""
|
||||
|
||||
def __init__(self, *, title: str, custom_id: str) -> None:
|
||||
self.title = title
|
||||
self.custom_id = custom_id
|
||||
self._children: List[TextInput] = []
|
||||
|
||||
for item in self.__class__.__dict__.values():
|
||||
if isinstance(item, TextInput):
|
||||
self.add_item(item)
|
||||
|
||||
@property
|
||||
def children(self) -> List[TextInput]:
|
||||
return self._children
|
||||
|
||||
def add_item(self, item: TextInput) -> None:
|
||||
if not isinstance(item, TextInput):
|
||||
raise TypeError("Only TextInput items can be added to a Modal.")
|
||||
if len(self._children) >= 5:
|
||||
raise ValueError("A modal can only have up to 5 text inputs.")
|
||||
item._view = None # Not part of a view but reuse item base
|
||||
self._children.append(item)
|
||||
|
||||
def to_components(self) -> List[ActionRow]:
|
||||
rows: List[ActionRow] = []
|
||||
for child in self.children:
|
||||
row = ActionRow(components=[child])
|
||||
rows.append(row)
|
||||
return rows
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"title": self.title,
|
||||
"custom_id": self.custom_id,
|
||||
"components": [r.to_dict() for r in self.to_components()],
|
||||
}
|
||||
|
||||
async def callback(
|
||||
self, interaction: Interaction
|
||||
) -> None: # pragma: no cover - default
|
||||
pass
|
92
disagreement/ui/select.py
Normal file
92
disagreement/ui/select.py
Normal file
@ -0,0 +1,92 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Callable, Coroutine, List, Optional, TYPE_CHECKING
|
||||
|
||||
from .item import Item
|
||||
from ..enums import ComponentType
|
||||
from ..models import SelectOption
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..interactions import Interaction
|
||||
|
||||
|
||||
class Select(Item):
|
||||
"""Represents a select menu component in a View.
|
||||
|
||||
Args:
|
||||
custom_id (str): The developer-defined identifier for the select menu.
|
||||
options (List[SelectOption]): The choices in the select menu.
|
||||
placeholder (Optional[str]): The placeholder text that is shown if nothing is selected.
|
||||
min_values (int): The minimum number of items that must be chosen.
|
||||
max_values (int): The maximum number of items that can be chosen.
|
||||
disabled (bool): Whether the select menu is disabled.
|
||||
row (Optional[int]): The row the select menu should be placed in, from 0 to 4.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
custom_id: str,
|
||||
options: List[SelectOption],
|
||||
placeholder: Optional[str] = None,
|
||||
min_values: int = 1,
|
||||
max_values: int = 1,
|
||||
disabled: bool = False,
|
||||
row: Optional[int] = None,
|
||||
):
|
||||
super().__init__(type=ComponentType.STRING_SELECT)
|
||||
self.custom_id = custom_id
|
||||
self.options = options
|
||||
self.placeholder = placeholder
|
||||
self.min_values = min_values
|
||||
self.max_values = max_values
|
||||
self.disabled = disabled
|
||||
self._row = row
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Converts the select menu to a dictionary that can be sent to Discord."""
|
||||
payload = {
|
||||
"type": ComponentType.STRING_SELECT.value,
|
||||
"custom_id": self.custom_id,
|
||||
"options": [option.to_dict() for option in self.options],
|
||||
"disabled": self.disabled,
|
||||
}
|
||||
if self.placeholder:
|
||||
payload["placeholder"] = self.placeholder
|
||||
if self.min_values is not None:
|
||||
payload["min_values"] = self.min_values
|
||||
if self.max_values is not None:
|
||||
payload["max_values"] = self.max_values
|
||||
return payload
|
||||
|
||||
|
||||
def select(
|
||||
*,
|
||||
custom_id: str,
|
||||
options: List[SelectOption],
|
||||
placeholder: Optional[str] = None,
|
||||
min_values: int = 1,
|
||||
max_values: int = 1,
|
||||
disabled: bool = False,
|
||||
row: Optional[int] = None,
|
||||
) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Select]:
|
||||
"""A decorator to create a select menu in a View."""
|
||||
|
||||
def decorator(func: Callable[..., Coroutine[Any, Any, Any]]) -> Select:
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError("Select callback must be a coroutine function.")
|
||||
|
||||
item = Select(
|
||||
custom_id=custom_id,
|
||||
options=options,
|
||||
placeholder=placeholder,
|
||||
min_values=min_values,
|
||||
max_values=max_values,
|
||||
disabled=disabled,
|
||||
row=row,
|
||||
)
|
||||
item.callback = func
|
||||
return item
|
||||
|
||||
return decorator
|
165
disagreement/ui/view.py
Normal file
165
disagreement/ui/view.py
Normal file
@ -0,0 +1,165 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from ..models import ActionRow
|
||||
from .item import Item
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import Client
|
||||
from ..interactions import Interaction
|
||||
|
||||
|
||||
class View:
|
||||
"""Represents a container for UI components that can be sent with a message.
|
||||
|
||||
Args:
|
||||
timeout (Optional[float]): The number of seconds to wait for an interaction before the view times out.
|
||||
Defaults to 180.
|
||||
"""
|
||||
|
||||
def __init__(self, *, timeout: Optional[float] = 180.0):
|
||||
self.timeout = timeout
|
||||
self.id = str(uuid.uuid4())
|
||||
self.__children: List[Item] = []
|
||||
self.__stopped = asyncio.Event()
|
||||
self._client: Optional[Client] = None
|
||||
self._message_id: Optional[str] = None
|
||||
|
||||
for item in self.__class__.__dict__.values():
|
||||
if isinstance(item, Item):
|
||||
self.add_item(item)
|
||||
|
||||
@property
|
||||
def children(self) -> List[Item]:
|
||||
return self.__children
|
||||
|
||||
def add_item(self, item: Item):
|
||||
"""Adds an item to the view."""
|
||||
if not isinstance(item, Item):
|
||||
raise TypeError("Only instances of 'Item' can be added to a View.")
|
||||
|
||||
if len(self.__children) >= 25:
|
||||
raise ValueError("A view can only have a maximum of 25 components.")
|
||||
|
||||
item._view = self
|
||||
self.__children.append(item)
|
||||
|
||||
@property
|
||||
def message_id(self) -> Optional[str]:
|
||||
return self._message_id
|
||||
|
||||
@message_id.setter
|
||||
def message_id(self, value: str):
|
||||
self._message_id = value
|
||||
|
||||
def to_components(self) -> List[ActionRow]:
|
||||
"""Converts the view's children into a list of ActionRow components.
|
||||
|
||||
This retains the original, simple layout behaviour where each item is
|
||||
placed in its own :class:`ActionRow` to ensure backward compatibility.
|
||||
"""
|
||||
|
||||
rows: List[ActionRow] = []
|
||||
|
||||
for item in self.children:
|
||||
if item.custom_id is None:
|
||||
item.custom_id = (
|
||||
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
|
||||
)
|
||||
|
||||
rows.append(ActionRow(components=[item]))
|
||||
|
||||
return rows
|
||||
|
||||
def layout_components_advanced(self) -> List[ActionRow]:
|
||||
"""Group compatible components into rows following Discord rules."""
|
||||
|
||||
rows: List[ActionRow] = []
|
||||
|
||||
for item in self.children:
|
||||
if item.custom_id is None:
|
||||
item.custom_id = (
|
||||
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
|
||||
)
|
||||
|
||||
target_row = item.row
|
||||
if target_row is not None:
|
||||
if not 0 <= target_row <= 4:
|
||||
raise ValueError("Row index must be between 0 and 4.")
|
||||
|
||||
while len(rows) <= target_row:
|
||||
if len(rows) >= 5:
|
||||
raise ValueError("A view can have at most 5 action rows.")
|
||||
rows.append(ActionRow())
|
||||
|
||||
rows[target_row].add_component(item)
|
||||
continue
|
||||
|
||||
placed = False
|
||||
for row in rows:
|
||||
try:
|
||||
row.add_component(item)
|
||||
placed = True
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if not placed:
|
||||
if len(rows) >= 5:
|
||||
raise ValueError("A view can have at most 5 action rows.")
|
||||
new_row = ActionRow([item])
|
||||
rows.append(new_row)
|
||||
|
||||
return rows
|
||||
|
||||
def to_components_payload(self) -> List[Dict[str, Any]]:
|
||||
"""Converts the view's children into a list of component dictionaries
|
||||
that can be sent to the Discord API."""
|
||||
return [row.to_dict() for row in self.to_components()]
|
||||
|
||||
async def _dispatch(self, interaction: Interaction):
|
||||
"""Called by the client to dispatch an interaction to the correct item."""
|
||||
if self.timeout is not None:
|
||||
self.__stopped.set() # Reset the timeout on each interaction
|
||||
self.__stopped.clear()
|
||||
|
||||
if interaction.data:
|
||||
custom_id = interaction.data.custom_id
|
||||
for child in self.children:
|
||||
if child.custom_id == custom_id:
|
||||
if child.callback:
|
||||
await child.callback(self, interaction)
|
||||
break
|
||||
|
||||
async def wait(self) -> bool:
|
||||
"""Waits until the view has stopped interacting."""
|
||||
return await self.__stopped.wait()
|
||||
|
||||
def stop(self):
|
||||
"""Stops the view from listening to interactions."""
|
||||
if not self.__stopped.is_set():
|
||||
self.__stopped.set()
|
||||
|
||||
async def on_timeout(self):
|
||||
"""Called when the view times out."""
|
||||
pass # User can override this
|
||||
|
||||
async def _start(self, client: Client):
|
||||
"""Starts the view's internal listener."""
|
||||
self._client = client
|
||||
if self.timeout is not None:
|
||||
asyncio.create_task(self._timeout_task())
|
||||
|
||||
async def _timeout_task(self):
|
||||
"""The task that waits for the timeout and then stops the view."""
|
||||
try:
|
||||
await asyncio.wait_for(self.wait(), timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self.stop()
|
||||
await self.on_timeout()
|
||||
if self._client and self._message_id:
|
||||
# Remove the view from the client's listeners
|
||||
self._client._views.pop(self._message_id, None)
|
120
disagreement/voice_client.py
Normal file
120
disagreement/voice_client.py
Normal file
@ -0,0 +1,120 @@
|
||||
# disagreement/voice_client.py
|
||||
"""Voice gateway and UDP audio client."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import socket
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
||||
class VoiceClient:
|
||||
"""Handles the Discord voice WebSocket connection and UDP streaming."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
session_id: str,
|
||||
token: str,
|
||||
guild_id: int,
|
||||
user_id: int,
|
||||
*,
|
||||
ws=None,
|
||||
udp: Optional[socket.socket] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
self.endpoint = endpoint
|
||||
self.session_id = session_id
|
||||
self.token = token
|
||||
self.guild_id = str(guild_id)
|
||||
self.user_id = str(user_id)
|
||||
self._ws: Optional[aiohttp.ClientWebSocketResponse] = ws
|
||||
self._udp = udp
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self._heartbeat_interval: Optional[float] = None
|
||||
self._loop = loop or asyncio.get_event_loop()
|
||||
self.verbose = verbose
|
||||
self.ssrc: Optional[int] = None
|
||||
self.secret_key: Optional[Sequence[int]] = None
|
||||
self._server_ip: Optional[str] = None
|
||||
self._server_port: Optional[int] = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self._ws is None:
|
||||
self._session = aiohttp.ClientSession()
|
||||
self._ws = await self._session.ws_connect(self.endpoint)
|
||||
|
||||
hello = await self._ws.receive_json()
|
||||
self._heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000
|
||||
self._heartbeat_task = self._loop.create_task(self._heartbeat())
|
||||
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"op": 0,
|
||||
"d": {
|
||||
"server_id": self.guild_id,
|
||||
"user_id": self.user_id,
|
||||
"session_id": self.session_id,
|
||||
"token": self.token,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
ready = await self._ws.receive_json()
|
||||
data = ready["d"]
|
||||
self.ssrc = data["ssrc"]
|
||||
self._server_ip = data["ip"]
|
||||
self._server_port = data["port"]
|
||||
|
||||
if self._udp is None:
|
||||
self._udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self._udp.connect((self._server_ip, self._server_port))
|
||||
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"op": 1,
|
||||
"d": {
|
||||
"protocol": "udp",
|
||||
"data": {
|
||||
"address": self._udp.getsockname()[0],
|
||||
"port": self._udp.getsockname()[1],
|
||||
"mode": "xsalsa20_poly1305",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
session_desc = await self._ws.receive_json()
|
||||
self.secret_key = session_desc["d"].get("secret_key")
|
||||
|
||||
async def _heartbeat(self) -> None:
|
||||
assert self._ws is not None
|
||||
assert self._heartbeat_interval is not None
|
||||
try:
|
||||
while True:
|
||||
await self._ws.send_json({"op": 3, "d": int(self._loop.time() * 1000)})
|
||||
await asyncio.sleep(self._heartbeat_interval)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def send_audio_frame(self, frame: bytes) -> None:
|
||||
if not self._udp:
|
||||
raise RuntimeError("UDP socket not initialised")
|
||||
self._udp.send(frame)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._heartbeat_task
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
if self._udp:
|
||||
self._udp.close()
|
18
docs/caching.md
Normal file
18
docs/caching.md
Normal file
@ -0,0 +1,18 @@
|
||||
# Caching
|
||||
|
||||
Disagreement ships with a simple in-memory cache used by the HTTP and Gateway clients. Cached objects reduce API requests and improve performance.
|
||||
|
||||
The client automatically caches guilds, channels and users as they are received from events or HTTP calls. You can access cached data through lookup helpers such as `Client.get_guild`.
|
||||
|
||||
The cache can be cleared manually if needed:
|
||||
|
||||
```python
|
||||
client.cache.clear()
|
||||
```
|
||||
|
||||
## Next Steps
|
||||
|
||||
- [Components](using_components.md)
|
||||
- [Slash Commands](slash_commands.md)
|
||||
- [Voice Features](voice_features.md)
|
||||
|
51
docs/commands.md
Normal file
51
docs/commands.md
Normal file
@ -0,0 +1,51 @@
|
||||
# Commands Extension
|
||||
|
||||
This guide covers the built-in prefix command system.
|
||||
|
||||
## Help Command
|
||||
|
||||
The command handler registers a `help` command automatically. Use it to list all available commands or get information about a single command.
|
||||
|
||||
```
|
||||
!help # lists commands
|
||||
!help ping # shows help for the "ping" command
|
||||
```
|
||||
|
||||
The help command will show each command's brief description if provided.
|
||||
|
||||
## Checks
|
||||
|
||||
Use `commands.check` to prevent a command from running unless a predicate
|
||||
returns ``True``. Checks may be regular or async callables that accept a
|
||||
`CommandContext`.
|
||||
|
||||
```python
|
||||
from disagreement.ext.commands import command, check, CheckFailure
|
||||
|
||||
def is_owner(ctx):
|
||||
return ctx.author.id == "1"
|
||||
|
||||
@command()
|
||||
@check(is_owner)
|
||||
async def secret(ctx):
|
||||
await ctx.send("Only for the owner!")
|
||||
```
|
||||
|
||||
When a check fails a :class:`CheckFailure` is raised and dispatched through the
|
||||
command error handler.
|
||||
|
||||
## Cooldowns
|
||||
|
||||
Commands can be rate limited using the ``cooldown`` decorator. The example
|
||||
below restricts usage to once every three seconds per user:
|
||||
|
||||
```python
|
||||
from disagreement.ext.commands import command, cooldown
|
||||
|
||||
@command()
|
||||
@cooldown(1, 3.0)
|
||||
async def ping(ctx):
|
||||
await ctx.send("Pong!")
|
||||
```
|
||||
|
||||
Invoking a command while it is on cooldown raises :class:`CommandOnCooldown`.
|
21
docs/context_menus.md
Normal file
21
docs/context_menus.md
Normal file
@ -0,0 +1,21 @@
|
||||
# Context Menu Commands
|
||||
|
||||
`disagreement` supports Discord's user and message context menu commands. Use the
|
||||
`user_command` and `message_command` decorators from `ext.app_commands` to
|
||||
define them.
|
||||
|
||||
```python
|
||||
from disagreement.ext.app_commands import user_command, message_command, AppCommandContext
|
||||
from disagreement.models import User, Message
|
||||
|
||||
@user_command(name="User Info")
|
||||
async def user_info(ctx: AppCommandContext, user: User) -> None:
|
||||
await ctx.send(f"User: {user.username}#{user.discriminator}")
|
||||
|
||||
@message_command(name="Quote")
|
||||
async def quote(ctx: AppCommandContext, message: Message) -> None:
|
||||
await ctx.send(message.content)
|
||||
```
|
||||
|
||||
Add the commands to your client's handler and run `sync_commands()` to register
|
||||
them with Discord.
|
25
docs/converters.md
Normal file
25
docs/converters.md
Normal file
@ -0,0 +1,25 @@
|
||||
# Command Argument Converters
|
||||
|
||||
`disagreement.ext.commands` provides a number of built in converters that will parse string arguments into richer objects. These converters are automatically used when a command callback annotates its parameters with one of the supported types.
|
||||
|
||||
## Supported Types
|
||||
|
||||
- `int`, `float`, `bool`, and `str`
|
||||
- `Member` – resolves a user mention or ID to a `Member` object for the current guild
|
||||
- `Role` – resolves a role mention or ID to a `Role` object
|
||||
- `Guild` – resolves a guild ID to a `Guild` object
|
||||
|
||||
## Example
|
||||
|
||||
```python
|
||||
from disagreement.ext.commands import command
|
||||
from disagreement.ext.commands.core import CommandContext
|
||||
from disagreement.models import Member
|
||||
|
||||
@command()
|
||||
async def kick(ctx: CommandContext, target: Member):
|
||||
await target.kick()
|
||||
await ctx.send(f"Kicked {target.display_name}")
|
||||
```
|
||||
|
||||
The framework will automatically convert the first argument to a `Member` using the mention or ID provided by the user.
|
25
docs/events.md
Normal file
25
docs/events.md
Normal file
@ -0,0 +1,25 @@
|
||||
# Events
|
||||
|
||||
Disagreement dispatches Gateway events to asynchronous callbacks. Handlers can be registered with `@client.event` or `client.on_event`.
|
||||
Listeners may be removed later using `EventDispatcher.unregister(event_name, coro)`.
|
||||
|
||||
|
||||
## PRESENCE_UPDATE
|
||||
|
||||
Triggered when a user's presence changes. The callback receives a `PresenceUpdate` model.
|
||||
|
||||
```python
|
||||
@client.event
|
||||
async def on_presence_update(presence: disagreement.PresenceUpdate):
|
||||
...
|
||||
```
|
||||
|
||||
## TYPING_START
|
||||
|
||||
Dispatched when a user begins typing in a channel. The callback receives a `TypingStart` model.
|
||||
|
||||
```python
|
||||
@client.event
|
||||
async def on_typing_start(typing: disagreement.TypingStart):
|
||||
...
|
||||
```
|
17
docs/gateway.md
Normal file
17
docs/gateway.md
Normal file
@ -0,0 +1,17 @@
|
||||
# Gateway Connection and Reconnection
|
||||
|
||||
`GatewayClient` manages the library's WebSocket connection. When the connection drops unexpectedly, it will now automatically attempt to reconnect using an exponential backoff strategy with jitter.
|
||||
|
||||
The default behaviour tries up to five reconnect attempts, doubling the delay each time up to a configurable maximum. A small random jitter is added to spread out reconnect attempts when multiple clients restart at once.
|
||||
|
||||
You can control the maximum number of retries and the backoff cap when constructing `Client`:
|
||||
|
||||
```python
|
||||
bot = Client(
|
||||
token="your-token",
|
||||
gateway_max_retries=10,
|
||||
gateway_max_backoff=120.0,
|
||||
)
|
||||
```
|
||||
|
||||
These values are passed to `GatewayClient` and applied whenever the connection needs to be re-established.
|
36
docs/i18n.md
Normal file
36
docs/i18n.md
Normal file
@ -0,0 +1,36 @@
|
||||
# Internationalization
|
||||
|
||||
Disagreement can translate command names, descriptions and other text using JSON translation files.
|
||||
|
||||
## Providing Translations
|
||||
|
||||
Use `disagreement.i18n.load_translations` to load a JSON file for a locale.
|
||||
|
||||
```python
|
||||
from disagreement import i18n
|
||||
|
||||
i18n.load_translations("es", "path/to/es.json")
|
||||
```
|
||||
|
||||
The JSON file should map translation keys to translated strings:
|
||||
|
||||
```json
|
||||
{
|
||||
"greet": "Hola",
|
||||
"description": "Comando de saludo"
|
||||
}
|
||||
```
|
||||
|
||||
You can also set translations programmatically with `i18n.set_translations`.
|
||||
|
||||
## Using with Commands
|
||||
|
||||
Pass a `locale` argument when defining an `AppCommand` or using the decorators. The command name and description will be looked up using the loaded translations.
|
||||
|
||||
```python
|
||||
@slash_command(name="greet", description="description", locale="es")
|
||||
async def greet(ctx):
|
||||
await ctx.send(i18n.translate("greet", ctx.locale or "es"))
|
||||
```
|
||||
|
||||
If a translation is missing the key itself is returned.
|
48
docs/oauth2.md
Normal file
48
docs/oauth2.md
Normal file
@ -0,0 +1,48 @@
|
||||
# OAuth2 Setup
|
||||
|
||||
This guide explains how to perform a basic OAuth2 flow with `disagreement`.
|
||||
|
||||
1. Generate the authorization URL:
|
||||
|
||||
```python
|
||||
from disagreement.oauth import build_authorization_url
|
||||
|
||||
url = build_authorization_url(
|
||||
client_id="YOUR_CLIENT_ID",
|
||||
redirect_uri="https://your.app/callback",
|
||||
scope=["identify"],
|
||||
)
|
||||
print(url)
|
||||
```
|
||||
|
||||
2. After the user authorizes your application and you receive a code, exchange it for a token:
|
||||
|
||||
```python
|
||||
import aiohttp
|
||||
from disagreement.oauth import exchange_code_for_token
|
||||
|
||||
async def get_token(code: str):
|
||||
return await exchange_code_for_token(
|
||||
client_id="YOUR_CLIENT_ID",
|
||||
client_secret="YOUR_CLIENT_SECRET",
|
||||
code=code,
|
||||
redirect_uri="https://your.app/callback",
|
||||
)
|
||||
```
|
||||
|
||||
`exchange_code_for_token` returns the JSON payload from Discord which includes
|
||||
`access_token`, `refresh_token` and expiry information.
|
||||
|
||||
3. When the access token expires, you can refresh it using the provided refresh
|
||||
token:
|
||||
|
||||
```python
|
||||
from disagreement.oauth import refresh_access_token
|
||||
|
||||
async def refresh(token: str):
|
||||
return await refresh_access_token(
|
||||
refresh_token=token,
|
||||
client_id="YOUR_CLIENT_ID",
|
||||
client_secret="YOUR_CLIENT_SECRET",
|
||||
)
|
||||
```
|
62
docs/permissions.md
Normal file
62
docs/permissions.md
Normal file
@ -0,0 +1,62 @@
|
||||
# Permission Helpers
|
||||
|
||||
The `disagreement.permissions` module defines an :class:`~enum.IntFlag`
|
||||
`Permissions` enumeration along with helper functions for working with
|
||||
the Discord permission bitmask.
|
||||
|
||||
## Permissions Enum
|
||||
|
||||
Each attribute of ``Permissions`` represents a single permission bit. The value
|
||||
is a power of two so multiple permissions can be combined using bitwise OR.
|
||||
|
||||
```python
|
||||
from disagreement.permissions import Permissions
|
||||
|
||||
value = Permissions.SEND_MESSAGES | Permissions.MANAGE_MESSAGES
|
||||
```
|
||||
|
||||
## Helper Functions
|
||||
|
||||
### ``permissions_value``
|
||||
|
||||
```python
|
||||
permissions_value(*perms) -> int
|
||||
```
|
||||
|
||||
Return an integer bitmask from one or more ``Permissions`` values. Nested
|
||||
iterables are flattened automatically.
|
||||
|
||||
### ``has_permissions``
|
||||
|
||||
```python
|
||||
has_permissions(current, *perms) -> bool
|
||||
```
|
||||
|
||||
Return ``True`` if ``current`` (an ``int`` or ``Permissions``) contains all of
|
||||
the provided permissions.
|
||||
|
||||
### ``missing_permissions``
|
||||
|
||||
```python
|
||||
missing_permissions(current, *perms) -> List[Permissions]
|
||||
```
|
||||
|
||||
Return a list of permissions that ``current`` does not contain.
|
||||
|
||||
## Example
|
||||
|
||||
```python
|
||||
from disagreement.permissions import (
|
||||
Permissions,
|
||||
has_permissions,
|
||||
missing_permissions,
|
||||
)
|
||||
|
||||
current = Permissions.SEND_MESSAGES | Permissions.MANAGE_MESSAGES
|
||||
|
||||
if has_permissions(current, Permissions.SEND_MESSAGES):
|
||||
print("Can send messages")
|
||||
|
||||
print(missing_permissions(current, Permissions.ADMINISTRATOR))
|
||||
```
|
||||
|
29
docs/presence.md
Normal file
29
docs/presence.md
Normal file
@ -0,0 +1,29 @@
|
||||
# Updating Presence
|
||||
|
||||
The `Client.change_presence` method allows you to update the bot's status and displayed activity.
|
||||
|
||||
## Status Strings
|
||||
|
||||
- `online` – show the bot as online
|
||||
- `idle` – mark the bot as away
|
||||
- `dnd` – do not disturb
|
||||
- `invisible` – appear offline
|
||||
|
||||
## Activity Types
|
||||
|
||||
An activity dictionary must include a `name` and a `type` field. The type value corresponds to Discord's activity types:
|
||||
|
||||
| Type | Meaning |
|
||||
|-----:|--------------|
|
||||
| `0` | Playing |
|
||||
| `1` | Streaming |
|
||||
| `2` | Listening |
|
||||
| `3` | Watching |
|
||||
| `4` | Custom |
|
||||
| `5` | Competing |
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
await client.change_presence(status="idle", activity={"name": "with Discord", "type": 0})
|
||||
```
|
32
docs/reactions.md
Normal file
32
docs/reactions.md
Normal file
@ -0,0 +1,32 @@
|
||||
# Handling Reactions
|
||||
|
||||
`disagreement` provides simple helpers for adding and removing message reactions.
|
||||
|
||||
## HTTP Methods
|
||||
|
||||
Use the `HTTPClient` methods directly if you need lower level control:
|
||||
|
||||
```python
|
||||
await client._http.create_reaction(channel_id, message_id, "👍")
|
||||
await client._http.delete_reaction(channel_id, message_id, "👍")
|
||||
users = await client._http.get_reactions(channel_id, message_id, "👍")
|
||||
```
|
||||
|
||||
You can also use the higher level helpers on :class:`Client`:
|
||||
|
||||
```python
|
||||
await client.create_reaction(channel_id, message_id, "👍")
|
||||
await client.delete_reaction(channel_id, message_id, "👍")
|
||||
users = await client.get_reactions(channel_id, message_id, "👍")
|
||||
```
|
||||
|
||||
## Reaction Events
|
||||
|
||||
Register listeners for `MESSAGE_REACTION_ADD` and `MESSAGE_REACTION_REMOVE`.
|
||||
Each listener receives a `Reaction` model instance.
|
||||
|
||||
```python
|
||||
@client.on_event("MESSAGE_REACTION_ADD")
|
||||
async def on_reaction(reaction: disagreement.Reaction):
|
||||
print(f"{reaction.user_id} reacted with {reaction.emoji}")
|
||||
```
|
22
docs/slash_commands.md
Normal file
22
docs/slash_commands.md
Normal file
@ -0,0 +1,22 @@
|
||||
# Using Slash Commands
|
||||
|
||||
The library provides a slash command framework via the `ext.app_commands` package. Define commands with decorators and register them with Discord.
|
||||
|
||||
```python
|
||||
from disagreement.ext.app_commands import AppCommandGroup
|
||||
|
||||
bot_commands = AppCommandGroup("bot", "Bot commands")
|
||||
|
||||
@bot_commands.command(name="ping")
|
||||
async def ping(ctx):
|
||||
await ctx.respond("Pong!")
|
||||
```
|
||||
|
||||
Use `AppCommandGroup` to group related commands. See the [components guide](using_components.md) for building interactive responses.
|
||||
|
||||
## Next Steps
|
||||
|
||||
- [Components](using_components.md)
|
||||
- [Caching](caching.md)
|
||||
- [Voice Features](voice_features.md)
|
||||
|
15
docs/task_loop.md
Normal file
15
docs/task_loop.md
Normal file
@ -0,0 +1,15 @@
|
||||
# Task Loops
|
||||
|
||||
The tasks extension allows you to run functions periodically. Decorate an async function with `@tasks.loop(seconds=...)` and start it using `.start()`.
|
||||
|
||||
```python
|
||||
from disagreement.ext import tasks
|
||||
|
||||
@tasks.loop(seconds=5.0)
|
||||
async def announce():
|
||||
print("Hello from a loop")
|
||||
|
||||
announce.start()
|
||||
```
|
||||
|
||||
Stop the loop with `.stop()` when you no longer need it.
|
16
docs/typing_indicator.md
Normal file
16
docs/typing_indicator.md
Normal file
@ -0,0 +1,16 @@
|
||||
# Typing Indicator
|
||||
|
||||
The library exposes an async context manager to send the typing indicator for a channel.
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import disagreement
|
||||
|
||||
client = disagreement.Client(token="YOUR_TOKEN")
|
||||
|
||||
async def indicate(channel_id: str):
|
||||
async with client.typing(channel_id):
|
||||
await long_running_task()
|
||||
```
|
||||
|
||||
This uses the underlying HTTP endpoint `/channels/{channel_id}/typing`.
|
168
docs/using_components.md
Normal file
168
docs/using_components.md
Normal file
@ -0,0 +1,168 @@
|
||||
# Using Message Components
|
||||
|
||||
This guide explains how to work with the `disagreement` message component models. These examples are up to date with the current code base.
|
||||
|
||||
## Enabling the New Component System
|
||||
|
||||
Messages that use the component system must include the flag `IS_COMPONENTS_V2` (value `1 << 15`). Once this flag is set on a message it cannot be removed.
|
||||
|
||||
## Component Categories
|
||||
|
||||
The library exposes three broad categories of components:
|
||||
|
||||
- **Layout Components** – organize the placement of other components.
|
||||
- **Content Components** – display static text or media.
|
||||
- **Interactive Components** – allow the user to interact with your message.
|
||||
|
||||
## Action Row
|
||||
|
||||
`ActionRow` is a layout container. It may hold up to five buttons or a single select menu.
|
||||
|
||||
```python
|
||||
from disagreement.models import ActionRow, Button
|
||||
from disagreement.enums import ButtonStyle
|
||||
|
||||
row = ActionRow(components=[
|
||||
Button(style=ButtonStyle.PRIMARY, label="Click", custom_id="btn")
|
||||
])
|
||||
```
|
||||
|
||||
## Button
|
||||
|
||||
Buttons provide a clickable UI element.
|
||||
|
||||
```python
|
||||
from disagreement.models import Button
|
||||
from disagreement.enums import ButtonStyle
|
||||
|
||||
button = Button(
|
||||
style=ButtonStyle.SUCCESS,
|
||||
label="Confirm",
|
||||
custom_id="confirm_button",
|
||||
)
|
||||
```
|
||||
|
||||
## Select Menus
|
||||
|
||||
`SelectMenu` lets the user choose one or more options. The `type` parameter controls the menu variety (`STRING_SELECT`, `USER_SELECT`, `ROLE_SELECT`, `MENTIONABLE_SELECT`, `CHANNEL_SELECT`).
|
||||
|
||||
```python
|
||||
from disagreement.models import SelectMenu, SelectOption
|
||||
from disagreement.enums import ComponentType, ChannelType
|
||||
|
||||
menu = SelectMenu(
|
||||
custom_id="example",
|
||||
options=[
|
||||
SelectOption(label="Option 1", value="1"),
|
||||
SelectOption(label="Option 2", value="2"),
|
||||
],
|
||||
placeholder="Choose an option",
|
||||
min_values=1,
|
||||
max_values=1,
|
||||
type=ComponentType.STRING_SELECT,
|
||||
)
|
||||
```
|
||||
|
||||
For channel selects you may pass `channel_types` with a list of allowed `ChannelType` values.
|
||||
|
||||
## Section
|
||||
|
||||
`Section` groups one or more `TextDisplay` components and can include an accessory `Button` or `Thumbnail`.
|
||||
|
||||
```python
|
||||
from disagreement.models import Section, TextDisplay, Thumbnail, UnfurledMediaItem
|
||||
|
||||
section = Section(
|
||||
components=[
|
||||
TextDisplay(content="## Section Title"),
|
||||
TextDisplay(content="Sections can hold multiple text displays."),
|
||||
],
|
||||
accessory=Thumbnail(media=UnfurledMediaItem(url="https://example.com/img.png")),
|
||||
)
|
||||
```
|
||||
|
||||
## Text Display
|
||||
|
||||
`TextDisplay` simply renders markdown text.
|
||||
|
||||
```python
|
||||
from disagreement.models import TextDisplay
|
||||
|
||||
text_display = TextDisplay(content="**Bold text**")
|
||||
```
|
||||
|
||||
## Thumbnail
|
||||
|
||||
`Thumbnail` shows a small image. Set `spoiler=True` to hide the image until clicked.
|
||||
|
||||
```python
|
||||
from disagreement.models import Thumbnail, UnfurledMediaItem
|
||||
|
||||
thumb = Thumbnail(
|
||||
media=UnfurledMediaItem(url="https://example.com/image.png"),
|
||||
description="A picture",
|
||||
spoiler=False,
|
||||
)
|
||||
```
|
||||
|
||||
## Media Gallery
|
||||
|
||||
`MediaGallery` holds multiple `MediaGalleryItem` objects.
|
||||
|
||||
```python
|
||||
from disagreement.models import MediaGallery, MediaGalleryItem, UnfurledMediaItem
|
||||
|
||||
gallery = MediaGallery(
|
||||
items=[
|
||||
MediaGalleryItem(media=UnfurledMediaItem(url="https://example.com/1.png")),
|
||||
MediaGalleryItem(media=UnfurledMediaItem(url="https://example.com/2.png")),
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
## File
|
||||
|
||||
`File` displays an uploaded file. Use `spoiler=True` to mark it as a spoiler.
|
||||
|
||||
```python
|
||||
from disagreement.models import File, UnfurledMediaItem
|
||||
|
||||
file_component = File(
|
||||
file=UnfurledMediaItem(url="attachment://file.zip"),
|
||||
spoiler=False,
|
||||
)
|
||||
```
|
||||
|
||||
## Separator
|
||||
|
||||
`Separator` adds vertical spacing or an optional divider line between components.
|
||||
|
||||
```python
|
||||
from disagreement.models import Separator
|
||||
|
||||
separator = Separator(divider=True, spacing=2)
|
||||
```
|
||||
|
||||
## Container
|
||||
|
||||
`Container` visually groups a set of components and can apply an accent colour or spoiler.
|
||||
|
||||
```python
|
||||
from disagreement.models import Container, TextDisplay
|
||||
|
||||
container = Container(
|
||||
components=[TextDisplay(content="Inside a container")],
|
||||
accent_color=0xFF0000,
|
||||
spoiler=False,
|
||||
)
|
||||
```
|
||||
|
||||
A container can itself contain layout and content components, letting you build complex messages.
|
||||
|
||||
|
||||
## Next Steps
|
||||
|
||||
- [Slash Commands](slash_commands.md)
|
||||
- [Caching](caching.md)
|
||||
- [Voice Features](voice_features.md)
|
||||
|
29
docs/voice_client.md
Normal file
29
docs/voice_client.md
Normal file
@ -0,0 +1,29 @@
|
||||
# VoiceClient
|
||||
|
||||
`VoiceClient` provides a minimal interface to Discord's voice gateway. It handles the WebSocket handshake and lets you stream audio over UDP.
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import os
|
||||
import disagreement
|
||||
|
||||
vc = disagreement.VoiceClient(
|
||||
os.environ["DISCORD_VOICE_ENDPOINT"],
|
||||
os.environ["DISCORD_SESSION_ID"],
|
||||
os.environ["DISCORD_VOICE_TOKEN"],
|
||||
int(os.environ["DISCORD_GUILD_ID"]),
|
||||
int(os.environ["DISCORD_USER_ID"]),
|
||||
)
|
||||
|
||||
asyncio.run(vc.connect())
|
||||
```
|
||||
|
||||
After connecting you can send raw Opus frames:
|
||||
|
||||
```python
|
||||
await vc.send_audio_frame(opus_bytes)
|
||||
```
|
||||
|
||||
Call `await vc.close()` when finished.
|
17
docs/voice_features.md
Normal file
17
docs/voice_features.md
Normal file
@ -0,0 +1,17 @@
|
||||
# Voice Features
|
||||
|
||||
Disagreement includes experimental support for connecting to voice channels. You can join a voice channel and play audio using an FFmpeg subprocess.
|
||||
|
||||
```python
|
||||
voice = await client.join_voice(guild_id, channel_id)
|
||||
voice.play_file("welcome.mp3")
|
||||
```
|
||||
|
||||
Voice support is optional and may require additional system dependencies such as FFmpeg.
|
||||
|
||||
## Next Steps
|
||||
|
||||
- [Components](using_components.md)
|
||||
- [Slash Commands](slash_commands.md)
|
||||
- [Caching](caching.md)
|
||||
|
34
docs/webhooks.md
Normal file
34
docs/webhooks.md
Normal file
@ -0,0 +1,34 @@
|
||||
# Working with Webhooks
|
||||
|
||||
The `HTTPClient` includes helper methods for creating, editing and deleting Discord webhooks.
|
||||
|
||||
## Create a webhook
|
||||
|
||||
```python
|
||||
from disagreement.http import HTTPClient
|
||||
|
||||
http = HTTPClient(token="TOKEN")
|
||||
payload = {"name": "My Webhook"}
|
||||
webhook_data = await http.create_webhook("123", payload)
|
||||
```
|
||||
|
||||
## Edit a webhook
|
||||
|
||||
```python
|
||||
await http.edit_webhook("456", {"name": "Renamed"})
|
||||
```
|
||||
|
||||
## Delete a webhook
|
||||
|
||||
```python
|
||||
await http.delete_webhook("456")
|
||||
```
|
||||
|
||||
The methods return the raw webhook JSON. You can construct a `Webhook` model if needed:
|
||||
|
||||
```python
|
||||
from disagreement.models import Webhook
|
||||
|
||||
webhook = Webhook(webhook_data)
|
||||
print(webhook.id, webhook.name)
|
||||
```
|
218
examples/basic_bot.py
Normal file
218
examples/basic_bot.py
Normal file
@ -0,0 +1,218 @@
|
||||
# examples/basic_bot.py
|
||||
|
||||
"""
|
||||
A basic example bot using the Disagreement library.
|
||||
|
||||
To run this bot:
|
||||
1. Make sure you have the 'disagreement' library installed or accessible in your PYTHONPATH.
|
||||
If running from the project root, it should be discoverable.
|
||||
2. Set the DISCORD_BOT_TOKEN environment variable to your bot's token.
|
||||
e.g., export DISCORD_BOT_TOKEN="your_actual_token_here" (Linux/macOS)
|
||||
set DISCORD_BOT_TOKEN="your_actual_token_here" (Windows CMD)
|
||||
$env:DISCORD_BOT_TOKEN="your_actual_token_here" (Windows PowerShell)
|
||||
3. Run this script: python examples/basic_bot.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import logging # Optional: for more detailed logging
|
||||
|
||||
# Assuming the 'disagreement' package is in the parent directory or installed
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
# Add project root to path if running script directly from examples folder
|
||||
# and disagreement is not installed
|
||||
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
try:
|
||||
import disagreement
|
||||
from disagreement.ext import commands # Import the new commands extension
|
||||
except ImportError:
|
||||
print(
|
||||
"Failed to import disagreement. Make sure it's installed or PYTHONPATH is set correctly."
|
||||
)
|
||||
print(
|
||||
"If running from the 'examples' directory, try running from the project root: python -m examples.basic_bot"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Optional: Configure logging for more insight, especially for gateway events
|
||||
# logging.basicConfig(level=logging.DEBUG) # For very verbose output
|
||||
# logging.getLogger('disagreement.gateway').setLevel(logging.INFO) # Or DEBUG
|
||||
# logging.getLogger('disagreement.http').setLevel(logging.INFO)
|
||||
|
||||
# --- Bot Configuration ---
|
||||
BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
||||
|
||||
# --- Intents Configuration ---
|
||||
# Define the intents your bot needs. For basic message reading and responding:
|
||||
intents = (
|
||||
disagreement.GatewayIntent.GUILDS
|
||||
| disagreement.GatewayIntent.GUILD_MESSAGES
|
||||
| disagreement.GatewayIntent.MESSAGE_CONTENT
|
||||
) # MESSAGE_CONTENT is privileged!
|
||||
|
||||
# If you don't need message content and only react to commands/mentions,
|
||||
# you might not need MESSAGE_CONTENT intent.
|
||||
# intents = disagreement.GatewayIntent.default() # A good starting point without privileged intents
|
||||
# intents |= disagreement.GatewayIntent.MESSAGE_CONTENT # Add if needed
|
||||
|
||||
# --- Initialize the Client ---
|
||||
if not BOT_TOKEN:
|
||||
print("Error: The DISCORD_BOT_TOKEN environment variable is not set.")
|
||||
print("Please set it before running the bot.")
|
||||
sys.exit(1)
|
||||
|
||||
# Initialize Client with a command prefix
|
||||
client = disagreement.Client(token=BOT_TOKEN, intents=intents, command_prefix="!")
|
||||
|
||||
|
||||
# --- Define a Cog for example commands ---
|
||||
class ExampleCog(commands.Cog): # Ensuring this uses commands.Cog
|
||||
def __init__(
|
||||
self, bot_client
|
||||
): # Renamed client to bot_client to avoid conflict with self.client
|
||||
super().__init__(bot_client) # Pass the client instance to the base Cog
|
||||
|
||||
@commands.command(name="hello", aliases=["hi"])
|
||||
async def hello_command(self, ctx: commands.CommandContext, *, who: str = "world"):
|
||||
"""Greets someone."""
|
||||
await ctx.reply(f"Hello {ctx.author.mention} and {who}!")
|
||||
print(f"Executed 'hello' command for {ctx.author.username}, greeting {who}")
|
||||
|
||||
@commands.command()
|
||||
async def ping(self, ctx: commands.CommandContext):
|
||||
"""Responds with Pong!"""
|
||||
await ctx.reply("Pong!")
|
||||
print(f"Executed 'ping' command for {ctx.author.username}")
|
||||
|
||||
@commands.command()
|
||||
async def me(self, ctx: commands.CommandContext):
|
||||
"""Shows information about the invoking user."""
|
||||
reply_content = (
|
||||
f"Hello {ctx.author.mention}!\n"
|
||||
f"Your User ID is: {ctx.author.id}\n"
|
||||
f"Your Username: {ctx.author.username}#{ctx.author.discriminator}\n"
|
||||
f"Are you a bot? {'Yes' if ctx.author.bot else 'No'}"
|
||||
)
|
||||
await ctx.reply(reply_content)
|
||||
print(f"Executed 'me' command for {ctx.author.username}")
|
||||
|
||||
@commands.command(name="add")
|
||||
async def add_numbers(self, ctx: commands.CommandContext, num1: int, num2: int):
|
||||
"""Adds two numbers."""
|
||||
result = num1 + num2
|
||||
await ctx.reply(f"The sum of {num1} and {num2} is {result}.")
|
||||
print(
|
||||
f"Executed 'add' command for {ctx.author.username}: {num1} + {num2} = {result}"
|
||||
)
|
||||
|
||||
@commands.command(name="say")
|
||||
async def say_something(self, ctx: commands.CommandContext, *, text_to_say: str):
|
||||
"""Repeats the text you provide."""
|
||||
await ctx.reply(f"You said: {text_to_say}")
|
||||
print(
|
||||
f"Executed 'say' command for {ctx.author.username}, saying: {text_to_say}"
|
||||
)
|
||||
|
||||
@commands.command(name="quit")
|
||||
async def quit_command(self, ctx: commands.CommandContext):
|
||||
"""Shuts down the bot (requires YOUR_USER_ID to be set)."""
|
||||
# Replace YOUR_USER_ID with your actual Discord User ID for a safe shutdown command
|
||||
your_user_id = "YOUR_USER_ID_REPLACE_ME" # IMPORTANT: Replace this
|
||||
if str(ctx.author.id) == your_user_id:
|
||||
print("Quit command received. Shutting down...")
|
||||
await ctx.reply("Shutting down...")
|
||||
await self.client.close() # Access client via self.client from Cog
|
||||
else:
|
||||
await ctx.reply("You are not authorized to use this command.")
|
||||
print(
|
||||
f"Unauthorized quit attempt by {ctx.author.username} ({ctx.author.id})"
|
||||
)
|
||||
|
||||
|
||||
# --- Event Handlers ---
|
||||
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
"""Called when the bot is ready and connected to Discord."""
|
||||
if client.user:
|
||||
print(
|
||||
f"Bot is ready! Logged in as {client.user.username}#{client.user.discriminator}"
|
||||
)
|
||||
print(f"User ID: {client.user.id}")
|
||||
else:
|
||||
print("Bot is ready, but client.user is missing!")
|
||||
print("------")
|
||||
print("Disagreement Bot is operational.")
|
||||
print("Listening for commands...")
|
||||
|
||||
|
||||
@client.event
|
||||
async def on_message(message: disagreement.Message):
|
||||
"""Called when a message is created and received."""
|
||||
# Command processing is now handled by the CommandHandler via client._process_message_for_commands
|
||||
# This on_message can be used for other message-related logic if needed,
|
||||
# or removed if all message handling is command-based.
|
||||
|
||||
# Example: Log all messages (excluding bot's own, if client.user was available)
|
||||
# if client.user and message.author.id == client.user.id:
|
||||
# return
|
||||
|
||||
print(
|
||||
f"General on_message: #{message.channel_id} from {message.author.username}: {message.content}"
|
||||
)
|
||||
# The old if/elif command structure is no longer needed here.
|
||||
|
||||
|
||||
@client.on_event(
|
||||
"GUILD_CREATE"
|
||||
) # Example of listening to a specific event by its Discord name
|
||||
async def on_guild_available(guild_data: dict): # Receives raw data for now
|
||||
# In a real scenario, guild_data would be parsed into a Guild model
|
||||
print(f"Guild available: {guild_data.get('name')} (ID: {guild_data.get('id')})")
|
||||
|
||||
|
||||
# --- Main Execution ---
|
||||
async def main():
|
||||
print("Starting Disagreement Bot...")
|
||||
try:
|
||||
# Add the Cog to the client
|
||||
client.add_cog(ExampleCog(client)) # Pass client instance to Cog constructor
|
||||
# client.add_cog is synchronous, but it schedules cog.cog_load() if it's async.
|
||||
|
||||
await client.run()
|
||||
except disagreement.AuthenticationError:
|
||||
print(
|
||||
"Authentication failed. Please check your bot token and ensure it's correct."
|
||||
)
|
||||
except disagreement.DisagreementException as e:
|
||||
print(f"A Disagreement library error occurred: {e}")
|
||||
except KeyboardInterrupt:
|
||||
print("Bot shutting down due to KeyboardInterrupt...")
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred: {e}")
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
if not client.is_closed():
|
||||
print("Ensuring client is closed...")
|
||||
await client.close()
|
||||
print("Bot has been shut down.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Note: On Windows, the default asyncio event loop policy might not support add_signal_handler.
|
||||
# If you encounter issues with Ctrl+C not working as expected,
|
||||
# you might need to adjust the event loop policy or handle shutdown differently.
|
||||
# For example, for Windows:
|
||||
# if os.name == 'nt':
|
||||
# asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
asyncio.run(main())
|
292
examples/component_bot.py
Normal file
292
examples/component_bot.py
Normal file
@ -0,0 +1,292 @@
|
||||
import os
|
||||
import asyncio
|
||||
from typing import Union, Optional
|
||||
from disagreement import Client, ui, HybridContext
|
||||
from disagreement.models import (
|
||||
Message,
|
||||
SelectOption,
|
||||
User,
|
||||
Member,
|
||||
Role,
|
||||
Attachment,
|
||||
Channel,
|
||||
ActionRow,
|
||||
Button,
|
||||
Section,
|
||||
TextDisplay,
|
||||
Thumbnail,
|
||||
UnfurledMediaItem,
|
||||
MediaGallery,
|
||||
MediaGalleryItem,
|
||||
Container,
|
||||
)
|
||||
from disagreement.enums import (
|
||||
ButtonStyle,
|
||||
GatewayIntent,
|
||||
ChannelType,
|
||||
MessageFlags,
|
||||
InteractionCallbackType,
|
||||
MessageFlags,
|
||||
)
|
||||
from disagreement.ext.commands.cog import Cog
|
||||
from disagreement.ext.commands.core import CommandContext
|
||||
from disagreement.ext.app_commands.decorators import hybrid_command, slash_command
|
||||
from disagreement.ext.app_commands.context import AppCommandContext
|
||||
from disagreement.interactions import (
|
||||
Interaction,
|
||||
InteractionResponsePayload,
|
||||
InteractionCallbackData,
|
||||
)
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Get the bot token and application ID from the environment variables
|
||||
token = os.getenv("DISCORD_BOT_TOKEN")
|
||||
application_id = os.getenv("DISCORD_APPLICATION_ID")
|
||||
|
||||
if not token:
|
||||
raise ValueError("Bot token not found in environment variables")
|
||||
if not application_id:
|
||||
raise ValueError("Application ID not found in environment variables")
|
||||
|
||||
# Define the intents
|
||||
intents = GatewayIntent.default() | GatewayIntent.MESSAGE_CONTENT
|
||||
|
||||
# Create a new client
|
||||
client = Client(
|
||||
|
||||
token=token,
|
||||
|
||||
intents=intents,
|
||||
|
||||
command_prefix="!",
|
||||
|
||||
mention_replies=True,
|
||||
)
|
||||
|
||||
# Simple stock data used for the stock cycling example
|
||||
STOCKS = [
|
||||
{"symbol": "AAPL", "name": "Apple Inc.", "price": "$175"},
|
||||
{"symbol": "MSFT", "name": "Microsoft Corp.", "price": "$315"},
|
||||
{"symbol": "GOOGL", "name": "Alphabet Inc.", "price": "$128"},
|
||||
]
|
||||
|
||||
|
||||
# Define a View class that contains our components
|
||||
class MyView(ui.View):
|
||||
def __init__(self):
|
||||
super().__init__(timeout=180) # 180-second timeout
|
||||
self.click_count = 0
|
||||
|
||||
@ui.button(label="Click Me!", style=ButtonStyle.SUCCESS, emoji="🖱️")
|
||||
async def click_me(self, interaction: Interaction):
|
||||
self.click_count += 1
|
||||
await interaction.respond(
|
||||
content=f"You've clicked the button {self.click_count} times!",
|
||||
ephemeral=True,
|
||||
)
|
||||
|
||||
@ui.select(
|
||||
custom_id="string_select",
|
||||
placeholder="Choose an option",
|
||||
options=[
|
||||
SelectOption(
|
||||
label="Option 1", value="opt1", description="This is the first option"
|
||||
),
|
||||
SelectOption(
|
||||
label="Option 2", value="opt2", description="This is the second option"
|
||||
),
|
||||
],
|
||||
)
|
||||
async def select_menu(self, interaction: Interaction):
|
||||
if interaction.data and interaction.data.values:
|
||||
await interaction.respond(
|
||||
content=f"You selected: {interaction.data.values[0]}",
|
||||
ephemeral=True,
|
||||
)
|
||||
|
||||
async def on_timeout(self):
|
||||
# This method is called when the view times out.
|
||||
# You can use this to edit the original message, for example.
|
||||
print("View has timed out!")
|
||||
|
||||
|
||||
# View for cycling through available stocks
|
||||
class StockView(ui.View):
|
||||
def __init__(self):
|
||||
super().__init__(timeout=180)
|
||||
self.index = 0
|
||||
|
||||
@ui.button(label="Next Stock", style=ButtonStyle.PRIMARY)
|
||||
async def next_stock(self, interaction: Interaction):
|
||||
self.index = (self.index + 1) % len(STOCKS)
|
||||
stock = STOCKS[self.index]
|
||||
# Edit the message by responding to the interaction with an update
|
||||
await interaction.edit(
|
||||
content=f"**{stock['symbol']}** - {stock['name']}\nPrice: {stock['price']}",
|
||||
components=self.to_components(),
|
||||
# Preserve the original reply mention
|
||||
allowed_mentions={"replied_user": True},
|
||||
)
|
||||
|
||||
|
||||
class ComponentCommandsCog(Cog):
|
||||
def __init__(self, client: Client):
|
||||
super().__init__(client)
|
||||
|
||||
@hybrid_command(name="components", description="Sends interactive components.")
|
||||
async def components_command(self, ctx: Union[CommandContext, AppCommandContext]):
|
||||
# Send a message with the view using a hybrid context helper
|
||||
hybrid = HybridContext(ctx)
|
||||
await hybrid.send("Here are your components:", view=MyView())
|
||||
|
||||
@hybrid_command(
|
||||
name="stocks", description="Shows stock information with navigation."
|
||||
)
|
||||
async def stocks_command(self, ctx: Union[CommandContext, AppCommandContext]):
|
||||
# Show the first stock and attach a button to cycle through them
|
||||
first = STOCKS[0]
|
||||
hybrid = HybridContext(ctx)
|
||||
await hybrid.send(
|
||||
f"**{first['symbol']}** - {first['name']}\nPrice: {first['price']}",
|
||||
view=StockView(),
|
||||
)
|
||||
|
||||
@hybrid_command(name="sectiondemo", description="Shows a section layout.")
|
||||
async def section_demo(self, ctx: Union[CommandContext, AppCommandContext]) -> None:
|
||||
section = Section(
|
||||
components=[
|
||||
TextDisplay(content="## Advanced Components"),
|
||||
TextDisplay(content="Sections group text with accessories."),
|
||||
],
|
||||
accessory=Thumbnail(
|
||||
media=UnfurledMediaItem(url="https://placehold.co/100x100.png")
|
||||
),
|
||||
)
|
||||
container = Container(components=[section], accent_color=0x5865F2)
|
||||
hybrid = HybridContext(ctx)
|
||||
await hybrid.send(
|
||||
components=[container],
|
||||
flags=MessageFlags.IS_COMPONENTS_V2.value,
|
||||
)
|
||||
|
||||
@hybrid_command(name="gallerydemo", description="Shows a media gallery.")
|
||||
async def gallery_demo(self, ctx: Union[CommandContext, AppCommandContext]) -> None:
|
||||
gallery = MediaGallery(
|
||||
items=[
|
||||
MediaGalleryItem(
|
||||
media=UnfurledMediaItem(url="https://placehold.co/600x400.png")
|
||||
),
|
||||
MediaGalleryItem(
|
||||
media=UnfurledMediaItem(url="https://placehold.co/600x400.jpg")
|
||||
),
|
||||
]
|
||||
)
|
||||
hybrid = HybridContext(ctx)
|
||||
await hybrid.send(
|
||||
components=[gallery],
|
||||
flags=MessageFlags.IS_COMPONENTS_V2.value,
|
||||
)
|
||||
|
||||
@hybrid_command(
|
||||
name="complex_components",
|
||||
description="Shows a complex layout with multiple containers.",
|
||||
)
|
||||
async def complex_components(
|
||||
self, ctx: Union[CommandContext, AppCommandContext]
|
||||
) -> None:
|
||||
container1 = Container(
|
||||
components=[
|
||||
Section(
|
||||
components=[
|
||||
TextDisplay(content="## Complex Layout Example"),
|
||||
TextDisplay(
|
||||
content="This container has an accent color and includes a section with an action row of buttons. There is a thumbnail accessory to the right."
|
||||
),
|
||||
],
|
||||
accessory=Thumbnail(
|
||||
media=UnfurledMediaItem(url="https://placehold.co/100x100.png")
|
||||
),
|
||||
),
|
||||
ActionRow(
|
||||
components=[
|
||||
Button(
|
||||
style=ButtonStyle.PRIMARY,
|
||||
label="Primary",
|
||||
custom_id="complex_primary",
|
||||
),
|
||||
Button(
|
||||
style=ButtonStyle.SUCCESS,
|
||||
label="Success",
|
||||
custom_id="complex_success",
|
||||
),
|
||||
Button(
|
||||
style=ButtonStyle.DANGER,
|
||||
label="Destructive",
|
||||
custom_id="complex_destructive",
|
||||
),
|
||||
]
|
||||
),
|
||||
],
|
||||
accent_color=0x5865F2,
|
||||
)
|
||||
container2 = Container(
|
||||
components=[
|
||||
TextDisplay(
|
||||
content="## Another Container\nThis container has no accent color and includes a media gallery."
|
||||
),
|
||||
MediaGallery(
|
||||
items=[
|
||||
MediaGalleryItem(
|
||||
media=UnfurledMediaItem(
|
||||
url="https://placehold.co/300x200.png"
|
||||
)
|
||||
),
|
||||
MediaGalleryItem(
|
||||
media=UnfurledMediaItem(
|
||||
url="https://placehold.co/300x200.jpg"
|
||||
)
|
||||
),
|
||||
MediaGalleryItem(
|
||||
media=UnfurledMediaItem(
|
||||
url="https://placehold.co/300x200.gif"
|
||||
)
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
hybrid = HybridContext(ctx)
|
||||
await hybrid.send(
|
||||
components=[container1, container2],
|
||||
flags=MessageFlags.IS_COMPONENTS_V2.value,
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
@client.event
|
||||
async def on_ready():
|
||||
if client.user:
|
||||
print(f"Logged in as {client.user.username}")
|
||||
if client.application_id:
|
||||
try:
|
||||
print("Attempting to sync application commands...")
|
||||
await client.app_command_handler.sync_commands(
|
||||
application_id=client.application_id
|
||||
)
|
||||
print("Application commands synced successfully.")
|
||||
except Exception as e:
|
||||
print(f"Error syncing application commands: {e}")
|
||||
else:
|
||||
print(
|
||||
"Client's application ID is not set. Skipping application command sync."
|
||||
)
|
||||
|
||||
client.add_cog(ComponentCommandsCog(client))
|
||||
await client.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
71
examples/context_menus.py
Normal file
71
examples/context_menus.py
Normal file
@ -0,0 +1,71 @@
|
||||
"""Examples showing how to use context menu commands."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Allow running example from repository root
|
||||
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.ext.app_commands import (
|
||||
user_command,
|
||||
message_command,
|
||||
AppCommandContext,
|
||||
)
|
||||
from disagreement.models import User, Message
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "")
|
||||
APP_ID = os.environ.get("DISCORD_APPLICATION_ID", "")
|
||||
client = Client(token=BOT_TOKEN, application_id=APP_ID)
|
||||
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
"""Called when the bot is ready and connected to Discord."""
|
||||
if client.user:
|
||||
print(f"Bot is ready! Logged in as {client.user.username}")
|
||||
print("Attempting to sync application commands...")
|
||||
try:
|
||||
if client.application_id:
|
||||
await client.app_command_handler.sync_commands(
|
||||
application_id=client.application_id
|
||||
)
|
||||
print("Application commands synced successfully.")
|
||||
else:
|
||||
print("Skipping command sync: application ID is not set.")
|
||||
except Exception as e:
|
||||
print(f"Error syncing application commands: {e}")
|
||||
else:
|
||||
print("Bot is ready, but client.user is missing!")
|
||||
print("------")
|
||||
|
||||
|
||||
@user_command(name="User Info")
|
||||
async def user_info(ctx: AppCommandContext, user: User) -> None:
|
||||
await ctx.send(
|
||||
f"Selected user: {user.username}#{user.discriminator}", ephemeral=True
|
||||
)
|
||||
|
||||
|
||||
@message_command(name="Quote")
|
||||
async def quote(ctx: AppCommandContext, message: Message) -> None:
|
||||
await ctx.send(f"Quoted message: {message.content}", ephemeral=True)
|
||||
|
||||
|
||||
client.app_command_handler.add_command(user_info)
|
||||
client.app_command_handler.add_command(quote)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
await client.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
315
examples/hybrid_bot.py
Normal file
315
examples/hybrid_bot.py
Normal file
@ -0,0 +1,315 @@
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
from typing import Any, Optional, Literal, Union
|
||||
|
||||
from disagreement import HybridContext
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.ext.commands.cog import Cog
|
||||
from disagreement.ext.commands.core import CommandContext
|
||||
from disagreement.ext.app_commands.decorators import (
|
||||
slash_command,
|
||||
user_command,
|
||||
message_command,
|
||||
hybrid_command,
|
||||
)
|
||||
from disagreement.ext.app_commands.commands import (
|
||||
AppCommandGroup,
|
||||
) # For defining groups
|
||||
from disagreement.ext.app_commands.context import AppCommandContext # Added
|
||||
from disagreement.models import (
|
||||
User,
|
||||
Member,
|
||||
Role,
|
||||
Attachment,
|
||||
Message,
|
||||
Channel,
|
||||
) # For type hints
|
||||
from disagreement.enums import (
|
||||
ChannelType,
|
||||
) # For channel option type hints, assuming it exists
|
||||
|
||||
# from disagreement.interactions import Interaction # Replaced by AppCommandContext
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# --- Define a Test Cog ---
|
||||
class TestCog(Cog):
|
||||
def __init__(self, client: Client):
|
||||
super().__init__(client)
|
||||
|
||||
@slash_command(name="greet", description="Sends a greeting.")
|
||||
async def greet_slash(self, ctx: AppCommandContext, name: str):
|
||||
await ctx.send(f"Hello, {name}! (Slash)")
|
||||
|
||||
@user_command(name="Show User Info")
|
||||
async def show_user_info_user(
|
||||
self, ctx: AppCommandContext, user: User
|
||||
): # Target user is in ctx.interaction.data.target_id and resolved
|
||||
target_user = (
|
||||
ctx.interaction.data.resolved.users.get(ctx.interaction.data.target_id)
|
||||
if ctx.interaction.data
|
||||
and ctx.interaction.data.resolved
|
||||
and ctx.interaction.data.target_id
|
||||
else user
|
||||
)
|
||||
if target_user:
|
||||
await ctx.send(
|
||||
f"User: {target_user.username}#{target_user.discriminator} (ID: {target_user.id}) (User Cmd)",
|
||||
ephemeral=True,
|
||||
)
|
||||
else:
|
||||
await ctx.send("Could not find user information.", ephemeral=True)
|
||||
|
||||
@message_command(name="Quote Message")
|
||||
async def quote_message_msg(
|
||||
self, ctx: AppCommandContext, message: Message
|
||||
): # Target message is in ctx.interaction.data.target_id and resolved
|
||||
target_message = (
|
||||
ctx.interaction.data.resolved.messages.get(ctx.interaction.data.target_id)
|
||||
if ctx.interaction.data
|
||||
and ctx.interaction.data.resolved
|
||||
and ctx.interaction.data.target_id
|
||||
else message
|
||||
)
|
||||
if target_message:
|
||||
await ctx.send(
|
||||
f'Quoting {target_message.author.username}: "{target_message.content}" (Message Cmd)',
|
||||
ephemeral=True,
|
||||
)
|
||||
else:
|
||||
await ctx.send("Could not find message to quote.", ephemeral=True)
|
||||
|
||||
@hybrid_command(name="ping", description="Checks bot latency.", aliases=["pong"])
|
||||
async def ping_hybrid(
|
||||
self, ctx: Union[CommandContext, AppCommandContext], arg: Optional[str] = None
|
||||
):
|
||||
# latency = self.client.latency # Assuming client has latency attribute from gateway - Commented out for now
|
||||
latency_ms = "N/A" # Placeholder
|
||||
hybrid = HybridContext(ctx)
|
||||
await hybrid.send(f"Pong! Arg: {arg} (Hybrid)")
|
||||
|
||||
@slash_command(name="options_test", description="Tests various option types.")
|
||||
async def options_test_slash(
|
||||
self,
|
||||
ctx: AppCommandContext,
|
||||
text: str,
|
||||
integer: int,
|
||||
boolean: bool,
|
||||
number: float,
|
||||
user_option: User,
|
||||
role_option: Role,
|
||||
attachment_option: Attachment,
|
||||
choice_option_str: Literal["apple", "banana", "cherry"],
|
||||
# Channel and member options as well as numeric Literal choices are
|
||||
# not yet exercised in tests pending full library support.
|
||||
):
|
||||
response_parts = [
|
||||
f"Text: {text}",
|
||||
f"Integer: {integer}",
|
||||
f"Boolean: {boolean}",
|
||||
f"Number: {number}",
|
||||
f"User: {user_option.username}#{user_option.discriminator}",
|
||||
f"Role: {role_option.name}",
|
||||
f"Attachment: {attachment_option.filename} (URL: {attachment_option.url})",
|
||||
f"Choice Str: {choice_option_str}",
|
||||
]
|
||||
await ctx.send("\n".join(response_parts), ephemeral=True)
|
||||
|
||||
# --- Subcommand Group Test ---
|
||||
# Define the group as a class attribute.
|
||||
# The AppCommandHandler's discovery mechanism (via Cog) should pick up AppCommandGroup instances.
|
||||
settings_group = AppCommandGroup(
|
||||
name="settings",
|
||||
description="Manage bot settings.",
|
||||
# guild_ids can be added here if the group is guild-specific
|
||||
)
|
||||
|
||||
@slash_command(
|
||||
name="show", description="Shows current setting values.", parent=settings_group
|
||||
)
|
||||
async def settings_show(
|
||||
self, ctx: AppCommandContext, setting_name: Optional[str] = None
|
||||
):
|
||||
if setting_name:
|
||||
await ctx.send(
|
||||
f"Showing value for setting: {setting_name} (Value: Placeholder)",
|
||||
ephemeral=True,
|
||||
)
|
||||
else:
|
||||
await ctx.send(
|
||||
"Showing all settings: (Placeholder for all settings)", ephemeral=True
|
||||
)
|
||||
|
||||
@slash_command(
|
||||
name="update", description="Updates a setting.", parent=settings_group
|
||||
)
|
||||
async def settings_update(
|
||||
self, ctx: AppCommandContext, setting_name: str, value: str
|
||||
):
|
||||
await ctx.send(
|
||||
f"Updated setting: {setting_name} to value: {value}", ephemeral=True
|
||||
)
|
||||
|
||||
# The Cog's metaclass or command registration logic should handle adding `settings_group`
|
||||
# (and its subcommands) to the client's AppCommandHandler.
|
||||
# The decorators now handle associating subcommands with their parent group.
|
||||
|
||||
@slash_command(
|
||||
name="numeric_choices_test", description="Tests integer and float choices."
|
||||
)
|
||||
async def numeric_choices_test_slash(
|
||||
self,
|
||||
ctx: AppCommandContext,
|
||||
int_choice: Literal[10, 20, 30, 42],
|
||||
float_choice: float,
|
||||
):
|
||||
response = (
|
||||
f"Integer Choice: {int_choice} (Type: {type(int_choice).__name__})\n"
|
||||
f"Float Choice: {float_choice} (Type: {type(float_choice).__name__})"
|
||||
)
|
||||
await ctx.send(response, ephemeral=True)
|
||||
|
||||
@slash_command(
|
||||
name="numeric_choices_extended",
|
||||
description="Tests additional integer and float choice handling.",
|
||||
)
|
||||
async def numeric_choices_extended_slash(
|
||||
self,
|
||||
ctx: AppCommandContext,
|
||||
int_choice: Literal[-5, 0, 5],
|
||||
float_choice: float,
|
||||
):
|
||||
response = (
|
||||
f"Int Choice: {int_choice} (Type: {type(int_choice).__name__})\n"
|
||||
f"Float Choice: {float_choice} (Type: {type(float_choice).__name__})"
|
||||
)
|
||||
await ctx.send(response, ephemeral=True)
|
||||
|
||||
@slash_command(
|
||||
name="channel_member_test",
|
||||
description="Tests channel and member options.",
|
||||
)
|
||||
async def channel_member_test_slash(
|
||||
self,
|
||||
ctx: AppCommandContext,
|
||||
channel: Channel,
|
||||
member: Member,
|
||||
):
|
||||
response = (
|
||||
f"Channel: {channel.name} (Type: {channel.type.name})\n"
|
||||
f"Member: {member.username}#{member.discriminator}"
|
||||
)
|
||||
await ctx.send(response, ephemeral=True)
|
||||
|
||||
@slash_command(
|
||||
name="channel_types_test",
|
||||
description="Demonstrates multiple channel type options.",
|
||||
)
|
||||
async def channel_types_test_slash(
|
||||
self,
|
||||
ctx: AppCommandContext,
|
||||
text_channel: Channel,
|
||||
voice_channel: Channel,
|
||||
category_channel: Channel,
|
||||
):
|
||||
response = (
|
||||
f"Text: {text_channel.type.name}\n"
|
||||
f"Voice: {voice_channel.type.name}\n"
|
||||
f"Category: {category_channel.type.name}"
|
||||
)
|
||||
await ctx.send(response, ephemeral=True)
|
||||
|
||||
|
||||
# --- Main Bot Script ---
|
||||
async def main():
|
||||
bot_token = os.getenv("DISCORD_BOT_TOKEN")
|
||||
application_id = os.getenv("DISCORD_APPLICATION_ID")
|
||||
|
||||
if not bot_token:
|
||||
logger.error("Error: DISCORD_BOT_TOKEN environment variable not set.")
|
||||
return
|
||||
if not application_id:
|
||||
logger.error("Error: DISCORD_APPLICATION_ID environment variable not set.")
|
||||
return
|
||||
|
||||
client = Client(token=bot_token, command_prefix="!", application_id=application_id)
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
if client.user:
|
||||
logger.info(
|
||||
f"Bot logged in as {client.user.username}#{client.user.discriminator}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
"Client ready, but client.user is not populated! This should not happen."
|
||||
)
|
||||
return # Avoid proceeding if basic client info isn't there
|
||||
|
||||
if client.application_id:
|
||||
logger.info(f"Application ID is: {client.application_id}")
|
||||
# Sync application commands (global in this case)
|
||||
try:
|
||||
logger.info("Attempting to sync application commands...")
|
||||
# Ensure application_id is not None before passing
|
||||
app_id_to_sync = client.application_id
|
||||
if (
|
||||
app_id_to_sync is not None
|
||||
): # Redundant due to outer if, but good for clarity
|
||||
await client.app_command_handler.sync_commands(
|
||||
application_id=app_id_to_sync
|
||||
)
|
||||
logger.info("Application commands synced successfully.")
|
||||
else: # Should not be reached if outer if client.application_id is true
|
||||
logger.error(
|
||||
"Application ID was None despite initial check. Skipping sync."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error syncing application commands: {e}", exc_info=True)
|
||||
else:
|
||||
# This case should be less likely now that Client gets it from READY.
|
||||
# If DISCORD_APPLICATION_ID was critical as a fallback, that logic would be here.
|
||||
# For now, we rely on the READY event.
|
||||
logger.warning(
|
||||
"Client's application ID is not set after READY. Skipping application command sync."
|
||||
)
|
||||
# Check if the environment variable was provided, as a diagnostic.
|
||||
if not application_id:
|
||||
logger.warning(
|
||||
"DISCORD_APPLICATION_ID environment variable was also not provided."
|
||||
)
|
||||
|
||||
client.add_cog(TestCog(client))
|
||||
|
||||
try:
|
||||
await client.run()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Bot shutting down...")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"An error occurred in the bot's main run loop: {e}", exc_info=True
|
||||
)
|
||||
finally:
|
||||
if not client.is_closed():
|
||||
await client.close()
|
||||
logger.info("Bot has been closed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# For Windows, to allow graceful shutdown with Ctrl+C
|
||||
if os.name == "nt":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Main loop interrupted. Exiting.")
|
65
examples/modal_command.py
Normal file
65
examples/modal_command.py
Normal file
@ -0,0 +1,65 @@
|
||||
"""Example showing how to present a modal using a slash command."""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from disagreement import Client, ui
|
||||
from disagreement.enums import GatewayIntent, TextInputStyle
|
||||
from disagreement.ext.app_commands.decorators import slash_command
|
||||
from disagreement.ext.app_commands.context import AppCommandContext
|
||||
|
||||
load_dotenv()
|
||||
|
||||
token = os.getenv("DISCORD_BOT_TOKEN", "")
|
||||
application_id = os.getenv("DISCORD_APPLICATION_ID", "")
|
||||
|
||||
client = Client(
|
||||
token=token, application_id=application_id, intents=GatewayIntent.default()
|
||||
)
|
||||
|
||||
|
||||
class FeedbackModal(ui.Modal):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(title="Feedback", custom_id="feedback")
|
||||
|
||||
@ui.text_input(label="Your feedback", style=TextInputStyle.PARAGRAPH)
|
||||
async def feedback(self, interaction):
|
||||
await interaction.respond(content="Thanks for your feedback!", ephemeral=True)
|
||||
|
||||
|
||||
@slash_command(name="feedback", description="Send feedback via a modal")
|
||||
async def feedback_command(ctx: AppCommandContext):
|
||||
await ctx.interaction.respond_modal(FeedbackModal())
|
||||
|
||||
|
||||
client.app_command_handler.add_command(feedback_command)
|
||||
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
"""Called when the bot is ready and connected to Discord."""
|
||||
if client.user:
|
||||
print(f"Bot is ready! Logged in as {client.user.username}")
|
||||
print("Attempting to sync application commands...")
|
||||
try:
|
||||
if client.application_id:
|
||||
await client.app_command_handler.sync_commands(
|
||||
application_id=client.application_id
|
||||
)
|
||||
print("Application commands synced successfully.")
|
||||
else:
|
||||
print("Skipping command sync: application ID is not set.")
|
||||
except Exception as e:
|
||||
print(f"Error syncing application commands: {e}")
|
||||
else:
|
||||
print("Bot is ready, but client.user is missing!")
|
||||
print("------")
|
||||
|
||||
|
||||
async def main():
|
||||
await client.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
63
examples/modal_send.py
Normal file
63
examples/modal_send.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""Example showing how to send a modal."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from disagreement import Client, GatewayIntent, ui # type: ignore
|
||||
from disagreement.ext.app_commands.decorators import slash_command
|
||||
from disagreement.ext.app_commands.context import AppCommandContext
|
||||
|
||||
load_dotenv()
|
||||
TOKEN = os.getenv("DISCORD_BOT_TOKEN", "")
|
||||
APP_ID = os.getenv("DISCORD_APPLICATION_ID", "")
|
||||
|
||||
if not TOKEN:
|
||||
print("DISCORD_BOT_TOKEN not set")
|
||||
sys.exit(1)
|
||||
|
||||
client = Client(token=TOKEN, intents=GatewayIntent.default(), application_id=APP_ID)
|
||||
|
||||
|
||||
class NameModal(ui.Modal):
|
||||
def __init__(self):
|
||||
super().__init__(title="Your Name", custom_id="name_modal")
|
||||
self.name = ui.TextInput(label="Name", custom_id="name")
|
||||
|
||||
|
||||
@slash_command(name="namemodal", description="Shows a modal")
|
||||
async def _namemodal(ctx: AppCommandContext):
|
||||
await ctx.interaction.response.send_modal(NameModal())
|
||||
|
||||
|
||||
client.app_command_handler.add_command(_namemodal)
|
||||
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
"""Called when the bot is ready and connected to Discord."""
|
||||
if client.user:
|
||||
print(f"Bot is ready! Logged in as {client.user.username}")
|
||||
print("Attempting to sync application commands...")
|
||||
try:
|
||||
if client.application_id:
|
||||
await client.app_command_handler.sync_commands(
|
||||
application_id=client.application_id
|
||||
)
|
||||
print("Application commands synced successfully.")
|
||||
else:
|
||||
print("Skipping command sync: application ID is not set.")
|
||||
except Exception as e:
|
||||
print(f"Error syncing application commands: {e}")
|
||||
else:
|
||||
print("Bot is ready, but client.user is missing!")
|
||||
print("------")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(client.run())
|
36
examples/sharded_bot.py
Normal file
36
examples/sharded_bot.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""Example bot demonstrating gateway sharding."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Ensure local package is importable when running from the examples directory
|
||||
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
import disagreement
|
||||
|
||||
TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
||||
if not TOKEN:
|
||||
raise RuntimeError("DISCORD_BOT_TOKEN environment variable not set")
|
||||
|
||||
client = disagreement.Client(token=TOKEN, shard_count=2)
|
||||
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
if client.user:
|
||||
print(f"Shard bot ready as {client.user.username}#{client.user.discriminator}")
|
||||
else:
|
||||
print("Shard bot ready")
|
||||
|
||||
|
||||
async def main():
|
||||
if not TOKEN:
|
||||
print("DISCORD_BOT_TOKEN environment variable not set")
|
||||
return
|
||||
await client.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
30
examples/task_loop.py
Normal file
30
examples/task_loop.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""Example showing the tasks extension."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Allow running from the examples folder without installing
|
||||
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from disagreement.ext import tasks
|
||||
|
||||
counter = 0
|
||||
|
||||
|
||||
@tasks.loop(seconds=1.0)
|
||||
async def ticker() -> None:
|
||||
global counter
|
||||
counter += 1
|
||||
print(f"Tick {counter}")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
ticker.start()
|
||||
await asyncio.sleep(5)
|
||||
ticker.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
61
examples/voice_bot.py
Normal file
61
examples/voice_bot.py
Normal file
@ -0,0 +1,61 @@
|
||||
"""Example bot demonstrating VoiceClient usage."""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
# If running from the examples directory
|
||||
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from typing import cast
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import disagreement
|
||||
|
||||
load_dotenv()
|
||||
|
||||
_VOICE_ENDPOINT = os.getenv("DISCORD_VOICE_ENDPOINT")
|
||||
_VOICE_TOKEN = os.getenv("DISCORD_VOICE_TOKEN")
|
||||
_VOICE_SESSION_ID = os.getenv("DISCORD_SESSION_ID")
|
||||
_GUILD_ID = os.getenv("DISCORD_GUILD_ID")
|
||||
_USER_ID = os.getenv("DISCORD_USER_ID")
|
||||
|
||||
if not all([_VOICE_ENDPOINT, _VOICE_TOKEN, _VOICE_SESSION_ID, _GUILD_ID, _USER_ID]):
|
||||
print("Missing one or more required environment variables for voice connection")
|
||||
sys.exit(1)
|
||||
|
||||
assert _VOICE_ENDPOINT
|
||||
assert _VOICE_TOKEN
|
||||
assert _VOICE_SESSION_ID
|
||||
assert _GUILD_ID
|
||||
assert _USER_ID
|
||||
|
||||
VOICE_ENDPOINT = cast(str, _VOICE_ENDPOINT)
|
||||
VOICE_TOKEN = cast(str, _VOICE_TOKEN)
|
||||
VOICE_SESSION_ID = cast(str, _VOICE_SESSION_ID)
|
||||
GUILD_ID = int(cast(str, _GUILD_ID))
|
||||
USER_ID = int(cast(str, _USER_ID))
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
vc = disagreement.VoiceClient(
|
||||
VOICE_ENDPOINT,
|
||||
VOICE_SESSION_ID,
|
||||
VOICE_TOKEN,
|
||||
GUILD_ID,
|
||||
USER_ID,
|
||||
)
|
||||
await vc.connect()
|
||||
|
||||
try:
|
||||
# Send silence frame as an example
|
||||
await vc.send_audio_frame(b"\xf8\xff\xfe")
|
||||
await asyncio.sleep(1)
|
||||
finally:
|
||||
await vc.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
55
pyproject.toml
Normal file
55
pyproject.toml
Normal file
@ -0,0 +1,55 @@
|
||||
[project]
|
||||
name = "disagreement"
|
||||
version = "0.0.1"
|
||||
description = "A Python library for the Discord API."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
license = {text = "BSD 3-Clause"}
|
||||
authors = [
|
||||
{name = "Slipstream", email = "me@slipstreamm.dev"}
|
||||
]
|
||||
keywords = ["discord", "api", "bot", "async", "aiohttp"]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: BSD License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Topic :: Software Development :: Libraries",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Topic :: Internet",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
"hypothesis>=6.89.0",
|
||||
]
|
||||
dev = [
|
||||
"dotenv>=0.0.5",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/Slipstreamm/disagreement"
|
||||
Issues = "https://github.com/Slipstreamm/disagreement/issues"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
# Optional: for linting/formatting, e.g., Ruff
|
||||
# [tool.ruff]
|
||||
# line-length = 88
|
||||
# select = ["E", "W", "F", "I", "UP", "C4", "B"] # Example rule set
|
||||
# ignore = []
|
||||
|
||||
# [tool.ruff.format]
|
||||
# quote-style = "double"
|
9
pyrightconfig.json
Normal file
9
pyrightconfig.json
Normal file
@ -0,0 +1,9 @@
|
||||
{
|
||||
"include": ["."],
|
||||
"exclude": ["**/node_modules", "**/__pycache__", "**/.venv", "**/.git", "**/dist", "**/build", "**/tests/**", "tavilytool.py"],
|
||||
"ignore": [],
|
||||
"reportMissingImports": true,
|
||||
"reportMissingTypeStubs": false,
|
||||
"pythonVersion": "3.13",
|
||||
"typeCheckingMode": "standard"
|
||||
}
|
1
requirements.txt
Normal file
1
requirements.txt
Normal file
@ -0,0 +1 @@
|
||||
-e .[test,dev]
|
171
tavilytool.py
Normal file
171
tavilytool.py
Normal file
@ -0,0 +1,171 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tavily API Script for AI Agents
|
||||
Execute with: python tavily.py "your search query"
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import requests # type: ignore
|
||||
import argparse
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
class TavilyAPI:
|
||||
def __init__(self, api_key: str):
|
||||
self.api_key = api_key
|
||||
self.base_url = "https://api.tavily.com"
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
search_depth: str = "basic",
|
||||
include_answer: bool = True,
|
||||
include_images: bool = False,
|
||||
include_raw_content: bool = False,
|
||||
max_results: int = 5,
|
||||
include_domains: Optional[List[str]] = None,
|
||||
exclude_domains: Optional[List[str]] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Perform a search using Tavily API
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
search_depth: "basic" or "advanced"
|
||||
include_answer: Include AI-generated answer
|
||||
include_images: Include images in results
|
||||
include_raw_content: Include raw HTML content
|
||||
max_results: Maximum number of results (1-20)
|
||||
include_domains: List of domains to include
|
||||
exclude_domains: List of domains to exclude
|
||||
|
||||
Returns:
|
||||
Dictionary containing search results
|
||||
"""
|
||||
url = f"{self.base_url}/search"
|
||||
|
||||
payload = {
|
||||
"api_key": self.api_key,
|
||||
"query": query,
|
||||
"search_depth": search_depth,
|
||||
"include_answer": include_answer,
|
||||
"include_images": include_images,
|
||||
"include_raw_content": include_raw_content,
|
||||
"max_results": max_results,
|
||||
}
|
||||
|
||||
if include_domains:
|
||||
payload["include_domains"] = include_domains
|
||||
if exclude_domains:
|
||||
payload["exclude_domains"] = exclude_domains
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=payload, timeout=30)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
return {"error": f"API request failed: {str(e)}"}
|
||||
except json.JSONDecodeError:
|
||||
return {"error": "Invalid JSON response from API"}
|
||||
|
||||
|
||||
def format_results(results: Dict) -> str:
|
||||
"""Format search results for display"""
|
||||
if "error" in results:
|
||||
return f"❌ Error: {results['error']}"
|
||||
|
||||
output = []
|
||||
|
||||
# Add answer if available
|
||||
if results.get("answer"):
|
||||
output.append("🤖 AI Answer:")
|
||||
output.append(f" {results['answer']}")
|
||||
output.append("")
|
||||
|
||||
# Add search results
|
||||
if results.get("results"):
|
||||
output.append("🔍 Search Results:")
|
||||
for i, result in enumerate(results["results"], 1):
|
||||
output.append(f" {i}. {result.get('title', 'No title')}")
|
||||
output.append(f" URL: {result.get('url', 'No URL')}")
|
||||
if result.get("content"):
|
||||
# Truncate content to first 200 chars
|
||||
content = (
|
||||
result["content"][:200] + "..."
|
||||
if len(result["content"]) > 200
|
||||
else result["content"]
|
||||
)
|
||||
output.append(f" Content: {content}")
|
||||
output.append("")
|
||||
|
||||
# Add images if available
|
||||
if results.get("images"):
|
||||
output.append("🖼️ Images:")
|
||||
for img in results["images"][:3]: # Show first 3 images
|
||||
output.append(f" {img}")
|
||||
output.append("")
|
||||
|
||||
return "\n".join(output)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Search using Tavily API")
|
||||
parser.add_argument("query", help="Search query")
|
||||
parser.add_argument(
|
||||
"--depth",
|
||||
choices=["basic", "advanced"],
|
||||
default="basic",
|
||||
help="Search depth (default: basic)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-results",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Maximum number of results (default: 5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-images", action="store_true", help="Include images in results"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-answer", action="store_true", help="Don't include AI-generated answer"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-domains", nargs="+", help="Include only these domains"
|
||||
)
|
||||
parser.add_argument("--exclude-domains", nargs="+", help="Exclude these domains")
|
||||
parser.add_argument("--raw", action="store_true", help="Output raw JSON response")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get API key from environment
|
||||
api_key = os.getenv("TAVILY_API_KEY")
|
||||
if not api_key:
|
||||
print("❌ Error: TAVILY_API_KEY environment variable not set")
|
||||
print("Set it with: export TAVILY_API_KEY='your-api-key-here'")
|
||||
sys.exit(1)
|
||||
|
||||
# Initialize Tavily API
|
||||
tavily = TavilyAPI(api_key)
|
||||
|
||||
# Perform search
|
||||
results = tavily.search(
|
||||
query=args.query,
|
||||
search_depth=args.depth,
|
||||
include_answer=not args.no_answer,
|
||||
include_images=args.include_images,
|
||||
max_results=args.max_results,
|
||||
include_domains=args.include_domains,
|
||||
exclude_domains=args.exclude_domains,
|
||||
)
|
||||
|
||||
# Output results
|
||||
if args.raw:
|
||||
print(json.dumps(results, indent=2))
|
||||
else:
|
||||
print(format_results(results))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
111
tests/conftest.py
Normal file
111
tests/conftest.py
Normal file
@ -0,0 +1,111 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.interactions import Interaction
|
||||
from disagreement.enums import InteractionType
|
||||
from disagreement.models import Message
|
||||
|
||||
|
||||
class DummyHTTP:
|
||||
def __init__(self):
|
||||
self.create_interaction_response = AsyncMock()
|
||||
self.create_followup_message = AsyncMock(
|
||||
return_value={
|
||||
"id": "123",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "1", "username": "u", "discriminator": "0001"},
|
||||
"content": "hi",
|
||||
"timestamp": "t",
|
||||
}
|
||||
)
|
||||
self.edit_original_interaction_response = AsyncMock(return_value={"id": "123"})
|
||||
self.edit_followup_message = AsyncMock(
|
||||
return_value={
|
||||
"id": "123",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "1", "username": "u", "discriminator": "0001"},
|
||||
"content": "hi",
|
||||
"timestamp": "t",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class DummyBot:
|
||||
def __init__(self):
|
||||
self._http = DummyHTTP()
|
||||
self.application_id = "app123"
|
||||
self._guilds = {}
|
||||
self._channels = {}
|
||||
|
||||
def get_guild(self, gid):
|
||||
return self._guilds.get(gid)
|
||||
|
||||
async def fetch_channel(self, cid):
|
||||
return self._channels.get(cid)
|
||||
|
||||
|
||||
class DummyCommandHTTP:
|
||||
def __init__(self):
|
||||
self.edit_message = AsyncMock(return_value={"id": "321", "channel_id": "c"})
|
||||
|
||||
|
||||
class DummyCommandBot:
|
||||
def __init__(self):
|
||||
self._http = DummyCommandHTTP()
|
||||
|
||||
async def edit_message(self, channel_id, message_id, *, content=None, **kwargs):
|
||||
return await self._http.edit_message(
|
||||
channel_id, message_id, {"content": content, **kwargs}
|
||||
)
|
||||
|
||||
|
||||
class DummyClient:
|
||||
pass
|
||||
|
||||
|
||||
class DummyInteraction:
|
||||
data = None
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def dummy_bot():
|
||||
return DummyBot()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def interaction(dummy_bot):
|
||||
data = {
|
||||
"id": "1",
|
||||
"application_id": dummy_bot.application_id,
|
||||
"type": InteractionType.APPLICATION_COMMAND.value,
|
||||
"token": "tok",
|
||||
"version": 1,
|
||||
}
|
||||
return Interaction(data, client_instance=dummy_bot)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def command_bot():
|
||||
return DummyCommandBot()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def message(command_bot):
|
||||
message_data = {
|
||||
"id": "1",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||
"content": "hi",
|
||||
"timestamp": "t",
|
||||
}
|
||||
return Message(message_data, client_instance=command_bot)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def dummy_client():
|
||||
return DummyClient()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def basic_interaction():
|
||||
return DummyInteraction()
|
172
tests/test_additional_converters.py
Normal file
172
tests/test_additional_converters.py
Normal file
@ -0,0 +1,172 @@
|
||||
import pytest
|
||||
|
||||
from disagreement.ext.commands.converters import run_converters
|
||||
from disagreement.ext.commands.core import CommandContext, Command
|
||||
from disagreement.ext.commands.errors import BadArgument
|
||||
from disagreement.models import Message, Member, Role, Guild
|
||||
from disagreement.enums import (
|
||||
VerificationLevel,
|
||||
MessageNotificationLevel,
|
||||
ExplicitContentFilterLevel,
|
||||
MFALevel,
|
||||
GuildNSFWLevel,
|
||||
PremiumTier,
|
||||
)
|
||||
|
||||
|
||||
class DummyBot:
|
||||
def __init__(self, guild: Guild):
|
||||
self._guilds = {guild.id: guild}
|
||||
|
||||
def get_guild(self, gid):
|
||||
return self._guilds.get(gid)
|
||||
|
||||
async def fetch_member(self, gid, mid):
|
||||
guild = self._guilds.get(gid)
|
||||
return guild.get_member(mid) if guild else None
|
||||
|
||||
async def fetch_role(self, gid, rid):
|
||||
guild = self._guilds.get(gid)
|
||||
return guild.get_role(rid) if guild else None
|
||||
|
||||
async def fetch_guild(self, gid):
|
||||
return self._guilds.get(gid)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def guild_objects():
|
||||
guild_data = {
|
||||
"id": "1",
|
||||
"name": "g",
|
||||
"owner_id": "2",
|
||||
"afk_timeout": 60,
|
||||
"verification_level": VerificationLevel.NONE.value,
|
||||
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||
"roles": [],
|
||||
"emojis": [],
|
||||
"features": [],
|
||||
"mfa_level": MFALevel.NONE.value,
|
||||
"system_channel_flags": 0,
|
||||
"premium_tier": PremiumTier.NONE.value,
|
||||
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||
}
|
||||
guild = Guild(guild_data, client_instance=None)
|
||||
|
||||
member = Member(
|
||||
{
|
||||
"user": {"id": "3", "username": "m", "discriminator": "0001"},
|
||||
"joined_at": "t",
|
||||
"roles": [],
|
||||
},
|
||||
None,
|
||||
)
|
||||
member.guild_id = guild.id
|
||||
|
||||
role = Role(
|
||||
{
|
||||
"id": "5",
|
||||
"name": "r",
|
||||
"color": 0,
|
||||
"hoist": False,
|
||||
"position": 0,
|
||||
"permissions": "0",
|
||||
"managed": False,
|
||||
"mentionable": True,
|
||||
}
|
||||
)
|
||||
|
||||
guild._members[member.id] = member
|
||||
guild.roles.append(role)
|
||||
|
||||
return guild, member, role
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def command_context(guild_objects):
|
||||
guild, member, role = guild_objects
|
||||
bot = DummyBot(guild)
|
||||
message_data = {
|
||||
"id": "10",
|
||||
"channel_id": "20",
|
||||
"guild_id": guild.id,
|
||||
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||
"content": "hi",
|
||||
"timestamp": "t",
|
||||
}
|
||||
msg = Message(message_data, client_instance=bot)
|
||||
|
||||
async def dummy(ctx):
|
||||
pass
|
||||
|
||||
cmd = Command(dummy)
|
||||
return CommandContext(
|
||||
message=msg, bot=bot, prefix="!", command=cmd, invoked_with="dummy"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_member_converter(command_context, guild_objects):
|
||||
_, member, _ = guild_objects
|
||||
mention = f"<@!{member.id}>"
|
||||
result = await run_converters(command_context, Member, mention)
|
||||
assert result is member
|
||||
result = await run_converters(command_context, Member, member.id)
|
||||
assert result is member
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_role_converter(command_context, guild_objects):
|
||||
_, _, role = guild_objects
|
||||
mention = f"<@&{role.id}>"
|
||||
result = await run_converters(command_context, Role, mention)
|
||||
assert result is role
|
||||
result = await run_converters(command_context, Role, role.id)
|
||||
assert result is role
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guild_converter(command_context, guild_objects):
|
||||
guild, _, _ = guild_objects
|
||||
result = await run_converters(command_context, Guild, guild.id)
|
||||
assert result is guild
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_member_converter_no_guild():
|
||||
guild_data = {
|
||||
"id": "99",
|
||||
"name": "g",
|
||||
"owner_id": "2",
|
||||
"afk_timeout": 60,
|
||||
"verification_level": VerificationLevel.NONE.value,
|
||||
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||
"roles": [],
|
||||
"emojis": [],
|
||||
"features": [],
|
||||
"mfa_level": MFALevel.NONE.value,
|
||||
"system_channel_flags": 0,
|
||||
"premium_tier": PremiumTier.NONE.value,
|
||||
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||
}
|
||||
guild = Guild(guild_data, client_instance=None)
|
||||
bot = DummyBot(guild)
|
||||
message_data = {
|
||||
"id": "11",
|
||||
"channel_id": "20",
|
||||
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||
"content": "hi",
|
||||
"timestamp": "t",
|
||||
}
|
||||
msg = Message(message_data, client_instance=bot)
|
||||
|
||||
async def dummy(ctx):
|
||||
pass
|
||||
|
||||
ctx = CommandContext(
|
||||
message=msg, bot=bot, prefix="!", command=Command(dummy), invoked_with="dummy"
|
||||
)
|
||||
|
||||
with pytest.raises(BadArgument):
|
||||
await run_converters(ctx, Member, "<@!1>")
|
17
tests/test_cache.py
Normal file
17
tests/test_cache.py
Normal file
@ -0,0 +1,17 @@
|
||||
import time
|
||||
|
||||
from disagreement.cache import Cache
|
||||
|
||||
|
||||
def test_cache_store_and_get():
|
||||
cache = Cache()
|
||||
cache.set("a", 123)
|
||||
assert cache.get("a") == 123
|
||||
|
||||
|
||||
def test_cache_ttl_expiry():
|
||||
cache = Cache(ttl=0.01)
|
||||
cache.set("b", 1)
|
||||
assert cache.get("b") == 1
|
||||
time.sleep(0.02)
|
||||
assert cache.get("b") is None
|
33
tests/test_client_context_manager.py
Normal file
33
tests/test_client_context_manager.py
Normal file
@ -0,0 +1,33 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.client import Client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_async_context_closes(monkeypatch):
|
||||
client = Client(token="t")
|
||||
monkeypatch.setattr(client, "connect", AsyncMock())
|
||||
monkeypatch.setattr(client._http, "close", AsyncMock())
|
||||
|
||||
async with client:
|
||||
client.connect.assert_awaited_once()
|
||||
|
||||
client._http.close.assert_awaited_once()
|
||||
assert client.is_closed()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_async_context_closes_on_exception(monkeypatch):
|
||||
client = Client(token="t")
|
||||
monkeypatch.setattr(client, "connect", AsyncMock())
|
||||
monkeypatch.setattr(client._http, "close", AsyncMock())
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async with client:
|
||||
raise ValueError("boom")
|
||||
|
||||
client.connect.assert_awaited_once()
|
||||
client._http.close.assert_awaited_once()
|
||||
assert client.is_closed()
|
51
tests/test_command_checks.py
Normal file
51
tests/test_command_checks.py
Normal file
@ -0,0 +1,51 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
|
||||
from disagreement.ext.commands.core import Command, CommandContext
|
||||
from disagreement.ext.commands.decorators import check, cooldown
|
||||
from disagreement.ext.commands.errors import CheckFailure, CommandOnCooldown
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_decorator_blocks(message):
|
||||
async def cb(ctx):
|
||||
pass
|
||||
|
||||
cmd = Command(check(lambda c: False)(cb))
|
||||
ctx = CommandContext(
|
||||
message=message,
|
||||
bot=message._client,
|
||||
prefix="!",
|
||||
command=cmd,
|
||||
invoked_with="test",
|
||||
)
|
||||
|
||||
with pytest.raises(CheckFailure):
|
||||
await cmd.invoke(ctx)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cooldown_per_user(message):
|
||||
uses = []
|
||||
|
||||
@cooldown(1, 0.05)
|
||||
async def cb(ctx):
|
||||
uses.append(1)
|
||||
|
||||
cmd = Command(cb)
|
||||
ctx = CommandContext(
|
||||
message=message,
|
||||
bot=message._client,
|
||||
prefix="!",
|
||||
command=cmd,
|
||||
invoked_with="test",
|
||||
)
|
||||
|
||||
await cmd.invoke(ctx)
|
||||
|
||||
with pytest.raises(CommandOnCooldown):
|
||||
await cmd.invoke(ctx)
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
await cmd.invoke(ctx)
|
||||
assert len(uses) == 2
|
29
tests/test_components_factory.py
Normal file
29
tests/test_components_factory.py
Normal file
@ -0,0 +1,29 @@
|
||||
from disagreement.components import component_factory
|
||||
from disagreement.enums import ComponentType, ButtonStyle
|
||||
|
||||
|
||||
def test_component_factory_button():
|
||||
data = {
|
||||
"type": ComponentType.BUTTON.value,
|
||||
"style": ButtonStyle.PRIMARY.value,
|
||||
"label": "Click",
|
||||
"custom_id": "x",
|
||||
}
|
||||
comp = component_factory(data)
|
||||
assert comp.to_dict()["label"] == "Click"
|
||||
|
||||
|
||||
def test_component_factory_action_row():
|
||||
data = {
|
||||
"type": ComponentType.ACTION_ROW.value,
|
||||
"components": [
|
||||
{
|
||||
"type": ComponentType.BUTTON.value,
|
||||
"style": ButtonStyle.PRIMARY.value,
|
||||
"label": "A",
|
||||
"custom_id": "b",
|
||||
}
|
||||
],
|
||||
}
|
||||
row = component_factory(data)
|
||||
assert len(row.components) == 1
|
199
tests/test_context.py
Normal file
199
tests/test_context.py
Normal file
@ -0,0 +1,199 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from disagreement.ext.app_commands.context import AppCommandContext
|
||||
from disagreement.interactions import Interaction
|
||||
from disagreement.enums import InteractionType, MessageFlags
|
||||
from disagreement.models import (
|
||||
Embed,
|
||||
ActionRow,
|
||||
Button,
|
||||
Container,
|
||||
TextDisplay,
|
||||
)
|
||||
from disagreement.enums import ButtonStyle, ComponentType
|
||||
|
||||
|
||||
class DummyHTTP:
|
||||
def __init__(self):
|
||||
self.create_interaction_response = AsyncMock()
|
||||
self.create_followup_message = AsyncMock(return_value={"id": "123"})
|
||||
self.edit_original_interaction_response = AsyncMock(return_value={"id": "123"})
|
||||
self.edit_followup_message = AsyncMock(return_value={"id": "123"})
|
||||
|
||||
|
||||
class DummyBot:
|
||||
def __init__(self):
|
||||
self._http = DummyHTTP()
|
||||
self.application_id = "app123"
|
||||
self._guilds = {}
|
||||
self._channels = {}
|
||||
|
||||
def get_guild(self, gid):
|
||||
return self._guilds.get(gid)
|
||||
|
||||
async def fetch_channel(self, cid):
|
||||
return self._channels.get(cid)
|
||||
from disagreement.ext.commands.core import CommandContext, Command
|
||||
from disagreement.enums import MessageFlags, ButtonStyle, ComponentType
|
||||
from disagreement.models import ActionRow, Button, Container, TextDisplay
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sends_extra_payload(dummy_bot, interaction):
|
||||
ctx = AppCommandContext(dummy_bot, interaction)
|
||||
button = Button(style=ButtonStyle.PRIMARY, label="Click", custom_id="a")
|
||||
row = ActionRow([button])
|
||||
await ctx.send(
|
||||
content="hi",
|
||||
tts=True,
|
||||
components=[row],
|
||||
allowed_mentions={"parse": []},
|
||||
files=[{"id": 1, "filename": "f.txt"}],
|
||||
ephemeral=True,
|
||||
)
|
||||
dummy_bot._http.create_interaction_response.assert_called_once()
|
||||
payload = dummy_bot._http.create_interaction_response.call_args.kwargs[
|
||||
"payload"
|
||||
].data.to_dict()
|
||||
assert payload["tts"] is True
|
||||
assert payload["components"]
|
||||
assert payload["allowed_mentions"] == {"parse": []}
|
||||
assert payload["attachments"] == [{"id": 1, "filename": "f.txt"}]
|
||||
assert payload["flags"] == MessageFlags.EPHEMERAL.value
|
||||
await ctx.send_followup(content="again")
|
||||
dummy_bot._http.create_followup_message.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_second_send_is_followup(dummy_bot, interaction):
|
||||
ctx = AppCommandContext(dummy_bot, interaction)
|
||||
await ctx.send(content="first")
|
||||
await ctx.send_followup(content="second")
|
||||
assert dummy_bot._http.create_interaction_response.call_count == 1
|
||||
assert dummy_bot._http.create_followup_message.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_with_components_and_attachments(dummy_bot, interaction):
|
||||
ctx = AppCommandContext(dummy_bot, interaction)
|
||||
await ctx.send(content="orig")
|
||||
row = ActionRow([Button(style=ButtonStyle.PRIMARY, label="B", custom_id="b")])
|
||||
await ctx.edit(content="new", components=[row], attachments=[{"id": 1}])
|
||||
dummy_bot._http.edit_original_interaction_response.assert_called_once()
|
||||
payload = dummy_bot._http.edit_original_interaction_response.call_args.kwargs[
|
||||
"payload"
|
||||
]
|
||||
assert payload["components"]
|
||||
assert payload["attachments"] == [{"id": 1}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_flags(dummy_bot, interaction):
|
||||
ctx = AppCommandContext(dummy_bot, interaction)
|
||||
await ctx.send(content="hi", flags=MessageFlags.IS_COMPONENTS_V2.value)
|
||||
payload = dummy_bot._http.create_interaction_response.call_args.kwargs[
|
||||
"payload"
|
||||
].data.to_dict()
|
||||
assert payload["flags"] == MessageFlags.IS_COMPONENTS_V2.value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_container_component(dummy_bot, interaction):
|
||||
ctx = AppCommandContext(dummy_bot, interaction)
|
||||
container = Container(components=[TextDisplay(content="hi")])
|
||||
await ctx.send(components=[container], flags=MessageFlags.IS_COMPONENTS_V2.value)
|
||||
payload = dummy_bot._http.create_interaction_response.call_args.kwargs[
|
||||
"payload"
|
||||
].data.to_dict()
|
||||
assert payload["components"][0]["type"] == ComponentType.CONTAINER.value
|
||||
assert payload["flags"] == MessageFlags.IS_COMPONENTS_V2.value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_command_context_edit(command_bot, message):
|
||||
async def dummy(ctx):
|
||||
pass
|
||||
|
||||
cmd = Command(dummy)
|
||||
ctx = CommandContext(
|
||||
message=message, bot=command_bot, prefix="!", command=cmd, invoked_with="dummy"
|
||||
)
|
||||
await ctx.edit(message.id, content="new")
|
||||
command_bot._http.edit_message.assert_called_once()
|
||||
args = command_bot._http.edit_message.call_args[0]
|
||||
assert args[0] == message.channel_id
|
||||
assert args[1] == message.id
|
||||
assert args[2]["content"] == "new"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_http_error_propagates(dummy_bot, interaction):
|
||||
ctx = AppCommandContext(dummy_bot, interaction)
|
||||
dummy_bot._http.create_interaction_response.side_effect = RuntimeError("boom")
|
||||
with pytest.raises(RuntimeError):
|
||||
await ctx.send(content="hi")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_send_only_initial_once(dummy_bot, interaction):
|
||||
ctx = AppCommandContext(dummy_bot, interaction)
|
||||
|
||||
async def send_msg(i: int):
|
||||
if i == 0:
|
||||
await ctx.send(content=str(i))
|
||||
else:
|
||||
await ctx.send_followup(content=str(i))
|
||||
|
||||
await asyncio.gather(*(send_msg(i) for i in range(50)))
|
||||
assert dummy_bot._http.create_interaction_response.call_count == 1
|
||||
assert dummy_bot._http.create_followup_message.call_count == 49
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_with_flags_2():
|
||||
bot = DummyBot()
|
||||
interaction = Interaction(
|
||||
{
|
||||
"id": "1",
|
||||
"application_id": bot.application_id,
|
||||
"type": InteractionType.APPLICATION_COMMAND.value,
|
||||
"token": "tok",
|
||||
"version": 1,
|
||||
},
|
||||
client_instance=bot,
|
||||
)
|
||||
ctx = AppCommandContext(bot, interaction)
|
||||
await ctx.send(content="hi", flags=MessageFlags.IS_COMPONENTS_V2.value)
|
||||
payload = bot._http.create_interaction_response.call_args[1][
|
||||
"payload"
|
||||
].data.to_dict()
|
||||
assert payload["flags"] == MessageFlags.IS_COMPONENTS_V2.value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_container_component_2():
|
||||
bot = DummyBot()
|
||||
interaction = Interaction(
|
||||
{
|
||||
"id": "1",
|
||||
"application_id": bot.application_id,
|
||||
"type": InteractionType.APPLICATION_COMMAND.value,
|
||||
"token": "tok",
|
||||
"version": 1,
|
||||
},
|
||||
client_instance=bot,
|
||||
)
|
||||
ctx = AppCommandContext(bot, interaction)
|
||||
container = Container(components=[TextDisplay(content="hi")])
|
||||
await ctx.send(
|
||||
components=[container],
|
||||
flags=MessageFlags.IS_COMPONENTS_V2.value,
|
||||
)
|
||||
payload = bot._http.create_interaction_response.call_args[1][
|
||||
"payload"
|
||||
].data.to_dict()
|
||||
assert payload["components"][0]["type"] == ComponentType.CONTAINER.value
|
||||
assert payload["flags"] == MessageFlags.IS_COMPONENTS_V2.value
|
80
tests/test_context_menus.py
Normal file
80
tests/test_context_menus.py
Normal file
@ -0,0 +1,80 @@
|
||||
import pytest
|
||||
|
||||
from disagreement.ext.app_commands.handler import AppCommandHandler
|
||||
from disagreement.ext.app_commands.decorators import user_command, message_command
|
||||
from disagreement.enums import ApplicationCommandType, InteractionType
|
||||
from disagreement.interactions import Interaction
|
||||
from disagreement.models import User, Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_context_menu_invokes(dummy_bot):
|
||||
handler = AppCommandHandler(dummy_bot)
|
||||
captured = {}
|
||||
|
||||
@user_command(name="Info")
|
||||
async def info(ctx, user: User):
|
||||
captured["user"] = user
|
||||
|
||||
handler.add_command(info)
|
||||
|
||||
data = {
|
||||
"id": "cmd",
|
||||
"name": "Info",
|
||||
"type": ApplicationCommandType.USER.value,
|
||||
"target_id": "42",
|
||||
"resolved": {
|
||||
"users": {"42": {"id": "42", "username": "Bob", "discriminator": "0001"}}
|
||||
},
|
||||
}
|
||||
payload = {
|
||||
"id": "1",
|
||||
"application_id": dummy_bot.application_id,
|
||||
"type": InteractionType.APPLICATION_COMMAND.value,
|
||||
"token": "tok",
|
||||
"version": 1,
|
||||
"data": data,
|
||||
}
|
||||
interaction = Interaction(payload, client_instance=dummy_bot)
|
||||
await handler.process_interaction(interaction)
|
||||
assert isinstance(captured.get("user"), User)
|
||||
assert captured["user"].id == "42"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_context_menu_invokes(dummy_bot):
|
||||
handler = AppCommandHandler(dummy_bot)
|
||||
captured = {}
|
||||
|
||||
@message_command(name="Quote")
|
||||
async def quote(ctx, message: Message):
|
||||
captured["msg"] = message
|
||||
|
||||
handler.add_command(quote)
|
||||
|
||||
msg_data = {
|
||||
"id": "99",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "2", "username": "Ann", "discriminator": "0001"},
|
||||
"content": "Hello",
|
||||
"timestamp": "t",
|
||||
}
|
||||
data = {
|
||||
"id": "cmd",
|
||||
"name": "Quote",
|
||||
"type": ApplicationCommandType.MESSAGE.value,
|
||||
"target_id": "99",
|
||||
"resolved": {"messages": {"99": msg_data}},
|
||||
}
|
||||
payload = {
|
||||
"id": "1",
|
||||
"application_id": dummy_bot.application_id,
|
||||
"type": InteractionType.APPLICATION_COMMAND.value,
|
||||
"token": "tok",
|
||||
"version": 1,
|
||||
"data": data,
|
||||
}
|
||||
interaction = Interaction(payload, client_instance=dummy_bot)
|
||||
await handler.process_interaction(interaction)
|
||||
assert isinstance(captured.get("msg"), Message)
|
||||
assert captured["msg"].id == "99"
|
30
tests/test_converter_registration.py
Normal file
30
tests/test_converter_registration.py
Normal file
@ -0,0 +1,30 @@
|
||||
import pytest
|
||||
|
||||
from disagreement.ext.app_commands.handler import AppCommandHandler
|
||||
from disagreement.ext.app_commands.converters import Converter
|
||||
|
||||
|
||||
class MyType:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
|
||||
class MyConverter(Converter[MyType]):
|
||||
async def convert(self, interaction, value):
|
||||
return MyType(f"converted-{value}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_converter_registration(dummy_client):
|
||||
handler = AppCommandHandler(client=dummy_client)
|
||||
handler.register_converter(MyType, MyConverter)
|
||||
assert handler.get_converter(MyType) is MyConverter
|
||||
|
||||
result = await handler._resolve_value(
|
||||
value="example",
|
||||
expected_type=MyType,
|
||||
resolved_data=None,
|
||||
guild_id=None,
|
||||
)
|
||||
assert isinstance(result, MyType)
|
||||
assert result.value == "converted-example"
|
113
tests/test_converters.py
Normal file
113
tests/test_converters.py
Normal file
@ -0,0 +1,113 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
from hypothesis import given, strategies as st
|
||||
|
||||
from disagreement.ext.app_commands.converters import run_converters
|
||||
from disagreement.enums import ApplicationCommandOptionType
|
||||
from disagreement.errors import AppCommandOptionConversionError
|
||||
from conftest import DummyInteraction, DummyClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"py_type, option_type, input_value, expected",
|
||||
[
|
||||
(str, ApplicationCommandOptionType.STRING, "hello", "hello"),
|
||||
(int, ApplicationCommandOptionType.INTEGER, "42", 42),
|
||||
(bool, ApplicationCommandOptionType.BOOLEAN, "true", True),
|
||||
(float, ApplicationCommandOptionType.NUMBER, "3.14", pytest.approx(3.14)),
|
||||
],
|
||||
)
|
||||
async def test_basic_type_converters(
|
||||
basic_interaction, dummy_client, py_type, option_type, input_value, expected
|
||||
):
|
||||
result = await run_converters(
|
||||
basic_interaction, py_type, option_type, input_value, dummy_client
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_converters_error_cases(basic_interaction, dummy_client):
|
||||
with pytest.raises(AppCommandOptionConversionError):
|
||||
await run_converters(
|
||||
basic_interaction,
|
||||
bool,
|
||||
ApplicationCommandOptionType.BOOLEAN,
|
||||
"maybe",
|
||||
dummy_client,
|
||||
)
|
||||
|
||||
with pytest.raises(AppCommandOptionConversionError):
|
||||
await run_converters(
|
||||
basic_interaction,
|
||||
list,
|
||||
ApplicationCommandOptionType.MENTIONABLE,
|
||||
"x",
|
||||
dummy_client,
|
||||
)
|
||||
|
||||
|
||||
@given(st.text())
|
||||
def test_string_roundtrip(value):
|
||||
interaction = DummyInteraction()
|
||||
client = DummyClient()
|
||||
result = asyncio.run(
|
||||
run_converters(
|
||||
interaction,
|
||||
str,
|
||||
ApplicationCommandOptionType.STRING,
|
||||
value,
|
||||
client,
|
||||
)
|
||||
)
|
||||
assert result == value
|
||||
|
||||
|
||||
@given(st.integers())
|
||||
def test_integer_roundtrip(value):
|
||||
interaction = DummyInteraction()
|
||||
client = DummyClient()
|
||||
result = asyncio.run(
|
||||
run_converters(
|
||||
interaction,
|
||||
int,
|
||||
ApplicationCommandOptionType.INTEGER,
|
||||
str(value),
|
||||
client,
|
||||
)
|
||||
)
|
||||
assert result == value
|
||||
|
||||
|
||||
@given(st.booleans())
|
||||
def test_boolean_roundtrip(value):
|
||||
interaction = DummyInteraction()
|
||||
client = DummyClient()
|
||||
raw = "true" if value else "false"
|
||||
result = asyncio.run(
|
||||
run_converters(
|
||||
interaction,
|
||||
bool,
|
||||
ApplicationCommandOptionType.BOOLEAN,
|
||||
raw,
|
||||
client,
|
||||
)
|
||||
)
|
||||
assert result is value
|
||||
|
||||
|
||||
@given(st.floats(allow_nan=False, allow_infinity=False))
|
||||
def test_number_roundtrip(value):
|
||||
interaction = DummyInteraction()
|
||||
client = DummyClient()
|
||||
result = asyncio.run(
|
||||
run_converters(
|
||||
interaction,
|
||||
float,
|
||||
ApplicationCommandOptionType.NUMBER,
|
||||
str(value),
|
||||
client,
|
||||
)
|
||||
)
|
||||
assert result == pytest.approx(value)
|
38
tests/test_error_handler.py
Normal file
38
tests/test_error_handler.py
Normal file
@ -0,0 +1,38 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from disagreement.error_handler import setup_global_error_handler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_exception_logs_error(monkeypatch, capsys):
|
||||
loop = asyncio.new_event_loop()
|
||||
records = []
|
||||
|
||||
def fake_error(msg, *args, **kwargs):
|
||||
records.append(msg % args if args else msg)
|
||||
|
||||
monkeypatch.setattr(logging, "error", fake_error)
|
||||
setup_global_error_handler(loop)
|
||||
exc = RuntimeError("boom")
|
||||
loop.call_exception_handler({"exception": exc})
|
||||
assert any("Unhandled exception" in r for r in records)
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_logs_error(monkeypatch):
|
||||
loop = asyncio.new_event_loop()
|
||||
logged = {}
|
||||
|
||||
def fake_error(msg, *args, **kwargs):
|
||||
logged["msg"] = msg % args if args else msg
|
||||
|
||||
monkeypatch.setattr(logging, "error", fake_error)
|
||||
setup_global_error_handler(loop)
|
||||
loop.call_exception_handler({"message": "oops"})
|
||||
assert "oops" in logged["msg"]
|
||||
loop.close()
|
23
tests/test_errors.py
Normal file
23
tests/test_errors.py
Normal file
@ -0,0 +1,23 @@
|
||||
import pytest
|
||||
|
||||
from disagreement.errors import (
|
||||
HTTPException,
|
||||
RateLimitError,
|
||||
AppCommandOptionConversionError,
|
||||
)
|
||||
|
||||
|
||||
def test_http_exception_message():
|
||||
exc = HTTPException(message="Bad", status=400)
|
||||
assert str(exc) == "HTTP 400: Bad"
|
||||
|
||||
|
||||
def test_rate_limit_error_inherits_httpexception():
|
||||
exc = RateLimitError(response=None, retry_after=1.0, is_global=True)
|
||||
assert isinstance(exc, HTTPException)
|
||||
assert "Rate limited" in str(exc)
|
||||
|
||||
|
||||
def test_app_command_option_conversion_error():
|
||||
exc = AppCommandOptionConversionError("bad", option_name="opt", original_value="x")
|
||||
assert "opt" in str(exc) and "x" in str(exc)
|
68
tests/test_event_dispatcher.py
Normal file
68
tests/test_event_dispatcher.py
Normal file
@ -0,0 +1,68 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from disagreement.event_dispatcher import EventDispatcher
|
||||
|
||||
|
||||
class DummyClient:
|
||||
def __init__(self):
|
||||
self.parsed = {}
|
||||
|
||||
def parse_message(self, data):
|
||||
self.parsed["message"] = True
|
||||
return data
|
||||
|
||||
def parse_guild(self, data):
|
||||
self.parsed["guild"] = True
|
||||
return data
|
||||
|
||||
def parse_channel(self, data):
|
||||
self.parsed["channel"] = True
|
||||
return data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_calls_listener():
|
||||
client = DummyClient()
|
||||
dispatcher = EventDispatcher(client)
|
||||
called = {}
|
||||
|
||||
async def listener(payload):
|
||||
called["data"] = payload
|
||||
|
||||
dispatcher.register("MESSAGE_CREATE", listener)
|
||||
await dispatcher.dispatch("MESSAGE_CREATE", {"id": 1})
|
||||
assert called["data"] == {"id": 1}
|
||||
assert client.parsed.get("message")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_listener_no_args():
|
||||
client = DummyClient()
|
||||
dispatcher = EventDispatcher(client)
|
||||
called = False
|
||||
|
||||
async def listener():
|
||||
nonlocal called
|
||||
called = True
|
||||
|
||||
dispatcher.register("GUILD_CREATE", listener)
|
||||
await dispatcher.dispatch("GUILD_CREATE", {"id": 123})
|
||||
assert called
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister_listener():
|
||||
client = DummyClient()
|
||||
dispatcher = EventDispatcher(client)
|
||||
called = False
|
||||
|
||||
async def listener(_):
|
||||
nonlocal called
|
||||
called = True
|
||||
|
||||
dispatcher.register("MESSAGE_CREATE", listener)
|
||||
dispatcher.unregister("MESSAGE_CREATE", listener)
|
||||
await dispatcher.dispatch("MESSAGE_CREATE", {"id": 1})
|
||||
assert not called
|
42
tests/test_event_error_hook.py
Normal file
42
tests/test_event_error_hook.py
Normal file
@ -0,0 +1,42 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.event_dispatcher import EventDispatcher
|
||||
|
||||
|
||||
class DummyClient:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_error_hook_called():
|
||||
dispatcher = EventDispatcher(DummyClient())
|
||||
hook = AsyncMock()
|
||||
dispatcher.on_dispatch_error = hook
|
||||
|
||||
async def listener(_):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
dispatcher.register("TEST_EVENT", listener)
|
||||
await dispatcher.dispatch("TEST_EVENT", {})
|
||||
|
||||
hook.assert_awaited_once()
|
||||
args = hook.call_args.args
|
||||
assert args[0] == "TEST_EVENT"
|
||||
assert isinstance(args[1], RuntimeError)
|
||||
assert args[2] is listener
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_error_hook_not_called_when_ok():
|
||||
dispatcher = EventDispatcher(DummyClient())
|
||||
hook = AsyncMock()
|
||||
dispatcher.on_dispatch_error = hook
|
||||
|
||||
async def listener(_):
|
||||
return
|
||||
|
||||
dispatcher.register("TEST_EVENT", listener)
|
||||
await dispatcher.dispatch("TEST_EVENT", {})
|
||||
|
||||
hook.assert_not_awaited()
|
44
tests/test_extension_loader.py
Normal file
44
tests/test_extension_loader.py
Normal file
@ -0,0 +1,44 @@
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
from disagreement.ext import loader
|
||||
|
||||
|
||||
def create_dummy_module(name):
|
||||
mod = types.ModuleType(name)
|
||||
called = {"setup": False, "teardown": False}
|
||||
|
||||
def setup():
|
||||
called["setup"] = True
|
||||
|
||||
def teardown():
|
||||
called["teardown"] = True
|
||||
|
||||
mod.setup = setup
|
||||
mod.teardown = teardown
|
||||
sys.modules[name] = mod
|
||||
return called
|
||||
|
||||
|
||||
def test_load_and_unload_extension():
|
||||
called = create_dummy_module("dummy_ext")
|
||||
|
||||
module = loader.load_extension("dummy_ext")
|
||||
assert module is sys.modules["dummy_ext"]
|
||||
assert called["setup"] is True
|
||||
|
||||
loader.unload_extension("dummy_ext")
|
||||
assert called["teardown"] is True
|
||||
assert "dummy_ext" not in loader._loaded_extensions
|
||||
assert "dummy_ext" not in sys.modules
|
||||
|
||||
|
||||
def test_load_extension_twice_raises():
|
||||
called = create_dummy_module("repeat_ext")
|
||||
loader.load_extension("repeat_ext")
|
||||
with pytest.raises(ValueError):
|
||||
loader.load_extension("repeat_ext")
|
||||
loader.unload_extension("repeat_ext")
|
||||
assert called["teardown"] is True
|
67
tests/test_gateway_backoff.py
Normal file
67
tests/test_gateway_backoff.py
Normal file
@ -0,0 +1,67 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from disagreement.gateway import GatewayClient, GatewayException
|
||||
from disagreement.client import Client
|
||||
|
||||
|
||||
class DummyHTTP:
|
||||
async def get_gateway_bot(self):
|
||||
return {"url": "ws://example"}
|
||||
|
||||
async def _ensure_session(self):
|
||||
self._session = AsyncMock()
|
||||
self._session.ws_connect = AsyncMock()
|
||||
|
||||
|
||||
class DummyDispatcher:
|
||||
async def dispatch(self, *_):
|
||||
pass
|
||||
|
||||
|
||||
class DummyClient:
|
||||
def __init__(self):
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.application_id = None # Mock application_id for Client.connect
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_connect_backoff(monkeypatch):
|
||||
http = DummyHTTP()
|
||||
# Mock the GatewayClient's connect method to simulate failures and then success
|
||||
mock_gateway_connect = AsyncMock(
|
||||
side_effect=[GatewayException("boom"), GatewayException("boom"), None]
|
||||
)
|
||||
# Create a dummy client instance
|
||||
client = Client(
|
||||
token="test_token",
|
||||
intents=0,
|
||||
loop=asyncio.get_event_loop(),
|
||||
command_prefix="!",
|
||||
verbose=False,
|
||||
mention_replies=False,
|
||||
shard_count=None,
|
||||
)
|
||||
# Patch the internal _gateway attribute after client initialization
|
||||
# This ensures _initialize_gateway is called and _gateway is set
|
||||
await client._initialize_gateway()
|
||||
monkeypatch.setattr(client._gateway, "connect", mock_gateway_connect)
|
||||
|
||||
# Mock wait_until_ready to prevent it from blocking the test
|
||||
monkeypatch.setattr(client, "wait_until_ready", AsyncMock())
|
||||
|
||||
delays = []
|
||||
|
||||
async def fake_sleep(d):
|
||||
delays.append(d)
|
||||
|
||||
monkeypatch.setattr(asyncio, "sleep", fake_sleep)
|
||||
|
||||
# Call the client's connect method, which contains the backoff logic
|
||||
await client.connect()
|
||||
|
||||
# Assert that GatewayClient.connect was called the correct number of times
|
||||
assert mock_gateway_connect.call_count == 3
|
||||
# Assert the delays experienced due to exponential backoff
|
||||
assert delays == [5, 10]
|
57
tests/test_help_command.py
Normal file
57
tests/test_help_command.py
Normal file
@ -0,0 +1,57 @@
|
||||
import pytest
|
||||
|
||||
from disagreement.ext.commands.core import CommandHandler, Command
|
||||
from disagreement.models import Message
|
||||
|
||||
|
||||
class DummyBot:
|
||||
def __init__(self):
|
||||
self.sent = []
|
||||
|
||||
async def send_message(self, channel_id, content, **kwargs):
|
||||
self.sent.append(content)
|
||||
return {"id": "1", "channel_id": channel_id, "content": content}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help_lists_commands():
|
||||
bot = DummyBot()
|
||||
handler = CommandHandler(client=bot, prefix="!")
|
||||
|
||||
async def foo(ctx):
|
||||
pass
|
||||
|
||||
handler.add_command(Command(foo, name="foo", brief="Foo cmd"))
|
||||
|
||||
msg_data = {
|
||||
"id": "1",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||
"content": "!help",
|
||||
"timestamp": "t",
|
||||
}
|
||||
msg = Message(msg_data, client_instance=bot)
|
||||
await handler.process_commands(msg)
|
||||
assert any("foo" in m for m in bot.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help_specific_command():
|
||||
bot = DummyBot()
|
||||
handler = CommandHandler(client=bot, prefix="!")
|
||||
|
||||
async def bar(ctx):
|
||||
pass
|
||||
|
||||
handler.add_command(Command(bar, name="bar", description="Bar desc"))
|
||||
|
||||
msg_data = {
|
||||
"id": "1",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||
"content": "!help bar",
|
||||
"timestamp": "t",
|
||||
}
|
||||
msg = Message(msg_data, client_instance=bot)
|
||||
await handler.process_commands(msg)
|
||||
assert any("Bar desc" in m for m in bot.sent)
|
57
tests/test_http_reactions.py
Normal file
57
tests/test_http_reactions.py
Normal file
@ -0,0 +1,57 @@
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.errors import DisagreementException
|
||||
from disagreement.models import User
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_reaction_calls_http():
|
||||
http = SimpleNamespace(create_reaction=AsyncMock())
|
||||
client = Client.__new__(Client)
|
||||
client._http = http
|
||||
client._closed = False
|
||||
|
||||
await client.create_reaction("1", "2", "😀")
|
||||
|
||||
http.create_reaction.assert_called_once_with("1", "2", "😀")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_reaction_closed():
|
||||
http = SimpleNamespace(create_reaction=AsyncMock())
|
||||
client = Client.__new__(Client)
|
||||
client._http = http
|
||||
client._closed = True
|
||||
|
||||
with pytest.raises(DisagreementException):
|
||||
await client.create_reaction("1", "2", "😀")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_reaction_calls_http():
|
||||
http = SimpleNamespace(delete_reaction=AsyncMock())
|
||||
client = Client.__new__(Client)
|
||||
client._http = http
|
||||
client._closed = False
|
||||
|
||||
await client.delete_reaction("1", "2", "😀")
|
||||
|
||||
http.delete_reaction.assert_called_once_with("1", "2", "😀")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_reactions_parses_users():
|
||||
users_payload = [{"id": "1", "username": "u", "discriminator": "0001"}]
|
||||
http = SimpleNamespace(get_reactions=AsyncMock(return_value=users_payload))
|
||||
client = Client.__new__(Client)
|
||||
client._http = http
|
||||
client._closed = False
|
||||
client._users = {}
|
||||
|
||||
users = await client.get_reactions("1", "2", "😀")
|
||||
|
||||
http.get_reactions.assert_called_once_with("1", "2", "😀")
|
||||
assert isinstance(users[0], User)
|
44
tests/test_hybrid_context.py
Normal file
44
tests/test_hybrid_context.py
Normal file
@ -0,0 +1,44 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
|
||||
from disagreement.hybrid_context import HybridContext
|
||||
from disagreement.ext.app_commands.context import AppCommandContext
|
||||
|
||||
|
||||
class DummyCommandCtx:
|
||||
def __init__(self):
|
||||
self.sent = []
|
||||
|
||||
async def reply(self, *a, **kw):
|
||||
self.sent.append(("reply", a, kw))
|
||||
|
||||
async def edit(self, *a, **kw):
|
||||
self.sent.append(("edit", a, kw))
|
||||
|
||||
|
||||
class DummyAppCtx(AppCommandContext):
|
||||
def __init__(self):
|
||||
self.sent = []
|
||||
|
||||
async def send(self, *a, **kw):
|
||||
self.sent.append(("send", a, kw))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_routes_based_on_context():
|
||||
cctx = DummyCommandCtx()
|
||||
actx = DummyAppCtx()
|
||||
await HybridContext(cctx).send("hi")
|
||||
await HybridContext(actx).send("hi")
|
||||
assert cctx.sent[0][0] == "reply"
|
||||
assert actx.sent[0][0] == "send"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_delegation_and_error():
|
||||
cctx = DummyCommandCtx()
|
||||
hctx = HybridContext(cctx)
|
||||
await hctx.edit("m")
|
||||
assert cctx.sent[0][0] == "edit"
|
||||
with pytest.raises(AttributeError):
|
||||
await HybridContext(DummyAppCtx()).edit("m")
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user