Initial commit
This commit is contained in:
commit
7c7cb4137c
34
.github/workflows/ci.yml
vendored
Normal file
34
.github/workflows/ci.yml
vendored
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
name: Python CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.11'
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install -r requirements.txt
|
||||||
|
pip install -e .
|
||||||
|
npm install -g pyright
|
||||||
|
|
||||||
|
- name: Run Pyright
|
||||||
|
run: pyright
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
pytest tests/
|
129
.gitignore
vendored
Normal file
129
.gitignore
vendored
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a PyInstaller script; this is deployed to PyInstaller's temporary directory.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# PEP 582; __pypackages__
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# IDE specific files
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
.kilocode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
64
AGENTS.md
Normal file
64
AGENTS.md
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
# Agents
|
||||||
|
- There are no nested `AGENTS.md` files; this is the only one in the project.
|
||||||
|
- Tools to use for testing: `pyright`, `pylint`, `pytest`, `black`
|
||||||
|
|
||||||
|
- You have a Python script `tavilytool.py` in the project root that you can use to search the web.
|
||||||
|
|
||||||
|
# Tavily API Script Usage Instructions
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
Search for information using simple queries:
|
||||||
|
```bash
|
||||||
|
python tavilytool.py "your search query"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
```bash
|
||||||
|
python tavilytool.py "latest AI development 2024"
|
||||||
|
python tavilytool.py "how to make chocolate chip cookies"
|
||||||
|
python tavilytool.py "current weather in New York"
|
||||||
|
python tavilytool.py "best programming practices Python"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Advanced Options
|
||||||
|
|
||||||
|
### Search Depth
|
||||||
|
- **Basic search**: `python tavilytool.py "query"` (default)
|
||||||
|
- **Advanced search**: `python tavilytool.py "query" --depth advanced`
|
||||||
|
|
||||||
|
### Control Results
|
||||||
|
- **Limit results**: `python tavilytool.py "query" --max-results 3`
|
||||||
|
- **Include images**: `python tavilytool.py "query" --include-images`
|
||||||
|
- **Skip AI answer**: `python tavilytool.py "query" --no-answer`
|
||||||
|
|
||||||
|
### Domain Filtering
|
||||||
|
- **Include specific domains**: `python tavilytool.py "query" --include-domains reddit.com stackoverflow.com`
|
||||||
|
- **Exclude domains**: `python tavilytool.py "query" --exclude-domains wikipedia.org`
|
||||||
|
|
||||||
|
### Output Format
|
||||||
|
- **Formatted output**: `python tavilytool.py "query"` (default - human readable)
|
||||||
|
- **Raw JSON**: `python tavilytool.py "query" --raw` (for programmatic processing)
|
||||||
|
|
||||||
|
## Output Structure
|
||||||
|
The default formatted output includes:
|
||||||
|
- 🤖 **AI Answer**: Direct answer to your query
|
||||||
|
- 🔍 **Search Results**: Titles, URLs, and content snippets
|
||||||
|
- 🖼️ **Images**: Relevant images (when `--include-images` is used)
|
||||||
|
|
||||||
|
## Command Combinations
|
||||||
|
```bash
|
||||||
|
# Advanced search with images, limited results
|
||||||
|
python tavilytool.py "machine learning tutorials" --depth advanced --include-images --max-results 3
|
||||||
|
|
||||||
|
# Search specific sites only, raw output
|
||||||
|
python tavilytool.py "Python best practices" --include-domains github.com stackoverflow.com --raw
|
||||||
|
|
||||||
|
# Quick search without AI answer
|
||||||
|
python tavilytool.py "today's news" --no-answer --max-results 5
|
||||||
|
```
|
||||||
|
|
||||||
|
## Tips
|
||||||
|
- Always quote your search queries to handle spaces and special characters
|
||||||
|
- Use `--max-results` to control response length and API usage
|
||||||
|
- Use `--raw` when you need to parse results programmatically
|
||||||
|
- Combine options as needed for specific use cases
|
26
LICENSE
Normal file
26
LICENSE
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
Copyright (c) 2025, Slipstream
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
this list of conditions and the following disclaimer in the documentation
|
||||||
|
and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
3. Neither the name of the copyright holder nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
3
MANIFEST.in
Normal file
3
MANIFEST.in
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
graft docs
|
||||||
|
graft examples
|
||||||
|
include LICENSE
|
131
README.md
Normal file
131
README.md
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
# Disagreement
|
||||||
|
|
||||||
|
A Python library for interacting with the Discord API, with a focus on bot development.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Asynchronous design using `aiohttp`
|
||||||
|
- Gateway and HTTP API clients
|
||||||
|
- Slash command framework
|
||||||
|
- Message component helpers
|
||||||
|
- Built-in caching layer
|
||||||
|
- Experimental voice support
|
||||||
|
- Helpful error handling utilities
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m pip install -U pip
|
||||||
|
pip install disagreement
|
||||||
|
# or install from source for development
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
Requires Python 3.11 or newer.
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import disagreement
|
||||||
|
|
||||||
|
# Ensure DISCORD_BOT_TOKEN is set in your environment
|
||||||
|
client = disagreement.Client(token=os.environ.get("DISCORD_BOT_TOKEN"))
|
||||||
|
|
||||||
|
@client.on_event('MESSAGE_CREATE')
|
||||||
|
async def on_message(message: disagreement.Message):
|
||||||
|
print(f"Received: {message.content} from {message.author.username}")
|
||||||
|
if message.content.lower() == '!ping':
|
||||||
|
await message.reply('Pong!')
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
if not client.token:
|
||||||
|
print("Error: DISCORD_BOT_TOKEN environment variable not set.")
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
async with client:
|
||||||
|
await asyncio.Future() # run until cancelled
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("Bot shutting down...")
|
||||||
|
# Add any other specific exception handling from your library, e.g., disagreement.AuthenticationError
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
asyncio.run(main())
|
||||||
|
```
|
||||||
|
|
||||||
|
### Global Error Handling
|
||||||
|
|
||||||
|
To ensure unexpected errors don't crash your bot, you can enable the library's
|
||||||
|
global error handler:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import disagreement
|
||||||
|
|
||||||
|
disagreement.setup_global_error_handler()
|
||||||
|
```
|
||||||
|
|
||||||
|
Call this early in your program to log unhandled exceptions instead of letting
|
||||||
|
them terminate the process.
|
||||||
|
|
||||||
|
### Configuring Logging
|
||||||
|
|
||||||
|
Use :func:`disagreement.logging_config.setup_logging` to configure logging for
|
||||||
|
your bot. The helper accepts a logging level and an optional file path.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import logging
|
||||||
|
from disagreement.logging_config import setup_logging
|
||||||
|
|
||||||
|
setup_logging(logging.INFO)
|
||||||
|
# Or log to a file
|
||||||
|
setup_logging(logging.DEBUG, file="bot.log")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Defining Subcommands with `AppCommandGroup`
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.ext.app_commands import AppCommandGroup
|
||||||
|
|
||||||
|
settings = AppCommandGroup("settings", "Manage settings")
|
||||||
|
|
||||||
|
@settings.command(name="show")
|
||||||
|
async def show(ctx):
|
||||||
|
"""Displays a setting."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@settings.group("admin", description="Admin settings")
|
||||||
|
def admin_group():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@admin_group.command(name="set")
|
||||||
|
async def set_setting(ctx, key: str, value: str):
|
||||||
|
...
|
||||||
|
## Fetching Guilds
|
||||||
|
|
||||||
|
Use `Client.fetch_guild` to retrieve a guild from the Discord API if it
|
||||||
|
isn't already cached. This is useful when working with guild IDs from
|
||||||
|
outside the gateway events.
|
||||||
|
|
||||||
|
```python
|
||||||
|
guild = await client.fetch_guild("123456789012345678")
|
||||||
|
roles = await client.fetch_roles(guild.id)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Sharding
|
||||||
|
|
||||||
|
To run your bot across multiple gateway shards, pass `shard_count` when creating
|
||||||
|
the client:
|
||||||
|
|
||||||
|
```python
|
||||||
|
client = disagreement.Client(token=BOT_TOKEN, shard_count=2)
|
||||||
|
```
|
||||||
|
|
||||||
|
See `examples/sharded_bot.py` for a full example.
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Contributions are welcome! Please open an issue or submit a pull request.
|
||||||
|
|
||||||
|
See the [docs](docs/) directory for detailed guides on components, slash commands, caching, and voice features.
|
||||||
|
|
36
disagreement/__init__.py
Normal file
36
disagreement/__init__.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
# disagreement/__init__.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
Disagreement
|
||||||
|
~~~~~~~~~~~~
|
||||||
|
|
||||||
|
A Python library for interacting with the Discord API.
|
||||||
|
|
||||||
|
:copyright: (c) 2025 Slipstream
|
||||||
|
:license: BSD 3-Clause License, see LICENSE for more details.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__title__ = "disagreement"
|
||||||
|
__author__ = "Slipstream"
|
||||||
|
__license__ = "BSD 3-Clause License"
|
||||||
|
__copyright__ = "Copyright 2025 Slipstream"
|
||||||
|
__version__ = "0.0.1"
|
||||||
|
|
||||||
|
from .client import Client
|
||||||
|
from .models import Message, User
|
||||||
|
from .voice_client import VoiceClient
|
||||||
|
from .typing import Typing
|
||||||
|
from .errors import (
|
||||||
|
DisagreementException,
|
||||||
|
HTTPException,
|
||||||
|
GatewayException,
|
||||||
|
AuthenticationError,
|
||||||
|
)
|
||||||
|
from .enums import GatewayIntent, GatewayOpcode # Export enums
|
||||||
|
from .error_handler import setup_global_error_handler
|
||||||
|
from .hybrid_context import HybridContext
|
||||||
|
from .ext import tasks
|
||||||
|
|
||||||
|
# Set up logging if desired
|
||||||
|
# import logging
|
||||||
|
# logging.getLogger(__name__).addHandler(logging.NullHandler())
|
55
disagreement/cache.py
Normal file
55
disagreement/cache.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .models import Channel, Guild
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class Cache(Generic[T]):
|
||||||
|
"""Simple in-memory cache with optional TTL support."""
|
||||||
|
|
||||||
|
def __init__(self, ttl: Optional[float] = None) -> None:
|
||||||
|
self.ttl = ttl
|
||||||
|
self._data: Dict[str, tuple[T, Optional[float]]] = {}
|
||||||
|
|
||||||
|
def set(self, key: str, value: T) -> None:
|
||||||
|
expiry = time.monotonic() + self.ttl if self.ttl is not None else None
|
||||||
|
self._data[key] = (value, expiry)
|
||||||
|
|
||||||
|
def get(self, key: str) -> Optional[T]:
|
||||||
|
item = self._data.get(key)
|
||||||
|
if not item:
|
||||||
|
return None
|
||||||
|
value, expiry = item
|
||||||
|
if expiry is not None and expiry < time.monotonic():
|
||||||
|
self.invalidate(key)
|
||||||
|
return None
|
||||||
|
return value
|
||||||
|
|
||||||
|
def invalidate(self, key: str) -> None:
|
||||||
|
self._data.pop(key, None)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
self._data.clear()
|
||||||
|
|
||||||
|
def values(self) -> list[T]:
|
||||||
|
now = time.monotonic()
|
||||||
|
items = []
|
||||||
|
for key, (value, expiry) in list(self._data.items()):
|
||||||
|
if expiry is not None and expiry < now:
|
||||||
|
self.invalidate(key)
|
||||||
|
else:
|
||||||
|
items.append(value)
|
||||||
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
class GuildCache(Cache["Guild"]):
|
||||||
|
"""Cache specifically for :class:`Guild` objects."""
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelCache(Cache["Channel"]):
|
||||||
|
"""Cache specifically for :class:`Channel` objects."""
|
1144
disagreement/client.py
Normal file
1144
disagreement/client.py
Normal file
File diff suppressed because it is too large
Load Diff
166
disagreement/components.py
Normal file
166
disagreement/components.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
"""Message component utilities."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
from .enums import ComponentType, ButtonStyle, ChannelType, TextInputStyle
|
||||||
|
from .models import (
|
||||||
|
ActionRow,
|
||||||
|
Button,
|
||||||
|
Component,
|
||||||
|
SelectMenu,
|
||||||
|
SelectOption,
|
||||||
|
PartialEmoji,
|
||||||
|
PartialEmoji,
|
||||||
|
Section,
|
||||||
|
TextDisplay,
|
||||||
|
Thumbnail,
|
||||||
|
MediaGallery,
|
||||||
|
MediaGalleryItem,
|
||||||
|
File,
|
||||||
|
Separator,
|
||||||
|
Container,
|
||||||
|
UnfurledMediaItem,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma: no cover - optional client for future use
|
||||||
|
from .client import Client
|
||||||
|
|
||||||
|
|
||||||
|
def component_factory(
|
||||||
|
data: Dict[str, Any], client: Optional["Client"] = None
|
||||||
|
) -> "Component":
|
||||||
|
"""Create a component object from raw API data."""
|
||||||
|
ctype = ComponentType(data["type"])
|
||||||
|
|
||||||
|
if ctype == ComponentType.ACTION_ROW:
|
||||||
|
row = ActionRow()
|
||||||
|
for comp in data.get("components", []):
|
||||||
|
row.add_component(component_factory(comp, client))
|
||||||
|
return row
|
||||||
|
|
||||||
|
if ctype == ComponentType.BUTTON:
|
||||||
|
return Button(
|
||||||
|
style=ButtonStyle(data["style"]),
|
||||||
|
label=data.get("label"),
|
||||||
|
emoji=PartialEmoji(data["emoji"]) if data.get("emoji") else None,
|
||||||
|
custom_id=data.get("custom_id"),
|
||||||
|
url=data.get("url"),
|
||||||
|
disabled=data.get("disabled", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
if ctype in {
|
||||||
|
ComponentType.STRING_SELECT,
|
||||||
|
ComponentType.USER_SELECT,
|
||||||
|
ComponentType.ROLE_SELECT,
|
||||||
|
ComponentType.MENTIONABLE_SELECT,
|
||||||
|
ComponentType.CHANNEL_SELECT,
|
||||||
|
}:
|
||||||
|
options = [
|
||||||
|
SelectOption(
|
||||||
|
label=o["label"],
|
||||||
|
value=o["value"],
|
||||||
|
description=o.get("description"),
|
||||||
|
emoji=PartialEmoji(o["emoji"]) if o.get("emoji") else None,
|
||||||
|
default=o.get("default", False),
|
||||||
|
)
|
||||||
|
for o in data.get("options", [])
|
||||||
|
]
|
||||||
|
channel_types = None
|
||||||
|
if ctype == ComponentType.CHANNEL_SELECT and data.get("channel_types"):
|
||||||
|
channel_types = [ChannelType(ct) for ct in data.get("channel_types", [])]
|
||||||
|
|
||||||
|
return SelectMenu(
|
||||||
|
custom_id=data["custom_id"],
|
||||||
|
options=options,
|
||||||
|
placeholder=data.get("placeholder"),
|
||||||
|
min_values=data.get("min_values", 1),
|
||||||
|
max_values=data.get("max_values", 1),
|
||||||
|
disabled=data.get("disabled", False),
|
||||||
|
channel_types=channel_types,
|
||||||
|
type=ctype,
|
||||||
|
)
|
||||||
|
|
||||||
|
if ctype == ComponentType.TEXT_INPUT:
|
||||||
|
from .ui.modal import TextInput
|
||||||
|
|
||||||
|
return TextInput(
|
||||||
|
label=data.get("label", ""),
|
||||||
|
custom_id=data.get("custom_id"),
|
||||||
|
style=TextInputStyle(data.get("style", TextInputStyle.SHORT.value)),
|
||||||
|
placeholder=data.get("placeholder"),
|
||||||
|
required=data.get("required", True),
|
||||||
|
min_length=data.get("min_length"),
|
||||||
|
max_length=data.get("max_length"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if ctype == ComponentType.SECTION:
|
||||||
|
# The components in a section can only be TextDisplay
|
||||||
|
section_components = []
|
||||||
|
for c in data.get("components", []):
|
||||||
|
comp = component_factory(c, client)
|
||||||
|
if isinstance(comp, TextDisplay):
|
||||||
|
section_components.append(comp)
|
||||||
|
|
||||||
|
accessory = None
|
||||||
|
if data.get("accessory"):
|
||||||
|
acc_comp = component_factory(data["accessory"], client)
|
||||||
|
if isinstance(acc_comp, (Thumbnail, Button)):
|
||||||
|
accessory = acc_comp
|
||||||
|
|
||||||
|
return Section(
|
||||||
|
components=section_components,
|
||||||
|
accessory=accessory,
|
||||||
|
id=data.get("id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if ctype == ComponentType.TEXT_DISPLAY:
|
||||||
|
return TextDisplay(content=data["content"], id=data.get("id"))
|
||||||
|
|
||||||
|
if ctype == ComponentType.THUMBNAIL:
|
||||||
|
return Thumbnail(
|
||||||
|
media=UnfurledMediaItem(**data["media"]),
|
||||||
|
description=data.get("description"),
|
||||||
|
spoiler=data.get("spoiler", False),
|
||||||
|
id=data.get("id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if ctype == ComponentType.MEDIA_GALLERY:
|
||||||
|
return MediaGallery(
|
||||||
|
items=[
|
||||||
|
MediaGalleryItem(
|
||||||
|
media=UnfurledMediaItem(**i["media"]),
|
||||||
|
description=i.get("description"),
|
||||||
|
spoiler=i.get("spoiler", False),
|
||||||
|
)
|
||||||
|
for i in data.get("items", [])
|
||||||
|
],
|
||||||
|
id=data.get("id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if ctype == ComponentType.FILE:
|
||||||
|
return File(
|
||||||
|
file=UnfurledMediaItem(**data["file"]),
|
||||||
|
spoiler=data.get("spoiler", False),
|
||||||
|
id=data.get("id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if ctype == ComponentType.SEPARATOR:
|
||||||
|
return Separator(
|
||||||
|
divider=data.get("divider", True),
|
||||||
|
spacing=data.get("spacing", 1),
|
||||||
|
id=data.get("id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if ctype == ComponentType.CONTAINER:
|
||||||
|
return Container(
|
||||||
|
components=[
|
||||||
|
component_factory(c, client) for c in data.get("components", [])
|
||||||
|
],
|
||||||
|
accent_color=data.get("accent_color"),
|
||||||
|
spoiler=data.get("spoiler", False),
|
||||||
|
id=data.get("id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(f"Unsupported component type: {ctype}")
|
357
disagreement/enums.py
Normal file
357
disagreement/enums.py
Normal file
@ -0,0 +1,357 @@
|
|||||||
|
# disagreement/enums.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
Enums for Discord constants.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import IntEnum, Enum # Import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class GatewayOpcode(IntEnum):
|
||||||
|
"""Represents a Discord Gateway Opcode."""
|
||||||
|
|
||||||
|
DISPATCH = 0
|
||||||
|
HEARTBEAT = 1
|
||||||
|
IDENTIFY = 2
|
||||||
|
PRESENCE_UPDATE = 3
|
||||||
|
VOICE_STATE_UPDATE = 4
|
||||||
|
RESUME = 6
|
||||||
|
RECONNECT = 7
|
||||||
|
REQUEST_GUILD_MEMBERS = 8
|
||||||
|
INVALID_SESSION = 9
|
||||||
|
HELLO = 10
|
||||||
|
HEARTBEAT_ACK = 11
|
||||||
|
|
||||||
|
|
||||||
|
class GatewayIntent(IntEnum):
|
||||||
|
"""Represents a Discord Gateway Intent bit.
|
||||||
|
|
||||||
|
Intents are used to subscribe to specific groups of events from the Gateway.
|
||||||
|
"""
|
||||||
|
|
||||||
|
GUILDS = 1 << 0
|
||||||
|
GUILD_MEMBERS = 1 << 1 # Privileged
|
||||||
|
GUILD_MODERATION = 1 << 2 # Formerly GUILD_BANS
|
||||||
|
GUILD_EMOJIS_AND_STICKERS = 1 << 3
|
||||||
|
GUILD_INTEGRATIONS = 1 << 4
|
||||||
|
GUILD_WEBHOOKS = 1 << 5
|
||||||
|
GUILD_INVITES = 1 << 6
|
||||||
|
GUILD_VOICE_STATES = 1 << 7
|
||||||
|
GUILD_PRESENCES = 1 << 8 # Privileged
|
||||||
|
GUILD_MESSAGES = 1 << 9
|
||||||
|
GUILD_MESSAGE_REACTIONS = 1 << 10
|
||||||
|
GUILD_MESSAGE_TYPING = 1 << 11
|
||||||
|
DIRECT_MESSAGES = 1 << 12
|
||||||
|
DIRECT_MESSAGE_REACTIONS = 1 << 13
|
||||||
|
DIRECT_MESSAGE_TYPING = 1 << 14
|
||||||
|
MESSAGE_CONTENT = 1 << 15 # Privileged (as of Aug 31, 2022)
|
||||||
|
GUILD_SCHEDULED_EVENTS = 1 << 16
|
||||||
|
AUTO_MODERATION_CONFIGURATION = 1 << 20
|
||||||
|
AUTO_MODERATION_EXECUTION = 1 << 21
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default(cls) -> int:
|
||||||
|
"""Returns default intents (excluding privileged ones like members, presences, message content)."""
|
||||||
|
return (
|
||||||
|
cls.GUILDS
|
||||||
|
| cls.GUILD_MODERATION
|
||||||
|
| cls.GUILD_EMOJIS_AND_STICKERS
|
||||||
|
| cls.GUILD_INTEGRATIONS
|
||||||
|
| cls.GUILD_WEBHOOKS
|
||||||
|
| cls.GUILD_INVITES
|
||||||
|
| cls.GUILD_VOICE_STATES
|
||||||
|
| cls.GUILD_MESSAGES
|
||||||
|
| cls.GUILD_MESSAGE_REACTIONS
|
||||||
|
| cls.GUILD_MESSAGE_TYPING
|
||||||
|
| cls.DIRECT_MESSAGES
|
||||||
|
| cls.DIRECT_MESSAGE_REACTIONS
|
||||||
|
| cls.DIRECT_MESSAGE_TYPING
|
||||||
|
| cls.GUILD_SCHEDULED_EVENTS
|
||||||
|
| cls.AUTO_MODERATION_CONFIGURATION
|
||||||
|
| cls.AUTO_MODERATION_EXECUTION
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def all(cls) -> int:
|
||||||
|
"""Returns all intents, including privileged ones. Use with caution."""
|
||||||
|
val = 0
|
||||||
|
for intent in cls:
|
||||||
|
val |= intent.value
|
||||||
|
return val
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def privileged(cls) -> int:
|
||||||
|
"""Returns a bitmask of all privileged intents."""
|
||||||
|
return cls.GUILD_MEMBERS | cls.GUILD_PRESENCES | cls.MESSAGE_CONTENT
|
||||||
|
|
||||||
|
|
||||||
|
# --- Application Command Enums ---
|
||||||
|
|
||||||
|
|
||||||
|
class ApplicationCommandType(IntEnum):
|
||||||
|
"""Type of application command."""
|
||||||
|
|
||||||
|
CHAT_INPUT = 1
|
||||||
|
USER = 2
|
||||||
|
MESSAGE = 3
|
||||||
|
PRIMARY_ENTRY_POINT = 4
|
||||||
|
|
||||||
|
|
||||||
|
class ApplicationCommandOptionType(IntEnum):
|
||||||
|
"""Type of application command option."""
|
||||||
|
|
||||||
|
SUB_COMMAND = 1
|
||||||
|
SUB_COMMAND_GROUP = 2
|
||||||
|
STRING = 3
|
||||||
|
INTEGER = 4 # Any integer between -2^53 and 2^53
|
||||||
|
BOOLEAN = 5
|
||||||
|
USER = 6
|
||||||
|
CHANNEL = 7 # Includes all channel types + categories
|
||||||
|
ROLE = 8
|
||||||
|
MENTIONABLE = 9 # Includes users and roles
|
||||||
|
NUMBER = 10 # Any double between -2^53 and 2^53
|
||||||
|
ATTACHMENT = 11
|
||||||
|
|
||||||
|
|
||||||
|
class InteractionType(IntEnum):
|
||||||
|
"""Type of interaction."""
|
||||||
|
|
||||||
|
PING = 1
|
||||||
|
APPLICATION_COMMAND = 2
|
||||||
|
MESSAGE_COMPONENT = 3
|
||||||
|
APPLICATION_COMMAND_AUTOCOMPLETE = 4
|
||||||
|
MODAL_SUBMIT = 5
|
||||||
|
|
||||||
|
|
||||||
|
class InteractionCallbackType(IntEnum):
|
||||||
|
"""Type of interaction callback."""
|
||||||
|
|
||||||
|
PONG = 1
|
||||||
|
CHANNEL_MESSAGE_WITH_SOURCE = 4
|
||||||
|
DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE = 5
|
||||||
|
DEFERRED_UPDATE_MESSAGE = 6
|
||||||
|
UPDATE_MESSAGE = 7
|
||||||
|
APPLICATION_COMMAND_AUTOCOMPLETE_RESULT = 8
|
||||||
|
MODAL = 9 # Response to send a modal
|
||||||
|
|
||||||
|
|
||||||
|
class IntegrationType(IntEnum):
|
||||||
|
"""
|
||||||
|
Installation context(s) where the command is available,
|
||||||
|
only for globally-scoped commands.
|
||||||
|
"""
|
||||||
|
|
||||||
|
GUILD_INSTALL = (
|
||||||
|
0 # Command is available when the app is installed to a guild (default)
|
||||||
|
)
|
||||||
|
USER_INSTALL = 1 # Command is available when the app is installed to a user
|
||||||
|
|
||||||
|
|
||||||
|
class InteractionContextType(IntEnum):
|
||||||
|
"""
|
||||||
|
Interaction context(s) where the command can be used,
|
||||||
|
only for globally-scoped commands.
|
||||||
|
"""
|
||||||
|
|
||||||
|
GUILD = 0 # Command can be used in guilds
|
||||||
|
BOT_DM = 1 # Command can be used in DMs with the app's bot user
|
||||||
|
PRIVATE_CHANNEL = 2 # Command can be used in Group DMs and DMs (requires USER_INSTALL integration_type)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageFlags(IntEnum):
|
||||||
|
"""Represents the flags of a message."""
|
||||||
|
|
||||||
|
CROSSPOSTED = 1 << 0
|
||||||
|
IS_CROSSPOST = 1 << 1
|
||||||
|
SUPPRESS_EMBEDS = 1 << 2
|
||||||
|
SOURCE_MESSAGE_DELETED = 1 << 3
|
||||||
|
URGENT = 1 << 4
|
||||||
|
HAS_THREAD = 1 << 5
|
||||||
|
EPHEMERAL = 1 << 6
|
||||||
|
LOADING = 1 << 7
|
||||||
|
FAILED_TO_MENTION_SOME_ROLES_IN_THREAD = 1 << 8
|
||||||
|
SUPPRESS_NOTIFICATIONS = (
|
||||||
|
1 << 12
|
||||||
|
) # Discord specific, was previously 1 << 4 (IS_VOICE_MESSAGE)
|
||||||
|
IS_COMPONENTS_V2 = 1 << 15
|
||||||
|
|
||||||
|
|
||||||
|
# --- Guild Enums ---
|
||||||
|
|
||||||
|
|
||||||
|
class VerificationLevel(IntEnum):
|
||||||
|
"""Guild verification level."""
|
||||||
|
|
||||||
|
NONE = 0
|
||||||
|
LOW = 1
|
||||||
|
MEDIUM = 2
|
||||||
|
HIGH = 3
|
||||||
|
VERY_HIGH = 4
|
||||||
|
|
||||||
|
|
||||||
|
class MessageNotificationLevel(IntEnum):
|
||||||
|
"""Default message notification level for a guild."""
|
||||||
|
|
||||||
|
ALL_MESSAGES = 0
|
||||||
|
ONLY_MENTIONS = 1
|
||||||
|
|
||||||
|
|
||||||
|
class ExplicitContentFilterLevel(IntEnum):
|
||||||
|
"""Explicit content filter level for a guild."""
|
||||||
|
|
||||||
|
DISABLED = 0
|
||||||
|
MEMBERS_WITHOUT_ROLES = 1
|
||||||
|
ALL_MEMBERS = 2
|
||||||
|
|
||||||
|
|
||||||
|
class MFALevel(IntEnum):
|
||||||
|
"""Multi-Factor Authentication level for a guild."""
|
||||||
|
|
||||||
|
NONE = 0
|
||||||
|
ELEVATED = 1
|
||||||
|
|
||||||
|
|
||||||
|
class GuildNSFWLevel(IntEnum):
|
||||||
|
"""NSFW level of a guild."""
|
||||||
|
|
||||||
|
DEFAULT = 0
|
||||||
|
EXPLICIT = 1
|
||||||
|
SAFE = 2
|
||||||
|
AGE_RESTRICTED = 3
|
||||||
|
|
||||||
|
|
||||||
|
class PremiumTier(IntEnum):
|
||||||
|
"""Guild premium tier (boost level)."""
|
||||||
|
|
||||||
|
NONE = 0
|
||||||
|
TIER_1 = 1
|
||||||
|
TIER_2 = 2
|
||||||
|
TIER_3 = 3
|
||||||
|
|
||||||
|
|
||||||
|
class GuildFeature(str, Enum): # Changed from IntEnum to Enum
|
||||||
|
"""Features that a guild can have.
|
||||||
|
|
||||||
|
Note: This is not an exhaustive list and Discord may add more.
|
||||||
|
Using str as a base allows for unknown features to be stored as strings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ANIMATED_BANNER = "ANIMATED_BANNER"
|
||||||
|
ANIMATED_ICON = "ANIMATED_ICON"
|
||||||
|
APPLICATION_COMMAND_PERMISSIONS_V2 = "APPLICATION_COMMAND_PERMISSIONS_V2"
|
||||||
|
AUTO_MODERATION = "AUTO_MODERATION"
|
||||||
|
BANNER = "BANNER"
|
||||||
|
COMMUNITY = "COMMUNITY"
|
||||||
|
CREATOR_MONETIZABLE_PROVISIONAL = "CREATOR_MONETIZABLE_PROVISIONAL"
|
||||||
|
CREATOR_STORE_PAGE = "CREATOR_STORE_PAGE"
|
||||||
|
DEVELOPER_SUPPORT_SERVER = "DEVELOPER_SUPPORT_SERVER"
|
||||||
|
DISCOVERABLE = "DISCOVERABLE"
|
||||||
|
FEATURABLE = "FEATURABLE"
|
||||||
|
INVITES_DISABLED = "INVITES_DISABLED"
|
||||||
|
INVITE_SPLASH = "INVITE_SPLASH"
|
||||||
|
MEMBER_VERIFICATION_GATE_ENABLED = "MEMBER_VERIFICATION_GATE_ENABLED"
|
||||||
|
MORE_STICKERS = "MORE_STICKERS"
|
||||||
|
NEWS = "NEWS"
|
||||||
|
PARTNERED = "PARTNERED"
|
||||||
|
PREVIEW_ENABLED = "PREVIEW_ENABLED"
|
||||||
|
RAID_ALERTS_DISABLED = "RAID_ALERTS_DISABLED"
|
||||||
|
ROLE_ICONS = "ROLE_ICONS"
|
||||||
|
ROLE_SUBSCRIPTIONS_AVAILABLE_FOR_PURCHASE = (
|
||||||
|
"ROLE_SUBSCRIPTIONS_AVAILABLE_FOR_PURCHASE"
|
||||||
|
)
|
||||||
|
ROLE_SUBSCRIPTIONS_ENABLED = "ROLE_SUBSCRIPTIONS_ENABLED"
|
||||||
|
TICKETED_EVENTS_ENABLED = "TICKETED_EVENTS_ENABLED"
|
||||||
|
VANITY_URL = "VANITY_URL"
|
||||||
|
VERIFIED = "VERIFIED"
|
||||||
|
VIP_REGIONS = "VIP_REGIONS"
|
||||||
|
WELCOME_SCREEN_ENABLED = "WELCOME_SCREEN_ENABLED"
|
||||||
|
# Add more as they become known or needed
|
||||||
|
|
||||||
|
# This allows GuildFeature("UNKNOWN_FEATURE_STRING") to work
|
||||||
|
@classmethod
|
||||||
|
def _missing_(cls, value): # type: ignore
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Channel Enums ---
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelType(IntEnum):
|
||||||
|
"""Type of channel."""
|
||||||
|
|
||||||
|
GUILD_TEXT = 0 # a text channel within a server
|
||||||
|
DM = 1 # a direct message between users
|
||||||
|
GUILD_VOICE = 2 # a voice channel within a server
|
||||||
|
GROUP_DM = 3 # a direct message between multiple users
|
||||||
|
GUILD_CATEGORY = 4 # an organizational category that contains up to 50 channels
|
||||||
|
GUILD_ANNOUNCEMENT = 5 # a channel that users can follow and crosspost into their own server (formerly GUILD_NEWS)
|
||||||
|
ANNOUNCEMENT_THREAD = (
|
||||||
|
10 # a temporary sub-channel within a GUILD_ANNOUNCEMENT channel
|
||||||
|
)
|
||||||
|
PUBLIC_THREAD = (
|
||||||
|
11 # a temporary sub-channel within a GUILD_TEXT or GUILD_ANNOUNCEMENT channel
|
||||||
|
)
|
||||||
|
PRIVATE_THREAD = 12 # a temporary sub-channel within a GUILD_TEXT channel that is only viewable by those invited and those with the MANAGE_THREADS permission
|
||||||
|
GUILD_STAGE_VOICE = (
|
||||||
|
13 # a voice channel for hosting events with speakers and audiences
|
||||||
|
)
|
||||||
|
GUILD_DIRECTORY = 14 # a channel in a hub containing the listed servers
|
||||||
|
GUILD_FORUM = 15 # (Still in development) a channel that can only contain threads
|
||||||
|
GUILD_MEDIA = 16 # (Still in development) a channel that can only contain media
|
||||||
|
|
||||||
|
|
||||||
|
class OverwriteType(IntEnum):
|
||||||
|
"""Type of target for a permission overwrite."""
|
||||||
|
|
||||||
|
ROLE = 0
|
||||||
|
MEMBER = 1
|
||||||
|
|
||||||
|
|
||||||
|
# --- Component Enums ---
|
||||||
|
|
||||||
|
|
||||||
|
class ComponentType(IntEnum):
|
||||||
|
"""Type of message component."""
|
||||||
|
|
||||||
|
ACTION_ROW = 1
|
||||||
|
BUTTON = 2
|
||||||
|
STRING_SELECT = 3 # Formerly SELECT_MENU
|
||||||
|
TEXT_INPUT = 4
|
||||||
|
USER_SELECT = 5
|
||||||
|
ROLE_SELECT = 6
|
||||||
|
MENTIONABLE_SELECT = 7
|
||||||
|
CHANNEL_SELECT = 8
|
||||||
|
SECTION = 9
|
||||||
|
TEXT_DISPLAY = 10
|
||||||
|
THUMBNAIL = 11
|
||||||
|
MEDIA_GALLERY = 12
|
||||||
|
FILE = 13
|
||||||
|
SEPARATOR = 14
|
||||||
|
CONTAINER = 17
|
||||||
|
|
||||||
|
|
||||||
|
class ButtonStyle(IntEnum):
|
||||||
|
"""Style of a button component."""
|
||||||
|
|
||||||
|
# Blurple
|
||||||
|
PRIMARY = 1
|
||||||
|
# Grey
|
||||||
|
SECONDARY = 2
|
||||||
|
# Green
|
||||||
|
SUCCESS = 3
|
||||||
|
# Red
|
||||||
|
DANGER = 4
|
||||||
|
# Grey, navigates to a URL
|
||||||
|
LINK = 5
|
||||||
|
|
||||||
|
|
||||||
|
class TextInputStyle(IntEnum):
|
||||||
|
"""Style of a text input component."""
|
||||||
|
|
||||||
|
SHORT = 1
|
||||||
|
PARAGRAPH = 2
|
||||||
|
|
||||||
|
|
||||||
|
# Example of how you might combine intents:
|
||||||
|
# intents = GatewayIntent.GUILDS | GatewayIntent.GUILD_MESSAGES | GatewayIntent.MESSAGE_CONTENT
|
||||||
|
# client = Client(token="YOUR_TOKEN", intents=intents)
|
33
disagreement/error_handler.py
Normal file
33
disagreement/error_handler.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .logging_config import setup_logging
|
||||||
|
|
||||||
|
|
||||||
|
def setup_global_error_handler(
|
||||||
|
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Configure a basic global error handler for the provided loop.
|
||||||
|
|
||||||
|
The handler logs unhandled exceptions so they don't crash the bot.
|
||||||
|
"""
|
||||||
|
if loop is None:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
if not logging.getLogger().hasHandlers():
|
||||||
|
setup_logging(logging.ERROR)
|
||||||
|
|
||||||
|
def handle_exception(loop: asyncio.AbstractEventLoop, context: dict) -> None:
|
||||||
|
exception = context.get("exception")
|
||||||
|
if exception:
|
||||||
|
logging.error("Unhandled exception in event loop: %s", exception)
|
||||||
|
traceback.print_exception(
|
||||||
|
type(exception), exception, exception.__traceback__
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
message = context.get("message")
|
||||||
|
logging.error("Event loop error: %s", message)
|
||||||
|
|
||||||
|
loop.set_exception_handler(handle_exception)
|
112
disagreement/errors.py
Normal file
112
disagreement/errors.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
# disagreement/errors.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
Custom exceptions for the Disagreement library.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Any # Add Optional and Any here
|
||||||
|
|
||||||
|
|
||||||
|
class DisagreementException(Exception):
|
||||||
|
"""Base exception class for all errors raised by this library."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPException(DisagreementException):
|
||||||
|
"""Exception raised for HTTP-related errors.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
response: The aiohttp response object, if available.
|
||||||
|
status: The HTTP status code.
|
||||||
|
text: The response text, if available.
|
||||||
|
error_code: Discord specific error code, if available.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, response=None, message=None, *, status=None, text=None, error_code=None
|
||||||
|
):
|
||||||
|
self.response = response
|
||||||
|
self.status = status or (response.status if response else None)
|
||||||
|
self.text = text or (
|
||||||
|
response.text if response else None
|
||||||
|
) # Or await response.text() if in async context
|
||||||
|
self.error_code = error_code
|
||||||
|
|
||||||
|
full_message = f"HTTP {self.status}"
|
||||||
|
if message:
|
||||||
|
full_message += f": {message}"
|
||||||
|
elif self.text:
|
||||||
|
full_message += f": {self.text}"
|
||||||
|
if self.error_code:
|
||||||
|
full_message += f" (Discord Error Code: {self.error_code})"
|
||||||
|
|
||||||
|
super().__init__(full_message)
|
||||||
|
|
||||||
|
|
||||||
|
class GatewayException(DisagreementException):
|
||||||
|
"""Exception raised for errors related to the Discord Gateway connection or protocol."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationError(DisagreementException):
|
||||||
|
"""Exception raised for authentication failures (e.g., invalid token)."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitError(HTTPException):
|
||||||
|
"""
|
||||||
|
Exception raised when a rate limit is encountered.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
retry_after (float): The number of seconds to wait before retrying.
|
||||||
|
is_global (bool): Whether this is a global rate limit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, response, message=None, *, retry_after: float, is_global: bool = False
|
||||||
|
):
|
||||||
|
self.retry_after = retry_after
|
||||||
|
self.is_global = is_global
|
||||||
|
super().__init__(
|
||||||
|
response,
|
||||||
|
message
|
||||||
|
or f"Rate limited. Retry after: {retry_after}s. Global: {is_global}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# You can add more specific exceptions as needed, e.g.:
|
||||||
|
# class NotFound(HTTPException):
|
||||||
|
# """Raised for 404 Not Found errors."""
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# class Forbidden(HTTPException):
|
||||||
|
# """Raised for 403 Forbidden errors."""
|
||||||
|
# pass
|
||||||
|
|
||||||
|
|
||||||
|
class AppCommandError(DisagreementException):
|
||||||
|
"""Base exception for application command related errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AppCommandOptionConversionError(AppCommandError):
|
||||||
|
"""Exception raised when an application command option fails to convert."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
option_name: Optional[str] = None,
|
||||||
|
original_value: Any = None,
|
||||||
|
):
|
||||||
|
self.option_name = option_name
|
||||||
|
self.original_value = original_value
|
||||||
|
full_message = message
|
||||||
|
if option_name:
|
||||||
|
full_message = f"Failed to convert option '{option_name}': {message}"
|
||||||
|
if original_value is not None:
|
||||||
|
full_message += f" (Original value: '{original_value}')"
|
||||||
|
super().__init__(full_message)
|
243
disagreement/event_dispatcher.py
Normal file
243
disagreement/event_dispatcher.py
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
# disagreement/event_dispatcher.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
Event dispatcher for handling Discord Gateway events.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import (
|
||||||
|
Callable,
|
||||||
|
Coroutine,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Set,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Awaitable,
|
||||||
|
Optional,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .models import Message, User # Assuming User might be part of other events
|
||||||
|
from .errors import DisagreementException
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import Client # For type hinting to avoid circular imports
|
||||||
|
from .interactions import Interaction
|
||||||
|
|
||||||
|
# Type alias for an event listener
|
||||||
|
EventListener = Callable[..., Awaitable[None]]
|
||||||
|
|
||||||
|
|
||||||
|
class EventDispatcher:
|
||||||
|
"""
|
||||||
|
Manages registration and dispatching of event listeners.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, client_instance: "Client"):
|
||||||
|
self._client: "Client" = client_instance
|
||||||
|
self._listeners: Dict[str, List[EventListener]] = defaultdict(list)
|
||||||
|
self._waiters: Dict[
|
||||||
|
str, List[tuple[asyncio.Future, Optional[Callable[[Any], bool]]]]
|
||||||
|
] = defaultdict(list)
|
||||||
|
self.on_dispatch_error: Optional[
|
||||||
|
Callable[[str, Exception, EventListener], Awaitable[None]]
|
||||||
|
] = None
|
||||||
|
# Pre-defined parsers for specific event types to convert raw data to models
|
||||||
|
self._event_parsers: Dict[str, Callable[[Dict[str, Any]], Any]] = {
|
||||||
|
"MESSAGE_CREATE": self._parse_message_create,
|
||||||
|
"INTERACTION_CREATE": self._parse_interaction_create,
|
||||||
|
"GUILD_CREATE": self._parse_guild_create,
|
||||||
|
"CHANNEL_CREATE": self._parse_channel_create,
|
||||||
|
"PRESENCE_UPDATE": self._parse_presence_update,
|
||||||
|
"TYPING_START": self._parse_typing_start,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _parse_message_create(self, data: Dict[str, Any]) -> Message:
|
||||||
|
"""Parses raw MESSAGE_CREATE data into a Message object."""
|
||||||
|
return self._client.parse_message(data)
|
||||||
|
|
||||||
|
def _parse_interaction_create(self, data: Dict[str, Any]) -> "Interaction":
|
||||||
|
"""Parses raw INTERACTION_CREATE data into an Interaction object."""
|
||||||
|
from .interactions import Interaction
|
||||||
|
|
||||||
|
return Interaction(data=data, client_instance=self._client)
|
||||||
|
|
||||||
|
def _parse_guild_create(self, data: Dict[str, Any]):
|
||||||
|
"""Parses raw GUILD_CREATE data into a Guild object."""
|
||||||
|
|
||||||
|
return self._client.parse_guild(data)
|
||||||
|
|
||||||
|
def _parse_channel_create(self, data: Dict[str, Any]):
|
||||||
|
"""Parses raw CHANNEL_CREATE data into a Channel object."""
|
||||||
|
|
||||||
|
return self._client.parse_channel(data)
|
||||||
|
|
||||||
|
def _parse_presence_update(self, data: Dict[str, Any]):
|
||||||
|
"""Parses raw PRESENCE_UPDATE data into a PresenceUpdate object."""
|
||||||
|
|
||||||
|
from .models import PresenceUpdate
|
||||||
|
|
||||||
|
return PresenceUpdate(data, client_instance=self._client)
|
||||||
|
|
||||||
|
def _parse_typing_start(self, data: Dict[str, Any]):
|
||||||
|
"""Parses raw TYPING_START data into a TypingStart object."""
|
||||||
|
|
||||||
|
from .models import TypingStart
|
||||||
|
|
||||||
|
return TypingStart(data, client_instance=self._client)
|
||||||
|
|
||||||
|
# Potentially add _parse_user for events that directly provide a full user object
|
||||||
|
# def _parse_user_update(self, data: Dict[str, Any]) -> User:
|
||||||
|
# return User(data=data)
|
||||||
|
|
||||||
|
def register(self, event_name: str, coro: EventListener):
|
||||||
|
"""
|
||||||
|
Registers a coroutine function to listen for a specific event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_name (str): The name of the event (e.g., 'MESSAGE_CREATE').
|
||||||
|
coro (Callable): The coroutine function to call when the event occurs.
|
||||||
|
It should accept arguments appropriate for the event.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the provided callback is not a coroutine function.
|
||||||
|
"""
|
||||||
|
if not inspect.iscoroutinefunction(coro):
|
||||||
|
raise TypeError(
|
||||||
|
f"Event listener for '{event_name}' must be a coroutine function (async def)."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Normalize event name, e.g., 'on_message' -> 'MESSAGE_CREATE'
|
||||||
|
# For now, we assume event_name is already the Discord event type string.
|
||||||
|
# If using decorators like @client.on_message, the decorator would handle this mapping.
|
||||||
|
self._listeners[event_name.upper()].append(coro)
|
||||||
|
|
||||||
|
def unregister(self, event_name: str, coro: EventListener):
|
||||||
|
"""
|
||||||
|
Unregisters a coroutine function from an event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_name (str): The name of the event.
|
||||||
|
coro (Callable): The coroutine function to unregister.
|
||||||
|
"""
|
||||||
|
event_name_upper = event_name.upper()
|
||||||
|
if event_name_upper in self._listeners:
|
||||||
|
try:
|
||||||
|
self._listeners[event_name_upper].remove(coro)
|
||||||
|
except ValueError:
|
||||||
|
pass # Listener not in list
|
||||||
|
|
||||||
|
def add_waiter(
|
||||||
|
self,
|
||||||
|
event_name: str,
|
||||||
|
future: asyncio.Future,
|
||||||
|
check: Optional[Callable[[Any], bool]] = None,
|
||||||
|
) -> None:
|
||||||
|
self._waiters[event_name.upper()].append((future, check))
|
||||||
|
|
||||||
|
def remove_waiter(self, event_name: str, future: asyncio.Future) -> None:
|
||||||
|
waiters = self._waiters.get(event_name.upper())
|
||||||
|
if not waiters:
|
||||||
|
return
|
||||||
|
self._waiters[event_name.upper()] = [
|
||||||
|
(f, c) for f, c in waiters if f is not future
|
||||||
|
]
|
||||||
|
if not self._waiters[event_name.upper()]:
|
||||||
|
self._waiters.pop(event_name.upper(), None)
|
||||||
|
|
||||||
|
def _resolve_waiters(self, event_name: str, data: Any) -> None:
|
||||||
|
waiters = self._waiters.get(event_name)
|
||||||
|
if not waiters:
|
||||||
|
return
|
||||||
|
to_remove: List[tuple[asyncio.Future, Optional[Callable[[Any], bool]]]] = []
|
||||||
|
for future, check in waiters:
|
||||||
|
if future.cancelled():
|
||||||
|
to_remove.append((future, check))
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
if check is None or check(data):
|
||||||
|
future.set_result(data)
|
||||||
|
to_remove.append((future, check))
|
||||||
|
except Exception as exc:
|
||||||
|
future.set_exception(exc)
|
||||||
|
to_remove.append((future, check))
|
||||||
|
for item in to_remove:
|
||||||
|
if item in waiters:
|
||||||
|
waiters.remove(item)
|
||||||
|
if not waiters:
|
||||||
|
self._waiters.pop(event_name, None)
|
||||||
|
|
||||||
|
async def dispatch(self, event_name: str, raw_data: Dict[str, Any]):
|
||||||
|
"""
|
||||||
|
Dispatches an event to all registered listeners.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_name (str): The name of the event (e.g., 'MESSAGE_CREATE').
|
||||||
|
raw_data (Dict[str, Any]): The raw data payload from the Discord Gateway for this event.
|
||||||
|
"""
|
||||||
|
event_name_upper = event_name.upper()
|
||||||
|
listeners = self._listeners.get(event_name_upper)
|
||||||
|
|
||||||
|
if not listeners:
|
||||||
|
# print(f"No listeners for event {event_name_upper}")
|
||||||
|
return
|
||||||
|
|
||||||
|
parsed_data: Any = raw_data
|
||||||
|
if event_name_upper in self._event_parsers:
|
||||||
|
try:
|
||||||
|
parser = self._event_parsers[event_name_upper]
|
||||||
|
parsed_data = parser(raw_data)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error parsing event data for {event_name_upper}: {e}")
|
||||||
|
# Optionally, dispatch with raw_data or raise, or log more formally
|
||||||
|
# For now, we'll proceed to dispatch with raw_data if parsing fails,
|
||||||
|
# or just log and return if parsed_data is critical.
|
||||||
|
# Let's assume if a parser exists, its output is critical.
|
||||||
|
return
|
||||||
|
|
||||||
|
self._resolve_waiters(event_name_upper, parsed_data)
|
||||||
|
# print(f"Dispatching event {event_name_upper} with data: {parsed_data} to {len(listeners)} listeners.")
|
||||||
|
for listener in listeners:
|
||||||
|
try:
|
||||||
|
# Inspect the listener to see how many arguments it expects
|
||||||
|
sig = inspect.signature(listener)
|
||||||
|
num_params = len(sig.parameters)
|
||||||
|
|
||||||
|
if num_params == 0: # Listener takes no arguments
|
||||||
|
await listener()
|
||||||
|
elif (
|
||||||
|
num_params == 1
|
||||||
|
): # Listener takes one argument (the parsed data or model)
|
||||||
|
await listener(parsed_data)
|
||||||
|
# elif num_params == 2 and event_name_upper == "MESSAGE_CREATE": # Special case for (client, message)
|
||||||
|
# await listener(self._client, parsed_data) # This might be too specific here
|
||||||
|
else:
|
||||||
|
# Fallback or error if signature doesn't match expected patterns
|
||||||
|
# For now, assume one arg is the most common for parsed data.
|
||||||
|
# Or, if you want to be strict:
|
||||||
|
print(
|
||||||
|
f"Warning: Listener {listener.__name__} for {event_name_upper} has an unhandled number of parameters ({num_params}). Skipping or attempting with one arg."
|
||||||
|
)
|
||||||
|
if num_params > 0: # Try with one arg if it takes any
|
||||||
|
await listener(parsed_data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
callback = self.on_dispatch_error
|
||||||
|
if callback is not None:
|
||||||
|
try:
|
||||||
|
await callback(event_name_upper, e, listener)
|
||||||
|
|
||||||
|
except Exception as hook_error:
|
||||||
|
print(f"Error in on_dispatch_error hook itself: {hook_error}")
|
||||||
|
else:
|
||||||
|
# Default error handling if no hook is set
|
||||||
|
print(
|
||||||
|
f"Error in event listener {listener.__name__} for {event_name_upper}: {e}"
|
||||||
|
)
|
||||||
|
if hasattr(self._client, "on_error"):
|
||||||
|
try:
|
||||||
|
await self._client.on_error(event_name_upper, e, listener)
|
||||||
|
except Exception as client_err_e:
|
||||||
|
print(f"Error in client.on_error itself: {client_err_e}")
|
46
disagreement/ext/app_commands/__init__.py
Normal file
46
disagreement/ext/app_commands/__init__.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
# disagreement/ext/app_commands/__init__.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
Application Commands Extension for Disagreement.
|
||||||
|
|
||||||
|
This package provides the framework for creating and handling
|
||||||
|
Discord Application Commands (slash commands, user commands, message commands).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .commands import (
|
||||||
|
AppCommand,
|
||||||
|
SlashCommand,
|
||||||
|
UserCommand,
|
||||||
|
MessageCommand,
|
||||||
|
AppCommandGroup,
|
||||||
|
)
|
||||||
|
from .decorators import (
|
||||||
|
slash_command,
|
||||||
|
user_command,
|
||||||
|
message_command,
|
||||||
|
hybrid_command,
|
||||||
|
group,
|
||||||
|
subcommand,
|
||||||
|
subcommand_group,
|
||||||
|
OptionMetadata,
|
||||||
|
)
|
||||||
|
from .context import AppCommandContext
|
||||||
|
|
||||||
|
# from .handler import AppCommandHandler # Will be imported when defined
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AppCommand",
|
||||||
|
"SlashCommand",
|
||||||
|
"UserCommand",
|
||||||
|
"MessageCommand",
|
||||||
|
"AppCommandGroup", # To be defined
|
||||||
|
"slash_command",
|
||||||
|
"user_command",
|
||||||
|
"message_command",
|
||||||
|
"hybrid_command",
|
||||||
|
"group",
|
||||||
|
"subcommand",
|
||||||
|
"subcommand_group",
|
||||||
|
"OptionMetadata",
|
||||||
|
"AppCommandContext", # To be defined
|
||||||
|
]
|
513
disagreement/ext/app_commands/commands.py
Normal file
513
disagreement/ext/app_commands/commands.py
Normal file
@ -0,0 +1,513 @@
|
|||||||
|
# disagreement/ext/app_commands/commands.py
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import Callable, Optional, List, Dict, Any, Union, TYPE_CHECKING
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from disagreement.ext.commands.core import (
|
||||||
|
Command as PrefixCommand,
|
||||||
|
) # Alias to avoid name clash
|
||||||
|
from disagreement.interactions import ApplicationCommandOption, Snowflake
|
||||||
|
from disagreement.enums import (
|
||||||
|
ApplicationCommandType,
|
||||||
|
IntegrationType,
|
||||||
|
InteractionContextType,
|
||||||
|
ApplicationCommandOptionType, # Added
|
||||||
|
)
|
||||||
|
from disagreement.ext.commands.cog import Cog # Corrected import path
|
||||||
|
|
||||||
|
# Placeholder for Cog if not using the existing one or if it needs adaptation
|
||||||
|
if not TYPE_CHECKING:
|
||||||
|
# This dynamic Cog = Any might not be ideal if Cog is used in runtime type checks.
|
||||||
|
# However, for type hinting purposes when TYPE_CHECKING is false, it avoids import.
|
||||||
|
# If Cog is needed at runtime by this module (it is, for AppCommand.cog type hint),
|
||||||
|
# it should be imported directly.
|
||||||
|
# For now, the TYPE_CHECKING block handles the proper import for static analysis.
|
||||||
|
# Let's ensure Cog is available at runtime if AppCommand.cog is accessed.
|
||||||
|
# A simple way is to import it outside TYPE_CHECKING too, if it doesn't cause circularity.
|
||||||
|
# Given its usage, a forward reference string 'Cog' might be better in AppCommand.cog type hint.
|
||||||
|
# Let's try importing it directly for runtime, assuming no circularity with this specific module.
|
||||||
|
try:
|
||||||
|
from disagreement.ext.commands.cog import Cog
|
||||||
|
except ImportError:
|
||||||
|
Cog = Any # Fallback if direct import fails (e.g. during partial builds/tests)
|
||||||
|
# Import PrefixCommand at runtime for HybridCommand
|
||||||
|
try:
|
||||||
|
from disagreement.ext.commands.core import Command as PrefixCommand
|
||||||
|
except Exception: # pragma: no cover - safeguard against unusual import issues
|
||||||
|
PrefixCommand = Any # type: ignore
|
||||||
|
# Import enums used at runtime
|
||||||
|
try:
|
||||||
|
from disagreement.enums import (
|
||||||
|
ApplicationCommandType,
|
||||||
|
IntegrationType,
|
||||||
|
InteractionContextType,
|
||||||
|
ApplicationCommandOptionType,
|
||||||
|
)
|
||||||
|
from disagreement.interactions import ApplicationCommandOption, Snowflake
|
||||||
|
except Exception: # pragma: no cover
|
||||||
|
ApplicationCommandType = ApplicationCommandOptionType = IntegrationType = (
|
||||||
|
InteractionContextType
|
||||||
|
) = Any # type: ignore
|
||||||
|
ApplicationCommandOption = Snowflake = Any # type: ignore
|
||||||
|
else: # When TYPE_CHECKING is true, Cog and PrefixCommand are already imported above.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AppCommand:
|
||||||
|
"""
|
||||||
|
Base class for an application command.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name (str): The name of the command.
|
||||||
|
description (Optional[str]): The description of the command.
|
||||||
|
Required for CHAT_INPUT, empty for USER and MESSAGE commands.
|
||||||
|
callback (Callable[..., Any]): The coroutine function that will be called when the command is executed.
|
||||||
|
type (ApplicationCommandType): The type of the application command.
|
||||||
|
options (Optional[List[ApplicationCommandOption]]): Parameters for the command. Populated by decorators.
|
||||||
|
guild_ids (Optional[List[Snowflake]]): List of guild IDs where this command is active. None for global.
|
||||||
|
default_member_permissions (Optional[str]): Bitwise permissions required by default for users to run the command.
|
||||||
|
nsfw (bool): Whether the command is age-restricted.
|
||||||
|
parent (Optional['AppCommandGroup']): The parent group if this is a subcommand.
|
||||||
|
cog (Optional[Cog]): The cog this command belongs to, if any.
|
||||||
|
_full_description (Optional[str]): Stores the original full description, e.g. from docstring,
|
||||||
|
even if the payload description is different (like for User/Message commands).
|
||||||
|
name_localizations (Optional[Dict[str, str]]): Localizations for the command's name.
|
||||||
|
description_localizations (Optional[Dict[str, str]]): Localizations for the command's description.
|
||||||
|
integration_types (Optional[List[IntegrationType]]): Installation contexts.
|
||||||
|
contexts (Optional[List[InteractionContextType]]): Interaction contexts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
callback: Callable[..., Any],
|
||||||
|
*,
|
||||||
|
name: str,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
locale: Optional[str] = None,
|
||||||
|
type: "ApplicationCommandType",
|
||||||
|
guild_ids: Optional[List["Snowflake"]] = None,
|
||||||
|
default_member_permissions: Optional[str] = None,
|
||||||
|
nsfw: bool = False,
|
||||||
|
parent: Optional["AppCommandGroup"] = None,
|
||||||
|
cog: Optional[
|
||||||
|
Any
|
||||||
|
] = None, # Changed 'Cog' to Any to avoid runtime import issues if Cog is complex
|
||||||
|
name_localizations: Optional[Dict[str, str]] = None,
|
||||||
|
description_localizations: Optional[Dict[str, str]] = None,
|
||||||
|
integration_types: Optional[List["IntegrationType"]] = None,
|
||||||
|
contexts: Optional[List["InteractionContextType"]] = None,
|
||||||
|
):
|
||||||
|
if not asyncio.iscoroutinefunction(callback):
|
||||||
|
raise TypeError(
|
||||||
|
"Application command callback must be a coroutine function."
|
||||||
|
)
|
||||||
|
|
||||||
|
if locale:
|
||||||
|
from disagreement import i18n
|
||||||
|
|
||||||
|
translate = i18n.translate
|
||||||
|
|
||||||
|
self.name = translate(name, locale)
|
||||||
|
self.description = (
|
||||||
|
translate(description, locale) if description is not None else None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.locale: Optional[str] = locale
|
||||||
|
self.callback: Callable[..., Any] = callback
|
||||||
|
self.type: "ApplicationCommandType" = type
|
||||||
|
self.options: List["ApplicationCommandOption"] = [] # Populated by decorator
|
||||||
|
self.guild_ids: Optional[List["Snowflake"]] = guild_ids
|
||||||
|
self.default_member_permissions: Optional[str] = default_member_permissions
|
||||||
|
self.nsfw: bool = nsfw
|
||||||
|
self.parent: Optional["AppCommandGroup"] = parent
|
||||||
|
self.cog: Optional[Any] = cog # Changed 'Cog' to Any
|
||||||
|
self.name_localizations: Optional[Dict[str, str]] = name_localizations
|
||||||
|
self.description_localizations: Optional[Dict[str, str]] = (
|
||||||
|
description_localizations
|
||||||
|
)
|
||||||
|
self.integration_types: Optional[List["IntegrationType"]] = integration_types
|
||||||
|
self.contexts: Optional[List["InteractionContextType"]] = contexts
|
||||||
|
self._full_description: Optional[str] = (
|
||||||
|
None # Initialized by decorator if needed
|
||||||
|
)
|
||||||
|
|
||||||
|
# Signature for argument parsing by decorators/handlers
|
||||||
|
self.params = inspect.signature(callback).parameters
|
||||||
|
|
||||||
|
async def invoke(
|
||||||
|
self, context: "AppCommandContext", *args: Any, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Invokes the command's callback with the given context and arguments."""
|
||||||
|
# Similar to Command.invoke, handle cog if present
|
||||||
|
actual_args = []
|
||||||
|
if self.cog:
|
||||||
|
actual_args.append(self.cog)
|
||||||
|
actual_args.append(context)
|
||||||
|
actual_args.extend(args)
|
||||||
|
|
||||||
|
await self.callback(*actual_args, **kwargs)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Converts the command to a dictionary payload for Discord API."""
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"name": self.name,
|
||||||
|
"type": self.type.value,
|
||||||
|
# CHAT_INPUT commands require a description.
|
||||||
|
# USER and MESSAGE commands must have an empty description in the payload if not omitted.
|
||||||
|
# The constructor for UserCommand/MessageCommand already sets self.description to ""
|
||||||
|
"description": (
|
||||||
|
self.description
|
||||||
|
if self.type == ApplicationCommandType.CHAT_INPUT
|
||||||
|
else ""
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
# For CHAT_INPUT commands, options are its parameters.
|
||||||
|
# For USER/MESSAGE commands, options should be empty or not present.
|
||||||
|
if self.type == ApplicationCommandType.CHAT_INPUT and self.options:
|
||||||
|
payload["options"] = [opt.to_dict() for opt in self.options]
|
||||||
|
|
||||||
|
if self.default_member_permissions is not None: # Can be "0" for no permissions
|
||||||
|
payload["default_member_permissions"] = str(self.default_member_permissions)
|
||||||
|
|
||||||
|
# nsfw defaults to False, only include if True
|
||||||
|
if self.nsfw:
|
||||||
|
payload["nsfw"] = True
|
||||||
|
|
||||||
|
if self.name_localizations:
|
||||||
|
payload["name_localizations"] = self.name_localizations
|
||||||
|
|
||||||
|
# Description localizations only apply if there's a description (CHAT_INPUT commands)
|
||||||
|
if (
|
||||||
|
self.type == ApplicationCommandType.CHAT_INPUT
|
||||||
|
and self.description
|
||||||
|
and self.description_localizations
|
||||||
|
):
|
||||||
|
payload["description_localizations"] = self.description_localizations
|
||||||
|
|
||||||
|
if self.integration_types:
|
||||||
|
payload["integration_types"] = [it.value for it in self.integration_types]
|
||||||
|
|
||||||
|
if self.contexts:
|
||||||
|
payload["contexts"] = [ict.value for ict in self.contexts]
|
||||||
|
|
||||||
|
# According to Discord API, guild_id is not part of this payload,
|
||||||
|
# it's used in the URL path for guild-specific command registration.
|
||||||
|
# However, the global command registration takes an 'application_id' in the payload,
|
||||||
|
# but that's handled by the HTTPClient.
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<{self.__class__.__name__} name='{self.name}' type={self.type!r}>"
|
||||||
|
|
||||||
|
|
||||||
|
class SlashCommand(AppCommand):
|
||||||
|
"""Represents a CHAT_INPUT (slash) command."""
|
||||||
|
|
||||||
|
def __init__(self, callback: Callable[..., Any], **kwargs: Any):
|
||||||
|
if not kwargs.get("description"):
|
||||||
|
raise ValueError("SlashCommand requires a description.")
|
||||||
|
super().__init__(callback, type=ApplicationCommandType.CHAT_INPUT, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class UserCommand(AppCommand):
|
||||||
|
"""Represents a USER context menu command."""
|
||||||
|
|
||||||
|
def __init__(self, callback: Callable[..., Any], **kwargs: Any):
|
||||||
|
# Description is not allowed by Discord API for User Commands, but can be set to empty string.
|
||||||
|
kwargs["description"] = kwargs.get(
|
||||||
|
"description", ""
|
||||||
|
) # Ensure it's empty or not present in payload
|
||||||
|
super().__init__(callback, type=ApplicationCommandType.USER, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageCommand(AppCommand):
|
||||||
|
"""Represents a MESSAGE context menu command."""
|
||||||
|
|
||||||
|
def __init__(self, callback: Callable[..., Any], **kwargs: Any):
|
||||||
|
# Description is not allowed by Discord API for Message Commands.
|
||||||
|
kwargs["description"] = kwargs.get("description", "")
|
||||||
|
super().__init__(callback, type=ApplicationCommandType.MESSAGE, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class HybridCommand(SlashCommand, PrefixCommand): # Inherit from both
|
||||||
|
"""
|
||||||
|
Represents a command that can be invoked as both a slash command
|
||||||
|
and a traditional prefix-based command.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, callback: Callable[..., Any], **kwargs: Any):
|
||||||
|
# Initialize SlashCommand part (which calls AppCommand.__init__)
|
||||||
|
# We need to ensure 'type' is correctly passed for AppCommand
|
||||||
|
# kwargs for SlashCommand: name, description, guild_ids, default_member_permissions, nsfw, parent, cog, etc.
|
||||||
|
# kwargs for PrefixCommand: name, aliases, brief, description, cog
|
||||||
|
|
||||||
|
# Pop prefix-specific args before passing to SlashCommand constructor
|
||||||
|
prefix_aliases = kwargs.pop("aliases", [])
|
||||||
|
prefix_brief = kwargs.pop("brief", None)
|
||||||
|
# Description is used by both, AppCommand's constructor will handle it.
|
||||||
|
# Name is used by both. Cog is used by both.
|
||||||
|
|
||||||
|
# Call SlashCommand's __init__
|
||||||
|
# This will set up name, description, callback, type=CHAT_INPUT, options, etc.
|
||||||
|
super().__init__(callback, **kwargs) # This is SlashCommand.__init__
|
||||||
|
|
||||||
|
# Now, explicitly initialize the PrefixCommand parts that SlashCommand didn't cover
|
||||||
|
# or that need specific values for the prefix version.
|
||||||
|
# PrefixCommand.__init__(self, callback, name=self.name, aliases=prefix_aliases, brief=prefix_brief, description=self.description, cog=self.cog)
|
||||||
|
# However, PrefixCommand.__init__ also sets self.params, which AppCommand already did.
|
||||||
|
# We need to be careful not to re-initialize things unnecessarily or incorrectly.
|
||||||
|
# Let's manually set the distinct attributes for the PrefixCommand aspect.
|
||||||
|
|
||||||
|
# Attributes from PrefixCommand:
|
||||||
|
# self.callback is already set by AppCommand
|
||||||
|
# self.name is already set by AppCommand
|
||||||
|
self.aliases: List[str] = (
|
||||||
|
prefix_aliases # This was specific to HybridCommand before, now aligns with PrefixCommand
|
||||||
|
)
|
||||||
|
self.brief: Optional[str] = prefix_brief
|
||||||
|
# self.description is already set by AppCommand (SlashCommand ensures it exists)
|
||||||
|
# self.cog is already set by AppCommand
|
||||||
|
# self.params is already set by AppCommand
|
||||||
|
|
||||||
|
# Ensure the MRO is handled correctly. Python's MRO (C3 linearization)
|
||||||
|
# should call SlashCommand's __init__ then AppCommand's __init__.
|
||||||
|
# PrefixCommand.__init__ won't be called automatically unless we explicitly call it.
|
||||||
|
# By setting attributes directly, we avoid potential issues with multiple __init__ calls
|
||||||
|
# if their logic overlaps too much (e.g., both trying to set self.params).
|
||||||
|
|
||||||
|
# We might need to override invoke if the context or argument passing differs significantly
|
||||||
|
# between app command invocation and prefix command invocation.
|
||||||
|
# For now, SlashCommand.invoke and PrefixCommand.invoke are separate.
|
||||||
|
# The correct one will be called depending on how the command is dispatched.
|
||||||
|
# The AppCommandHandler will use AppCommand.invoke (via SlashCommand).
|
||||||
|
# The prefix CommandHandler will use PrefixCommand.invoke.
|
||||||
|
# This seems acceptable.
|
||||||
|
|
||||||
|
|
||||||
|
class AppCommandGroup:
|
||||||
|
"""
|
||||||
|
Represents a group of application commands (subcommands or subcommand groups).
|
||||||
|
This itself is not directly callable but acts as a namespace.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: Optional[
|
||||||
|
str
|
||||||
|
] = None, # Required for top-level groups that form part of a slash command
|
||||||
|
guild_ids: Optional[List["Snowflake"]] = None,
|
||||||
|
parent: Optional["AppCommandGroup"] = None,
|
||||||
|
default_member_permissions: Optional[str] = None,
|
||||||
|
nsfw: bool = False,
|
||||||
|
name_localizations: Optional[Dict[str, str]] = None,
|
||||||
|
description_localizations: Optional[Dict[str, str]] = None,
|
||||||
|
integration_types: Optional[List["IntegrationType"]] = None,
|
||||||
|
contexts: Optional[List["InteractionContextType"]] = None,
|
||||||
|
):
|
||||||
|
self.name: str = name
|
||||||
|
self.description: Optional[str] = description
|
||||||
|
self.guild_ids: Optional[List["Snowflake"]] = guild_ids
|
||||||
|
self.parent: Optional["AppCommandGroup"] = parent
|
||||||
|
self.commands: Dict[str, Union[AppCommand, "AppCommandGroup"]] = {}
|
||||||
|
self.default_member_permissions: Optional[str] = default_member_permissions
|
||||||
|
self.nsfw: bool = nsfw
|
||||||
|
self.name_localizations: Optional[Dict[str, str]] = name_localizations
|
||||||
|
self.description_localizations: Optional[Dict[str, str]] = (
|
||||||
|
description_localizations
|
||||||
|
)
|
||||||
|
self.integration_types: Optional[List["IntegrationType"]] = integration_types
|
||||||
|
self.contexts: Optional[List["InteractionContextType"]] = contexts
|
||||||
|
# A group itself doesn't have a cog directly, its commands do.
|
||||||
|
|
||||||
|
def add_command(self, command: Union[AppCommand, "AppCommandGroup"]) -> None:
|
||||||
|
if command.name in self.commands:
|
||||||
|
raise ValueError(
|
||||||
|
f"Command or group '{command.name}' already exists in group '{self.name}'."
|
||||||
|
)
|
||||||
|
command.parent = self
|
||||||
|
self.commands[command.name] = command
|
||||||
|
|
||||||
|
def get_command(self, name: str) -> Optional[Union[AppCommand, "AppCommandGroup"]]:
|
||||||
|
return self.commands.get(name)
|
||||||
|
|
||||||
|
def command(self, *d_args: Any, **d_kwargs: Any):
|
||||||
|
d_kwargs.setdefault("parent", self)
|
||||||
|
from .decorators import slash_command
|
||||||
|
|
||||||
|
return slash_command(*d_args, **d_kwargs)
|
||||||
|
|
||||||
|
def group(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
sub_group = AppCommandGroup(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
parent=self,
|
||||||
|
guild_ids=kwargs.get("guild_ids"),
|
||||||
|
default_member_permissions=kwargs.get("default_member_permissions"),
|
||||||
|
nsfw=kwargs.get("nsfw", False),
|
||||||
|
name_localizations=kwargs.get("name_localizations"),
|
||||||
|
description_localizations=kwargs.get("description_localizations"),
|
||||||
|
integration_types=kwargs.get("integration_types"),
|
||||||
|
contexts=kwargs.get("contexts"),
|
||||||
|
)
|
||||||
|
self.add_command(sub_group)
|
||||||
|
|
||||||
|
def decorator(func: Optional[Callable[..., Any]] = None):
|
||||||
|
if func is not None:
|
||||||
|
setattr(func, "__app_command_object__", sub_group)
|
||||||
|
return sub_group
|
||||||
|
return sub_group
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<AppCommandGroup name='{self.name}' commands={len(self.commands)}>"
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Converts the command group to a dictionary payload for Discord API.
|
||||||
|
This represents a top-level command that has subcommands/subcommand groups.
|
||||||
|
"""
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"name": self.name,
|
||||||
|
"type": ApplicationCommandType.CHAT_INPUT.value, # Groups are implicitly CHAT_INPUT
|
||||||
|
"description": self.description
|
||||||
|
or "No description provided", # Top-level groups require a description
|
||||||
|
"options": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.default_member_permissions is not None:
|
||||||
|
payload["default_member_permissions"] = str(self.default_member_permissions)
|
||||||
|
if self.nsfw:
|
||||||
|
payload["nsfw"] = True
|
||||||
|
if self.name_localizations:
|
||||||
|
payload["name_localizations"] = self.name_localizations
|
||||||
|
if (
|
||||||
|
self.description and self.description_localizations
|
||||||
|
): # Only if description is not empty
|
||||||
|
payload["description_localizations"] = self.description_localizations
|
||||||
|
if self.integration_types:
|
||||||
|
payload["integration_types"] = [it.value for it in self.integration_types]
|
||||||
|
if self.contexts:
|
||||||
|
payload["contexts"] = [ict.value for ict in self.contexts]
|
||||||
|
|
||||||
|
# guild_ids are handled at the registration level, not in this specific payload part.
|
||||||
|
|
||||||
|
options_payload: List[Dict[str, Any]] = []
|
||||||
|
for cmd_name, command_or_group in self.commands.items():
|
||||||
|
if isinstance(command_or_group, AppCommand): # This is a Subcommand
|
||||||
|
# Subcommands use their own options (parameters)
|
||||||
|
sub_options = (
|
||||||
|
[opt.to_dict() for opt in command_or_group.options]
|
||||||
|
if command_or_group.options
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
option_dict = {
|
||||||
|
"type": ApplicationCommandOptionType.SUB_COMMAND.value,
|
||||||
|
"name": command_or_group.name,
|
||||||
|
"description": command_or_group.description
|
||||||
|
or "No description provided",
|
||||||
|
"options": sub_options,
|
||||||
|
}
|
||||||
|
# Add localization for subcommand name and description if available
|
||||||
|
if command_or_group.name_localizations:
|
||||||
|
option_dict["name_localizations"] = (
|
||||||
|
command_or_group.name_localizations
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
command_or_group.description
|
||||||
|
and command_or_group.description_localizations
|
||||||
|
):
|
||||||
|
option_dict["description_localizations"] = (
|
||||||
|
command_or_group.description_localizations
|
||||||
|
)
|
||||||
|
options_payload.append(option_dict)
|
||||||
|
|
||||||
|
elif isinstance(
|
||||||
|
command_or_group, AppCommandGroup
|
||||||
|
): # This is a Subcommand Group
|
||||||
|
# Subcommand groups have their subcommands/groups as options
|
||||||
|
sub_group_options: List[Dict[str, Any]] = []
|
||||||
|
for sub_cmd_name, sub_command in command_or_group.commands.items():
|
||||||
|
# Nested groups can only contain subcommands, not further nested groups as per Discord rules.
|
||||||
|
# So, sub_command here must be an AppCommand.
|
||||||
|
if isinstance(
|
||||||
|
sub_command, AppCommand
|
||||||
|
): # Should always be AppCommand if structure is valid
|
||||||
|
sub_cmd_options = (
|
||||||
|
[opt.to_dict() for opt in sub_command.options]
|
||||||
|
if sub_command.options
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
sub_group_option_entry = {
|
||||||
|
"type": ApplicationCommandOptionType.SUB_COMMAND.value,
|
||||||
|
"name": sub_command.name,
|
||||||
|
"description": sub_command.description
|
||||||
|
or "No description provided",
|
||||||
|
"options": sub_cmd_options,
|
||||||
|
}
|
||||||
|
# Add localization for subcommand name and description if available
|
||||||
|
if sub_command.name_localizations:
|
||||||
|
sub_group_option_entry["name_localizations"] = (
|
||||||
|
sub_command.name_localizations
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
sub_command.description
|
||||||
|
and sub_command.description_localizations
|
||||||
|
):
|
||||||
|
sub_group_option_entry["description_localizations"] = (
|
||||||
|
sub_command.description_localizations
|
||||||
|
)
|
||||||
|
sub_group_options.append(sub_group_option_entry)
|
||||||
|
# else:
|
||||||
|
# # This case implies a group nested inside a group, which then contains another group.
|
||||||
|
# # Discord's structure is:
|
||||||
|
# # command -> option (SUB_COMMAND_GROUP) -> option (SUB_COMMAND) -> option (param)
|
||||||
|
# # This should be caught by validation logic in decorators or add_command.
|
||||||
|
# # For now, we assume valid structure where AppCommandGroup's commands are AppCommands.
|
||||||
|
# pass
|
||||||
|
|
||||||
|
option_dict = {
|
||||||
|
"type": ApplicationCommandOptionType.SUB_COMMAND_GROUP.value,
|
||||||
|
"name": command_or_group.name,
|
||||||
|
"description": command_or_group.description
|
||||||
|
or "No description provided",
|
||||||
|
"options": sub_group_options, # These are the SUB_COMMANDs
|
||||||
|
}
|
||||||
|
# Add localization for subcommand group name and description if available
|
||||||
|
if command_or_group.name_localizations:
|
||||||
|
option_dict["name_localizations"] = (
|
||||||
|
command_or_group.name_localizations
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
command_or_group.description
|
||||||
|
and command_or_group.description_localizations
|
||||||
|
):
|
||||||
|
option_dict["description_localizations"] = (
|
||||||
|
command_or_group.description_localizations
|
||||||
|
)
|
||||||
|
options_payload.append(option_dict)
|
||||||
|
|
||||||
|
payload["options"] = options_payload
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
# Need to import asyncio for iscoroutinefunction check
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .context import AppCommandContext # For type hint in AppCommand.invoke
|
||||||
|
|
||||||
|
# Ensure ApplicationCommandOptionType is available for the to_dict method
|
||||||
|
from disagreement.enums import ApplicationCommandOptionType
|
556
disagreement/ext/app_commands/context.py
Normal file
556
disagreement/ext/app_commands/context.py
Normal file
@ -0,0 +1,556 @@
|
|||||||
|
# disagreement/ext/app_commands/context.py
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional, List, Union, Any, Dict
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.interactions import (
|
||||||
|
Interaction,
|
||||||
|
InteractionCallbackData,
|
||||||
|
InteractionResponsePayload,
|
||||||
|
Snowflake,
|
||||||
|
)
|
||||||
|
from disagreement.enums import InteractionCallbackType, MessageFlags
|
||||||
|
from disagreement.models import (
|
||||||
|
User,
|
||||||
|
Member,
|
||||||
|
Message,
|
||||||
|
Channel,
|
||||||
|
ActionRow,
|
||||||
|
)
|
||||||
|
from disagreement.ui.view import View
|
||||||
|
|
||||||
|
# For full model hints, these would be imported from disagreement.models when defined:
|
||||||
|
Embed = Any
|
||||||
|
PartialAttachment = Any
|
||||||
|
Guild = Any # from disagreement.models import Guild
|
||||||
|
TextChannel = Any # from disagreement.models import TextChannel, etc.
|
||||||
|
from .commands import AppCommand
|
||||||
|
|
||||||
|
from disagreement.enums import InteractionCallbackType, MessageFlags
|
||||||
|
from disagreement.interactions import (
|
||||||
|
Interaction,
|
||||||
|
InteractionCallbackData,
|
||||||
|
InteractionResponsePayload,
|
||||||
|
Snowflake,
|
||||||
|
)
|
||||||
|
from disagreement.models import Message
|
||||||
|
from disagreement.typing import Typing
|
||||||
|
|
||||||
|
|
||||||
|
class AppCommandContext:
|
||||||
|
"""
|
||||||
|
Represents the context in which an application command is being invoked.
|
||||||
|
Provides methods to respond to the interaction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
bot: "Client",
|
||||||
|
interaction: "Interaction",
|
||||||
|
command: Optional["AppCommand"] = None,
|
||||||
|
):
|
||||||
|
self.bot: "Client" = bot
|
||||||
|
self.interaction: "Interaction" = interaction
|
||||||
|
self.command: Optional["AppCommand"] = command # The command that was invoked
|
||||||
|
|
||||||
|
self._responded: bool = False
|
||||||
|
self._deferred: bool = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def token(self) -> str:
|
||||||
|
"""The interaction token."""
|
||||||
|
return self.interaction.token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def interaction_id(self) -> "Snowflake":
|
||||||
|
"""The interaction ID."""
|
||||||
|
return self.interaction.id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def application_id(self) -> "Snowflake":
|
||||||
|
"""The application ID of the interaction."""
|
||||||
|
return self.interaction.application_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def guild_id(self) -> Optional["Snowflake"]:
|
||||||
|
"""The ID of the guild where the interaction occurred, if any."""
|
||||||
|
return self.interaction.guild_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def channel_id(self) -> Optional["Snowflake"]:
|
||||||
|
"""The ID of the channel where the interaction occurred."""
|
||||||
|
return self.interaction.channel_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def author(self) -> Optional[Union["User", "Member"]]:
|
||||||
|
"""The user or member who invoked the interaction."""
|
||||||
|
return self.interaction.member or self.interaction.user
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user(self) -> Optional["User"]:
|
||||||
|
"""The user who invoked the interaction.
|
||||||
|
If in a guild, this is the user part of the member.
|
||||||
|
If in a DM, this is the top-level user.
|
||||||
|
"""
|
||||||
|
return self.interaction.user
|
||||||
|
|
||||||
|
@property
|
||||||
|
def member(self) -> Optional["Member"]:
|
||||||
|
"""The member who invoked the interaction, if this occurred in a guild."""
|
||||||
|
return self.interaction.member
|
||||||
|
|
||||||
|
@property
|
||||||
|
def locale(self) -> Optional[str]:
|
||||||
|
"""The selected language of the invoking user."""
|
||||||
|
return self.interaction.locale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def guild_locale(self) -> Optional[str]:
|
||||||
|
"""The guild's preferred language, if applicable."""
|
||||||
|
return self.interaction.guild_locale
|
||||||
|
|
||||||
|
@property
|
||||||
|
async def guild(self) -> Optional["Guild"]:
|
||||||
|
"""The guild object where the interaction occurred, if available."""
|
||||||
|
|
||||||
|
if not self.guild_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
guild = None
|
||||||
|
if hasattr(self.bot, "get_guild"):
|
||||||
|
guild = self.bot.get_guild(self.guild_id)
|
||||||
|
|
||||||
|
if not guild and hasattr(self.bot, "fetch_guild"):
|
||||||
|
try:
|
||||||
|
guild = await self.bot.fetch_guild(self.guild_id)
|
||||||
|
except Exception:
|
||||||
|
guild = None
|
||||||
|
|
||||||
|
return guild
|
||||||
|
|
||||||
|
@property
|
||||||
|
async def channel(self) -> Optional[Any]:
|
||||||
|
"""The channel object where the interaction occurred, if available."""
|
||||||
|
|
||||||
|
if not self.channel_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
channel = None
|
||||||
|
if hasattr(self.bot, "get_channel"):
|
||||||
|
channel = self.bot.get_channel(self.channel_id)
|
||||||
|
elif hasattr(self.bot, "_channels"):
|
||||||
|
channel = self.bot._channels.get(self.channel_id)
|
||||||
|
|
||||||
|
if not channel and hasattr(self.bot, "fetch_channel"):
|
||||||
|
try:
|
||||||
|
channel = await self.bot.fetch_channel(self.channel_id)
|
||||||
|
except Exception:
|
||||||
|
channel = None
|
||||||
|
|
||||||
|
return channel
|
||||||
|
|
||||||
|
async def _send_response(
|
||||||
|
self,
|
||||||
|
response_type: "InteractionCallbackType",
|
||||||
|
data: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Internal helper to send interaction responses."""
|
||||||
|
if (
|
||||||
|
self._responded
|
||||||
|
and not self._deferred
|
||||||
|
and response_type
|
||||||
|
!= InteractionCallbackType.APPLICATION_COMMAND_AUTOCOMPLETE_RESULT
|
||||||
|
):
|
||||||
|
# If already responded and not deferred, subsequent responses must be followups
|
||||||
|
# (unless it's an autocomplete result which is a special case)
|
||||||
|
# For now, let's assume followups are handled by separate methods.
|
||||||
|
# This logic might need refinement based on how followups are exposed.
|
||||||
|
raise RuntimeError(
|
||||||
|
"Interaction has already been responded to. Use send_followup()."
|
||||||
|
)
|
||||||
|
|
||||||
|
callback_data = InteractionCallbackData(data) if data else None
|
||||||
|
payload = InteractionResponsePayload(type=response_type, data=callback_data)
|
||||||
|
|
||||||
|
await self.bot._http.create_interaction_response(
|
||||||
|
interaction_id=self.interaction_id,
|
||||||
|
interaction_token=self.token,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
response_type
|
||||||
|
!= InteractionCallbackType.APPLICATION_COMMAND_AUTOCOMPLETE_RESULT
|
||||||
|
):
|
||||||
|
self._responded = True
|
||||||
|
if (
|
||||||
|
response_type
|
||||||
|
== InteractionCallbackType.DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE
|
||||||
|
or response_type == InteractionCallbackType.DEFERRED_UPDATE_MESSAGE
|
||||||
|
):
|
||||||
|
self._deferred = True
|
||||||
|
|
||||||
|
async def defer(self, ephemeral: bool = False, thinking: bool = True) -> None:
|
||||||
|
"""
|
||||||
|
Defers the interaction response.
|
||||||
|
|
||||||
|
This is typically used when your command might take longer than 3 seconds to process.
|
||||||
|
You must send a followup message within 15 minutes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ephemeral (bool): Whether the subsequent followup response should be ephemeral.
|
||||||
|
Only applicable if `thinking` is True.
|
||||||
|
thinking (bool): If True (default), responds with a "Bot is thinking..." message
|
||||||
|
(DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE).
|
||||||
|
If False, responds with DEFERRED_UPDATE_MESSAGE (for components).
|
||||||
|
"""
|
||||||
|
if self._responded:
|
||||||
|
raise RuntimeError("Interaction has already been responded to or deferred.")
|
||||||
|
|
||||||
|
response_type = (
|
||||||
|
InteractionCallbackType.DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE
|
||||||
|
if thinking
|
||||||
|
else InteractionCallbackType.DEFERRED_UPDATE_MESSAGE
|
||||||
|
)
|
||||||
|
data = None
|
||||||
|
if ephemeral and thinking:
|
||||||
|
data = {
|
||||||
|
"flags": MessageFlags.EPHEMERAL.value
|
||||||
|
} # Assuming MessageFlags enum exists
|
||||||
|
|
||||||
|
await self._send_response(response_type, data)
|
||||||
|
self._deferred = True # Mark as deferred
|
||||||
|
|
||||||
|
async def send(
|
||||||
|
self,
|
||||||
|
content: Optional[str] = None,
|
||||||
|
embed: Optional["Embed"] = None, # Convenience for single embed
|
||||||
|
embeds: Optional[List["Embed"]] = None,
|
||||||
|
*,
|
||||||
|
tts: bool = False,
|
||||||
|
files: Optional[List[Any]] = None,
|
||||||
|
components: Optional[List[ActionRow]] = None,
|
||||||
|
view: Optional[View] = None,
|
||||||
|
allowed_mentions: Optional[Dict[str, Any]] = None,
|
||||||
|
ephemeral: bool = False,
|
||||||
|
flags: Optional[int] = None,
|
||||||
|
) -> Optional[
|
||||||
|
"Message"
|
||||||
|
]: # Returns Message if not ephemeral and response was not deferred
|
||||||
|
"""
|
||||||
|
Sends a response to the interaction.
|
||||||
|
If the interaction was previously deferred, this will edit the original deferred response.
|
||||||
|
Otherwise, it sends an initial response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content (Optional[str]): The message content.
|
||||||
|
embed (Optional[Embed]): A single embed to send. If `embeds` is also provided, this is ignored.
|
||||||
|
embeds (Optional[List[Embed]]): A list of embeds to send (max 10).
|
||||||
|
ephemeral (bool): Whether the message should be ephemeral (only visible to the invoker).
|
||||||
|
flags (Optional[int]): Additional message flags to apply.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Message]: The sent message object if a new message was created and not ephemeral.
|
||||||
|
None if the response was ephemeral or an edit to a deferred message.
|
||||||
|
"""
|
||||||
|
if not self._responded and self._deferred: # Editing a deferred response
|
||||||
|
# Use edit_original_interaction_response
|
||||||
|
payload: Dict[str, Any] = {}
|
||||||
|
if content is not None:
|
||||||
|
payload["content"] = content
|
||||||
|
|
||||||
|
if tts:
|
||||||
|
payload["tts"] = True
|
||||||
|
|
||||||
|
actual_embeds = embeds
|
||||||
|
if embed and not embeds:
|
||||||
|
actual_embeds = [embed]
|
||||||
|
if actual_embeds:
|
||||||
|
payload["embeds"] = [e.to_dict() for e in actual_embeds]
|
||||||
|
|
||||||
|
if view:
|
||||||
|
await view._start(self.bot)
|
||||||
|
payload["components"] = view.to_components_payload()
|
||||||
|
elif components:
|
||||||
|
payload["components"] = [c.to_dict() for c in components]
|
||||||
|
|
||||||
|
if files is not None:
|
||||||
|
payload["attachments"] = [
|
||||||
|
f.to_dict() if hasattr(f, "to_dict") else f for f in files
|
||||||
|
]
|
||||||
|
|
||||||
|
if allowed_mentions is not None:
|
||||||
|
payload["allowed_mentions"] = allowed_mentions
|
||||||
|
|
||||||
|
# Flags (like ephemeral) cannot be set when editing the original deferred response this way.
|
||||||
|
# Ephemeral for deferred must be set during defer().
|
||||||
|
|
||||||
|
msg_data = await self.bot._http.edit_original_interaction_response(
|
||||||
|
application_id=self.application_id,
|
||||||
|
interaction_token=self.token,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
self._responded = True # Ensure it's marked as fully responded
|
||||||
|
if view and msg_data and "id" in msg_data:
|
||||||
|
view.message_id = msg_data["id"]
|
||||||
|
self.bot._views[msg_data["id"]] = view
|
||||||
|
# Construct and return Message object if needed, for now returns None for edits
|
||||||
|
return None
|
||||||
|
|
||||||
|
elif not self._responded: # Sending an initial response
|
||||||
|
data: Dict[str, Any] = {}
|
||||||
|
if content is not None:
|
||||||
|
data["content"] = content
|
||||||
|
|
||||||
|
if tts:
|
||||||
|
data["tts"] = True
|
||||||
|
|
||||||
|
actual_embeds = embeds
|
||||||
|
if embed and not embeds:
|
||||||
|
actual_embeds = [embed]
|
||||||
|
if actual_embeds:
|
||||||
|
data["embeds"] = [
|
||||||
|
e.to_dict() for e in actual_embeds
|
||||||
|
] # Assuming embeds have to_dict()
|
||||||
|
|
||||||
|
if view:
|
||||||
|
await view._start(self.bot)
|
||||||
|
data["components"] = view.to_components_payload()
|
||||||
|
elif components:
|
||||||
|
data["components"] = [c.to_dict() for c in components]
|
||||||
|
|
||||||
|
if files is not None:
|
||||||
|
data["attachments"] = [
|
||||||
|
f.to_dict() if hasattr(f, "to_dict") else f for f in files
|
||||||
|
]
|
||||||
|
|
||||||
|
if allowed_mentions is not None:
|
||||||
|
data["allowed_mentions"] = allowed_mentions
|
||||||
|
|
||||||
|
flags_value = 0
|
||||||
|
if ephemeral:
|
||||||
|
flags_value |= MessageFlags.EPHEMERAL.value
|
||||||
|
if flags:
|
||||||
|
flags_value |= flags
|
||||||
|
if flags_value:
|
||||||
|
data["flags"] = flags_value
|
||||||
|
|
||||||
|
await self._send_response(
|
||||||
|
InteractionCallbackType.CHANNEL_MESSAGE_WITH_SOURCE, data
|
||||||
|
)
|
||||||
|
|
||||||
|
if view and not ephemeral:
|
||||||
|
try:
|
||||||
|
msg_data = await self.bot._http.get_original_interaction_response(
|
||||||
|
application_id=self.application_id,
|
||||||
|
interaction_token=self.token,
|
||||||
|
)
|
||||||
|
if msg_data and "id" in msg_data:
|
||||||
|
view.message_id = msg_data["id"]
|
||||||
|
self.bot._views[msg_data["id"]] = view
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if not ephemeral:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
# If already responded and not deferred, this should be a followup.
|
||||||
|
# This method is for initial response or editing deferred.
|
||||||
|
raise RuntimeError(
|
||||||
|
"Interaction has already been responded to. Use send_followup()."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send_followup(
|
||||||
|
self,
|
||||||
|
content: Optional[str] = None,
|
||||||
|
embed: Optional["Embed"] = None,
|
||||||
|
embeds: Optional[List["Embed"]] = None,
|
||||||
|
*,
|
||||||
|
ephemeral: bool = False,
|
||||||
|
tts: bool = False,
|
||||||
|
files: Optional[List[Any]] = None,
|
||||||
|
components: Optional[List["ActionRow"]] = None,
|
||||||
|
view: Optional[View] = None,
|
||||||
|
allowed_mentions: Optional[Dict[str, Any]] = None,
|
||||||
|
flags: Optional[int] = None,
|
||||||
|
) -> Optional["Message"]:
|
||||||
|
"""
|
||||||
|
Sends a followup message to an interaction.
|
||||||
|
This can be used after an initial response or a deferred response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content (Optional[str]): The message content.
|
||||||
|
embed (Optional[Embed]): A single embed to send.
|
||||||
|
embeds (Optional[List[Embed]]): A list of embeds to send.
|
||||||
|
ephemeral (bool): Whether the followup message should be ephemeral.
|
||||||
|
flags (Optional[int]): Additional message flags to apply.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Message: The sent followup message object.
|
||||||
|
"""
|
||||||
|
if not self._responded:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Must acknowledge or defer the interaction before sending a followup."
|
||||||
|
)
|
||||||
|
|
||||||
|
payload: Dict[str, Any] = {}
|
||||||
|
if content is not None:
|
||||||
|
payload["content"] = content
|
||||||
|
|
||||||
|
if tts:
|
||||||
|
payload["tts"] = True
|
||||||
|
|
||||||
|
actual_embeds = embeds
|
||||||
|
if embed and not embeds:
|
||||||
|
actual_embeds = [embed]
|
||||||
|
if actual_embeds:
|
||||||
|
payload["embeds"] = [
|
||||||
|
e.to_dict() for e in actual_embeds
|
||||||
|
] # Assuming embeds have to_dict()
|
||||||
|
|
||||||
|
if view:
|
||||||
|
await view._start(self.bot)
|
||||||
|
payload["components"] = view.to_components_payload()
|
||||||
|
elif components:
|
||||||
|
payload["components"] = [c.to_dict() for c in components]
|
||||||
|
|
||||||
|
if files is not None:
|
||||||
|
payload["attachments"] = [
|
||||||
|
f.to_dict() if hasattr(f, "to_dict") else f for f in files
|
||||||
|
]
|
||||||
|
|
||||||
|
if allowed_mentions is not None:
|
||||||
|
payload["allowed_mentions"] = allowed_mentions
|
||||||
|
|
||||||
|
flags_value = 0
|
||||||
|
if ephemeral:
|
||||||
|
flags_value |= MessageFlags.EPHEMERAL.value
|
||||||
|
if flags:
|
||||||
|
flags_value |= flags
|
||||||
|
if flags_value:
|
||||||
|
payload["flags"] = flags_value
|
||||||
|
|
||||||
|
# Followup messages are sent to a webhook endpoint
|
||||||
|
message_data = await self.bot._http.create_followup_message(
|
||||||
|
application_id=self.application_id,
|
||||||
|
interaction_token=self.token,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
if view and message_data and "id" in message_data:
|
||||||
|
view.message_id = message_data["id"]
|
||||||
|
self.bot._views[message_data["id"]] = view
|
||||||
|
from disagreement.models import Message # Ensure Message is available
|
||||||
|
|
||||||
|
return Message(data=message_data, client_instance=self.bot)
|
||||||
|
|
||||||
|
async def edit(
|
||||||
|
self,
|
||||||
|
message_id: "Snowflake" = "@original", # Defaults to editing the original response
|
||||||
|
content: Optional[str] = None,
|
||||||
|
embed: Optional["Embed"] = None,
|
||||||
|
embeds: Optional[List["Embed"]] = None,
|
||||||
|
*,
|
||||||
|
components: Optional[List["ActionRow"]] = None,
|
||||||
|
attachments: Optional[List[Any]] = None,
|
||||||
|
allowed_mentions: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Optional["Message"]:
|
||||||
|
"""
|
||||||
|
Edits a message previously sent in response to this interaction.
|
||||||
|
Can edit the original response or a followup message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message_id (Snowflake): The ID of the message to edit. Defaults to "@original"
|
||||||
|
to edit the initial interaction response.
|
||||||
|
content (Optional[str]): The new message content.
|
||||||
|
embed (Optional[Embed]): A single new embed.
|
||||||
|
embeds (Optional[List[Embed]]): A list of new embeds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Message]: The edited message object if available.
|
||||||
|
"""
|
||||||
|
if not self._responded:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot edit response if interaction hasn't been responded to or deferred."
|
||||||
|
)
|
||||||
|
|
||||||
|
payload: Dict[str, Any] = {}
|
||||||
|
if content is not None:
|
||||||
|
payload["content"] = content # Use None to clear
|
||||||
|
|
||||||
|
actual_embeds = embeds
|
||||||
|
if embed and not embeds:
|
||||||
|
actual_embeds = [embed]
|
||||||
|
if actual_embeds is not None: # Allow passing empty list to clear embeds
|
||||||
|
payload["embeds"] = [
|
||||||
|
e.to_dict() for e in actual_embeds
|
||||||
|
] # Assuming embeds have to_dict()
|
||||||
|
|
||||||
|
if components is not None:
|
||||||
|
payload["components"] = [c.to_dict() for c in components]
|
||||||
|
|
||||||
|
if attachments is not None:
|
||||||
|
payload["attachments"] = [
|
||||||
|
a.to_dict() if hasattr(a, "to_dict") else a for a in attachments
|
||||||
|
]
|
||||||
|
|
||||||
|
if allowed_mentions is not None:
|
||||||
|
payload["allowed_mentions"] = allowed_mentions
|
||||||
|
|
||||||
|
if message_id == "@original":
|
||||||
|
edited_message_data = (
|
||||||
|
await self.bot._http.edit_original_interaction_response(
|
||||||
|
application_id=self.application_id,
|
||||||
|
interaction_token=self.token,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
edited_message_data = await self.bot._http.edit_followup_message(
|
||||||
|
application_id=self.application_id,
|
||||||
|
interaction_token=self.token,
|
||||||
|
message_id=message_id,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
# The HTTP methods used in tests return minimal data that is insufficient
|
||||||
|
# to construct a full ``Message`` instance, so we simply return ``None``
|
||||||
|
# rather than attempting to parse the response.
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def delete(self, message_id: "Snowflake" = "@original") -> None:
|
||||||
|
"""
|
||||||
|
Deletes a message previously sent in response to this interaction.
|
||||||
|
Can delete the original response or a followup message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message_id (Snowflake): The ID of the message to delete. Defaults to "@original"
|
||||||
|
to delete the initial interaction response.
|
||||||
|
"""
|
||||||
|
if not self._responded:
|
||||||
|
# If not responded, there's nothing to delete via this interaction's lifecycle.
|
||||||
|
# Deferral doesn't create a message to delete until a followup is sent.
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot delete response if interaction hasn't been responded to."
|
||||||
|
)
|
||||||
|
|
||||||
|
if message_id == "@original":
|
||||||
|
await self.bot._http.delete_original_interaction_response(
|
||||||
|
application_id=self.application_id, interaction_token=self.token
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self.bot._http.delete_followup_message(
|
||||||
|
application_id=self.application_id,
|
||||||
|
interaction_token=self.token,
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
|
# After deleting the original response, further followups might be problematic.
|
||||||
|
# Discord docs: "Once the original message is deleted, you can no longer edit the message or send followups."
|
||||||
|
# Consider implications for context state.
|
||||||
|
|
||||||
|
def typing(self) -> Typing:
|
||||||
|
"""Return a typing context manager for this interaction's channel."""
|
||||||
|
|
||||||
|
if not self.channel_id:
|
||||||
|
raise RuntimeError("Cannot send typing indicator without a channel.")
|
||||||
|
return self.bot.typing(self.channel_id)
|
478
disagreement/ext/app_commands/converters.py
Normal file
478
disagreement/ext/app_commands/converters.py
Normal file
@ -0,0 +1,478 @@
|
|||||||
|
# disagreement/ext/app_commands/converters.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
Converters for transforming application command option values.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Protocol,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
)
|
||||||
|
from disagreement.enums import ApplicationCommandOptionType
|
||||||
|
from disagreement.errors import (
|
||||||
|
AppCommandOptionConversionError,
|
||||||
|
) # To be created in disagreement/errors.py
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from disagreement.interactions import Interaction # For context if needed
|
||||||
|
from disagreement.models import (
|
||||||
|
User,
|
||||||
|
Member,
|
||||||
|
Role,
|
||||||
|
Channel,
|
||||||
|
Attachment,
|
||||||
|
) # Discord models
|
||||||
|
from disagreement.client import Client # For fetching objects
|
||||||
|
|
||||||
|
T = TypeVar("T", covariant=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Converter(Protocol[T]):
|
||||||
|
"""
|
||||||
|
A protocol for classes that can convert an interaction option value to a specific type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def convert(self, interaction: "Interaction", value: Any) -> T:
|
||||||
|
"""
|
||||||
|
Converts the given value to the target type.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
interaction (Interaction): The interaction context.
|
||||||
|
value (Any): The raw value from the interaction option.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
T: The converted value.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AppCommandOptionConversionError: If conversion fails.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# Basic Type Converters
|
||||||
|
|
||||||
|
|
||||||
|
class StringConverter(Converter[str]):
|
||||||
|
async def convert(self, interaction: "Interaction", value: Any) -> str:
|
||||||
|
if not isinstance(value, str):
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Expected a string, but got {type(value).__name__}: {value}"
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class IntegerConverter(Converter[int]):
|
||||||
|
async def convert(self, interaction: "Interaction", value: Any) -> int:
|
||||||
|
if not isinstance(value, int):
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Expected an integer, but got {type(value).__name__}: {value}"
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class BooleanConverter(Converter[bool]):
|
||||||
|
async def convert(self, interaction: "Interaction", value: Any) -> bool:
|
||||||
|
if not isinstance(value, bool):
|
||||||
|
if isinstance(value, str):
|
||||||
|
if value.lower() == "true":
|
||||||
|
return True
|
||||||
|
elif value.lower() == "false":
|
||||||
|
return False
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Expected a boolean, but got {type(value).__name__}: {value}"
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class NumberConverter(Converter[float]): # Discord 'NUMBER' type is float
|
||||||
|
async def convert(self, interaction: "Interaction", value: Any) -> float:
|
||||||
|
if not isinstance(value, (int, float)):
|
||||||
|
try:
|
||||||
|
return float(value)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Expected a number (float), but got {type(value).__name__}: {value}"
|
||||||
|
)
|
||||||
|
return float(value) # Ensure it's a float even if int is passed
|
||||||
|
|
||||||
|
|
||||||
|
# Discord Model Converters
|
||||||
|
|
||||||
|
|
||||||
|
class UserConverter(Converter["User"]):
|
||||||
|
def __init__(self, client: "Client"):
|
||||||
|
self._client = client
|
||||||
|
|
||||||
|
async def convert(self, interaction: "Interaction", value: Any) -> "User":
|
||||||
|
if isinstance(value, str): # Assume it's a user ID
|
||||||
|
user_id = value
|
||||||
|
# Attempt to get from interaction resolved data first
|
||||||
|
if (
|
||||||
|
interaction.data
|
||||||
|
and interaction.data.resolved
|
||||||
|
and interaction.data.resolved.users
|
||||||
|
):
|
||||||
|
user_object = interaction.data.resolved.users.get(
|
||||||
|
user_id
|
||||||
|
) # This is already a User object
|
||||||
|
if user_object:
|
||||||
|
return user_object # Return the already parsed User object
|
||||||
|
|
||||||
|
# Fallback to fetching if not in resolved or if interaction has no resolved data
|
||||||
|
try:
|
||||||
|
user = await self._client.fetch_user(
|
||||||
|
user_id
|
||||||
|
) # fetch_user now also parses and caches
|
||||||
|
if user:
|
||||||
|
return user
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"User with ID '{user_id}' not found.",
|
||||||
|
option_name="user",
|
||||||
|
original_value=value,
|
||||||
|
)
|
||||||
|
except Exception as e: # Catch potential HTTP errors from fetch_user
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Failed to fetch user '{user_id}': {e}",
|
||||||
|
option_name="user",
|
||||||
|
original_value=value,
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
isinstance(value, dict) and "id" in value
|
||||||
|
): # If it's raw user data dict (less common path now)
|
||||||
|
return self._client.parse_user(value) # parse_user handles dict -> User
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Expected a user ID string or user data dict, got {type(value).__name__}",
|
||||||
|
option_name="user",
|
||||||
|
original_value=value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemberConverter(Converter["Member"]):
|
||||||
|
def __init__(self, client: "Client"):
|
||||||
|
self._client = client
|
||||||
|
|
||||||
|
async def convert(self, interaction: "Interaction", value: Any) -> "Member":
|
||||||
|
if not interaction.guild_id:
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
"Cannot convert to Member outside of a guild context.",
|
||||||
|
option_name="member",
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(value, str): # Assume it's a user ID
|
||||||
|
member_id = value
|
||||||
|
# Attempt to get from interaction resolved data first
|
||||||
|
if (
|
||||||
|
interaction.data
|
||||||
|
and interaction.data.resolved
|
||||||
|
and interaction.data.resolved.members
|
||||||
|
):
|
||||||
|
# The Member object from resolved.members should already be correctly initialized
|
||||||
|
# by ResolvedData, including its User part.
|
||||||
|
member = interaction.data.resolved.members.get(member_id)
|
||||||
|
if member:
|
||||||
|
return (
|
||||||
|
member # Return the already resolved and parsed Member object
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fallback to fetching if not in resolved
|
||||||
|
try:
|
||||||
|
member = await self._client.fetch_member(
|
||||||
|
interaction.guild_id, member_id
|
||||||
|
)
|
||||||
|
if member:
|
||||||
|
return member
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Member with ID '{member_id}' not found in guild '{interaction.guild_id}'.",
|
||||||
|
option_name="member",
|
||||||
|
original_value=value,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Failed to fetch member '{member_id}': {e}",
|
||||||
|
option_name="member",
|
||||||
|
original_value=value,
|
||||||
|
)
|
||||||
|
elif isinstance(value, dict) and "id" in value.get(
|
||||||
|
"user", {}
|
||||||
|
): # If it's already a member data dict
|
||||||
|
return self._client.parse_member(value, interaction.guild_id)
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Expected a member ID string or member data dict, got {type(value).__name__}",
|
||||||
|
option_name="member",
|
||||||
|
original_value=value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RoleConverter(Converter["Role"]):
|
||||||
|
def __init__(self, client: "Client"):
|
||||||
|
self._client = client
|
||||||
|
|
||||||
|
async def convert(self, interaction: "Interaction", value: Any) -> "Role":
|
||||||
|
if not interaction.guild_id:
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
"Cannot convert to Role outside of a guild context.", option_name="role"
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(value, str): # Assume it's a role ID
|
||||||
|
role_id = value
|
||||||
|
# Attempt to get from interaction resolved data first
|
||||||
|
if (
|
||||||
|
interaction.data
|
||||||
|
and interaction.data.resolved
|
||||||
|
and interaction.data.resolved.roles
|
||||||
|
):
|
||||||
|
role_object = interaction.data.resolved.roles.get(
|
||||||
|
role_id
|
||||||
|
) # Should be a Role object
|
||||||
|
if role_object:
|
||||||
|
return role_object
|
||||||
|
|
||||||
|
# Fallback to fetching from guild if not in resolved
|
||||||
|
# This requires Client to have a fetch_role method or similar
|
||||||
|
try:
|
||||||
|
# Assuming Client.fetch_role(guild_id, role_id) will be implemented
|
||||||
|
role = await self._client.fetch_role(interaction.guild_id, role_id)
|
||||||
|
if role:
|
||||||
|
return role
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Role with ID '{role_id}' not found in guild '{interaction.guild_id}'.",
|
||||||
|
option_name="role",
|
||||||
|
original_value=value,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Failed to fetch role '{role_id}': {e}",
|
||||||
|
option_name="role",
|
||||||
|
original_value=value,
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
isinstance(value, dict) and "id" in value
|
||||||
|
): # If it's already role data dict
|
||||||
|
if (
|
||||||
|
not interaction.guild_id
|
||||||
|
): # Should have been caught earlier, but as a safeguard
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
"Guild context is required to parse role data.",
|
||||||
|
option_name="role",
|
||||||
|
original_value=value,
|
||||||
|
)
|
||||||
|
return self._client.parse_role(value, interaction.guild_id)
|
||||||
|
# This path is reached if value is not a string (role ID) and not a dict (role data)
|
||||||
|
# or if it's a string but all fetching/lookup attempts failed.
|
||||||
|
# The final raise AppCommandOptionConversionError should be outside the if/elif for string values.
|
||||||
|
# If value was a string, an error should have been raised within the 'if isinstance(value, str)' block.
|
||||||
|
# If it wasn't a string or dict, this is the correct place to raise.
|
||||||
|
# The previous structure was slightly off, as the final raise was inside the string check block.
|
||||||
|
# Let's ensure the final raise is at the correct scope.
|
||||||
|
# The current structure seems to imply that if it's not a string, it must be a dict or error.
|
||||||
|
# If it's a string and all lookups fail, an error is raised within that block.
|
||||||
|
# If it's a dict and parsing fails (or guild_id missing), error raised.
|
||||||
|
# If it's neither, this final raise is correct.
|
||||||
|
# The "Function with declared return type "Role" must return value on all code paths"
|
||||||
|
# error suggests a path where no return or raise happens.
|
||||||
|
# This happens if `isinstance(value, str)` is true, but then all internal paths
|
||||||
|
# (resolved check, fetch try/except) don't lead to a return or raise *before*
|
||||||
|
# falling out of the `if isinstance(value, str)` block.
|
||||||
|
# The `raise AppCommandOptionConversionError` at the end of the `if isinstance(value, str)` block
|
||||||
|
# (line 156 in previous version) handles the case where a role ID is given but not found.
|
||||||
|
# The one at the very end (line 164 in previous) handles cases where value is not str/dict.
|
||||||
|
|
||||||
|
# Corrected structure for the final raise:
|
||||||
|
# It should be at the same level as the initial `if isinstance(value, str):`
|
||||||
|
# to catch cases where `value` is neither a str nor a dict.
|
||||||
|
# However, the current logic within the `if isinstance(value, str):` block
|
||||||
|
# ensures a raise if the role ID is not found.
|
||||||
|
# The `elif isinstance(value, dict)` handles the dict case.
|
||||||
|
# The final `raise` (line 164) is for types other than str or dict.
|
||||||
|
# The Pylance error "Function with declared return type "Role" must return value on all code paths"
|
||||||
|
# implies that if `value` is a string, and `interaction.data.resolved.roles.get(role_id)` is None,
|
||||||
|
# AND `self._client.fetch_role` returns None (which it can), then the
|
||||||
|
# `raise AppCommandOptionConversionError` on line 156 is correctly hit.
|
||||||
|
# The issue might be that Pylance doesn't see `AppCommandOptionConversionError` as definitively terminating.
|
||||||
|
# This is unlikely. Let's re-verify the logic flow.
|
||||||
|
|
||||||
|
# The `raise` on line 156 is correct if role_id is not found after fetching.
|
||||||
|
# The `raise` on line 164 is for when `value` is not a str and not a dict.
|
||||||
|
# This seems logically sound. The Pylance error might be a misinterpretation or a subtle issue.
|
||||||
|
# For now, the duplicated `except` is the primary syntax error.
|
||||||
|
# The "must return value on all code paths" often occurs if an if/elif chain doesn't
|
||||||
|
# exhaust all possibilities or if a path through a try/except doesn't guarantee a return/raise.
|
||||||
|
# In this case, if `value` is a string, it either returns a Role or raises an error.
|
||||||
|
# If `value` is a dict, it either returns a Role or raises an error.
|
||||||
|
# If `value` is neither, it raises an error. All paths seem covered.
|
||||||
|
# The syntax error from the duplicated `except` is the most likely culprit for Pylance's confusion.
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Expected a role ID string or role data dict, got {type(value).__name__}",
|
||||||
|
option_name="role",
|
||||||
|
original_value=value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelConverter(Converter["Channel"]):
|
||||||
|
def __init__(self, client: "Client"):
|
||||||
|
self._client = client
|
||||||
|
|
||||||
|
async def convert(self, interaction: "Interaction", value: Any) -> "Channel":
|
||||||
|
if isinstance(value, str): # Assume it's a channel ID
|
||||||
|
channel_id = value
|
||||||
|
# Attempt to get from interaction resolved data first
|
||||||
|
if (
|
||||||
|
interaction.data
|
||||||
|
and interaction.data.resolved
|
||||||
|
and interaction.data.resolved.channels
|
||||||
|
):
|
||||||
|
# Resolved channels are PartialChannel. Client.fetch_channel will get the full typed one.
|
||||||
|
partial_channel = interaction.data.resolved.channels.get(channel_id)
|
||||||
|
if partial_channel:
|
||||||
|
# Client.fetch_channel should handle fetching and parsing to the correct Channel subtype
|
||||||
|
full_channel = await self._client.fetch_channel(partial_channel.id)
|
||||||
|
if full_channel:
|
||||||
|
return full_channel
|
||||||
|
# If fetch_channel returns None even with a resolved ID, it's an issue.
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Failed to fetch full channel for resolved ID '{channel_id}'.",
|
||||||
|
option_name="channel",
|
||||||
|
original_value=value,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fallback to fetching directly if not in resolved or if resolved fetch failed
|
||||||
|
try:
|
||||||
|
channel = await self._client.fetch_channel(
|
||||||
|
channel_id
|
||||||
|
) # fetch_channel handles parsing
|
||||||
|
if channel:
|
||||||
|
return channel
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Channel with ID '{channel_id}' not found.",
|
||||||
|
option_name="channel",
|
||||||
|
original_value=value,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Failed to fetch channel '{channel_id}': {e}",
|
||||||
|
option_name="channel",
|
||||||
|
original_value=value,
|
||||||
|
)
|
||||||
|
# Raw channel data dicts are not typically provided for slash command options.
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Expected a channel ID string, got {type(value).__name__}",
|
||||||
|
option_name="channel",
|
||||||
|
original_value=value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AttachmentConverter(Converter["Attachment"]):
|
||||||
|
def __init__(
|
||||||
|
self, client: "Client"
|
||||||
|
): # Client might be needed for future enhancements or consistency
|
||||||
|
self._client = client
|
||||||
|
|
||||||
|
async def convert(self, interaction: "Interaction", value: Any) -> "Attachment":
|
||||||
|
if isinstance(value, str): # Value is the attachment ID
|
||||||
|
attachment_id = value
|
||||||
|
if (
|
||||||
|
interaction.data
|
||||||
|
and interaction.data.resolved
|
||||||
|
and interaction.data.resolved.attachments
|
||||||
|
):
|
||||||
|
attachment_object = interaction.data.resolved.attachments.get(
|
||||||
|
attachment_id
|
||||||
|
) # This is already an Attachment object
|
||||||
|
if attachment_object:
|
||||||
|
return (
|
||||||
|
attachment_object # Return the already parsed Attachment object
|
||||||
|
)
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Attachment with ID '{attachment_id}' not found in resolved data.",
|
||||||
|
option_name="attachment",
|
||||||
|
original_value=value,
|
||||||
|
)
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"Expected an attachment ID string, got {type(value).__name__}",
|
||||||
|
option_name="attachment",
|
||||||
|
original_value=value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Converters can be registered dynamically using
|
||||||
|
# :meth:`disagreement.ext.app_commands.handler.AppCommandHandler.register_converter`.
|
||||||
|
|
||||||
|
# Mapping from ApplicationCommandOptionType to default converters
|
||||||
|
# This will be used by the AppCommandHandler to automatically apply converters
|
||||||
|
# if no explicit converter is specified for a command option's type hint.
|
||||||
|
DEFAULT_CONVERTERS: Dict[
|
||||||
|
ApplicationCommandOptionType, Callable[..., Converter[Any]]
|
||||||
|
] = { # Changed Callable signature
|
||||||
|
ApplicationCommandOptionType.STRING: StringConverter,
|
||||||
|
ApplicationCommandOptionType.INTEGER: IntegerConverter,
|
||||||
|
ApplicationCommandOptionType.BOOLEAN: BooleanConverter,
|
||||||
|
ApplicationCommandOptionType.NUMBER: NumberConverter,
|
||||||
|
ApplicationCommandOptionType.USER: UserConverter,
|
||||||
|
ApplicationCommandOptionType.CHANNEL: ChannelConverter,
|
||||||
|
ApplicationCommandOptionType.ROLE: RoleConverter,
|
||||||
|
# ApplicationCommandOptionType.MENTIONABLE: MentionableConverter, # Special case, can be User or Role
|
||||||
|
ApplicationCommandOptionType.ATTACHMENT: AttachmentConverter, # Added
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def run_converters(
|
||||||
|
interaction: "Interaction",
|
||||||
|
param_type: Any, # The type hint of the parameter
|
||||||
|
option_type: ApplicationCommandOptionType, # The Discord option type
|
||||||
|
value: Any,
|
||||||
|
client: "Client", # Needed for model converters
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Runs the appropriate converter for a given parameter type and value.
|
||||||
|
This function will be more complex, handling custom converters, unions, optionals etc.
|
||||||
|
For now, a basic lookup.
|
||||||
|
"""
|
||||||
|
converter_class_factory = DEFAULT_CONVERTERS.get(option_type)
|
||||||
|
if converter_class_factory:
|
||||||
|
# Check if the factory needs the client instance
|
||||||
|
# This is a bit simplistic; a more robust way might involve inspecting __init__ signature
|
||||||
|
# or having converters register their needs.
|
||||||
|
if option_type in [
|
||||||
|
ApplicationCommandOptionType.USER,
|
||||||
|
ApplicationCommandOptionType.CHANNEL, # Anticipating these
|
||||||
|
ApplicationCommandOptionType.ROLE,
|
||||||
|
ApplicationCommandOptionType.MENTIONABLE,
|
||||||
|
ApplicationCommandOptionType.ATTACHMENT,
|
||||||
|
]:
|
||||||
|
converter_instance = converter_class_factory(client=client)
|
||||||
|
else:
|
||||||
|
converter_instance = converter_class_factory()
|
||||||
|
|
||||||
|
return await converter_instance.convert(interaction, value)
|
||||||
|
|
||||||
|
# Fallback for unhandled types or if direct type matching is needed
|
||||||
|
if param_type is str and isinstance(value, str):
|
||||||
|
return value
|
||||||
|
if param_type is int and isinstance(value, int):
|
||||||
|
return value
|
||||||
|
if param_type is bool and isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if param_type is float and isinstance(value, (float, int)):
|
||||||
|
return float(value)
|
||||||
|
|
||||||
|
# If no specific converter, and it's not a basic type match, raise error or return raw
|
||||||
|
# For now, let's raise if no converter found for a specific option type
|
||||||
|
if option_type in DEFAULT_CONVERTERS: # Should have been handled
|
||||||
|
pass # This path implies a logic error above or missing converter in DEFAULT_CONVERTERS
|
||||||
|
|
||||||
|
# If it's a model type but no converter yet, this will need to be handled
|
||||||
|
# e.g. if param_type is User and option_type is ApplicationCommandOptionType.USER
|
||||||
|
|
||||||
|
raise AppCommandOptionConversionError(
|
||||||
|
f"No suitable converter found for option type {option_type.name} "
|
||||||
|
f"with value '{value}' to target type {param_type.__name__ if hasattr(param_type, '__name__') else param_type}"
|
||||||
|
)
|
569
disagreement/ext/app_commands/decorators.py
Normal file
569
disagreement/ext/app_commands/decorators.py
Normal file
@ -0,0 +1,569 @@
|
|||||||
|
# disagreement/ext/app_commands/decorators.py
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import (
|
||||||
|
Callable,
|
||||||
|
Optional,
|
||||||
|
List,
|
||||||
|
Dict,
|
||||||
|
Any,
|
||||||
|
Union,
|
||||||
|
Type,
|
||||||
|
get_origin,
|
||||||
|
get_args,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Literal,
|
||||||
|
Annotated,
|
||||||
|
TypeVar,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .commands import (
|
||||||
|
SlashCommand,
|
||||||
|
UserCommand,
|
||||||
|
MessageCommand,
|
||||||
|
AppCommand,
|
||||||
|
AppCommandGroup,
|
||||||
|
HybridCommand,
|
||||||
|
)
|
||||||
|
from disagreement.interactions import (
|
||||||
|
ApplicationCommandOption,
|
||||||
|
ApplicationCommandOptionChoice,
|
||||||
|
Snowflake,
|
||||||
|
)
|
||||||
|
from disagreement.enums import (
|
||||||
|
ApplicationCommandOptionType,
|
||||||
|
IntegrationType,
|
||||||
|
InteractionContextType,
|
||||||
|
# Assuming ChannelType will be added to disagreement.enums
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from disagreement.client import Client # For potential future use
|
||||||
|
from disagreement.models import Channel, User
|
||||||
|
|
||||||
|
# Assuming TextChannel, VoiceChannel etc. might be defined or aliased
|
||||||
|
# For now, we'll use string comparisons for channel types or rely on a yet-to-be-defined ChannelType enum
|
||||||
|
Channel = Any
|
||||||
|
Member = Any
|
||||||
|
Role = Any
|
||||||
|
Attachment = Any
|
||||||
|
# from .cog import Cog # Placeholder
|
||||||
|
else:
|
||||||
|
# Runtime fallbacks for optional model classes
|
||||||
|
from disagreement.models import Channel
|
||||||
|
|
||||||
|
Client = Any # type: ignore
|
||||||
|
User = Any # type: ignore
|
||||||
|
Member = Any # type: ignore
|
||||||
|
Role = Any # type: ignore
|
||||||
|
Attachment = Any # type: ignore
|
||||||
|
|
||||||
|
# Mapping Python types to Discord ApplicationCommandOptionType
|
||||||
|
# This will need to be expanded and made more robust.
|
||||||
|
# Consider using a registry or a more sophisticated type mapping system.
|
||||||
|
_type_mapping: Dict[Any, ApplicationCommandOptionType] = (
|
||||||
|
{ # Changed Type to Any for key due to placeholders
|
||||||
|
str: ApplicationCommandOptionType.STRING,
|
||||||
|
int: ApplicationCommandOptionType.INTEGER,
|
||||||
|
bool: ApplicationCommandOptionType.BOOLEAN,
|
||||||
|
float: ApplicationCommandOptionType.NUMBER, # Discord 'NUMBER' type is for float/double
|
||||||
|
User: ApplicationCommandOptionType.USER,
|
||||||
|
Channel: ApplicationCommandOptionType.CHANNEL,
|
||||||
|
# Placeholders for actual model types from disagreement.models
|
||||||
|
# These will be resolved to their actual types when TYPE_CHECKING is False or via isinstance checks
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Helper dataclass for storing extra option metadata
|
||||||
|
@dataclass
|
||||||
|
class OptionMetadata:
|
||||||
|
channel_types: Optional[List[int]] = None
|
||||||
|
min_value: Optional[Union[int, float]] = None
|
||||||
|
max_value: Optional[Union[int, float]] = None
|
||||||
|
min_length: Optional[int] = None
|
||||||
|
max_length: Optional[int] = None
|
||||||
|
autocomplete: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
# Ensure these are updated if model names/locations change
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# _type_mapping[User] = ApplicationCommandOptionType.USER # Already added above
|
||||||
|
_type_mapping[Member] = ApplicationCommandOptionType.USER # Member implies User
|
||||||
|
_type_mapping[Role] = ApplicationCommandOptionType.ROLE
|
||||||
|
_type_mapping[Attachment] = ApplicationCommandOptionType.ATTACHMENT
|
||||||
|
_type_mapping[Channel] = ApplicationCommandOptionType.CHANNEL
|
||||||
|
|
||||||
|
# TypeVar for the app command decorator factory
|
||||||
|
AppCmdType = TypeVar("AppCmdType", bound=AppCommand)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_options_from_signature(
|
||||||
|
func: Callable[..., Any], option_meta: Optional[Dict[str, OptionMetadata]] = None
|
||||||
|
) -> List[ApplicationCommandOption]:
|
||||||
|
"""
|
||||||
|
Inspects a function signature and generates ApplicationCommandOption list.
|
||||||
|
"""
|
||||||
|
options: List[ApplicationCommandOption] = []
|
||||||
|
params = inspect.signature(func).parameters
|
||||||
|
|
||||||
|
doc = inspect.getdoc(func)
|
||||||
|
param_descriptions: Dict[str, str] = {}
|
||||||
|
if doc:
|
||||||
|
for line in inspect.cleandoc(doc).splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if line.startswith(":param"):
|
||||||
|
try:
|
||||||
|
_, rest = line.split(" ", 1)
|
||||||
|
name, desc = rest.split(":", 1)
|
||||||
|
param_descriptions[name.strip()] = desc.strip()
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Skip 'self' (for cogs) and 'ctx' (context) parameters
|
||||||
|
param_iter = iter(params.values())
|
||||||
|
first_param = next(param_iter, None)
|
||||||
|
|
||||||
|
# Heuristic: if the function is bound to a class (cog), 'self' might be the first param.
|
||||||
|
# A more robust way would be to check if `func` is a method of a Cog instance later.
|
||||||
|
# For now, simple name check.
|
||||||
|
if first_param and first_param.name == "self":
|
||||||
|
first_param = next(param_iter, None) # Consume 'self', get next
|
||||||
|
|
||||||
|
if first_param and first_param.name == "ctx": # Consume 'ctx'
|
||||||
|
pass # ctx is handled, now iterate over actual command options
|
||||||
|
elif (
|
||||||
|
first_param
|
||||||
|
): # If first_param was not 'self' and not 'ctx', it's a command option
|
||||||
|
param_iter = iter(params.values()) # Reset iterator to include the first param
|
||||||
|
|
||||||
|
for param in param_iter:
|
||||||
|
if param.name == "self" or param.name == "ctx": # Should have been skipped
|
||||||
|
continue
|
||||||
|
|
||||||
|
if param.kind == param.VAR_POSITIONAL or param.kind == param.VAR_KEYWORD:
|
||||||
|
# *args and **kwargs are not directly supported by slash command options structure.
|
||||||
|
# Could raise an error or ignore. For now, ignore.
|
||||||
|
# print(f"Warning: *args/**kwargs ({param.name}) are not supported for slash command options.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
option_name = param.name
|
||||||
|
option_description = param_descriptions.get(
|
||||||
|
option_name, f"Description for {option_name}"
|
||||||
|
)
|
||||||
|
meta = option_meta.get(option_name) if option_meta else None
|
||||||
|
|
||||||
|
param_type_hint = param.annotation
|
||||||
|
if param_type_hint == inspect.Parameter.empty:
|
||||||
|
# Default to string if no type hint, or raise error.
|
||||||
|
# Forcing type hints is generally better for slash commands.
|
||||||
|
# raise TypeError(f"Option '{option_name}' must have a type hint for slash commands.")
|
||||||
|
param_type_hint = str # Defaulting to string, can be made stricter
|
||||||
|
|
||||||
|
option_type: Optional[ApplicationCommandOptionType] = None
|
||||||
|
choices: Optional[List[ApplicationCommandOptionChoice]] = None
|
||||||
|
|
||||||
|
origin = get_origin(param_type_hint)
|
||||||
|
args = get_args(param_type_hint)
|
||||||
|
|
||||||
|
if origin is Annotated:
|
||||||
|
param_type_hint = args[0]
|
||||||
|
for extra in args[1:]:
|
||||||
|
if isinstance(extra, OptionMetadata):
|
||||||
|
meta = extra
|
||||||
|
origin = get_origin(param_type_hint)
|
||||||
|
args = get_args(param_type_hint)
|
||||||
|
|
||||||
|
actual_type_for_mapping = param_type_hint
|
||||||
|
is_optional = False
|
||||||
|
|
||||||
|
if origin is Union: # Handles Optional[T] which is Union[T, NoneType]
|
||||||
|
# Filter out NoneType to get the actual type for mapping
|
||||||
|
union_types = [t for t in args if t is not type(None)]
|
||||||
|
if len(union_types) == 1:
|
||||||
|
actual_type_for_mapping = union_types[0]
|
||||||
|
is_optional = True
|
||||||
|
else:
|
||||||
|
# More complex Unions are not directly supported by a single option type.
|
||||||
|
# Could default to STRING or raise.
|
||||||
|
# For now, let's assume simple Optional[T] or direct types.
|
||||||
|
# print(f"Warning: Complex Union type for '{option_name}' not fully supported, defaulting to STRING.")
|
||||||
|
actual_type_for_mapping = str
|
||||||
|
|
||||||
|
elif origin is list and len(args) == 1:
|
||||||
|
# List[T] is not a direct option type. Discord handles multiple values for some types
|
||||||
|
# via repeated options or specific component interactions, not directly in slash command options.
|
||||||
|
# This might indicate a need for a different interaction pattern or custom parsing.
|
||||||
|
# For now, treat List[str] as a string, others might error or default.
|
||||||
|
# print(f"Warning: List type for '{option_name}' not directly supported as a single option. Consider type {args[0]}.")
|
||||||
|
actual_type_for_mapping = args[
|
||||||
|
0
|
||||||
|
] # Use the inner type for mapping, but this is a simplification.
|
||||||
|
|
||||||
|
if origin is Literal: # typing.Literal['a', 'b']
|
||||||
|
choices = []
|
||||||
|
for choice_val in args:
|
||||||
|
if not isinstance(choice_val, (str, int, float)):
|
||||||
|
raise TypeError(
|
||||||
|
f"Literal choices for '{option_name}' must be str, int, or float. Got {type(choice_val)}."
|
||||||
|
)
|
||||||
|
choices.append(
|
||||||
|
ApplicationCommandOptionChoice(
|
||||||
|
data={"name": str(choice_val), "value": choice_val}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# The type of the Literal's arguments determines the option type
|
||||||
|
if choices:
|
||||||
|
literal_arg_type = type(choices[0].value)
|
||||||
|
option_type = _type_mapping.get(literal_arg_type)
|
||||||
|
if (
|
||||||
|
not option_type and literal_arg_type is float
|
||||||
|
): # float maps to NUMBER
|
||||||
|
option_type = ApplicationCommandOptionType.NUMBER
|
||||||
|
|
||||||
|
if not option_type: # If not determined by Literal
|
||||||
|
option_type = _type_mapping.get(actual_type_for_mapping)
|
||||||
|
# Special handling for User, Member, Role, Attachment, Channel if not directly in _type_mapping
|
||||||
|
# This is a bit crude; a proper registry or isinstance checks would be better.
|
||||||
|
if not option_type:
|
||||||
|
if (
|
||||||
|
actual_type_for_mapping.__name__ == "User"
|
||||||
|
or actual_type_for_mapping.__name__ == "Member"
|
||||||
|
):
|
||||||
|
option_type = ApplicationCommandOptionType.USER
|
||||||
|
elif actual_type_for_mapping.__name__ == "Role":
|
||||||
|
option_type = ApplicationCommandOptionType.ROLE
|
||||||
|
elif actual_type_for_mapping.__name__ == "Attachment":
|
||||||
|
option_type = ApplicationCommandOptionType.ATTACHMENT
|
||||||
|
elif (
|
||||||
|
inspect.isclass(actual_type_for_mapping)
|
||||||
|
and isinstance(Channel, type)
|
||||||
|
and issubclass(actual_type_for_mapping, cast(type, Channel))
|
||||||
|
):
|
||||||
|
option_type = ApplicationCommandOptionType.CHANNEL
|
||||||
|
|
||||||
|
if not option_type:
|
||||||
|
# Fallback or error if type couldn't be mapped
|
||||||
|
# print(f"Warning: Could not map type '{actual_type_for_mapping}' for option '{option_name}'. Defaulting to STRING.")
|
||||||
|
option_type = ApplicationCommandOptionType.STRING # Default fallback
|
||||||
|
|
||||||
|
required = (param.default == inspect.Parameter.empty) and not is_optional
|
||||||
|
|
||||||
|
data: Dict[str, Any] = {
|
||||||
|
"name": option_name,
|
||||||
|
"description": option_description,
|
||||||
|
"type": option_type.value,
|
||||||
|
"required": required,
|
||||||
|
"choices": ([c.to_dict() for c in choices] if choices else None),
|
||||||
|
}
|
||||||
|
|
||||||
|
if meta:
|
||||||
|
if meta.channel_types is not None:
|
||||||
|
data["channel_types"] = meta.channel_types
|
||||||
|
if meta.min_value is not None:
|
||||||
|
data["min_value"] = meta.min_value
|
||||||
|
if meta.max_value is not None:
|
||||||
|
data["max_value"] = meta.max_value
|
||||||
|
if meta.min_length is not None:
|
||||||
|
data["min_length"] = meta.min_length
|
||||||
|
if meta.max_length is not None:
|
||||||
|
data["max_length"] = meta.max_length
|
||||||
|
if meta.autocomplete:
|
||||||
|
data["autocomplete"] = True
|
||||||
|
|
||||||
|
options.append(ApplicationCommandOption(data=data))
|
||||||
|
return options
|
||||||
|
|
||||||
|
|
||||||
|
def _app_command_decorator(
|
||||||
|
cls: Type[AppCmdType],
|
||||||
|
option_meta: Optional[Dict[str, OptionMetadata]] = None,
|
||||||
|
**attrs: Any,
|
||||||
|
) -> Callable[[Callable[..., Any]], AppCmdType]:
|
||||||
|
"""Generic factory for creating app command decorators."""
|
||||||
|
|
||||||
|
def decorator(func: Callable[..., Any]) -> AppCmdType:
|
||||||
|
if not asyncio.iscoroutinefunction(func):
|
||||||
|
raise TypeError(
|
||||||
|
"Application command callback must be a coroutine function."
|
||||||
|
)
|
||||||
|
|
||||||
|
name = attrs.pop("name", None) or func.__name__
|
||||||
|
description = attrs.pop("description", None) or inspect.getdoc(func)
|
||||||
|
if description: # Clean up docstring
|
||||||
|
description = inspect.cleandoc(description).split("\n\n", 1)[
|
||||||
|
0
|
||||||
|
] # Use first paragraph
|
||||||
|
|
||||||
|
parent_group = attrs.pop("parent", None)
|
||||||
|
if parent_group and not isinstance(parent_group, AppCommandGroup):
|
||||||
|
raise TypeError(
|
||||||
|
"The 'parent' argument must be an AppCommandGroup instance."
|
||||||
|
)
|
||||||
|
|
||||||
|
# For User/Message commands, description should be empty for payload, but can be stored for help.
|
||||||
|
if cls is UserCommand or cls is MessageCommand:
|
||||||
|
actual_description_for_payload = ""
|
||||||
|
else:
|
||||||
|
actual_description_for_payload = description
|
||||||
|
if not actual_description_for_payload and cls is SlashCommand:
|
||||||
|
raise ValueError(f"Slash command '{name}' must have a description.")
|
||||||
|
|
||||||
|
# Create the command instance
|
||||||
|
cmd_instance = cls(
|
||||||
|
callback=func,
|
||||||
|
name=name,
|
||||||
|
description=actual_description_for_payload, # Use payload-appropriate description
|
||||||
|
**attrs, # Remaining attributes like guild_ids, nsfw, etc.
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store original description if different (e.g. for User/Message commands for help text)
|
||||||
|
if description != actual_description_for_payload:
|
||||||
|
cmd_instance._full_description = (
|
||||||
|
description # Custom attribute for library use
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(cmd_instance, SlashCommand):
|
||||||
|
cmd_instance.options = _extract_options_from_signature(func, option_meta)
|
||||||
|
|
||||||
|
if parent_group:
|
||||||
|
parent_group.add_command(cmd_instance) # This also sets cmd_instance.parent
|
||||||
|
|
||||||
|
# Attach command object to the function for later collection by Cog or Client
|
||||||
|
# This is a common pattern.
|
||||||
|
if hasattr(func, "__app_command_object__"):
|
||||||
|
# Function might already be decorated (e.g. hybrid or stacked decorators)
|
||||||
|
# Decide on behavior: error, overwrite, or store list of commands.
|
||||||
|
# For now, let's assume one app command decorator of a specific type per function.
|
||||||
|
# Hybrid commands will need special handling.
|
||||||
|
print(
|
||||||
|
f"Warning: Function {func.__name__} is already an app command or has one attached. Overwriting."
|
||||||
|
)
|
||||||
|
|
||||||
|
setattr(func, "__app_command_object__", cmd_instance)
|
||||||
|
setattr(cmd_instance, "__app_command_object__", cmd_instance)
|
||||||
|
|
||||||
|
# If the command is a HybridCommand, also set the attribute
|
||||||
|
# that the prefix command system's Cog._inject looks for.
|
||||||
|
if isinstance(cmd_instance, HybridCommand):
|
||||||
|
setattr(func, "__command_object__", cmd_instance)
|
||||||
|
setattr(cmd_instance, "__command_object__", cmd_instance)
|
||||||
|
|
||||||
|
return cmd_instance # Return the command instance itself, not the function
|
||||||
|
# This allows it to be added to cogs/handlers directly.
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def slash_command(
|
||||||
|
name: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
guild_ids: Optional[List[Snowflake]] = None,
|
||||||
|
default_member_permissions: Optional[str] = None,
|
||||||
|
nsfw: bool = False,
|
||||||
|
name_localizations: Optional[Dict[str, str]] = None,
|
||||||
|
description_localizations: Optional[Dict[str, str]] = None,
|
||||||
|
integration_types: Optional[List[IntegrationType]] = None,
|
||||||
|
contexts: Optional[List[InteractionContextType]] = None,
|
||||||
|
*,
|
||||||
|
guilds: bool = True,
|
||||||
|
dms: bool = True,
|
||||||
|
private_channels: bool = True,
|
||||||
|
parent: Optional[AppCommandGroup] = None, # Added parent parameter
|
||||||
|
locale: Optional[str] = None,
|
||||||
|
option_meta: Optional[Dict[str, OptionMetadata]] = None,
|
||||||
|
) -> Callable[[Callable[..., Any]], SlashCommand]:
|
||||||
|
"""
|
||||||
|
Decorator to create a CHAT_INPUT (slash) command.
|
||||||
|
Options are inferred from the function's type hints.
|
||||||
|
"""
|
||||||
|
if contexts is None:
|
||||||
|
ctxs: List[InteractionContextType] = []
|
||||||
|
if guilds:
|
||||||
|
ctxs.append(InteractionContextType.GUILD)
|
||||||
|
if dms:
|
||||||
|
ctxs.append(InteractionContextType.BOT_DM)
|
||||||
|
if private_channels:
|
||||||
|
ctxs.append(InteractionContextType.PRIVATE_CHANNEL)
|
||||||
|
if len(ctxs) != 3:
|
||||||
|
contexts = ctxs
|
||||||
|
attrs = {
|
||||||
|
"name": name,
|
||||||
|
"description": description,
|
||||||
|
"guild_ids": guild_ids,
|
||||||
|
"default_member_permissions": default_member_permissions,
|
||||||
|
"nsfw": nsfw,
|
||||||
|
"name_localizations": name_localizations,
|
||||||
|
"description_localizations": description_localizations,
|
||||||
|
"integration_types": integration_types,
|
||||||
|
"contexts": contexts,
|
||||||
|
"parent": parent, # Pass parent to attrs
|
||||||
|
"locale": locale,
|
||||||
|
}
|
||||||
|
# Filter out None values to avoid passing them as explicit None to command constructor
|
||||||
|
# Keep 'parent' even if None, as _app_command_decorator handles None parent.
|
||||||
|
# nsfw default is False, so it's fine if not present and defaults.
|
||||||
|
attrs = {k: v for k, v in attrs.items() if v is not None or k in ["nsfw", "parent"]}
|
||||||
|
return _app_command_decorator(SlashCommand, option_meta, **attrs)
|
||||||
|
|
||||||
|
|
||||||
|
def user_command(
|
||||||
|
name: Optional[str] = None,
|
||||||
|
guild_ids: Optional[List[Snowflake]] = None,
|
||||||
|
default_member_permissions: Optional[str] = None,
|
||||||
|
nsfw: bool = False, # Though less common for user commands
|
||||||
|
name_localizations: Optional[Dict[str, str]] = None,
|
||||||
|
integration_types: Optional[List[IntegrationType]] = None,
|
||||||
|
contexts: Optional[List[InteractionContextType]] = None,
|
||||||
|
locale: Optional[str] = None,
|
||||||
|
# description is not used by Discord for User commands
|
||||||
|
) -> Callable[[Callable[..., Any]], UserCommand]:
|
||||||
|
"""Decorator to create a USER context menu command."""
|
||||||
|
attrs = {
|
||||||
|
"name": name,
|
||||||
|
"guild_ids": guild_ids,
|
||||||
|
"default_member_permissions": default_member_permissions,
|
||||||
|
"nsfw": nsfw,
|
||||||
|
"name_localizations": name_localizations,
|
||||||
|
"integration_types": integration_types,
|
||||||
|
"contexts": contexts,
|
||||||
|
"locale": locale,
|
||||||
|
}
|
||||||
|
attrs = {k: v for k, v in attrs.items() if v is not None or k in ["nsfw"]}
|
||||||
|
return _app_command_decorator(UserCommand, **attrs)
|
||||||
|
|
||||||
|
|
||||||
|
def message_command(
|
||||||
|
name: Optional[str] = None,
|
||||||
|
guild_ids: Optional[List[Snowflake]] = None,
|
||||||
|
default_member_permissions: Optional[str] = None,
|
||||||
|
nsfw: bool = False, # Though less common for message commands
|
||||||
|
name_localizations: Optional[Dict[str, str]] = None,
|
||||||
|
integration_types: Optional[List[IntegrationType]] = None,
|
||||||
|
contexts: Optional[List[InteractionContextType]] = None,
|
||||||
|
locale: Optional[str] = None,
|
||||||
|
# description is not used by Discord for Message commands
|
||||||
|
) -> Callable[[Callable[..., Any]], MessageCommand]:
|
||||||
|
"""Decorator to create a MESSAGE context menu command."""
|
||||||
|
attrs = {
|
||||||
|
"name": name,
|
||||||
|
"guild_ids": guild_ids,
|
||||||
|
"default_member_permissions": default_member_permissions,
|
||||||
|
"nsfw": nsfw,
|
||||||
|
"name_localizations": name_localizations,
|
||||||
|
"integration_types": integration_types,
|
||||||
|
"contexts": contexts,
|
||||||
|
"locale": locale,
|
||||||
|
}
|
||||||
|
attrs = {k: v for k, v in attrs.items() if v is not None or k in ["nsfw"]}
|
||||||
|
return _app_command_decorator(MessageCommand, **attrs)
|
||||||
|
|
||||||
|
|
||||||
|
def hybrid_command(
|
||||||
|
name: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
guild_ids: Optional[List[Snowflake]] = None,
|
||||||
|
default_member_permissions: Optional[str] = None,
|
||||||
|
nsfw: bool = False,
|
||||||
|
name_localizations: Optional[Dict[str, str]] = None,
|
||||||
|
description_localizations: Optional[Dict[str, str]] = None,
|
||||||
|
integration_types: Optional[List[IntegrationType]] = None,
|
||||||
|
contexts: Optional[List[InteractionContextType]] = None,
|
||||||
|
*,
|
||||||
|
guilds: bool = True,
|
||||||
|
dms: bool = True,
|
||||||
|
private_channels: bool = True,
|
||||||
|
aliases: Optional[List[str]] = None, # Specific to prefix command aspect
|
||||||
|
# Other prefix-specific options can be added here (e.g., help, brief)
|
||||||
|
option_meta: Optional[Dict[str, OptionMetadata]] = None,
|
||||||
|
locale: Optional[str] = None,
|
||||||
|
) -> Callable[[Callable[..., Any]], HybridCommand]:
|
||||||
|
"""
|
||||||
|
Decorator to create a command that can be invoked as both a slash command
|
||||||
|
and a traditional prefix-based command.
|
||||||
|
Options for the slash command part are inferred from the function's type hints.
|
||||||
|
"""
|
||||||
|
if contexts is None:
|
||||||
|
ctxs: List[InteractionContextType] = []
|
||||||
|
if guilds:
|
||||||
|
ctxs.append(InteractionContextType.GUILD)
|
||||||
|
if dms:
|
||||||
|
ctxs.append(InteractionContextType.BOT_DM)
|
||||||
|
if private_channels:
|
||||||
|
ctxs.append(InteractionContextType.PRIVATE_CHANNEL)
|
||||||
|
if len(ctxs) != 3:
|
||||||
|
contexts = ctxs
|
||||||
|
attrs = {
|
||||||
|
"name": name,
|
||||||
|
"description": description,
|
||||||
|
"guild_ids": guild_ids,
|
||||||
|
"default_member_permissions": default_member_permissions,
|
||||||
|
"nsfw": nsfw,
|
||||||
|
"name_localizations": name_localizations,
|
||||||
|
"description_localizations": description_localizations,
|
||||||
|
"integration_types": integration_types,
|
||||||
|
"contexts": contexts,
|
||||||
|
"aliases": aliases or [], # Ensure aliases is a list
|
||||||
|
"locale": locale,
|
||||||
|
}
|
||||||
|
# Filter out None values to avoid passing them as explicit None to command constructor
|
||||||
|
# Keep 'nsfw' and 'aliases' as they have defaults (False, [])
|
||||||
|
attrs = {
|
||||||
|
k: v for k, v in attrs.items() if v is not None or k in ["nsfw", "aliases"]
|
||||||
|
}
|
||||||
|
return _app_command_decorator(HybridCommand, option_meta, **attrs)
|
||||||
|
|
||||||
|
|
||||||
|
def subcommand(
|
||||||
|
parent: AppCommandGroup, *d_args: Any, **d_kwargs: Any
|
||||||
|
) -> Callable[[Callable[..., Any]], SlashCommand]:
|
||||||
|
"""Create a subcommand under an existing :class:`AppCommandGroup`."""
|
||||||
|
|
||||||
|
d_kwargs.setdefault("parent", parent)
|
||||||
|
return slash_command(*d_args, **d_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def group(
|
||||||
|
name: str,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Callable[[Optional[Callable[..., Any]]], AppCommandGroup]:
|
||||||
|
"""Decorator to declare a top level :class:`AppCommandGroup`."""
|
||||||
|
|
||||||
|
def decorator(func: Optional[Callable[..., Any]] = None) -> AppCommandGroup:
|
||||||
|
grp = AppCommandGroup(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
guild_ids=kwargs.get("guild_ids"),
|
||||||
|
parent=kwargs.get("parent"),
|
||||||
|
default_member_permissions=kwargs.get("default_member_permissions"),
|
||||||
|
nsfw=kwargs.get("nsfw", False),
|
||||||
|
name_localizations=kwargs.get("name_localizations"),
|
||||||
|
description_localizations=kwargs.get("description_localizations"),
|
||||||
|
integration_types=kwargs.get("integration_types"),
|
||||||
|
contexts=kwargs.get("contexts"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if func is not None:
|
||||||
|
setattr(func, "__app_command_object__", grp)
|
||||||
|
return grp
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def subcommand_group(
|
||||||
|
parent: AppCommandGroup,
|
||||||
|
name: str,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Callable[[Optional[Callable[..., Any]]], AppCommandGroup]:
|
||||||
|
"""Create a nested :class:`AppCommandGroup` under ``parent``."""
|
||||||
|
|
||||||
|
return parent.group(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
**kwargs,
|
||||||
|
)
|
627
disagreement/ext/app_commands/handler.py
Normal file
627
disagreement/ext/app_commands/handler.py
Normal file
@ -0,0 +1,627 @@
|
|||||||
|
# disagreement/ext/app_commands/handler.py
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Dict,
|
||||||
|
Optional,
|
||||||
|
List,
|
||||||
|
Any,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
get_origin,
|
||||||
|
get_args,
|
||||||
|
Literal,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.interactions import Interaction, ResolvedData, Snowflake
|
||||||
|
from disagreement.enums import (
|
||||||
|
ApplicationCommandType,
|
||||||
|
ApplicationCommandOptionType,
|
||||||
|
InteractionType,
|
||||||
|
)
|
||||||
|
from .commands import (
|
||||||
|
AppCommand,
|
||||||
|
SlashCommand,
|
||||||
|
UserCommand,
|
||||||
|
MessageCommand,
|
||||||
|
AppCommandGroup,
|
||||||
|
)
|
||||||
|
from .context import AppCommandContext
|
||||||
|
from disagreement.models import (
|
||||||
|
User,
|
||||||
|
Member,
|
||||||
|
Role,
|
||||||
|
Attachment,
|
||||||
|
Message,
|
||||||
|
) # For resolved data
|
||||||
|
|
||||||
|
# Channel models would also go here
|
||||||
|
|
||||||
|
# Placeholder for models not yet fully defined or imported
|
||||||
|
if not TYPE_CHECKING:
|
||||||
|
from disagreement.enums import (
|
||||||
|
ApplicationCommandType,
|
||||||
|
ApplicationCommandOptionType,
|
||||||
|
InteractionType,
|
||||||
|
)
|
||||||
|
from .commands import (
|
||||||
|
AppCommand,
|
||||||
|
SlashCommand,
|
||||||
|
UserCommand,
|
||||||
|
MessageCommand,
|
||||||
|
AppCommandGroup,
|
||||||
|
)
|
||||||
|
from .context import AppCommandContext
|
||||||
|
|
||||||
|
User = Any
|
||||||
|
Member = Any
|
||||||
|
Role = Any
|
||||||
|
Attachment = Any
|
||||||
|
Channel = Any
|
||||||
|
Message = Any
|
||||||
|
|
||||||
|
|
||||||
|
class AppCommandHandler:
|
||||||
|
"""
|
||||||
|
Manages application command registration, parsing, and dispatching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, client: "Client"):
|
||||||
|
self.client: "Client" = client
|
||||||
|
# Store commands: key could be (name, type) for global, or (name, type, guild_id) for guild-specific
|
||||||
|
# For simplicity, let's start with a flat structure and refine if needed for guild commands.
|
||||||
|
# A more robust system might have separate dicts for global and guild commands.
|
||||||
|
self._slash_commands: Dict[str, SlashCommand] = {}
|
||||||
|
self._user_commands: Dict[str, UserCommand] = {}
|
||||||
|
self._message_commands: Dict[str, MessageCommand] = {}
|
||||||
|
self._app_command_groups: Dict[str, AppCommandGroup] = {}
|
||||||
|
self._converter_registry: Dict[type, type] = {}
|
||||||
|
|
||||||
|
def add_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None:
|
||||||
|
"""Adds an application command or a command group to the handler."""
|
||||||
|
if isinstance(command, AppCommandGroup):
|
||||||
|
if command.name in self._app_command_groups:
|
||||||
|
raise ValueError(
|
||||||
|
f"AppCommandGroup '{command.name}' is already registered."
|
||||||
|
)
|
||||||
|
self._app_command_groups[command.name] = command
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(command, SlashCommand):
|
||||||
|
if command.name in self._slash_commands:
|
||||||
|
raise ValueError(
|
||||||
|
f"SlashCommand '{command.name}' is already registered."
|
||||||
|
)
|
||||||
|
self._slash_commands[command.name] = command
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(command, UserCommand):
|
||||||
|
if command.name in self._user_commands:
|
||||||
|
raise ValueError(f"UserCommand '{command.name}' is already registered.")
|
||||||
|
self._user_commands[command.name] = command
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(command, MessageCommand):
|
||||||
|
if command.name in self._message_commands:
|
||||||
|
raise ValueError(
|
||||||
|
f"MessageCommand '{command.name}' is already registered."
|
||||||
|
)
|
||||||
|
self._message_commands[command.name] = command
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(command, AppCommand):
|
||||||
|
# Fallback for plain AppCommand objects
|
||||||
|
if command.type == ApplicationCommandType.CHAT_INPUT:
|
||||||
|
if command.name in self._slash_commands:
|
||||||
|
raise ValueError(
|
||||||
|
f"SlashCommand '{command.name}' is already registered."
|
||||||
|
)
|
||||||
|
self._slash_commands[command.name] = command # type: ignore
|
||||||
|
elif command.type == ApplicationCommandType.USER:
|
||||||
|
if command.name in self._user_commands:
|
||||||
|
raise ValueError(
|
||||||
|
f"UserCommand '{command.name}' is already registered."
|
||||||
|
)
|
||||||
|
self._user_commands[command.name] = command # type: ignore
|
||||||
|
elif command.type == ApplicationCommandType.MESSAGE:
|
||||||
|
if command.name in self._message_commands:
|
||||||
|
raise ValueError(
|
||||||
|
f"MessageCommand '{command.name}' is already registered."
|
||||||
|
)
|
||||||
|
self._message_commands[command.name] = command # type: ignore
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"Unsupported command type: {command.type} for '{command.name}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise TypeError("Can only add AppCommand or AppCommandGroup instances.")
|
||||||
|
|
||||||
|
def remove_command(
|
||||||
|
self, name: str
|
||||||
|
) -> Optional[Union["AppCommand", "AppCommandGroup"]]:
|
||||||
|
"""Removes an application command or group by name."""
|
||||||
|
if name in self._slash_commands:
|
||||||
|
return self._slash_commands.pop(name)
|
||||||
|
if name in self._user_commands:
|
||||||
|
return self._user_commands.pop(name)
|
||||||
|
if name in self._message_commands:
|
||||||
|
return self._message_commands.pop(name)
|
||||||
|
if name in self._app_command_groups:
|
||||||
|
return self._app_command_groups.pop(name)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def register_converter(self, annotation: type, converter_cls: type) -> None:
|
||||||
|
"""Register a custom converter class for a type annotation."""
|
||||||
|
self._converter_registry[annotation] = converter_cls
|
||||||
|
|
||||||
|
def get_converter(self, annotation: type) -> Optional[type]:
|
||||||
|
"""Retrieve a registered converter class for a type annotation."""
|
||||||
|
return self._converter_registry.get(annotation)
|
||||||
|
|
||||||
|
def get_command(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
command_type: "ApplicationCommandType",
|
||||||
|
interaction_options: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
) -> Optional["AppCommand"]:
|
||||||
|
"""Retrieves a command of a specific type."""
|
||||||
|
if command_type == ApplicationCommandType.CHAT_INPUT:
|
||||||
|
if not interaction_options:
|
||||||
|
return self._slash_commands.get(name)
|
||||||
|
|
||||||
|
# Handle subcommands/groups
|
||||||
|
current_options = interaction_options
|
||||||
|
target_command_or_group: Optional[Union[AppCommand, AppCommandGroup]] = (
|
||||||
|
self._app_command_groups.get(name)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not target_command_or_group:
|
||||||
|
return self._slash_commands.get(name)
|
||||||
|
|
||||||
|
final_command: Optional[AppCommand] = None
|
||||||
|
|
||||||
|
while current_options:
|
||||||
|
opt_data = current_options[0]
|
||||||
|
opt_name = opt_data.get("name")
|
||||||
|
opt_type = (
|
||||||
|
ApplicationCommandOptionType(opt_data["type"])
|
||||||
|
if opt_data.get("type")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if not opt_name or not isinstance(
|
||||||
|
target_command_or_group, AppCommandGroup
|
||||||
|
):
|
||||||
|
break
|
||||||
|
|
||||||
|
next_target = target_command_or_group.get_command(opt_name)
|
||||||
|
|
||||||
|
if isinstance(next_target, AppCommand) and (
|
||||||
|
opt_type == ApplicationCommandOptionType.SUB_COMMAND
|
||||||
|
or not opt_data.get("options")
|
||||||
|
):
|
||||||
|
final_command = next_target
|
||||||
|
break
|
||||||
|
elif (
|
||||||
|
isinstance(next_target, AppCommandGroup)
|
||||||
|
and opt_type == ApplicationCommandOptionType.SUB_COMMAND_GROUP
|
||||||
|
):
|
||||||
|
target_command_or_group = next_target
|
||||||
|
current_options = opt_data.get("options", [])
|
||||||
|
if not current_options:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
return final_command
|
||||||
|
|
||||||
|
if command_type == ApplicationCommandType.USER:
|
||||||
|
return self._user_commands.get(name)
|
||||||
|
|
||||||
|
if command_type == ApplicationCommandType.MESSAGE:
|
||||||
|
return self._message_commands.get(name)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _resolve_option_value(
|
||||||
|
self,
|
||||||
|
value: Any,
|
||||||
|
expected_type: Any,
|
||||||
|
resolved_data: Optional["ResolvedData"],
|
||||||
|
guild_id: Optional["Snowflake"],
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Resolves an option value to the expected Python type using resolved_data.
|
||||||
|
"""
|
||||||
|
converter_cls = self.get_converter(expected_type)
|
||||||
|
if converter_cls:
|
||||||
|
try:
|
||||||
|
init_params = inspect.signature(converter_cls.__init__).parameters
|
||||||
|
if "client" in init_params:
|
||||||
|
converter_instance = converter_cls(client=self.client) # type: ignore[arg-type]
|
||||||
|
else:
|
||||||
|
converter_instance = converter_cls()
|
||||||
|
return await converter_instance.convert(None, value) # type: ignore[arg-type]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# This is a simplified resolver. A more robust one would use converters.
|
||||||
|
if resolved_data:
|
||||||
|
if expected_type is User or expected_type.__name__ == "User":
|
||||||
|
return resolved_data.users.get(value) if resolved_data.users else None
|
||||||
|
|
||||||
|
if expected_type is Member or expected_type.__name__ == "Member":
|
||||||
|
member_obj = (
|
||||||
|
resolved_data.members.get(value) if resolved_data.members else None
|
||||||
|
)
|
||||||
|
if member_obj:
|
||||||
|
if (
|
||||||
|
hasattr(member_obj, "username")
|
||||||
|
and not member_obj.username
|
||||||
|
and resolved_data.users
|
||||||
|
):
|
||||||
|
user_obj = resolved_data.users.get(value)
|
||||||
|
if user_obj:
|
||||||
|
member_obj.username = user_obj.username
|
||||||
|
member_obj.discriminator = user_obj.discriminator
|
||||||
|
member_obj.avatar = user_obj.avatar
|
||||||
|
member_obj.bot = user_obj.bot
|
||||||
|
member_obj.user = user_obj # type: ignore[attr-defined]
|
||||||
|
return member_obj
|
||||||
|
return None
|
||||||
|
if expected_type is Role or expected_type.__name__ == "Role":
|
||||||
|
return resolved_data.roles.get(value) if resolved_data.roles else None
|
||||||
|
if expected_type is Attachment or expected_type.__name__ == "Attachment":
|
||||||
|
return (
|
||||||
|
resolved_data.attachments.get(value)
|
||||||
|
if resolved_data.attachments
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if expected_type is Message or expected_type.__name__ == "Message":
|
||||||
|
return (
|
||||||
|
resolved_data.messages.get(value)
|
||||||
|
if resolved_data.messages
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if "Channel" in expected_type.__name__:
|
||||||
|
return (
|
||||||
|
resolved_data.channels.get(value)
|
||||||
|
if resolved_data.channels
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# For basic types, Discord already sends them correctly (string, int, bool, float)
|
||||||
|
if isinstance(value, expected_type):
|
||||||
|
return value
|
||||||
|
try: # Attempt direct conversion for basic types if Discord sent string for int/float/bool
|
||||||
|
if expected_type is int:
|
||||||
|
return int(value)
|
||||||
|
if expected_type is float:
|
||||||
|
return float(value)
|
||||||
|
if expected_type is bool: # Discord sends true/false
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value.lower() == "true"
|
||||||
|
return bool(value)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
pass # Conversion failed
|
||||||
|
return value # Return as is if no specific resolution or conversion applied
|
||||||
|
|
||||||
|
async def _resolve_value(
|
||||||
|
self,
|
||||||
|
value: Any,
|
||||||
|
expected_type: Any,
|
||||||
|
resolved_data: Optional["ResolvedData"],
|
||||||
|
guild_id: Optional["Snowflake"],
|
||||||
|
) -> Any:
|
||||||
|
"""Public wrapper around ``_resolve_option_value`` used by tests."""
|
||||||
|
|
||||||
|
return await self._resolve_option_value(
|
||||||
|
value=value,
|
||||||
|
expected_type=expected_type,
|
||||||
|
resolved_data=resolved_data,
|
||||||
|
guild_id=guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _parse_interaction_options(
|
||||||
|
self,
|
||||||
|
command_params: Dict[str, inspect.Parameter], # From command.params
|
||||||
|
interaction_options: Optional[List[Dict[str, Any]]],
|
||||||
|
resolved_data: Optional["ResolvedData"],
|
||||||
|
guild_id: Optional["Snowflake"],
|
||||||
|
) -> Tuple[List[Any], Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Parses options from an interaction payload and maps them to command function arguments.
|
||||||
|
"""
|
||||||
|
args_list: List[Any] = []
|
||||||
|
kwargs_dict: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
if not interaction_options: # No options provided in interaction
|
||||||
|
# Check if command has required params without defaults
|
||||||
|
for name, param in command_params.items():
|
||||||
|
if param.default == inspect.Parameter.empty:
|
||||||
|
# This should ideally be caught by Discord if option is marked required
|
||||||
|
raise ValueError(f"Missing required option: {name}")
|
||||||
|
return args_list, kwargs_dict
|
||||||
|
|
||||||
|
# Create a dictionary of provided options by name for easier lookup
|
||||||
|
provided_options: Dict[str, Any] = {
|
||||||
|
opt["name"]: opt["value"] for opt in interaction_options if "value" in opt
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, param in command_params.items():
|
||||||
|
if name in provided_options:
|
||||||
|
raw_value = provided_options[name]
|
||||||
|
expected_type = (
|
||||||
|
param.annotation
|
||||||
|
if param.annotation != inspect.Parameter.empty
|
||||||
|
else str
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle Optional[T]
|
||||||
|
origin_type = get_origin(expected_type)
|
||||||
|
if origin_type is Union:
|
||||||
|
union_args = get_args(expected_type)
|
||||||
|
# Assuming Optional[T] is Union[T, NoneType]
|
||||||
|
non_none_types = [t for t in union_args if t is not type(None)]
|
||||||
|
if len(non_none_types) == 1:
|
||||||
|
expected_type = non_none_types[0]
|
||||||
|
# Else, complex Union, might need more sophisticated handling or default to raw_value/str
|
||||||
|
elif origin_type is Literal:
|
||||||
|
literal_args = get_args(expected_type)
|
||||||
|
if literal_args:
|
||||||
|
expected_type = type(literal_args[0])
|
||||||
|
else:
|
||||||
|
expected_type = str
|
||||||
|
|
||||||
|
resolved_value = await self._resolve_option_value(
|
||||||
|
raw_value, expected_type, resolved_data, guild_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
param.kind == inspect.Parameter.KEYWORD_ONLY
|
||||||
|
or param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||||
|
):
|
||||||
|
kwargs_dict[name] = resolved_value
|
||||||
|
# Note: Slash commands don't map directly to *args. All options are named.
|
||||||
|
# So, we'll primarily use kwargs_dict and then construct args_list based on param order if needed,
|
||||||
|
# but Discord sends named options, so direct kwarg usage is more natural.
|
||||||
|
elif param.default != inspect.Parameter.empty:
|
||||||
|
kwargs_dict[name] = param.default
|
||||||
|
else:
|
||||||
|
# Required parameter not provided by Discord - this implies an issue with command definition
|
||||||
|
# or Discord's validation, as Discord should enforce required options.
|
||||||
|
raise ValueError(
|
||||||
|
f"Required option '{name}' not found in interaction payload."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Populate args_list based on the order in command_params for positional arguments
|
||||||
|
# This assumes that all args that are not keyword-only are passed positionally if present in kwargs_dict
|
||||||
|
for name, param in command_params.items():
|
||||||
|
if param.kind == inspect.Parameter.POSITIONAL_ONLY or (
|
||||||
|
param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||||
|
and name in kwargs_dict
|
||||||
|
):
|
||||||
|
if name in kwargs_dict: # Ensure it was resolved or had a default
|
||||||
|
args_list.append(kwargs_dict[name])
|
||||||
|
# If it was POSITIONAL_ONLY and not in kwargs_dict, it's an error (already raised)
|
||||||
|
elif param.kind == inspect.Parameter.VAR_POSITIONAL: # *args
|
||||||
|
# Slash commands don't map to *args well. This would be empty.
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Filter kwargs_dict to only include actual KEYWORD_ONLY or POSITIONAL_OR_KEYWORD params
|
||||||
|
# that were not used for args_list (if strict positional/keyword separation is desired).
|
||||||
|
# For slash commands, it's simpler to pass all resolved named options as kwargs.
|
||||||
|
final_kwargs = {
|
||||||
|
k: v
|
||||||
|
for k, v in kwargs_dict.items()
|
||||||
|
if k in command_params
|
||||||
|
and command_params[k].kind != inspect.Parameter.POSITIONAL_ONLY
|
||||||
|
}
|
||||||
|
|
||||||
|
# For simplicity with slash commands, let's assume all resolved options are passed via kwargs
|
||||||
|
# and the command signature is primarily (self, ctx, **options) or (ctx, **options)
|
||||||
|
# or (self, ctx, option1, option2) where names match.
|
||||||
|
# The AppCommand.invoke will handle passing them.
|
||||||
|
# The current args_list and final_kwargs might be redundant if invoke just uses **final_kwargs.
|
||||||
|
# Let's return kwargs_dict directly for now, and AppCommand.invoke can map them.
|
||||||
|
|
||||||
|
return [], kwargs_dict # Return empty args, all in kwargs for now.
|
||||||
|
|
||||||
|
async def dispatch_app_command_error(
|
||||||
|
self, context: "AppCommandContext", error: Exception
|
||||||
|
) -> None:
|
||||||
|
"""Dispatches an app command error to the client if implemented."""
|
||||||
|
if hasattr(self.client, "on_app_command_error"):
|
||||||
|
await self.client.on_app_command_error(context, error)
|
||||||
|
|
||||||
|
async def process_interaction(self, interaction: "Interaction") -> None:
|
||||||
|
"""Processes an incoming interaction."""
|
||||||
|
if interaction.type == InteractionType.MODAL_SUBMIT:
|
||||||
|
callback = getattr(self.client, "on_modal_submit", None)
|
||||||
|
if callback is not None:
|
||||||
|
from typing import Awaitable, Callable, cast
|
||||||
|
|
||||||
|
await cast(Callable[["Interaction"], Awaitable[None]], callback)(
|
||||||
|
interaction
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if interaction.type == InteractionType.APPLICATION_COMMAND_AUTOCOMPLETE:
|
||||||
|
callback = getattr(self.client, "on_autocomplete", None)
|
||||||
|
if callback is not None:
|
||||||
|
from typing import Awaitable, Callable, cast
|
||||||
|
|
||||||
|
await cast(Callable[["Interaction"], Awaitable[None]], callback)(
|
||||||
|
interaction
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if interaction.type != InteractionType.APPLICATION_COMMAND:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not interaction.data or not interaction.data.name:
|
||||||
|
from .context import AppCommandContext
|
||||||
|
|
||||||
|
ctx = AppCommandContext(
|
||||||
|
bot=self.client, interaction=interaction, command=None
|
||||||
|
)
|
||||||
|
await ctx.send("Command not found.", ephemeral=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
command_name = interaction.data.name
|
||||||
|
command_type = interaction.data.type or ApplicationCommandType.CHAT_INPUT
|
||||||
|
command = self.get_command(
|
||||||
|
command_name,
|
||||||
|
command_type,
|
||||||
|
interaction.data.options if interaction.data else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not command:
|
||||||
|
from .context import AppCommandContext
|
||||||
|
|
||||||
|
ctx = AppCommandContext(
|
||||||
|
bot=self.client, interaction=interaction, command=None
|
||||||
|
)
|
||||||
|
await ctx.send(f"Command '{command_name}' not found.", ephemeral=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create context
|
||||||
|
from .context import AppCommandContext # Ensure AppCommandContext is available
|
||||||
|
|
||||||
|
ctx = AppCommandContext(
|
||||||
|
bot=self.client, interaction=interaction, command=command
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Prepare arguments for the command callback
|
||||||
|
# Skip 'self' and 'ctx' from command.params for parsing interaction options
|
||||||
|
params_to_parse = {
|
||||||
|
name: param
|
||||||
|
for name, param in command.params.items()
|
||||||
|
if name not in ("self", "ctx")
|
||||||
|
}
|
||||||
|
|
||||||
|
if command.type in (
|
||||||
|
ApplicationCommandType.USER,
|
||||||
|
ApplicationCommandType.MESSAGE,
|
||||||
|
):
|
||||||
|
# Context menu commands provide a target_id. Resolve and pass it
|
||||||
|
args = []
|
||||||
|
kwargs = {}
|
||||||
|
if params_to_parse and interaction.data and interaction.data.target_id:
|
||||||
|
first_param = next(iter(params_to_parse.values()))
|
||||||
|
expected = (
|
||||||
|
first_param.annotation
|
||||||
|
if first_param.annotation != inspect.Parameter.empty
|
||||||
|
else str
|
||||||
|
)
|
||||||
|
resolved = await self._resolve_option_value(
|
||||||
|
interaction.data.target_id,
|
||||||
|
expected,
|
||||||
|
interaction.data.resolved,
|
||||||
|
interaction.guild_id,
|
||||||
|
)
|
||||||
|
if first_param.kind in (
|
||||||
|
inspect.Parameter.POSITIONAL_ONLY,
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
):
|
||||||
|
args.append(resolved)
|
||||||
|
else:
|
||||||
|
kwargs[first_param.name] = resolved
|
||||||
|
|
||||||
|
await command.invoke(ctx, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
parsed_args, parsed_kwargs = await self._parse_interaction_options(
|
||||||
|
command_params=params_to_parse,
|
||||||
|
interaction_options=interaction.data.options,
|
||||||
|
resolved_data=interaction.data.resolved,
|
||||||
|
guild_id=interaction.guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await command.invoke(ctx, *parsed_args, **parsed_kwargs)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error invoking app command '{command.name}': {e}")
|
||||||
|
await self.dispatch_app_command_error(ctx, e)
|
||||||
|
# else:
|
||||||
|
# # Default error reply if no handler on client
|
||||||
|
# try:
|
||||||
|
# await ctx.send(f"An error occurred: {e}", ephemeral=True)
|
||||||
|
# except Exception as send_e:
|
||||||
|
# print(f"Failed to send error message for app command: {send_e}")
|
||||||
|
|
||||||
|
async def sync_commands(
|
||||||
|
self, application_id: "Snowflake", guild_id: Optional["Snowflake"] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Synchronizes (registers/updates) all application commands with Discord.
|
||||||
|
If guild_id is provided, syncs commands for that guild. Otherwise, syncs global commands.
|
||||||
|
"""
|
||||||
|
commands_to_sync: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
# Collect commands based on scope (global or specific guild)
|
||||||
|
# This needs to be more sophisticated to handle guild_ids on commands/groups
|
||||||
|
|
||||||
|
source_commands = (
|
||||||
|
list(self._slash_commands.values())
|
||||||
|
+ list(self._user_commands.values())
|
||||||
|
+ list(self._message_commands.values())
|
||||||
|
+ list(self._app_command_groups.values())
|
||||||
|
)
|
||||||
|
|
||||||
|
for cmd_or_group in source_commands:
|
||||||
|
# Determine if this command/group should be synced for the current scope
|
||||||
|
is_guild_specific_command = (
|
||||||
|
cmd_or_group.guild_ids is not None and len(cmd_or_group.guild_ids) > 0
|
||||||
|
)
|
||||||
|
|
||||||
|
if guild_id: # Syncing for a specific guild
|
||||||
|
# Skip if not a guild-specific command OR if it's for a different guild
|
||||||
|
if not is_guild_specific_command or (
|
||||||
|
cmd_or_group.guild_ids is not None
|
||||||
|
and guild_id not in cmd_or_group.guild_ids
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
else: # Syncing global commands
|
||||||
|
if is_guild_specific_command:
|
||||||
|
continue # Skip guild-specific commands when syncing global
|
||||||
|
|
||||||
|
# Use the to_dict() method from AppCommand or AppCommandGroup
|
||||||
|
try:
|
||||||
|
payload = cmd_or_group.to_dict()
|
||||||
|
commands_to_sync.append(payload)
|
||||||
|
except AttributeError:
|
||||||
|
print(
|
||||||
|
f"Warning: Command or group '{cmd_or_group.name}' does not have a to_dict() method. Skipping."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
f"Error converting command/group '{cmd_or_group.name}' to dict: {e}. Skipping."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not commands_to_sync:
|
||||||
|
print(
|
||||||
|
f"No commands to sync for {'guild ' + str(guild_id) if guild_id else 'global'} scope."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
if guild_id:
|
||||||
|
print(
|
||||||
|
f"Syncing {len(commands_to_sync)} commands for guild {guild_id}..."
|
||||||
|
)
|
||||||
|
await self.client._http.bulk_overwrite_guild_application_commands(
|
||||||
|
application_id, guild_id, commands_to_sync
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(f"Syncing {len(commands_to_sync)} global commands...")
|
||||||
|
await self.client._http.bulk_overwrite_global_application_commands(
|
||||||
|
application_id, commands_to_sync
|
||||||
|
)
|
||||||
|
print("Command sync successful.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error syncing application commands: {e}")
|
||||||
|
# Consider re-raising or specific error handling
|
49
disagreement/ext/commands/__init__.py
Normal file
49
disagreement/ext/commands/__init__.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
# disagreement/ext/commands/__init__.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
disagreement.ext.commands - A command framework extension for the Disagreement library.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .cog import Cog
|
||||||
|
from .core import (
|
||||||
|
Command,
|
||||||
|
CommandContext,
|
||||||
|
CommandHandler,
|
||||||
|
) # CommandHandler might be internal
|
||||||
|
from .decorators import command, listener, check, check_any, cooldown
|
||||||
|
from .errors import (
|
||||||
|
CommandError,
|
||||||
|
CommandNotFound,
|
||||||
|
BadArgument,
|
||||||
|
MissingRequiredArgument,
|
||||||
|
ArgumentParsingError,
|
||||||
|
CheckFailure,
|
||||||
|
CheckAnyFailure,
|
||||||
|
CommandOnCooldown,
|
||||||
|
CommandInvokeError,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Cog
|
||||||
|
"Cog",
|
||||||
|
# Core
|
||||||
|
"Command",
|
||||||
|
"CommandContext",
|
||||||
|
# "CommandHandler", # Usually not part of public API for direct use by bot devs
|
||||||
|
# Decorators
|
||||||
|
"command",
|
||||||
|
"listener",
|
||||||
|
"check",
|
||||||
|
"check_any",
|
||||||
|
"cooldown",
|
||||||
|
# Errors
|
||||||
|
"CommandError",
|
||||||
|
"CommandNotFound",
|
||||||
|
"BadArgument",
|
||||||
|
"MissingRequiredArgument",
|
||||||
|
"ArgumentParsingError",
|
||||||
|
"CheckFailure",
|
||||||
|
"CheckAnyFailure",
|
||||||
|
"CommandOnCooldown",
|
||||||
|
"CommandInvokeError",
|
||||||
|
]
|
155
disagreement/ext/commands/cog.py
Normal file
155
disagreement/ext/commands/cog.py
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
# disagreement/ext/commands/cog.py
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import TYPE_CHECKING, List, Tuple, Callable, Awaitable, Any, Dict, Union
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from disagreement.client import Client
|
||||||
|
from .core import Command
|
||||||
|
from disagreement.ext.app_commands.commands import (
|
||||||
|
AppCommand,
|
||||||
|
AppCommandGroup,
|
||||||
|
) # Added
|
||||||
|
else: # pragma: no cover - runtime imports for isinstance checks
|
||||||
|
from disagreement.ext.app_commands.commands import AppCommand, AppCommandGroup
|
||||||
|
|
||||||
|
# EventDispatcher might be needed if cogs register listeners directly
|
||||||
|
# from disagreement.event_dispatcher import EventDispatcher
|
||||||
|
|
||||||
|
|
||||||
|
class Cog:
|
||||||
|
"""
|
||||||
|
The base class for cogs, which are collections of commands and listeners.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, client: "Client"):
|
||||||
|
self._client: "Client" = client
|
||||||
|
self._cog_name: str = self.__class__.__name__
|
||||||
|
self._commands: Dict[str, "Command"] = {}
|
||||||
|
self._listeners: List[Tuple[str, Callable[..., Awaitable[None]]]] = []
|
||||||
|
self._app_commands_and_groups: List[Union["AppCommand", "AppCommandGroup"]] = (
|
||||||
|
[]
|
||||||
|
) # Added
|
||||||
|
|
||||||
|
# Discover commands and listeners defined in this cog instance
|
||||||
|
self._inject()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self) -> "Client":
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cog_name(self) -> str:
|
||||||
|
return self._cog_name
|
||||||
|
|
||||||
|
def _inject(self) -> None:
|
||||||
|
"""
|
||||||
|
Called to discover and prepare commands and listeners within this cog.
|
||||||
|
This is typically called by the CommandHandler when adding the cog.
|
||||||
|
"""
|
||||||
|
# Clear any previously injected state (e.g., if re-injecting)
|
||||||
|
self._commands.clear()
|
||||||
|
self._listeners.clear()
|
||||||
|
self._app_commands_and_groups.clear() # Added
|
||||||
|
|
||||||
|
for member_name, member in inspect.getmembers(self):
|
||||||
|
if hasattr(member, "__command_object__"):
|
||||||
|
# This is a prefix or hybrid command object
|
||||||
|
cmd: "Command" = getattr(member, "__command_object__")
|
||||||
|
cmd.cog = self # Assign the cog instance to the command
|
||||||
|
if cmd.name in self._commands:
|
||||||
|
# This should ideally be caught earlier or handled by CommandHandler
|
||||||
|
print(
|
||||||
|
f"Warning: Duplicate command name '{cmd.name}' in cog '{self.cog_name}'. Overwriting."
|
||||||
|
)
|
||||||
|
self._commands[cmd.name.lower()] = cmd
|
||||||
|
# Also register aliases
|
||||||
|
for alias in cmd.aliases:
|
||||||
|
self._commands[alias.lower()] = cmd
|
||||||
|
|
||||||
|
# If this command is also an application command (HybridCommand)
|
||||||
|
if isinstance(cmd, (AppCommand, AppCommandGroup)):
|
||||||
|
self._app_commands_and_groups.append(cmd)
|
||||||
|
|
||||||
|
elif hasattr(member, "__app_command_object__"): # Added for app commands
|
||||||
|
app_cmd_obj = getattr(member, "__app_command_object__")
|
||||||
|
if isinstance(app_cmd_obj, (AppCommand, AppCommandGroup)):
|
||||||
|
if isinstance(app_cmd_obj, AppCommand):
|
||||||
|
app_cmd_obj.cog = self # Associate cog
|
||||||
|
# For AppCommandGroup, its commands will have cog set individually if they are AppCommands
|
||||||
|
self._app_commands_and_groups.append(app_cmd_obj)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Warning: Member '{member_name}' in cog '{self.cog_name}' has '__app_command_object__' but it's not an AppCommand or AppCommandGroup."
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(member, (AppCommand, AppCommandGroup)):
|
||||||
|
if isinstance(member, AppCommand):
|
||||||
|
member.cog = self
|
||||||
|
self._app_commands_and_groups.append(member)
|
||||||
|
|
||||||
|
elif hasattr(member, "__listener_name__"):
|
||||||
|
# This is a method decorated with @commands.Cog.listener or @commands.listener
|
||||||
|
if not inspect.iscoroutinefunction(member):
|
||||||
|
# Decorator should have caught this, but double check
|
||||||
|
print(
|
||||||
|
f"Warning: Listener '{member_name}' in cog '{self.cog_name}' is not a coroutine. Skipping."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
event_name: str = getattr(member, "__listener_name__")
|
||||||
|
# The callback needs to be the bound method from this cog instance
|
||||||
|
self._listeners.append((event_name, member))
|
||||||
|
|
||||||
|
def _eject(self) -> None:
|
||||||
|
"""
|
||||||
|
Called when the cog is being removed.
|
||||||
|
The CommandHandler will handle unregistering commands/listeners.
|
||||||
|
This method is for any cog-specific cleanup before that.
|
||||||
|
"""
|
||||||
|
# For now, just clear local collections. Actual unregistration is external.
|
||||||
|
self._commands.clear()
|
||||||
|
self._listeners.clear()
|
||||||
|
self._app_commands_and_groups.clear() # Added
|
||||||
|
|
||||||
|
def get_commands(self) -> List["Command"]:
|
||||||
|
"""Returns a list of commands in this cog."""
|
||||||
|
# Avoid duplicates if aliases point to the same command object
|
||||||
|
return list(dict.fromkeys(self._commands.values()))
|
||||||
|
|
||||||
|
def get_listeners(self) -> List[Tuple[str, Callable[..., Awaitable[None]]]]:
|
||||||
|
"""Returns a list of (event_name, callback) tuples for listeners in this cog."""
|
||||||
|
return self._listeners
|
||||||
|
|
||||||
|
def get_app_commands_and_groups(
|
||||||
|
self,
|
||||||
|
) -> List[Union["AppCommand", "AppCommandGroup"]]:
|
||||||
|
"""Returns a list of application commands and groups in this cog."""
|
||||||
|
return self._app_commands_and_groups
|
||||||
|
|
||||||
|
async def cog_load(self) -> None:
|
||||||
|
"""
|
||||||
|
A special method that is called when the cog is loaded.
|
||||||
|
This is a good place for any asynchronous setup.
|
||||||
|
Subclasses should override this if they need async setup.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def cog_unload(self) -> None:
|
||||||
|
"""
|
||||||
|
A special method that is called when the cog is unloaded.
|
||||||
|
This is a good place for any asynchronous cleanup.
|
||||||
|
Subclasses should override this if they need async cleanup.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Example of how a listener might be defined within a Cog using the decorator
|
||||||
|
# from .decorators import listener # Would be imported at module level
|
||||||
|
#
|
||||||
|
# @listener(name="ON_MESSAGE_CREATE_CUSTOM") # Explicit name
|
||||||
|
# async def on_my_event(self, message: 'Message'):
|
||||||
|
# print(f"Cog '{self.cog_name}' received event with message: {message.content}")
|
||||||
|
#
|
||||||
|
# @listener() # Name derived from method: on_ready
|
||||||
|
# async def on_ready(self):
|
||||||
|
# print(f"Cog '{self.cog_name}' is ready.")
|
175
disagreement/ext/commands/converters.py
Normal file
175
disagreement/ext/commands/converters.py
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
# disagreement/ext/commands/converters.py
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, Generic
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import re
|
||||||
|
|
||||||
|
from .errors import BadArgument
|
||||||
|
from disagreement.models import Member, Guild, Role
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .core import CommandContext
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class Converter(ABC, Generic[T]):
|
||||||
|
"""
|
||||||
|
Base class for custom command argument converters.
|
||||||
|
Subclasses must implement the `convert` method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def convert(self, ctx: "CommandContext", argument: str) -> T:
|
||||||
|
"""
|
||||||
|
Converts the argument to the desired type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: The invocation context.
|
||||||
|
argument: The string argument to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The converted argument.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BadArgument: If the conversion fails.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Converter subclass must implement convert method.")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Built-in Type Converters ---
|
||||||
|
|
||||||
|
|
||||||
|
class IntConverter(Converter[int]):
|
||||||
|
async def convert(self, ctx: "CommandContext", argument: str) -> int:
|
||||||
|
try:
|
||||||
|
return int(argument)
|
||||||
|
except ValueError:
|
||||||
|
raise BadArgument(f"'{argument}' is not a valid integer.")
|
||||||
|
|
||||||
|
|
||||||
|
class FloatConverter(Converter[float]):
|
||||||
|
async def convert(self, ctx: "CommandContext", argument: str) -> float:
|
||||||
|
try:
|
||||||
|
return float(argument)
|
||||||
|
except ValueError:
|
||||||
|
raise BadArgument(f"'{argument}' is not a valid number.")
|
||||||
|
|
||||||
|
|
||||||
|
class BoolConverter(Converter[bool]):
|
||||||
|
async def convert(self, ctx: "CommandContext", argument: str) -> bool:
|
||||||
|
lowered = argument.lower()
|
||||||
|
if lowered in ("yes", "y", "true", "t", "1", "on", "enable", "enabled"):
|
||||||
|
return True
|
||||||
|
elif lowered in ("no", "n", "false", "f", "0", "off", "disable", "disabled"):
|
||||||
|
return False
|
||||||
|
raise BadArgument(f"'{argument}' is not a valid boolean-like value.")
|
||||||
|
|
||||||
|
|
||||||
|
class StringConverter(Converter[str]):
|
||||||
|
async def convert(self, ctx: "CommandContext", argument: str) -> str:
|
||||||
|
# For basic string, no conversion is needed, but this provides a consistent interface
|
||||||
|
return argument
|
||||||
|
|
||||||
|
|
||||||
|
# --- Discord Model Converters ---
|
||||||
|
|
||||||
|
|
||||||
|
class MemberConverter(Converter["Member"]):
|
||||||
|
async def convert(self, ctx: "CommandContext", argument: str) -> "Member":
|
||||||
|
if not ctx.message.guild_id:
|
||||||
|
raise BadArgument("Member converter requires guild context.")
|
||||||
|
|
||||||
|
match = re.match(r"<@!?(\d+)>$", argument)
|
||||||
|
member_id = match.group(1) if match else argument
|
||||||
|
|
||||||
|
guild = ctx.bot.get_guild(ctx.message.guild_id)
|
||||||
|
if guild:
|
||||||
|
member = guild.get_member(member_id)
|
||||||
|
if member:
|
||||||
|
return member
|
||||||
|
|
||||||
|
member = await ctx.bot.fetch_member(ctx.message.guild_id, member_id)
|
||||||
|
if member:
|
||||||
|
return member
|
||||||
|
raise BadArgument(f"Member '{argument}' not found.")
|
||||||
|
|
||||||
|
|
||||||
|
class RoleConverter(Converter["Role"]):
|
||||||
|
async def convert(self, ctx: "CommandContext", argument: str) -> "Role":
|
||||||
|
if not ctx.message.guild_id:
|
||||||
|
raise BadArgument("Role converter requires guild context.")
|
||||||
|
|
||||||
|
match = re.match(r"<@&(?P<id>\d+)>$", argument)
|
||||||
|
role_id = match.group("id") if match else argument
|
||||||
|
|
||||||
|
guild = ctx.bot.get_guild(ctx.message.guild_id)
|
||||||
|
if guild:
|
||||||
|
role = guild.get_role(role_id)
|
||||||
|
if role:
|
||||||
|
return role
|
||||||
|
|
||||||
|
role = await ctx.bot.fetch_role(ctx.message.guild_id, role_id)
|
||||||
|
if role:
|
||||||
|
return role
|
||||||
|
raise BadArgument(f"Role '{argument}' not found.")
|
||||||
|
|
||||||
|
|
||||||
|
class GuildConverter(Converter["Guild"]):
|
||||||
|
async def convert(self, ctx: "CommandContext", argument: str) -> "Guild":
|
||||||
|
guild_id = argument.strip("<>") # allow <id> style
|
||||||
|
|
||||||
|
guild = ctx.bot.get_guild(guild_id)
|
||||||
|
if guild:
|
||||||
|
return guild
|
||||||
|
|
||||||
|
guild = await ctx.bot.fetch_guild(guild_id)
|
||||||
|
if guild:
|
||||||
|
return guild
|
||||||
|
raise BadArgument(f"Guild '{argument}' not found.")
|
||||||
|
|
||||||
|
|
||||||
|
# Default converters mapping
|
||||||
|
DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
|
||||||
|
int: IntConverter(),
|
||||||
|
float: FloatConverter(),
|
||||||
|
bool: BoolConverter(),
|
||||||
|
str: StringConverter(),
|
||||||
|
Member: MemberConverter(),
|
||||||
|
Guild: GuildConverter(),
|
||||||
|
Role: RoleConverter(),
|
||||||
|
# User: UserConverter(), # Add when User model and converter are ready
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def run_converters(ctx: "CommandContext", annotation: Any, argument: str) -> Any:
|
||||||
|
"""
|
||||||
|
Attempts to run a converter for the given annotation and argument.
|
||||||
|
"""
|
||||||
|
converter = DEFAULT_CONVERTERS.get(annotation)
|
||||||
|
if converter:
|
||||||
|
return await converter.convert(ctx, argument)
|
||||||
|
|
||||||
|
# If no direct converter, check if annotation itself is a Converter subclass
|
||||||
|
if inspect.isclass(annotation) and issubclass(annotation, Converter):
|
||||||
|
try:
|
||||||
|
instance = annotation() # type: ignore
|
||||||
|
return await instance.convert(ctx, argument)
|
||||||
|
except Exception as e: # Catch instantiation errors or other issues
|
||||||
|
raise BadArgument(
|
||||||
|
f"Failed to use custom converter {annotation.__name__}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If it's a custom class that's not a Converter, we can't handle it by default
|
||||||
|
# Or if it's a complex type hint like Union, Optional, Literal etc.
|
||||||
|
# This part needs more advanced logic for those.
|
||||||
|
|
||||||
|
# For now, if no specific converter, and it's not 'str', raise error or return as str?
|
||||||
|
# Let's be strict for now if an annotation is given but no converter found.
|
||||||
|
if annotation is not str and annotation is not inspect.Parameter.empty:
|
||||||
|
raise BadArgument(f"No converter found for type annotation '{annotation}'.")
|
||||||
|
|
||||||
|
return argument # Default to string if no annotation or annotation is str
|
||||||
|
|
||||||
|
|
||||||
|
# Need to import inspect for the run_converters function
|
||||||
|
import inspect
|
490
disagreement/ext/commands/core.py
Normal file
490
disagreement/ext/commands/core.py
Normal file
@ -0,0 +1,490 @@
|
|||||||
|
# disagreement/ext/commands/core.py
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Optional,
|
||||||
|
List,
|
||||||
|
Dict,
|
||||||
|
Any,
|
||||||
|
Union,
|
||||||
|
Callable,
|
||||||
|
Awaitable,
|
||||||
|
Tuple,
|
||||||
|
get_origin,
|
||||||
|
get_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .view import StringView
|
||||||
|
from .errors import (
|
||||||
|
CommandError,
|
||||||
|
CommandNotFound,
|
||||||
|
BadArgument,
|
||||||
|
MissingRequiredArgument,
|
||||||
|
ArgumentParsingError,
|
||||||
|
CheckFailure,
|
||||||
|
CommandInvokeError,
|
||||||
|
)
|
||||||
|
from .converters import run_converters, DEFAULT_CONVERTERS, Converter
|
||||||
|
from .cog import Cog
|
||||||
|
from disagreement.typing import Typing
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.models import Message, User
|
||||||
|
|
||||||
|
|
||||||
|
class Command:
|
||||||
|
"""
|
||||||
|
Represents a bot command.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name (str): The primary name of the command.
|
||||||
|
callback (Callable[..., Awaitable[None]]): The coroutine function to execute.
|
||||||
|
aliases (List[str]): Alternative names for the command.
|
||||||
|
brief (Optional[str]): A short description for help commands.
|
||||||
|
description (Optional[str]): A longer description for help commands.
|
||||||
|
cog (Optional['Cog']): Reference to the Cog this command belongs to.
|
||||||
|
params (Dict[str, inspect.Parameter]): Cached parameters of the callback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any):
|
||||||
|
if not asyncio.iscoroutinefunction(callback):
|
||||||
|
raise TypeError("Command callback must be a coroutine function.")
|
||||||
|
|
||||||
|
self.callback: Callable[..., Awaitable[None]] = callback
|
||||||
|
self.name: str = attrs.get("name", callback.__name__)
|
||||||
|
self.aliases: List[str] = attrs.get("aliases", [])
|
||||||
|
self.brief: Optional[str] = attrs.get("brief")
|
||||||
|
self.description: Optional[str] = attrs.get("description") or callback.__doc__
|
||||||
|
self.cog: Optional["Cog"] = attrs.get("cog")
|
||||||
|
|
||||||
|
self.params = inspect.signature(callback).parameters
|
||||||
|
self.checks: List[Callable[["CommandContext"], Awaitable[bool] | bool]] = []
|
||||||
|
if hasattr(callback, "__command_checks__"):
|
||||||
|
self.checks.extend(getattr(callback, "__command_checks__"))
|
||||||
|
|
||||||
|
def add_check(
|
||||||
|
self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool]
|
||||||
|
) -> None:
|
||||||
|
self.checks.append(predicate)
|
||||||
|
|
||||||
|
async def invoke(self, ctx: "CommandContext", *args: Any, **kwargs: Any) -> None:
|
||||||
|
from .errors import CheckFailure
|
||||||
|
|
||||||
|
for predicate in self.checks:
|
||||||
|
result = predicate(ctx)
|
||||||
|
if inspect.isawaitable(result):
|
||||||
|
result = await result
|
||||||
|
if not result:
|
||||||
|
raise CheckFailure("Check predicate failed.")
|
||||||
|
|
||||||
|
if self.cog:
|
||||||
|
await self.callback(self.cog, ctx, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
await self.callback(ctx, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class CommandContext:
|
||||||
|
"""
|
||||||
|
Represents the context in which a command is being invoked.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
message: "Message",
|
||||||
|
bot: "Client",
|
||||||
|
prefix: str,
|
||||||
|
command: "Command",
|
||||||
|
invoked_with: str,
|
||||||
|
args: Optional[List[Any]] = None,
|
||||||
|
kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
cog: Optional["Cog"] = None,
|
||||||
|
):
|
||||||
|
self.message: "Message" = message
|
||||||
|
self.bot: "Client" = bot
|
||||||
|
self.prefix: str = prefix
|
||||||
|
self.command: "Command" = command
|
||||||
|
self.invoked_with: str = invoked_with
|
||||||
|
self.args: List[Any] = args or []
|
||||||
|
self.kwargs: Dict[str, Any] = kwargs or {}
|
||||||
|
self.cog: Optional["Cog"] = cog
|
||||||
|
|
||||||
|
self.author: "User" = message.author
|
||||||
|
|
||||||
|
async def reply(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
*,
|
||||||
|
mention_author: Optional[bool] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "Message":
|
||||||
|
"""Replies to the invoking message.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
content: str
|
||||||
|
The content to send.
|
||||||
|
mention_author: Optional[bool]
|
||||||
|
Whether to mention the author in the reply. If ``None`` the
|
||||||
|
client's :attr:`mention_replies` value is used.
|
||||||
|
"""
|
||||||
|
|
||||||
|
allowed_mentions = kwargs.pop("allowed_mentions", None)
|
||||||
|
if mention_author is None:
|
||||||
|
mention_author = getattr(self.bot, "mention_replies", False)
|
||||||
|
|
||||||
|
if allowed_mentions is None:
|
||||||
|
allowed_mentions = {"replied_user": mention_author}
|
||||||
|
else:
|
||||||
|
allowed_mentions = dict(allowed_mentions)
|
||||||
|
allowed_mentions.setdefault("replied_user", mention_author)
|
||||||
|
|
||||||
|
return await self.bot.send_message(
|
||||||
|
channel_id=self.message.channel_id,
|
||||||
|
content=content,
|
||||||
|
message_reference={
|
||||||
|
"message_id": self.message.id,
|
||||||
|
"channel_id": self.message.channel_id,
|
||||||
|
"guild_id": self.message.guild_id,
|
||||||
|
},
|
||||||
|
allowed_mentions=allowed_mentions,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send(self, content: str, **kwargs: Any) -> "Message":
|
||||||
|
return await self.bot.send_message(
|
||||||
|
channel_id=self.message.channel_id, content=content, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
async def edit(
|
||||||
|
self,
|
||||||
|
message: Union[str, "Message"],
|
||||||
|
*,
|
||||||
|
content: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "Message":
|
||||||
|
"""Edits a message previously sent by the bot."""
|
||||||
|
|
||||||
|
message_id = message if isinstance(message, str) else message.id
|
||||||
|
return await self.bot.edit_message(
|
||||||
|
channel_id=self.message.channel_id,
|
||||||
|
message_id=message_id,
|
||||||
|
content=content,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def typing(self) -> "Typing":
|
||||||
|
"""Return a typing context manager for this context's channel."""
|
||||||
|
|
||||||
|
return self.bot.typing(self.message.channel_id)
|
||||||
|
|
||||||
|
|
||||||
|
class CommandHandler:
|
||||||
|
"""
|
||||||
|
Manages command registration, parsing, and dispatching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
client: "Client",
|
||||||
|
prefix: Union[
|
||||||
|
str, List[str], Callable[["Client", "Message"], Union[str, List[str]]]
|
||||||
|
],
|
||||||
|
):
|
||||||
|
self.client: "Client" = client
|
||||||
|
self.prefix: Union[
|
||||||
|
str, List[str], Callable[["Client", "Message"], Union[str, List[str]]]
|
||||||
|
] = prefix
|
||||||
|
self.commands: Dict[str, Command] = {}
|
||||||
|
self.cogs: Dict[str, "Cog"] = {}
|
||||||
|
|
||||||
|
from .help import HelpCommand
|
||||||
|
|
||||||
|
self.add_command(HelpCommand(self))
|
||||||
|
|
||||||
|
def add_command(self, command: Command) -> None:
|
||||||
|
if command.name in self.commands:
|
||||||
|
raise ValueError(f"Command '{command.name}' is already registered.")
|
||||||
|
|
||||||
|
self.commands[command.name.lower()] = command
|
||||||
|
for alias in command.aliases:
|
||||||
|
if alias in self.commands:
|
||||||
|
print(
|
||||||
|
f"Warning: Alias '{alias}' for command '{command.name}' conflicts with an existing command or alias."
|
||||||
|
)
|
||||||
|
self.commands[alias.lower()] = command
|
||||||
|
|
||||||
|
def remove_command(self, name: str) -> Optional[Command]:
|
||||||
|
command = self.commands.pop(name.lower(), None)
|
||||||
|
if command:
|
||||||
|
for alias in command.aliases:
|
||||||
|
self.commands.pop(alias.lower(), None)
|
||||||
|
return command
|
||||||
|
|
||||||
|
def get_command(self, name: str) -> Optional[Command]:
|
||||||
|
return self.commands.get(name.lower())
|
||||||
|
|
||||||
|
def add_cog(self, cog_to_add: "Cog") -> None:
|
||||||
|
if not isinstance(cog_to_add, Cog):
|
||||||
|
raise TypeError("Argument must be a subclass of Cog.")
|
||||||
|
|
||||||
|
if cog_to_add.cog_name in self.cogs:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cog with name '{cog_to_add.cog_name}' is already registered."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.cogs[cog_to_add.cog_name] = cog_to_add
|
||||||
|
|
||||||
|
for cmd in cog_to_add.get_commands():
|
||||||
|
self.add_command(cmd)
|
||||||
|
|
||||||
|
if hasattr(self.client, "_event_dispatcher"):
|
||||||
|
for event_name, callback in cog_to_add.get_listeners():
|
||||||
|
self.client._event_dispatcher.register(event_name.upper(), callback)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Warning: Client does not have '_event_dispatcher'. Listeners for cog '{cog_to_add.cog_name}' not registered."
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(cog_to_add, "cog_load") and inspect.iscoroutinefunction(
|
||||||
|
cog_to_add.cog_load
|
||||||
|
):
|
||||||
|
asyncio.create_task(cog_to_add.cog_load())
|
||||||
|
|
||||||
|
print(f"Cog '{cog_to_add.cog_name}' added.")
|
||||||
|
|
||||||
|
def remove_cog(self, cog_name: str) -> Optional["Cog"]:
|
||||||
|
cog_to_remove = self.cogs.pop(cog_name, None)
|
||||||
|
if cog_to_remove:
|
||||||
|
for cmd in cog_to_remove.get_commands():
|
||||||
|
self.remove_command(cmd.name)
|
||||||
|
|
||||||
|
if hasattr(self.client, "_event_dispatcher"):
|
||||||
|
for event_name, callback in cog_to_remove.get_listeners():
|
||||||
|
print(
|
||||||
|
f"Note: Listener '{callback.__name__}' for event '{event_name}' from cog '{cog_name}' needs manual unregistration logic in EventDispatcher."
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(cog_to_remove, "cog_unload") and inspect.iscoroutinefunction(
|
||||||
|
cog_to_remove.cog_unload
|
||||||
|
):
|
||||||
|
asyncio.create_task(cog_to_remove.cog_unload())
|
||||||
|
|
||||||
|
cog_to_remove._eject()
|
||||||
|
print(f"Cog '{cog_name}' removed.")
|
||||||
|
return cog_to_remove
|
||||||
|
|
||||||
|
async def get_prefix(self, message: "Message") -> Union[str, List[str], None]:
|
||||||
|
if callable(self.prefix):
|
||||||
|
if inspect.iscoroutinefunction(self.prefix):
|
||||||
|
return await self.prefix(self.client, message)
|
||||||
|
else:
|
||||||
|
return self.prefix(self.client, message) # type: ignore
|
||||||
|
return self.prefix
|
||||||
|
|
||||||
|
async def _parse_arguments(
|
||||||
|
self, command: Command, ctx: CommandContext, view: StringView
|
||||||
|
) -> Tuple[List[Any], Dict[str, Any]]:
|
||||||
|
args_list = []
|
||||||
|
kwargs_dict = {}
|
||||||
|
params_to_parse = list(command.params.values())
|
||||||
|
|
||||||
|
if params_to_parse and params_to_parse[0].name == "self" and command.cog:
|
||||||
|
params_to_parse.pop(0)
|
||||||
|
if params_to_parse and params_to_parse[0].name == "ctx":
|
||||||
|
params_to_parse.pop(0)
|
||||||
|
|
||||||
|
for param in params_to_parse:
|
||||||
|
view.skip_whitespace()
|
||||||
|
final_value_for_param: Any = inspect.Parameter.empty
|
||||||
|
|
||||||
|
if param.kind == inspect.Parameter.VAR_POSITIONAL:
|
||||||
|
while not view.eof:
|
||||||
|
view.skip_whitespace()
|
||||||
|
if view.eof:
|
||||||
|
break
|
||||||
|
word = view.get_word()
|
||||||
|
if word or not view.eof:
|
||||||
|
args_list.append(word)
|
||||||
|
elif view.eof:
|
||||||
|
break
|
||||||
|
break
|
||||||
|
|
||||||
|
arg_str_value: Optional[str] = (
|
||||||
|
None # Holds the raw string for current param
|
||||||
|
)
|
||||||
|
|
||||||
|
if view.eof: # No more input string
|
||||||
|
if param.default is not inspect.Parameter.empty:
|
||||||
|
final_value_for_param = param.default
|
||||||
|
elif param.kind != inspect.Parameter.VAR_KEYWORD:
|
||||||
|
raise MissingRequiredArgument(param.name)
|
||||||
|
else: # VAR_KEYWORD at EOF is fine
|
||||||
|
break
|
||||||
|
else: # Input available
|
||||||
|
is_last_pos_str_greedy = (
|
||||||
|
param == params_to_parse[-1]
|
||||||
|
and param.annotation is str
|
||||||
|
and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_last_pos_str_greedy:
|
||||||
|
arg_str_value = view.read_rest().strip()
|
||||||
|
if (
|
||||||
|
not arg_str_value
|
||||||
|
and param.default is not inspect.Parameter.empty
|
||||||
|
):
|
||||||
|
final_value_for_param = param.default
|
||||||
|
else: # Includes empty string if that's what's left
|
||||||
|
final_value_for_param = arg_str_value
|
||||||
|
else: # Not greedy, or not string, or not last positional
|
||||||
|
if view.buffer[view.index] == '"':
|
||||||
|
arg_str_value = view.get_quoted_string()
|
||||||
|
if arg_str_value == "" and view.buffer[view.index] == '"':
|
||||||
|
raise BadArgument(
|
||||||
|
f"Unterminated quoted string for argument '{param.name}'."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
arg_str_value = view.get_word()
|
||||||
|
|
||||||
|
# If final_value_for_param was not set by greedy logic, try conversion
|
||||||
|
if final_value_for_param is inspect.Parameter.empty:
|
||||||
|
if (
|
||||||
|
arg_str_value is None
|
||||||
|
): # Should not happen if view.get_word/get_quoted_string is robust
|
||||||
|
if param.default is not inspect.Parameter.empty:
|
||||||
|
final_value_for_param = param.default
|
||||||
|
else:
|
||||||
|
raise MissingRequiredArgument(param.name)
|
||||||
|
else: # We have an arg_str_value (could be empty string "" from quotes)
|
||||||
|
annotation = param.annotation
|
||||||
|
origin = get_origin(annotation)
|
||||||
|
|
||||||
|
if origin is Union: # Handles Optional[T] and Union[T1, T2]
|
||||||
|
union_args = get_args(annotation)
|
||||||
|
is_optional = (
|
||||||
|
len(union_args) == 2 and type(None) in union_args
|
||||||
|
)
|
||||||
|
|
||||||
|
converted_for_union = False
|
||||||
|
last_err_union: Optional[BadArgument] = None
|
||||||
|
for t_arg in union_args:
|
||||||
|
if t_arg is type(None):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
final_value_for_param = await run_converters(
|
||||||
|
ctx, t_arg, arg_str_value
|
||||||
|
)
|
||||||
|
converted_for_union = True
|
||||||
|
break
|
||||||
|
except BadArgument as e:
|
||||||
|
last_err_union = e
|
||||||
|
|
||||||
|
if not converted_for_union:
|
||||||
|
if (
|
||||||
|
is_optional and param.default is None
|
||||||
|
): # Special handling for Optional[T] if conversion failed
|
||||||
|
# If arg_str_value was "" and type was Optional[str], StringConverter would return ""
|
||||||
|
# If arg_str_value was "" and type was Optional[int], BadArgument would be raised.
|
||||||
|
# This path is for when all actual types in Optional[T] fail conversion.
|
||||||
|
# If default is None, we can assign None.
|
||||||
|
final_value_for_param = None
|
||||||
|
elif last_err_union:
|
||||||
|
raise last_err_union
|
||||||
|
else: # Should not be reached if logic is correct
|
||||||
|
raise BadArgument(
|
||||||
|
f"Could not convert '{arg_str_value}' to any of {union_args} for param '{param.name}'."
|
||||||
|
)
|
||||||
|
elif annotation is inspect.Parameter.empty or annotation is str:
|
||||||
|
final_value_for_param = arg_str_value
|
||||||
|
else: # Standard type hint
|
||||||
|
final_value_for_param = await run_converters(
|
||||||
|
ctx, annotation, arg_str_value
|
||||||
|
)
|
||||||
|
|
||||||
|
# Final check if value was resolved
|
||||||
|
if final_value_for_param is inspect.Parameter.empty:
|
||||||
|
if param.default is not inspect.Parameter.empty:
|
||||||
|
final_value_for_param = param.default
|
||||||
|
elif param.kind != inspect.Parameter.VAR_KEYWORD:
|
||||||
|
# This state implies an issue if required and no default, and no input was parsed.
|
||||||
|
raise MissingRequiredArgument(
|
||||||
|
f"Parameter '{param.name}' could not be resolved."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign to args_list or kwargs_dict if a value was determined
|
||||||
|
if final_value_for_param is not inspect.Parameter.empty:
|
||||||
|
if (
|
||||||
|
param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||||
|
or param.kind == inspect.Parameter.POSITIONAL_ONLY
|
||||||
|
):
|
||||||
|
args_list.append(final_value_for_param)
|
||||||
|
elif param.kind == inspect.Parameter.KEYWORD_ONLY:
|
||||||
|
kwargs_dict[param.name] = final_value_for_param
|
||||||
|
|
||||||
|
return args_list, kwargs_dict
|
||||||
|
|
||||||
|
async def process_commands(self, message: "Message") -> None:
|
||||||
|
if not message.content:
|
||||||
|
return
|
||||||
|
|
||||||
|
prefix_to_use = await self.get_prefix(message)
|
||||||
|
if not prefix_to_use:
|
||||||
|
return
|
||||||
|
|
||||||
|
actual_prefix: Optional[str] = None
|
||||||
|
if isinstance(prefix_to_use, list):
|
||||||
|
for p in prefix_to_use:
|
||||||
|
if message.content.startswith(p):
|
||||||
|
actual_prefix = p
|
||||||
|
break
|
||||||
|
if not actual_prefix:
|
||||||
|
return
|
||||||
|
elif isinstance(prefix_to_use, str):
|
||||||
|
if message.content.startswith(prefix_to_use):
|
||||||
|
actual_prefix = prefix_to_use
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
if actual_prefix is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
content_without_prefix = message.content[len(actual_prefix) :]
|
||||||
|
view = StringView(content_without_prefix)
|
||||||
|
|
||||||
|
command_name = view.get_word()
|
||||||
|
if not command_name:
|
||||||
|
return
|
||||||
|
|
||||||
|
command = self.get_command(command_name)
|
||||||
|
if not command:
|
||||||
|
return
|
||||||
|
|
||||||
|
ctx = CommandContext(
|
||||||
|
message=message,
|
||||||
|
bot=self.client,
|
||||||
|
prefix=actual_prefix,
|
||||||
|
command=command,
|
||||||
|
invoked_with=command_name,
|
||||||
|
cog=command.cog,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view)
|
||||||
|
ctx.args = parsed_args
|
||||||
|
ctx.kwargs = parsed_kwargs
|
||||||
|
await command.invoke(ctx, *parsed_args, **parsed_kwargs)
|
||||||
|
except CommandError as e:
|
||||||
|
print(f"Command error for '{command.name}': {e}")
|
||||||
|
if hasattr(self.client, "on_command_error"):
|
||||||
|
await self.client.on_command_error(ctx, e)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Unexpected error invoking command '{command.name}': {e}")
|
||||||
|
exc = CommandInvokeError(e)
|
||||||
|
if hasattr(self.client, "on_command_error"):
|
||||||
|
await self.client.on_command_error(ctx, exc)
|
150
disagreement/ext/commands/decorators.py
Normal file
150
disagreement/ext/commands/decorators.py
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
# disagreement/ext/commands/decorators.py
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
import time
|
||||||
|
from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .core import Command, CommandContext # For type hinting return or internal use
|
||||||
|
|
||||||
|
# from .cog import Cog # For Cog specific decorators
|
||||||
|
|
||||||
|
|
||||||
|
def command(
|
||||||
|
name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any
|
||||||
|
) -> Callable:
|
||||||
|
"""
|
||||||
|
A decorator that transforms a function into a Command.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (Optional[str]): The name of the command. Defaults to the function name.
|
||||||
|
aliases (Optional[List[str]]): Alternative names for the command.
|
||||||
|
**attrs: Additional attributes to pass to the Command constructor
|
||||||
|
(e.g., brief, description, hidden).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable: A decorator that registers the command.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(
|
||||||
|
func: Callable[..., Awaitable[None]],
|
||||||
|
) -> Callable[..., Awaitable[None]]:
|
||||||
|
if not asyncio.iscoroutinefunction(func):
|
||||||
|
raise TypeError("Command callback must be a coroutine function.")
|
||||||
|
|
||||||
|
from .core import (
|
||||||
|
Command,
|
||||||
|
) # Late import to avoid circular dependencies at module load time
|
||||||
|
|
||||||
|
# The actual registration will happen when a Cog is added or if commands are global.
|
||||||
|
# For now, this decorator creates a Command instance and attaches it to the function,
|
||||||
|
# or returns a Command instance that can be collected.
|
||||||
|
|
||||||
|
cmd_name = name or func.__name__
|
||||||
|
|
||||||
|
# Store command attributes on the function itself for later collection by Cog or Client
|
||||||
|
# This is a common pattern.
|
||||||
|
if hasattr(func, "__command_attrs__"):
|
||||||
|
# This case might occur if decorators are stacked in an unusual way,
|
||||||
|
# or if a function is decorated multiple times (which should be disallowed or handled).
|
||||||
|
# For now, let's assume one @command decorator per function.
|
||||||
|
raise TypeError("Function is already a command or has command attributes.")
|
||||||
|
|
||||||
|
# Create the command object. It will be registered by the Cog or Client.
|
||||||
|
cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
|
||||||
|
|
||||||
|
# We can attach the command object to the function, so Cogs can find it.
|
||||||
|
func.__command_object__ = cmd # type: ignore # type: ignore[attr-defined]
|
||||||
|
return func # Return the original function, now marked.
|
||||||
|
# Or return `cmd` if commands are registered globally immediately.
|
||||||
|
# For Cogs, returning `func` and letting Cog collect is cleaner.
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def listener(
|
||||||
|
name: Optional[str] = None,
|
||||||
|
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||||
|
"""
|
||||||
|
A decorator that marks a function as an event listener within a Cog.
|
||||||
|
The actual registration happens when the Cog is added to the client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (Optional[str]): The name of the event to listen to.
|
||||||
|
Defaults to the function name (e.g., `on_message`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(
|
||||||
|
func: Callable[..., Awaitable[None]],
|
||||||
|
) -> Callable[..., Awaitable[None]]:
|
||||||
|
if not asyncio.iscoroutinefunction(func):
|
||||||
|
raise TypeError("Listener callback must be a coroutine function.")
|
||||||
|
|
||||||
|
# 'name' here is from the outer 'listener' scope (closure)
|
||||||
|
actual_event_name = name or func.__name__
|
||||||
|
# Store listener info on the function for Cog to collect
|
||||||
|
setattr(func, "__listener_name__", actual_event_name)
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator # This must be correctly indented under 'listener'
|
||||||
|
|
||||||
|
|
||||||
|
def check(
|
||||||
|
predicate: Callable[["CommandContext"], Awaitable[bool] | bool],
|
||||||
|
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||||
|
"""Decorator to add a check to a command."""
|
||||||
|
|
||||||
|
def decorator(
|
||||||
|
func: Callable[..., Awaitable[None]],
|
||||||
|
) -> Callable[..., Awaitable[None]]:
|
||||||
|
checks = getattr(func, "__command_checks__", [])
|
||||||
|
checks.append(predicate)
|
||||||
|
setattr(func, "__command_checks__", checks)
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def check_any(
|
||||||
|
*predicates: Callable[["CommandContext"], Awaitable[bool] | bool]
|
||||||
|
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||||
|
"""Decorator that passes if any predicate returns ``True``."""
|
||||||
|
|
||||||
|
async def predicate(ctx: "CommandContext") -> bool:
|
||||||
|
from .errors import CheckAnyFailure, CheckFailure
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
for p in predicates:
|
||||||
|
try:
|
||||||
|
result = p(ctx)
|
||||||
|
if inspect.isawaitable(result):
|
||||||
|
result = await result
|
||||||
|
if result:
|
||||||
|
return True
|
||||||
|
except CheckFailure as e:
|
||||||
|
errors.append(e)
|
||||||
|
raise CheckAnyFailure(errors)
|
||||||
|
|
||||||
|
return check(predicate)
|
||||||
|
|
||||||
|
|
||||||
|
def cooldown(
|
||||||
|
rate: int, per: float
|
||||||
|
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
||||||
|
"""Simple per-user cooldown decorator."""
|
||||||
|
|
||||||
|
buckets: dict[str, dict[str, float]] = {}
|
||||||
|
|
||||||
|
async def predicate(ctx: "CommandContext") -> bool:
|
||||||
|
from .errors import CommandOnCooldown
|
||||||
|
|
||||||
|
now = time.monotonic()
|
||||||
|
user_buckets = buckets.setdefault(ctx.command.name, {})
|
||||||
|
reset = user_buckets.get(ctx.author.id, 0)
|
||||||
|
if now < reset:
|
||||||
|
raise CommandOnCooldown(reset - now)
|
||||||
|
user_buckets[ctx.author.id] = now + per
|
||||||
|
return True
|
||||||
|
|
||||||
|
return check(predicate)
|
76
disagreement/ext/commands/errors.py
Normal file
76
disagreement/ext/commands/errors.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
# disagreement/ext/commands/errors.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
Custom exceptions for the command extension.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from disagreement.errors import DisagreementException
|
||||||
|
|
||||||
|
|
||||||
|
class CommandError(DisagreementException):
|
||||||
|
"""Base exception for errors raised by the commands extension."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CommandNotFound(CommandError):
|
||||||
|
"""Exception raised when a command is not found."""
|
||||||
|
|
||||||
|
def __init__(self, command_name: str):
|
||||||
|
self.command_name = command_name
|
||||||
|
super().__init__(f"Command '{command_name}' not found.")
|
||||||
|
|
||||||
|
|
||||||
|
class BadArgument(CommandError):
|
||||||
|
"""Exception raised when a command argument fails to parse or validate."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MissingRequiredArgument(BadArgument):
|
||||||
|
"""Exception raised when a required command argument is missing."""
|
||||||
|
|
||||||
|
def __init__(self, param_name: str):
|
||||||
|
self.param_name = param_name
|
||||||
|
super().__init__(f"Missing required argument: {param_name}")
|
||||||
|
|
||||||
|
|
||||||
|
class ArgumentParsingError(BadArgument):
|
||||||
|
"""Exception raised during the argument parsing process."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CheckFailure(CommandError):
|
||||||
|
"""Exception raised when a command check fails."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CheckAnyFailure(CheckFailure):
|
||||||
|
"""Raised when :func:`check_any` fails all checks."""
|
||||||
|
|
||||||
|
def __init__(self, errors: list[CheckFailure]):
|
||||||
|
self.errors = errors
|
||||||
|
msg = "; ".join(str(e) for e in errors)
|
||||||
|
super().__init__(f"All checks failed: {msg}")
|
||||||
|
|
||||||
|
|
||||||
|
class CommandOnCooldown(CheckFailure):
|
||||||
|
"""Raised when a command is invoked while on cooldown."""
|
||||||
|
|
||||||
|
def __init__(self, retry_after: float):
|
||||||
|
self.retry_after = retry_after
|
||||||
|
super().__init__(f"Command is on cooldown. Retry in {retry_after:.2f}s")
|
||||||
|
|
||||||
|
|
||||||
|
class CommandInvokeError(CommandError):
|
||||||
|
"""Exception raised when an error occurs during command invocation."""
|
||||||
|
|
||||||
|
def __init__(self, original: Exception):
|
||||||
|
self.original = original
|
||||||
|
super().__init__(f"Error during command invocation: {original}")
|
||||||
|
|
||||||
|
|
||||||
|
# Add more specific errors as needed, e.g., UserNotFound, ChannelNotFound, etc.
|
||||||
|
# These might inherit from BadArgument.
|
37
disagreement/ext/commands/help.py
Normal file
37
disagreement/ext/commands/help.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
# disagreement/ext/commands/help.py
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from .core import Command, CommandContext, CommandHandler
|
||||||
|
|
||||||
|
|
||||||
|
class HelpCommand(Command):
|
||||||
|
"""Built-in command that displays help information for other commands."""
|
||||||
|
|
||||||
|
def __init__(self, handler: CommandHandler) -> None:
|
||||||
|
self.handler = handler
|
||||||
|
|
||||||
|
async def callback(ctx: CommandContext, command: Optional[str] = None) -> None:
|
||||||
|
if command:
|
||||||
|
cmd = handler.get_command(command)
|
||||||
|
if not cmd or cmd.name.lower() != command.lower():
|
||||||
|
await ctx.send(f"Command '{command}' not found.")
|
||||||
|
return
|
||||||
|
description = cmd.description or cmd.brief or "No description provided."
|
||||||
|
await ctx.send(f"**{ctx.prefix}{cmd.name}**\n{description}")
|
||||||
|
else:
|
||||||
|
lines: List[str] = []
|
||||||
|
for registered in dict.fromkeys(handler.commands.values()):
|
||||||
|
brief = registered.brief or registered.description or ""
|
||||||
|
lines.append(f"{ctx.prefix}{registered.name} - {brief}".strip())
|
||||||
|
if lines:
|
||||||
|
await ctx.send("\n".join(lines))
|
||||||
|
else:
|
||||||
|
await ctx.send("No commands available.")
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
callback,
|
||||||
|
name="help",
|
||||||
|
brief="Show command help.",
|
||||||
|
description="Displays help for commands.",
|
||||||
|
)
|
103
disagreement/ext/commands/view.py
Normal file
103
disagreement/ext/commands/view.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
# disagreement/ext/commands/view.py
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
class StringView:
|
||||||
|
"""
|
||||||
|
A utility class to help with parsing strings, particularly for command arguments.
|
||||||
|
It keeps track of the current position in the string and provides methods
|
||||||
|
to read parts of it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, buffer: str):
|
||||||
|
self.buffer: str = buffer
|
||||||
|
self.original: str = buffer # Keep original for error reporting if needed
|
||||||
|
self.index: int = 0
|
||||||
|
self.end: int = len(buffer)
|
||||||
|
self.previous: int = 0 # Index before the last successful read
|
||||||
|
|
||||||
|
@property
|
||||||
|
def remaining(self) -> str:
|
||||||
|
"""Returns the rest of the string that hasn't been consumed."""
|
||||||
|
return self.buffer[self.index :]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eof(self) -> bool:
|
||||||
|
"""Checks if the end of the string has been reached."""
|
||||||
|
return self.index >= self.end
|
||||||
|
|
||||||
|
def skip_whitespace(self) -> None:
|
||||||
|
"""Skips any leading whitespace from the current position."""
|
||||||
|
while not self.eof and self.buffer[self.index].isspace():
|
||||||
|
self.index += 1
|
||||||
|
|
||||||
|
def get_word(self) -> str:
|
||||||
|
"""
|
||||||
|
Reads a "word" from the current position.
|
||||||
|
A word is a sequence of non-whitespace characters.
|
||||||
|
"""
|
||||||
|
self.skip_whitespace()
|
||||||
|
if self.eof:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
self.previous = self.index
|
||||||
|
match = re.match(r"\S+", self.buffer[self.index :])
|
||||||
|
if match:
|
||||||
|
word = match.group(0)
|
||||||
|
self.index += len(word)
|
||||||
|
return word
|
||||||
|
return "" # Should not happen if not eof and skip_whitespace was called
|
||||||
|
|
||||||
|
def get_quoted_string(self) -> str:
|
||||||
|
"""
|
||||||
|
Reads a string enclosed in double quotes.
|
||||||
|
Handles escaped quotes inside the string.
|
||||||
|
"""
|
||||||
|
self.skip_whitespace()
|
||||||
|
if self.eof or self.buffer[self.index] != '"':
|
||||||
|
return "" # Or raise an error, or return None
|
||||||
|
|
||||||
|
self.previous = self.index
|
||||||
|
self.index += 1 # Skip the opening quote
|
||||||
|
result = []
|
||||||
|
escaped = False
|
||||||
|
|
||||||
|
while not self.eof:
|
||||||
|
char = self.buffer[self.index]
|
||||||
|
self.index += 1
|
||||||
|
|
||||||
|
if escaped:
|
||||||
|
result.append(char)
|
||||||
|
escaped = False
|
||||||
|
elif char == "\\":
|
||||||
|
escaped = True
|
||||||
|
elif char == '"':
|
||||||
|
return "".join(result) # Closing quote found
|
||||||
|
else:
|
||||||
|
result.append(char)
|
||||||
|
|
||||||
|
# If loop finishes, means EOF was reached before closing quote
|
||||||
|
# This is an error condition. Restore index and indicate failure.
|
||||||
|
self.index = self.previous
|
||||||
|
# Consider raising an error like UnterminatedQuotedStringError
|
||||||
|
return "" # Or raise
|
||||||
|
|
||||||
|
def read_rest(self) -> str:
|
||||||
|
"""Reads all remaining characters from the current position."""
|
||||||
|
self.skip_whitespace()
|
||||||
|
if self.eof:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
self.previous = self.index
|
||||||
|
result = self.buffer[self.index :]
|
||||||
|
self.index = self.end
|
||||||
|
return result
|
||||||
|
|
||||||
|
def undo(self) -> None:
|
||||||
|
"""Resets the current position to before the last successful read."""
|
||||||
|
self.index = self.previous
|
||||||
|
|
||||||
|
# Could add more methods like:
|
||||||
|
# peek() - look at next char without consuming
|
||||||
|
# match_regex(pattern) - consume if regex matches
|
43
disagreement/ext/loader.py
Normal file
43
disagreement/ext/loader.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from importlib import import_module
|
||||||
|
import sys
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
__all__ = ["load_extension", "unload_extension"]
|
||||||
|
|
||||||
|
_loaded_extensions: Dict[str, ModuleType] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def load_extension(name: str) -> ModuleType:
|
||||||
|
"""Load an extension by name.
|
||||||
|
|
||||||
|
The extension module must define a ``setup`` coroutine or function that
|
||||||
|
will be called after loading. Any value returned by ``setup`` is ignored.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if name in _loaded_extensions:
|
||||||
|
raise ValueError(f"Extension '{name}' already loaded")
|
||||||
|
|
||||||
|
module = import_module(name)
|
||||||
|
|
||||||
|
if not hasattr(module, "setup"):
|
||||||
|
raise ImportError(f"Extension '{name}' does not define a setup function")
|
||||||
|
|
||||||
|
module.setup()
|
||||||
|
_loaded_extensions[name] = module
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def unload_extension(name: str) -> None:
|
||||||
|
"""Unload a previously loaded extension."""
|
||||||
|
|
||||||
|
module = _loaded_extensions.pop(name, None)
|
||||||
|
if module is None:
|
||||||
|
raise ValueError(f"Extension '{name}' is not loaded")
|
||||||
|
|
||||||
|
if hasattr(module, "teardown"):
|
||||||
|
module.teardown()
|
||||||
|
|
||||||
|
sys.modules.pop(name, None)
|
89
disagreement/ext/tasks.py
Normal file
89
disagreement/ext/tasks.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import Any, Awaitable, Callable, Optional
|
||||||
|
|
||||||
|
__all__ = ["loop", "Task"]
|
||||||
|
|
||||||
|
|
||||||
|
class Task:
|
||||||
|
"""Simple repeating task."""
|
||||||
|
|
||||||
|
def __init__(self, coro: Callable[..., Awaitable[Any]], *, seconds: float) -> None:
|
||||||
|
self._coro = coro
|
||||||
|
self._seconds = float(seconds)
|
||||||
|
self._task: Optional[asyncio.Task[None]] = None
|
||||||
|
|
||||||
|
async def _run(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
await self._coro(*args, **kwargs)
|
||||||
|
await asyncio.sleep(self._seconds)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
||||||
|
if self._task is None or self._task.done():
|
||||||
|
self._task = asyncio.create_task(self._run(*args, **kwargs))
|
||||||
|
return self._task
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
if self._task is not None:
|
||||||
|
self._task.cancel()
|
||||||
|
self._task = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def running(self) -> bool:
|
||||||
|
return self._task is not None and not self._task.done()
|
||||||
|
|
||||||
|
|
||||||
|
class _Loop:
|
||||||
|
def __init__(self, func: Callable[..., Awaitable[Any]], seconds: float) -> None:
|
||||||
|
self.func = func
|
||||||
|
self.seconds = seconds
|
||||||
|
self._task: Optional[Task] = None
|
||||||
|
self._owner: Any = None
|
||||||
|
|
||||||
|
def __get__(self, obj: Any, objtype: Any) -> "_BoundLoop":
|
||||||
|
return _BoundLoop(self, obj)
|
||||||
|
|
||||||
|
def _coro(self, *args: Any, **kwargs: Any) -> Awaitable[Any]:
|
||||||
|
if self._owner is None:
|
||||||
|
return self.func(*args, **kwargs)
|
||||||
|
return self.func(self._owner, *args, **kwargs)
|
||||||
|
|
||||||
|
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
||||||
|
self._task = Task(self._coro, seconds=self.seconds)
|
||||||
|
return self._task.start(*args, **kwargs)
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
if self._task is not None:
|
||||||
|
self._task.stop()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def running(self) -> bool:
|
||||||
|
return self._task.running if self._task else False
|
||||||
|
|
||||||
|
|
||||||
|
class _BoundLoop:
|
||||||
|
def __init__(self, parent: _Loop, owner: Any) -> None:
|
||||||
|
self._parent = parent
|
||||||
|
self._owner = owner
|
||||||
|
|
||||||
|
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
||||||
|
self._parent._owner = self._owner
|
||||||
|
return self._parent.start(*args, **kwargs)
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
self._parent.stop()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def running(self) -> bool:
|
||||||
|
return self._parent.running
|
||||||
|
|
||||||
|
|
||||||
|
def loop(*, seconds: float) -> Callable[[Callable[..., Awaitable[Any]]], _Loop]:
|
||||||
|
"""Decorator to create a looping task."""
|
||||||
|
|
||||||
|
def decorator(func: Callable[..., Awaitable[Any]]) -> _Loop:
|
||||||
|
return _Loop(func, seconds)
|
||||||
|
|
||||||
|
return decorator
|
490
disagreement/gateway.py
Normal file
490
disagreement/gateway.py
Normal file
@ -0,0 +1,490 @@
|
|||||||
|
# disagreement/gateway.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
Manages the WebSocket connection to the Discord Gateway.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import traceback
|
||||||
|
import aiohttp
|
||||||
|
import json
|
||||||
|
import zlib
|
||||||
|
import time
|
||||||
|
from typing import Optional, TYPE_CHECKING, Any, Dict
|
||||||
|
|
||||||
|
from .enums import GatewayOpcode, GatewayIntent
|
||||||
|
from .errors import GatewayException, DisagreementException, AuthenticationError
|
||||||
|
from .interactions import Interaction
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import Client # For type hinting
|
||||||
|
from .event_dispatcher import EventDispatcher
|
||||||
|
from .http import HTTPClient
|
||||||
|
from .interactions import Interaction # Added for INTERACTION_CREATE
|
||||||
|
|
||||||
|
# ZLIB Decompression constants
|
||||||
|
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
|
||||||
|
MAX_DECOMPRESSION_SIZE = 10 * 1024 * 1024 # 10 MiB, adjust as needed
|
||||||
|
|
||||||
|
|
||||||
|
class GatewayClient:
|
||||||
|
"""
|
||||||
|
Handles the Discord Gateway WebSocket connection, heartbeating, and event dispatching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
http_client: "HTTPClient",
|
||||||
|
event_dispatcher: "EventDispatcher",
|
||||||
|
token: str,
|
||||||
|
intents: int,
|
||||||
|
client_instance: "Client", # Pass the main client instance
|
||||||
|
verbose: bool = False,
|
||||||
|
*,
|
||||||
|
shard_id: Optional[int] = None,
|
||||||
|
shard_count: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self._http: "HTTPClient" = http_client
|
||||||
|
self._dispatcher: "EventDispatcher" = event_dispatcher
|
||||||
|
self._token: str = token
|
||||||
|
self._intents: int = intents
|
||||||
|
self._client_instance: "Client" = client_instance # Store client instance
|
||||||
|
self.verbose: bool = verbose
|
||||||
|
self._shard_id: Optional[int] = shard_id
|
||||||
|
self._shard_count: Optional[int] = shard_count
|
||||||
|
|
||||||
|
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
|
||||||
|
self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
||||||
|
self._heartbeat_interval: Optional[float] = None
|
||||||
|
self._last_sequence: Optional[int] = None
|
||||||
|
self._session_id: Optional[str] = None
|
||||||
|
self._resume_gateway_url: Optional[str] = None
|
||||||
|
|
||||||
|
self._keep_alive_task: Optional[asyncio.Task] = None
|
||||||
|
self._receive_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
# For zlib decompression
|
||||||
|
self._buffer = bytearray()
|
||||||
|
self._inflator = zlib.decompressobj()
|
||||||
|
|
||||||
|
async def _decompress_message(
|
||||||
|
self, message_bytes: bytes
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Decompresses a zlib-compressed message from the Gateway."""
|
||||||
|
self._buffer.extend(message_bytes)
|
||||||
|
|
||||||
|
if len(message_bytes) < 4 or message_bytes[-4:] != ZLIB_SUFFIX:
|
||||||
|
# Message is not complete or not zlib compressed in the expected way
|
||||||
|
return None
|
||||||
|
# Or handle partial messages if Discord ever sends them fragmented like this,
|
||||||
|
# but typically each binary message is a complete zlib stream.
|
||||||
|
|
||||||
|
try:
|
||||||
|
decompressed = self._inflator.decompress(self._buffer)
|
||||||
|
self._buffer.clear() # Reset buffer after successful decompression
|
||||||
|
return json.loads(decompressed.decode("utf-8"))
|
||||||
|
except zlib.error as e:
|
||||||
|
print(f"Zlib decompression error: {e}")
|
||||||
|
self._buffer.clear() # Clear buffer on error
|
||||||
|
self._inflator = zlib.decompressobj() # Reset inflator
|
||||||
|
return None
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
print(f"JSON decode error after decompression: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _send_json(self, payload: Dict[str, Any]):
|
||||||
|
if self._ws and not self._ws.closed:
|
||||||
|
if self.verbose:
|
||||||
|
print(f"GATEWAY SEND: {payload}")
|
||||||
|
await self._ws.send_json(payload)
|
||||||
|
else:
|
||||||
|
print("Gateway send attempted but WebSocket is closed or not available.")
|
||||||
|
# raise GatewayException("WebSocket is not connected.")
|
||||||
|
|
||||||
|
async def _heartbeat(self):
|
||||||
|
"""Sends a heartbeat to the Gateway."""
|
||||||
|
payload = {"op": GatewayOpcode.HEARTBEAT, "d": self._last_sequence}
|
||||||
|
await self._send_json(payload)
|
||||||
|
# print("Sent heartbeat.")
|
||||||
|
|
||||||
|
async def _keep_alive(self):
|
||||||
|
"""Manages the heartbeating loop."""
|
||||||
|
if self._heartbeat_interval is None:
|
||||||
|
# This should not happen if HELLO was processed correctly
|
||||||
|
print("Error: Heartbeat interval not set. Cannot start keep_alive.")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
await self._heartbeat()
|
||||||
|
await asyncio.sleep(
|
||||||
|
self._heartbeat_interval / 1000
|
||||||
|
) # Interval is in ms
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
print("Keep_alive task cancelled.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in keep_alive loop: {e}")
|
||||||
|
# Potentially trigger a reconnect here or notify client
|
||||||
|
await self._client_instance.close_gateway(code=1000) # Generic close
|
||||||
|
|
||||||
|
async def _identify(self):
|
||||||
|
"""Sends the IDENTIFY payload to the Gateway."""
|
||||||
|
payload = {
|
||||||
|
"op": GatewayOpcode.IDENTIFY,
|
||||||
|
"d": {
|
||||||
|
"token": self._token,
|
||||||
|
"intents": self._intents,
|
||||||
|
"properties": {
|
||||||
|
"$os": "python", # Or platform.system()
|
||||||
|
"$browser": "disagreement", # Library name
|
||||||
|
"$device": "disagreement", # Library name
|
||||||
|
},
|
||||||
|
"compress": True, # Request zlib compression
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if self._shard_id is not None and self._shard_count is not None:
|
||||||
|
payload["d"]["shard"] = [self._shard_id, self._shard_count]
|
||||||
|
await self._send_json(payload)
|
||||||
|
print("Sent IDENTIFY.")
|
||||||
|
|
||||||
|
async def _resume(self):
|
||||||
|
"""Sends the RESUME payload to the Gateway."""
|
||||||
|
if not self._session_id or self._last_sequence is None:
|
||||||
|
print("Cannot RESUME: session_id or last_sequence is missing.")
|
||||||
|
await self._identify() # Fallback to identify
|
||||||
|
return
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"op": GatewayOpcode.RESUME,
|
||||||
|
"d": {
|
||||||
|
"token": self._token,
|
||||||
|
"session_id": self._session_id,
|
||||||
|
"seq": self._last_sequence,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
await self._send_json(payload)
|
||||||
|
print(
|
||||||
|
f"Sent RESUME for session {self._session_id} at sequence {self._last_sequence}."
|
||||||
|
)
|
||||||
|
async def update_presence(
|
||||||
|
self,
|
||||||
|
status: str,
|
||||||
|
activity_name: Optional[str] = None,
|
||||||
|
activity_type: int = 0,
|
||||||
|
since: int = 0,
|
||||||
|
afk: bool = False,
|
||||||
|
):
|
||||||
|
"""Sends the presence update payload to the Gateway."""
|
||||||
|
payload = {
|
||||||
|
"op": GatewayOpcode.PRESENCE_UPDATE,
|
||||||
|
"d": {
|
||||||
|
"since": since,
|
||||||
|
"activities": [
|
||||||
|
{
|
||||||
|
"name": activity_name,
|
||||||
|
"type": activity_type,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
if activity_name
|
||||||
|
else [],
|
||||||
|
"status": status,
|
||||||
|
"afk": afk,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
await self._send_json(payload)
|
||||||
|
|
||||||
|
async def _handle_dispatch(self, data: Dict[str, Any]):
|
||||||
|
"""Handles DISPATCH events (actual Discord events)."""
|
||||||
|
event_name = data.get("t")
|
||||||
|
sequence_num = data.get("s")
|
||||||
|
raw_event_d_payload = data.get(
|
||||||
|
"d"
|
||||||
|
) # This is the 'd' field from the gateway event
|
||||||
|
|
||||||
|
if sequence_num is not None:
|
||||||
|
self._last_sequence = sequence_num
|
||||||
|
|
||||||
|
if event_name == "READY": # Special handling for READY
|
||||||
|
if not isinstance(raw_event_d_payload, dict):
|
||||||
|
print(
|
||||||
|
f"Error: READY event 'd' payload is not a dict or is missing: {raw_event_d_payload}"
|
||||||
|
)
|
||||||
|
# Consider raising an error or attempting a reconnect
|
||||||
|
return
|
||||||
|
self._session_id = raw_event_d_payload.get("session_id")
|
||||||
|
self._resume_gateway_url = raw_event_d_payload.get("resume_gateway_url")
|
||||||
|
|
||||||
|
app_id_str = "N/A"
|
||||||
|
# Store application_id on the client instance
|
||||||
|
if (
|
||||||
|
"application" in raw_event_d_payload
|
||||||
|
and isinstance(raw_event_d_payload["application"], dict)
|
||||||
|
and "id" in raw_event_d_payload["application"]
|
||||||
|
):
|
||||||
|
app_id_value = raw_event_d_payload["application"]["id"]
|
||||||
|
self._client_instance.application_id = (
|
||||||
|
app_id_value # Snowflake can be str or int
|
||||||
|
)
|
||||||
|
app_id_str = str(app_id_value)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Warning: Could not find application ID in READY payload. App commands may not work."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse and store the bot's own user object
|
||||||
|
if "user" in raw_event_d_payload and isinstance(
|
||||||
|
raw_event_d_payload["user"], dict
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
# Assuming Client has a parse_user method that takes user data dict
|
||||||
|
# and returns a User object, also caching it.
|
||||||
|
bot_user_obj = self._client_instance.parse_user(
|
||||||
|
raw_event_d_payload["user"]
|
||||||
|
)
|
||||||
|
self._client_instance.user = bot_user_obj
|
||||||
|
print(
|
||||||
|
f"Gateway READY. Bot User: {bot_user_obj.username}#{bot_user_obj.discriminator}. Session ID: {self._session_id}. App ID: {app_id_str}. Resume URL: {self._resume_gateway_url}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error parsing bot user from READY payload: {e}")
|
||||||
|
print(
|
||||||
|
f"Gateway READY (user parse failed). Session ID: {self._session_id}. App ID: {app_id_str}. Resume URL: {self._resume_gateway_url}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Warning: Bot user object not found or invalid in READY payload."
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Gateway READY (no user). Session ID: {self._session_id}. App ID: {app_id_str}. Resume URL: {self._resume_gateway_url}"
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._dispatcher.dispatch(event_name, raw_event_d_payload)
|
||||||
|
elif event_name == "INTERACTION_CREATE":
|
||||||
|
# print(f"GATEWAY RECV INTERACTION_CREATE: {raw_event_d_payload}")
|
||||||
|
if isinstance(raw_event_d_payload, dict):
|
||||||
|
interaction = Interaction(
|
||||||
|
data=raw_event_d_payload, client_instance=self._client_instance
|
||||||
|
)
|
||||||
|
await self._dispatcher.dispatch(
|
||||||
|
"INTERACTION_CREATE", raw_event_d_payload
|
||||||
|
)
|
||||||
|
# Dispatch to a new client method that will then call AppCommandHandler
|
||||||
|
if hasattr(self._client_instance, "process_interaction"):
|
||||||
|
asyncio.create_task(
|
||||||
|
self._client_instance.process_interaction(interaction)
|
||||||
|
) # type: ignore
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
"Warning: Client instance does not have process_interaction method for INTERACTION_CREATE."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Error: INTERACTION_CREATE event 'd' payload is not a dict: {raw_event_d_payload}"
|
||||||
|
)
|
||||||
|
elif event_name == "RESUMED":
|
||||||
|
print("Gateway RESUMED successfully.")
|
||||||
|
# RESUMED 'd' payload is often an empty object or debug info.
|
||||||
|
# Ensure it's a dict for the dispatcher.
|
||||||
|
event_data_to_dispatch = (
|
||||||
|
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
|
||||||
|
)
|
||||||
|
await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
|
||||||
|
elif event_name:
|
||||||
|
# For other events, ensure 'd' is a dict, or pass {} if 'd' is null/missing.
|
||||||
|
# Models/parsers in EventDispatcher will need to handle potentially empty dicts.
|
||||||
|
event_data_to_dispatch = (
|
||||||
|
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
|
||||||
|
)
|
||||||
|
# print(f"GATEWAY RECV EVENT: {event_name} | DATA: {event_data_to_dispatch}")
|
||||||
|
await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
|
||||||
|
else:
|
||||||
|
print(f"Received dispatch with no event name: {data}")
|
||||||
|
|
||||||
|
async def _process_message(self, msg: aiohttp.WSMessage):
|
||||||
|
"""Processes a single message from the WebSocket."""
|
||||||
|
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||||
|
try:
|
||||||
|
data = json.loads(msg.data)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print(
|
||||||
|
f"Failed to decode JSON from Gateway: {msg.data[:200]}"
|
||||||
|
) # Log snippet
|
||||||
|
return
|
||||||
|
elif msg.type == aiohttp.WSMsgType.BINARY:
|
||||||
|
decompressed_data = await self._decompress_message(msg.data)
|
||||||
|
if decompressed_data is None:
|
||||||
|
print("Failed to decompress or decode binary message from Gateway.")
|
||||||
|
return
|
||||||
|
data = decompressed_data
|
||||||
|
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||||
|
print(
|
||||||
|
f"WebSocket error: {self._ws.exception() if self._ws else 'Unknown WSError'}"
|
||||||
|
)
|
||||||
|
raise GatewayException(
|
||||||
|
f"WebSocket error: {self._ws.exception() if self._ws else 'Unknown WSError'}"
|
||||||
|
)
|
||||||
|
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||||
|
close_code = (
|
||||||
|
self._ws.close_code
|
||||||
|
if self._ws and hasattr(self._ws, "close_code")
|
||||||
|
else "N/A"
|
||||||
|
)
|
||||||
|
print(f"WebSocket connection closed by server. Code: {close_code}")
|
||||||
|
# Raise an exception to signal the closure to the client's main run loop
|
||||||
|
raise GatewayException(f"WebSocket closed by server. Code: {close_code}")
|
||||||
|
else:
|
||||||
|
print(f"Received unhandled WebSocket message type: {msg.type}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
print(f"GATEWAY RECV: {data}")
|
||||||
|
op = data.get("op")
|
||||||
|
# 'd' payload (event_data) is handled specifically by each opcode handler below
|
||||||
|
|
||||||
|
if op == GatewayOpcode.DISPATCH:
|
||||||
|
await self._handle_dispatch(data) # _handle_dispatch will extract 'd'
|
||||||
|
elif op == GatewayOpcode.HEARTBEAT: # Server requests a heartbeat
|
||||||
|
await self._heartbeat()
|
||||||
|
elif op == GatewayOpcode.RECONNECT: # Server requests a reconnect
|
||||||
|
print("Gateway requested RECONNECT. Closing and will attempt to reconnect.")
|
||||||
|
await self.close(code=4000) # Use a non-1000 code to indicate reconnect
|
||||||
|
elif op == GatewayOpcode.INVALID_SESSION:
|
||||||
|
# The 'd' payload for INVALID_SESSION is a boolean indicating resumability
|
||||||
|
can_resume = data.get("d") is True
|
||||||
|
print(f"Gateway indicated INVALID_SESSION. Resumable: {can_resume}")
|
||||||
|
if not can_resume:
|
||||||
|
self._session_id = None # Clear session_id to force re-identify
|
||||||
|
self._last_sequence = None
|
||||||
|
# Close and reconnect. The connect logic will decide to resume or identify.
|
||||||
|
await self.close(
|
||||||
|
code=4000 if can_resume else 4009
|
||||||
|
) # 4009 for non-resumable
|
||||||
|
elif op == GatewayOpcode.HELLO:
|
||||||
|
hello_d_payload = data.get("d")
|
||||||
|
if (
|
||||||
|
not isinstance(hello_d_payload, dict)
|
||||||
|
or "heartbeat_interval" not in hello_d_payload
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
f"Error: HELLO event 'd' payload is invalid or missing heartbeat_interval: {hello_d_payload}"
|
||||||
|
)
|
||||||
|
await self.close(code=1011) # Internal error, malformed HELLO
|
||||||
|
return
|
||||||
|
self._heartbeat_interval = hello_d_payload["heartbeat_interval"]
|
||||||
|
print(f"Gateway HELLO. Heartbeat interval: {self._heartbeat_interval}ms.")
|
||||||
|
# Start heartbeating
|
||||||
|
if self._keep_alive_task:
|
||||||
|
self._keep_alive_task.cancel()
|
||||||
|
self._keep_alive_task = self._loop.create_task(self._keep_alive())
|
||||||
|
|
||||||
|
# Identify or Resume
|
||||||
|
if self._session_id and self._resume_gateway_url: # Check if we can resume
|
||||||
|
print("Attempting to RESUME session.")
|
||||||
|
await self._resume()
|
||||||
|
else:
|
||||||
|
print("Performing initial IDENTIFY.")
|
||||||
|
await self._identify()
|
||||||
|
elif op == GatewayOpcode.HEARTBEAT_ACK:
|
||||||
|
# print("Received heartbeat ACK.")
|
||||||
|
pass # Good, connection is alive
|
||||||
|
else:
|
||||||
|
print(f"Received unhandled Gateway Opcode: {op} with data: {data}")
|
||||||
|
|
||||||
|
async def _receive_loop(self):
|
||||||
|
"""Continuously receives and processes messages from the WebSocket."""
|
||||||
|
if not self._ws or self._ws.closed:
|
||||||
|
print("Receive loop cannot start: WebSocket is not connected or closed.")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for msg in self._ws:
|
||||||
|
await self._process_message(msg)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
print("Receive_loop task cancelled.")
|
||||||
|
except aiohttp.ClientConnectionError as e:
|
||||||
|
print(f"ClientConnectionError in receive_loop: {e}. Attempting reconnect.")
|
||||||
|
# This might be handled by an outer reconnect loop in the Client class
|
||||||
|
await self.close(code=1006) # Abnormal closure
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Unexpected error in receive_loop: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
# Consider specific error types for more granular handling
|
||||||
|
await self.close(code=1011) # Internal error
|
||||||
|
finally:
|
||||||
|
print("Receive_loop ended.")
|
||||||
|
# If the loop ends unexpectedly (not due to explicit close),
|
||||||
|
# the main client might want to try reconnecting.
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
"""Connects to the Discord Gateway."""
|
||||||
|
if self._ws and not self._ws.closed:
|
||||||
|
print("Gateway already connected or connecting.")
|
||||||
|
return
|
||||||
|
|
||||||
|
gateway_url = (
|
||||||
|
self._resume_gateway_url or (await self._http.get_gateway_bot())["url"]
|
||||||
|
)
|
||||||
|
if not gateway_url.endswith("?v=10&encoding=json&compress=zlib-stream"):
|
||||||
|
gateway_url += "?v=10&encoding=json&compress=zlib-stream"
|
||||||
|
|
||||||
|
print(f"Connecting to Gateway: {gateway_url}")
|
||||||
|
try:
|
||||||
|
await self._http._ensure_session() # Ensure the HTTP client's session is active
|
||||||
|
assert (
|
||||||
|
self._http._session is not None
|
||||||
|
), "HTTPClient session not initialized after ensure_session"
|
||||||
|
self._ws = await self._http._session.ws_connect(gateway_url, max_msg_size=0)
|
||||||
|
print("Gateway WebSocket connection established.")
|
||||||
|
|
||||||
|
if self._receive_task:
|
||||||
|
self._receive_task.cancel()
|
||||||
|
self._receive_task = self._loop.create_task(self._receive_loop())
|
||||||
|
|
||||||
|
except aiohttp.ClientConnectorError as e:
|
||||||
|
raise GatewayException(
|
||||||
|
f"Failed to connect to Gateway (Connector Error): {e}"
|
||||||
|
) from e
|
||||||
|
except aiohttp.WSServerHandshakeError as e:
|
||||||
|
if e.status == 401: # Unauthorized during handshake
|
||||||
|
raise AuthenticationError(
|
||||||
|
f"Gateway handshake failed (401 Unauthorized): {e.message}. Check your bot token."
|
||||||
|
) from e
|
||||||
|
raise GatewayException(
|
||||||
|
f"Gateway handshake failed (Status: {e.status}): {e.message}"
|
||||||
|
) from e
|
||||||
|
except Exception as e: # Catch other potential errors during connection
|
||||||
|
raise GatewayException(
|
||||||
|
f"An unexpected error occurred during Gateway connection: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
async def close(self, code: int = 1000):
|
||||||
|
"""Closes the Gateway connection."""
|
||||||
|
print(f"Closing Gateway connection with code {code}...")
|
||||||
|
if self._keep_alive_task and not self._keep_alive_task.done():
|
||||||
|
self._keep_alive_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._keep_alive_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass # Expected
|
||||||
|
|
||||||
|
if self._receive_task and not self._receive_task.done():
|
||||||
|
self._receive_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._receive_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass # Expected
|
||||||
|
|
||||||
|
if self._ws and not self._ws.closed:
|
||||||
|
await self._ws.close(code=code)
|
||||||
|
print("Gateway WebSocket closed.")
|
||||||
|
|
||||||
|
self._ws = None
|
||||||
|
# Do not reset session_id, last_sequence, or resume_gateway_url here
|
||||||
|
# if the close code indicates a resumable disconnect (e.g. 4000-4009, or server-initiated RECONNECT)
|
||||||
|
# The connect logic will decide whether to resume or re-identify.
|
||||||
|
# However, if it's a non-resumable close (e.g. Invalid Session non-resumable), clear them.
|
||||||
|
if code == 4009: # Invalid session, not resumable
|
||||||
|
print("Clearing session state due to non-resumable invalid session.")
|
||||||
|
self._session_id = None
|
||||||
|
self._last_sequence = None
|
||||||
|
self._resume_gateway_url = None # This might be re-fetched anyway
|
657
disagreement/http.py
Normal file
657
disagreement/http.py
Normal file
@ -0,0 +1,657 @@
|
|||||||
|
# disagreement/http.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
HTTP client for interacting with the Discord REST API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import aiohttp # pylint: disable=import-error
|
||||||
|
import json
|
||||||
|
from urllib.parse import quote
|
||||||
|
from typing import Optional, Dict, Any, Union, TYPE_CHECKING, List
|
||||||
|
|
||||||
|
from .errors import (
|
||||||
|
HTTPException,
|
||||||
|
RateLimitError,
|
||||||
|
AuthenticationError,
|
||||||
|
DisagreementException,
|
||||||
|
)
|
||||||
|
from . import __version__ # For User-Agent
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import Client
|
||||||
|
from .models import Message
|
||||||
|
from .interactions import ApplicationCommand, InteractionResponsePayload, Snowflake
|
||||||
|
|
||||||
|
# Discord API constants
|
||||||
|
API_BASE_URL = "https://discord.com/api/v10" # Using API v10
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPClient:
|
||||||
|
"""Handles HTTP requests to the Discord API."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
token: str,
|
||||||
|
client_session: Optional[aiohttp.ClientSession] = None,
|
||||||
|
verbose: bool = False,
|
||||||
|
):
|
||||||
|
self.token = token
|
||||||
|
self._session: Optional[aiohttp.ClientSession] = (
|
||||||
|
client_session # Can be externally managed
|
||||||
|
)
|
||||||
|
self.user_agent = f"DiscordBot (https://github.com/yourusername/disagreement, {__version__})" # Customize URL
|
||||||
|
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
self._global_rate_limit_lock = asyncio.Event()
|
||||||
|
self._global_rate_limit_lock.set() # Initially unlocked
|
||||||
|
|
||||||
|
async def _ensure_session(self):
|
||||||
|
if self._session is None or self._session.closed:
|
||||||
|
self._session = aiohttp.ClientSession()
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Closes the underlying aiohttp.ClientSession."""
|
||||||
|
if self._session and not self._session.closed:
|
||||||
|
await self._session.close()
|
||||||
|
|
||||||
|
async def request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
endpoint: str,
|
||||||
|
payload: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
is_json: bool = True,
|
||||||
|
use_auth_header: bool = True,
|
||||||
|
custom_headers: Optional[Dict[str, str]] = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Makes an HTTP request to the Discord API."""
|
||||||
|
await self._ensure_session()
|
||||||
|
|
||||||
|
url = f"{API_BASE_URL}{endpoint}"
|
||||||
|
final_headers: Dict[str, str] = { # Renamed to final_headers
|
||||||
|
"User-Agent": self.user_agent,
|
||||||
|
}
|
||||||
|
if use_auth_header:
|
||||||
|
final_headers["Authorization"] = f"Bot {self.token}"
|
||||||
|
|
||||||
|
if is_json and payload:
|
||||||
|
final_headers["Content-Type"] = "application/json"
|
||||||
|
|
||||||
|
if custom_headers: # Merge custom headers
|
||||||
|
final_headers.update(custom_headers)
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
print(f"HTTP REQUEST: {method} {url} | payload={payload} params={params}")
|
||||||
|
|
||||||
|
# Global rate limit handling
|
||||||
|
await self._global_rate_limit_lock.wait()
|
||||||
|
|
||||||
|
for attempt in range(5): # Max 5 retries for rate limits
|
||||||
|
assert self._session is not None, "ClientSession not initialized"
|
||||||
|
async with self._session.request(
|
||||||
|
method,
|
||||||
|
url,
|
||||||
|
json=payload if is_json else None,
|
||||||
|
data=payload if not is_json else None,
|
||||||
|
headers=final_headers,
|
||||||
|
params=params,
|
||||||
|
) as response:
|
||||||
|
|
||||||
|
data = None
|
||||||
|
try:
|
||||||
|
if response.headers.get("Content-Type", "").startswith(
|
||||||
|
"application/json"
|
||||||
|
):
|
||||||
|
data = await response.json()
|
||||||
|
else:
|
||||||
|
# For non-JSON responses, like fetching images or other files
|
||||||
|
# We might return the raw response or handle it differently
|
||||||
|
# For now, let's assume most API calls expect JSON
|
||||||
|
data = await response.text()
|
||||||
|
except (aiohttp.ContentTypeError, json.JSONDecodeError):
|
||||||
|
data = (
|
||||||
|
await response.text()
|
||||||
|
) # Fallback to text if JSON parsing fails
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
print(f"HTTP RESPONSE: {response.status} {url} | {data}")
|
||||||
|
|
||||||
|
if 200 <= response.status < 300:
|
||||||
|
if response.status == 204:
|
||||||
|
return None
|
||||||
|
return data
|
||||||
|
|
||||||
|
# Rate limit handling
|
||||||
|
if response.status == 429: # Rate limited
|
||||||
|
retry_after_str = response.headers.get("Retry-After", "1")
|
||||||
|
try:
|
||||||
|
retry_after = float(retry_after_str)
|
||||||
|
except ValueError:
|
||||||
|
retry_after = 1.0 # Default retry if header is malformed
|
||||||
|
|
||||||
|
is_global = (
|
||||||
|
response.headers.get("X-RateLimit-Global", "false").lower()
|
||||||
|
== "true"
|
||||||
|
)
|
||||||
|
|
||||||
|
error_message = f"Rate limited on {method} {endpoint}."
|
||||||
|
if data and isinstance(data, dict) and "message" in data:
|
||||||
|
error_message += f" Discord says: {data['message']}"
|
||||||
|
|
||||||
|
if is_global:
|
||||||
|
self._global_rate_limit_lock.clear()
|
||||||
|
await asyncio.sleep(retry_after)
|
||||||
|
self._global_rate_limit_lock.set()
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(retry_after)
|
||||||
|
|
||||||
|
if attempt < 4: # Don't log on the last attempt before raising
|
||||||
|
print(
|
||||||
|
f"{error_message} Retrying after {retry_after}s (Attempt {attempt + 1}/5). Global: {is_global}"
|
||||||
|
)
|
||||||
|
continue # Retry the request
|
||||||
|
else: # Last attempt failed
|
||||||
|
raise RateLimitError(
|
||||||
|
response,
|
||||||
|
message=error_message,
|
||||||
|
retry_after=retry_after,
|
||||||
|
is_global=is_global,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Other error handling
|
||||||
|
if response.status == 401: # Unauthorized
|
||||||
|
raise AuthenticationError(response, "Invalid token provided.")
|
||||||
|
if response.status == 403: # Forbidden
|
||||||
|
raise HTTPException(
|
||||||
|
response,
|
||||||
|
"Missing permissions or access denied.",
|
||||||
|
status=response.status,
|
||||||
|
text=str(data),
|
||||||
|
)
|
||||||
|
|
||||||
|
# General HTTP error
|
||||||
|
error_text = str(data) if data else "Unknown error"
|
||||||
|
discord_error_code = (
|
||||||
|
data.get("code") if isinstance(data, dict) else None
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
response,
|
||||||
|
f"API Error on {method} {endpoint}: {error_text}",
|
||||||
|
status=response.status,
|
||||||
|
text=error_text,
|
||||||
|
error_code=discord_error_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not be reached if retries are exhausted by RateLimitError
|
||||||
|
raise DisagreementException(
|
||||||
|
f"Failed request to {method} {endpoint} after multiple retries."
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Specific API call methods ---
|
||||||
|
|
||||||
|
async def get_gateway_bot(self) -> Dict[str, Any]:
|
||||||
|
"""Gets the WSS URL and sharding information for the Gateway."""
|
||||||
|
return await self.request("GET", "/gateway/bot")
|
||||||
|
|
||||||
|
async def send_message(
|
||||||
|
self,
|
||||||
|
channel_id: str,
|
||||||
|
content: Optional[str] = None,
|
||||||
|
tts: bool = False,
|
||||||
|
embeds: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
components: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
allowed_mentions: Optional[dict] = None,
|
||||||
|
message_reference: Optional[Dict[str, Any]] = None,
|
||||||
|
flags: Optional[int] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Sends a message to a channel.
|
||||||
|
|
||||||
|
Returns the created message data as a dict.
|
||||||
|
"""
|
||||||
|
payload: Dict[str, Any] = {}
|
||||||
|
if content is not None: # Content is optional if embeds/components are present
|
||||||
|
payload["content"] = content
|
||||||
|
if tts:
|
||||||
|
payload["tts"] = True
|
||||||
|
if embeds:
|
||||||
|
payload["embeds"] = embeds
|
||||||
|
if components:
|
||||||
|
payload["components"] = components
|
||||||
|
if allowed_mentions:
|
||||||
|
payload["allowed_mentions"] = allowed_mentions
|
||||||
|
if flags:
|
||||||
|
payload["flags"] = flags
|
||||||
|
if message_reference:
|
||||||
|
payload["message_reference"] = message_reference
|
||||||
|
|
||||||
|
if not payload:
|
||||||
|
raise ValueError("Message must have content, embeds, or components.")
|
||||||
|
|
||||||
|
return await self.request(
|
||||||
|
"POST", f"/channels/{channel_id}/messages", payload=payload
|
||||||
|
)
|
||||||
|
|
||||||
|
async def edit_message(
|
||||||
|
self,
|
||||||
|
channel_id: str,
|
||||||
|
message_id: str,
|
||||||
|
payload: Dict[str, Any],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Edits a message in a channel."""
|
||||||
|
|
||||||
|
return await self.request(
|
||||||
|
"PATCH",
|
||||||
|
f"/channels/{channel_id}/messages/{message_id}",
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_message(
|
||||||
|
self, channel_id: "Snowflake", message_id: "Snowflake"
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Fetches a message from a channel."""
|
||||||
|
|
||||||
|
return await self.request(
|
||||||
|
"GET", f"/channels/{channel_id}/messages/{message_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create_reaction(
|
||||||
|
self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
|
||||||
|
) -> None:
|
||||||
|
"""Adds a reaction to a message as the current user."""
|
||||||
|
encoded = quote(emoji)
|
||||||
|
await self.request(
|
||||||
|
"PUT",
|
||||||
|
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/@me",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_reaction(
|
||||||
|
self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
|
||||||
|
) -> None:
|
||||||
|
"""Removes the current user's reaction from a message."""
|
||||||
|
encoded = quote(emoji)
|
||||||
|
await self.request(
|
||||||
|
"DELETE",
|
||||||
|
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/@me",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_reactions(
|
||||||
|
self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Fetches the users that reacted with a specific emoji."""
|
||||||
|
encoded = quote(emoji)
|
||||||
|
return await self.request(
|
||||||
|
"GET",
|
||||||
|
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_channel(
|
||||||
|
self, channel_id: str, reason: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""Deletes a channel.
|
||||||
|
|
||||||
|
If the channel is a guild channel, requires the MANAGE_CHANNELS permission.
|
||||||
|
If the channel is a thread, requires the MANAGE_THREADS permission (if locked) or
|
||||||
|
be the thread creator (if not locked).
|
||||||
|
Deleting a category does not delete its child channels.
|
||||||
|
"""
|
||||||
|
custom_headers = {}
|
||||||
|
if reason:
|
||||||
|
custom_headers["X-Audit-Log-Reason"] = reason
|
||||||
|
|
||||||
|
await self.request(
|
||||||
|
"DELETE",
|
||||||
|
f"/channels/{channel_id}",
|
||||||
|
custom_headers=custom_headers if custom_headers else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_channel(self, channel_id: str) -> Dict[str, Any]:
|
||||||
|
"""Fetches a channel by ID."""
|
||||||
|
return await self.request("GET", f"/channels/{channel_id}")
|
||||||
|
|
||||||
|
async def get_user(self, user_id: "Snowflake") -> Dict[str, Any]:
|
||||||
|
"""Fetches a user object for a given user ID."""
|
||||||
|
return await self.request("GET", f"/users/{user_id}")
|
||||||
|
|
||||||
|
async def get_guild_member(
|
||||||
|
self, guild_id: "Snowflake", user_id: "Snowflake"
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Returns a guild member object for the specified user."""
|
||||||
|
return await self.request("GET", f"/guilds/{guild_id}/members/{user_id}")
|
||||||
|
|
||||||
|
async def kick_member(
|
||||||
|
self, guild_id: "Snowflake", user_id: "Snowflake", reason: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""Kicks a member from the guild."""
|
||||||
|
headers = {"X-Audit-Log-Reason": reason} if reason else None
|
||||||
|
await self.request(
|
||||||
|
"DELETE",
|
||||||
|
f"/guilds/{guild_id}/members/{user_id}",
|
||||||
|
custom_headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def ban_member(
|
||||||
|
self,
|
||||||
|
guild_id: "Snowflake",
|
||||||
|
user_id: "Snowflake",
|
||||||
|
*,
|
||||||
|
delete_message_seconds: int = 0,
|
||||||
|
reason: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Bans a member from the guild."""
|
||||||
|
payload = {}
|
||||||
|
if delete_message_seconds:
|
||||||
|
payload["delete_message_seconds"] = delete_message_seconds
|
||||||
|
headers = {"X-Audit-Log-Reason": reason} if reason else None
|
||||||
|
await self.request(
|
||||||
|
"PUT",
|
||||||
|
f"/guilds/{guild_id}/bans/{user_id}",
|
||||||
|
payload=payload if payload else None,
|
||||||
|
custom_headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def timeout_member(
|
||||||
|
self,
|
||||||
|
guild_id: "Snowflake",
|
||||||
|
user_id: "Snowflake",
|
||||||
|
*,
|
||||||
|
until: Optional[str],
|
||||||
|
reason: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Times out a member until the given ISO8601 timestamp."""
|
||||||
|
payload = {"communication_disabled_until": until}
|
||||||
|
headers = {"X-Audit-Log-Reason": reason} if reason else None
|
||||||
|
return await self.request(
|
||||||
|
"PATCH",
|
||||||
|
f"/guilds/{guild_id}/members/{user_id}",
|
||||||
|
payload=payload,
|
||||||
|
custom_headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_guild_roles(self, guild_id: "Snowflake") -> List[Dict[str, Any]]:
|
||||||
|
"""Returns a list of role objects for the guild."""
|
||||||
|
return await self.request("GET", f"/guilds/{guild_id}/roles")
|
||||||
|
|
||||||
|
async def get_guild(self, guild_id: "Snowflake") -> Dict[str, Any]:
|
||||||
|
"""Fetches a guild object for a given guild ID."""
|
||||||
|
return await self.request("GET", f"/guilds/{guild_id}")
|
||||||
|
|
||||||
|
# Add other methods like:
|
||||||
|
# async def get_guild(self, guild_id: str) -> Dict[str, Any]: ...
|
||||||
|
# async def create_reaction(self, channel_id: str, message_id: str, emoji: str) -> None: ...
|
||||||
|
# etc.
|
||||||
|
# --- Application Command Endpoints ---
|
||||||
|
|
||||||
|
# Global Application Commands
|
||||||
|
async def get_global_application_commands(
|
||||||
|
self, application_id: "Snowflake", with_localizations: bool = False
|
||||||
|
) -> List["ApplicationCommand"]:
|
||||||
|
"""Fetches all global commands for your application."""
|
||||||
|
params = {"with_localizations": str(with_localizations).lower()}
|
||||||
|
data = await self.request(
|
||||||
|
"GET", f"/applications/{application_id}/commands", params=params
|
||||||
|
)
|
||||||
|
from .interactions import ApplicationCommand # Ensure constructor is available
|
||||||
|
|
||||||
|
return [ApplicationCommand(cmd_data) for cmd_data in data]
|
||||||
|
|
||||||
|
async def create_global_application_command(
|
||||||
|
self, application_id: "Snowflake", payload: Dict[str, Any]
|
||||||
|
) -> "ApplicationCommand":
|
||||||
|
"""Creates a new global command."""
|
||||||
|
data = await self.request(
|
||||||
|
"POST", f"/applications/{application_id}/commands", payload=payload
|
||||||
|
)
|
||||||
|
from .interactions import ApplicationCommand
|
||||||
|
|
||||||
|
return ApplicationCommand(data)
|
||||||
|
|
||||||
|
async def get_global_application_command(
|
||||||
|
self, application_id: "Snowflake", command_id: "Snowflake"
|
||||||
|
) -> "ApplicationCommand":
|
||||||
|
"""Fetches a specific global command."""
|
||||||
|
data = await self.request(
|
||||||
|
"GET", f"/applications/{application_id}/commands/{command_id}"
|
||||||
|
)
|
||||||
|
from .interactions import ApplicationCommand
|
||||||
|
|
||||||
|
return ApplicationCommand(data)
|
||||||
|
|
||||||
|
async def edit_global_application_command(
|
||||||
|
self,
|
||||||
|
application_id: "Snowflake",
|
||||||
|
command_id: "Snowflake",
|
||||||
|
payload: Dict[str, Any],
|
||||||
|
) -> "ApplicationCommand":
|
||||||
|
"""Edits a specific global command."""
|
||||||
|
data = await self.request(
|
||||||
|
"PATCH",
|
||||||
|
f"/applications/{application_id}/commands/{command_id}",
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
from .interactions import ApplicationCommand
|
||||||
|
|
||||||
|
return ApplicationCommand(data)
|
||||||
|
|
||||||
|
async def delete_global_application_command(
|
||||||
|
self, application_id: "Snowflake", command_id: "Snowflake"
|
||||||
|
) -> None:
|
||||||
|
"""Deletes a specific global command."""
|
||||||
|
await self.request(
|
||||||
|
"DELETE", f"/applications/{application_id}/commands/{command_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def bulk_overwrite_global_application_commands(
|
||||||
|
self, application_id: "Snowflake", payload: List[Dict[str, Any]]
|
||||||
|
) -> List["ApplicationCommand"]:
|
||||||
|
"""Bulk overwrites all global commands for your application."""
|
||||||
|
data = await self.request(
|
||||||
|
"PUT", f"/applications/{application_id}/commands", payload=payload
|
||||||
|
)
|
||||||
|
from .interactions import ApplicationCommand
|
||||||
|
|
||||||
|
return [ApplicationCommand(cmd_data) for cmd_data in data]
|
||||||
|
|
||||||
|
# Guild Application Commands
|
||||||
|
async def get_guild_application_commands(
|
||||||
|
self,
|
||||||
|
application_id: "Snowflake",
|
||||||
|
guild_id: "Snowflake",
|
||||||
|
with_localizations: bool = False,
|
||||||
|
) -> List["ApplicationCommand"]:
|
||||||
|
"""Fetches all commands for your application for a specific guild."""
|
||||||
|
params = {"with_localizations": str(with_localizations).lower()}
|
||||||
|
data = await self.request(
|
||||||
|
"GET",
|
||||||
|
f"/applications/{application_id}/guilds/{guild_id}/commands",
|
||||||
|
params=params,
|
||||||
|
)
|
||||||
|
from .interactions import ApplicationCommand
|
||||||
|
|
||||||
|
return [ApplicationCommand(cmd_data) for cmd_data in data]
|
||||||
|
|
||||||
|
async def create_guild_application_command(
|
||||||
|
self,
|
||||||
|
application_id: "Snowflake",
|
||||||
|
guild_id: "Snowflake",
|
||||||
|
payload: Dict[str, Any],
|
||||||
|
) -> "ApplicationCommand":
|
||||||
|
"""Creates a new guild command."""
|
||||||
|
data = await self.request(
|
||||||
|
"POST",
|
||||||
|
f"/applications/{application_id}/guilds/{guild_id}/commands",
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
from .interactions import ApplicationCommand
|
||||||
|
|
||||||
|
return ApplicationCommand(data)
|
||||||
|
|
||||||
|
async def get_guild_application_command(
|
||||||
|
self,
|
||||||
|
application_id: "Snowflake",
|
||||||
|
guild_id: "Snowflake",
|
||||||
|
command_id: "Snowflake",
|
||||||
|
) -> "ApplicationCommand":
|
||||||
|
"""Fetches a specific guild command."""
|
||||||
|
data = await self.request(
|
||||||
|
"GET",
|
||||||
|
f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}",
|
||||||
|
)
|
||||||
|
from .interactions import ApplicationCommand
|
||||||
|
|
||||||
|
return ApplicationCommand(data)
|
||||||
|
|
||||||
|
async def edit_guild_application_command(
|
||||||
|
self,
|
||||||
|
application_id: "Snowflake",
|
||||||
|
guild_id: "Snowflake",
|
||||||
|
command_id: "Snowflake",
|
||||||
|
payload: Dict[str, Any],
|
||||||
|
) -> "ApplicationCommand":
|
||||||
|
"""Edits a specific guild command."""
|
||||||
|
data = await self.request(
|
||||||
|
"PATCH",
|
||||||
|
f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}",
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
from .interactions import ApplicationCommand
|
||||||
|
|
||||||
|
return ApplicationCommand(data)
|
||||||
|
|
||||||
|
async def delete_guild_application_command(
|
||||||
|
self,
|
||||||
|
application_id: "Snowflake",
|
||||||
|
guild_id: "Snowflake",
|
||||||
|
command_id: "Snowflake",
|
||||||
|
) -> None:
|
||||||
|
"""Deletes a specific guild command."""
|
||||||
|
await self.request(
|
||||||
|
"DELETE",
|
||||||
|
f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def bulk_overwrite_guild_application_commands(
|
||||||
|
self,
|
||||||
|
application_id: "Snowflake",
|
||||||
|
guild_id: "Snowflake",
|
||||||
|
payload: List[Dict[str, Any]],
|
||||||
|
) -> List["ApplicationCommand"]:
|
||||||
|
"""Bulk overwrites all commands for your application for a specific guild."""
|
||||||
|
data = await self.request(
|
||||||
|
"PUT",
|
||||||
|
f"/applications/{application_id}/guilds/{guild_id}/commands",
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
from .interactions import ApplicationCommand
|
||||||
|
|
||||||
|
return [ApplicationCommand(cmd_data) for cmd_data in data]
|
||||||
|
|
||||||
|
# --- Interaction Response Endpoints ---
|
||||||
|
# Note: These methods return Dict[str, Any] representing the Message data.
|
||||||
|
# The caller (e.g., AppCommandHandler) will be responsible for constructing Message models
|
||||||
|
# if needed, as Message model instantiation requires a `client_instance`.
|
||||||
|
|
||||||
|
async def create_interaction_response(
|
||||||
|
self,
|
||||||
|
interaction_id: "Snowflake",
|
||||||
|
interaction_token: str,
|
||||||
|
payload: "InteractionResponsePayload",
|
||||||
|
*,
|
||||||
|
ephemeral: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Creates a response to an Interaction.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
ephemeral: bool
|
||||||
|
Ignored parameter for test compatibility.
|
||||||
|
"""
|
||||||
|
# Interaction responses do not use the bot token in the Authorization header.
|
||||||
|
# They are authenticated by the interaction_token in the URL.
|
||||||
|
await self.request(
|
||||||
|
"POST",
|
||||||
|
f"/interactions/{interaction_id}/{interaction_token}/callback",
|
||||||
|
payload=payload.to_dict(),
|
||||||
|
use_auth_header=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_original_interaction_response(
|
||||||
|
self, application_id: "Snowflake", interaction_token: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Gets the initial Interaction response."""
|
||||||
|
# This endpoint uses the bot token for auth.
|
||||||
|
return await self.request(
|
||||||
|
"GET", f"/webhooks/{application_id}/{interaction_token}/messages/@original"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def edit_original_interaction_response(
|
||||||
|
self,
|
||||||
|
application_id: "Snowflake",
|
||||||
|
interaction_token: str,
|
||||||
|
payload: Dict[str, Any],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Edits the initial Interaction response."""
|
||||||
|
return await self.request(
|
||||||
|
"PATCH",
|
||||||
|
f"/webhooks/{application_id}/{interaction_token}/messages/@original",
|
||||||
|
payload=payload,
|
||||||
|
use_auth_header=False,
|
||||||
|
) # Docs imply webhook-style auth
|
||||||
|
|
||||||
|
async def delete_original_interaction_response(
|
||||||
|
self, application_id: "Snowflake", interaction_token: str
|
||||||
|
) -> None:
|
||||||
|
"""Deletes the initial Interaction response."""
|
||||||
|
await self.request(
|
||||||
|
"DELETE",
|
||||||
|
f"/webhooks/{application_id}/{interaction_token}/messages/@original",
|
||||||
|
use_auth_header=False,
|
||||||
|
) # Docs imply webhook-style auth
|
||||||
|
|
||||||
|
async def create_followup_message(
|
||||||
|
self,
|
||||||
|
application_id: "Snowflake",
|
||||||
|
interaction_token: str,
|
||||||
|
payload: Dict[str, Any],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Creates a followup message for an Interaction."""
|
||||||
|
# Followup messages are sent to a webhook endpoint.
|
||||||
|
return await self.request(
|
||||||
|
"POST",
|
||||||
|
f"/webhooks/{application_id}/{interaction_token}",
|
||||||
|
payload=payload,
|
||||||
|
use_auth_header=False,
|
||||||
|
) # Docs imply webhook-style auth
|
||||||
|
|
||||||
|
async def edit_followup_message(
|
||||||
|
self,
|
||||||
|
application_id: "Snowflake",
|
||||||
|
interaction_token: str,
|
||||||
|
message_id: "Snowflake",
|
||||||
|
payload: Dict[str, Any],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Edits a followup message for an Interaction."""
|
||||||
|
return await self.request(
|
||||||
|
"PATCH",
|
||||||
|
f"/webhooks/{application_id}/{interaction_token}/messages/{message_id}",
|
||||||
|
payload=payload,
|
||||||
|
use_auth_header=False,
|
||||||
|
) # Docs imply webhook-style auth
|
||||||
|
|
||||||
|
async def delete_followup_message(
|
||||||
|
self,
|
||||||
|
application_id: "Snowflake",
|
||||||
|
interaction_token: str,
|
||||||
|
message_id: "Snowflake",
|
||||||
|
) -> None:
|
||||||
|
"""Deletes a followup message for an Interaction."""
|
||||||
|
await self.request(
|
||||||
|
"DELETE",
|
||||||
|
f"/webhooks/{application_id}/{interaction_token}/messages/{message_id}",
|
||||||
|
use_auth_header=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def trigger_typing(self, channel_id: str) -> None:
|
||||||
|
"""Sends a typing indicator to the specified channel."""
|
||||||
|
await self.request("POST", f"/channels/{channel_id}/typing")
|
32
disagreement/hybrid_context.py
Normal file
32
disagreement/hybrid_context.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
"""Utility class for working with either command or app contexts."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
from .ext.commands.core import CommandContext
|
||||||
|
from .ext.app_commands.context import AppCommandContext
|
||||||
|
|
||||||
|
|
||||||
|
class HybridContext:
|
||||||
|
"""Wraps :class:`CommandContext` and :class:`AppCommandContext`.
|
||||||
|
|
||||||
|
Provides a single :meth:`send` method that proxies to ``reply`` for
|
||||||
|
prefix commands and to ``send`` for slash commands.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ctx: Union[CommandContext, AppCommandContext]):
|
||||||
|
self._ctx = ctx
|
||||||
|
|
||||||
|
async def send(self, *args: Any, **kwargs: Any):
|
||||||
|
if isinstance(self._ctx, AppCommandContext):
|
||||||
|
return await self._ctx.send(*args, **kwargs)
|
||||||
|
return await self._ctx.reply(*args, **kwargs)
|
||||||
|
|
||||||
|
async def edit(self, *args: Any, **kwargs: Any):
|
||||||
|
if hasattr(self._ctx, "edit"):
|
||||||
|
return await self._ctx.edit(*args, **kwargs)
|
||||||
|
raise AttributeError("Underlying context does not support editing.")
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
return getattr(self._ctx, name)
|
22
disagreement/i18n.py
Normal file
22
disagreement/i18n.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
import json
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
_translations: Dict[str, Dict[str, str]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def set_translations(locale: str, mapping: Dict[str, str]) -> None:
|
||||||
|
"""Set translations for a locale."""
|
||||||
|
_translations[locale] = mapping
|
||||||
|
|
||||||
|
|
||||||
|
def load_translations(locale: str, file_path: str) -> None:
|
||||||
|
"""Load translations for *locale* from a JSON file."""
|
||||||
|
with open(file_path, "r", encoding="utf-8") as handle:
|
||||||
|
_translations[locale] = json.load(handle)
|
||||||
|
|
||||||
|
|
||||||
|
def translate(key: str, locale: str, *, default: Optional[str] = None) -> str:
|
||||||
|
"""Return the translated string for *key* in *locale*."""
|
||||||
|
return _translations.get(locale, {}).get(
|
||||||
|
key, default if default is not None else key
|
||||||
|
)
|
572
disagreement/interactions.py
Normal file
572
disagreement/interactions.py
Normal file
@ -0,0 +1,572 @@
|
|||||||
|
# disagreement/interactions.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
Data models for Discord Interaction objects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, List, Dict, Union, Any, TYPE_CHECKING
|
||||||
|
|
||||||
|
from .enums import (
|
||||||
|
ApplicationCommandType,
|
||||||
|
ApplicationCommandOptionType,
|
||||||
|
InteractionType,
|
||||||
|
InteractionCallbackType,
|
||||||
|
IntegrationType,
|
||||||
|
InteractionContextType,
|
||||||
|
ChannelType,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Runtime imports for models used in this module
|
||||||
|
from .models import (
|
||||||
|
User,
|
||||||
|
Message,
|
||||||
|
Member,
|
||||||
|
Role,
|
||||||
|
Embed,
|
||||||
|
PartialChannel,
|
||||||
|
Attachment,
|
||||||
|
ActionRow,
|
||||||
|
Component,
|
||||||
|
AllowedMentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# Import Client type only for type checking to avoid circular imports
|
||||||
|
from .client import Client
|
||||||
|
from .ui.modal import Modal
|
||||||
|
|
||||||
|
# MessageFlags, PartialAttachment can be added if/when defined
|
||||||
|
|
||||||
|
Snowflake = str
|
||||||
|
|
||||||
|
|
||||||
|
# Based on Application Command Option Choice Structure
|
||||||
|
class ApplicationCommandOptionChoice:
|
||||||
|
"""Represents a choice for an application command option."""
|
||||||
|
|
||||||
|
def __init__(self, data: dict):
|
||||||
|
self.name: str = data["name"]
|
||||||
|
self.value: Union[str, int, float] = data["value"]
|
||||||
|
self.name_localizations: Optional[Dict[str, str]] = data.get(
|
||||||
|
"name_localizations"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"<ApplicationCommandOptionChoice name='{self.name}' value={self.value!r}>"
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
payload: Dict[str, Any] = {"name": self.name, "value": self.value}
|
||||||
|
if self.name_localizations:
|
||||||
|
payload["name_localizations"] = self.name_localizations
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
# Based on Application Command Option Structure
|
||||||
|
class ApplicationCommandOption:
|
||||||
|
"""Represents an option for an application command."""
|
||||||
|
|
||||||
|
def __init__(self, data: dict):
|
||||||
|
self.type: ApplicationCommandOptionType = ApplicationCommandOptionType(
|
||||||
|
data["type"]
|
||||||
|
)
|
||||||
|
self.name: str = data["name"]
|
||||||
|
self.description: str = data["description"]
|
||||||
|
self.required: bool = data.get("required", False)
|
||||||
|
|
||||||
|
self.choices: Optional[List[ApplicationCommandOptionChoice]] = (
|
||||||
|
[ApplicationCommandOptionChoice(c) for c in data["choices"]]
|
||||||
|
if data.get("choices")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.options: Optional[List["ApplicationCommandOption"]] = (
|
||||||
|
[ApplicationCommandOption(o) for o in data["options"]]
|
||||||
|
if data.get("options")
|
||||||
|
else None
|
||||||
|
) # For subcommands/groups
|
||||||
|
|
||||||
|
self.channel_types: Optional[List[ChannelType]] = (
|
||||||
|
[ChannelType(ct) for ct in data.get("channel_types", [])]
|
||||||
|
if data.get("channel_types")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
self.min_value: Optional[Union[int, float]] = data.get("min_value")
|
||||||
|
self.max_value: Optional[Union[int, float]] = data.get("max_value")
|
||||||
|
self.min_length: Optional[int] = data.get("min_length")
|
||||||
|
self.max_length: Optional[int] = data.get("max_length")
|
||||||
|
self.autocomplete: bool = data.get("autocomplete", False)
|
||||||
|
self.name_localizations: Optional[Dict[str, str]] = data.get(
|
||||||
|
"name_localizations"
|
||||||
|
)
|
||||||
|
self.description_localizations: Optional[Dict[str, str]] = data.get(
|
||||||
|
"description_localizations"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<ApplicationCommandOption name='{self.name}' type={self.type!r}>"
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"type": self.type.value,
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
}
|
||||||
|
if self.required: # Defaults to False, only include if True
|
||||||
|
payload["required"] = self.required
|
||||||
|
if self.choices:
|
||||||
|
payload["choices"] = [c.to_dict() for c in self.choices]
|
||||||
|
if self.options: # For subcommands/groups
|
||||||
|
payload["options"] = [o.to_dict() for o in self.options]
|
||||||
|
if self.channel_types:
|
||||||
|
payload["channel_types"] = [ct.value for ct in self.channel_types]
|
||||||
|
if self.min_value is not None:
|
||||||
|
payload["min_value"] = self.min_value
|
||||||
|
if self.max_value is not None:
|
||||||
|
payload["max_value"] = self.max_value
|
||||||
|
if self.min_length is not None:
|
||||||
|
payload["min_length"] = self.min_length
|
||||||
|
if self.max_length is not None:
|
||||||
|
payload["max_length"] = self.max_length
|
||||||
|
if self.autocomplete: # Defaults to False, only include if True
|
||||||
|
payload["autocomplete"] = self.autocomplete
|
||||||
|
if self.name_localizations:
|
||||||
|
payload["name_localizations"] = self.name_localizations
|
||||||
|
if self.description_localizations:
|
||||||
|
payload["description_localizations"] = self.description_localizations
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
# Based on Application Command Structure
|
||||||
|
class ApplicationCommand:
|
||||||
|
"""Represents an application command."""
|
||||||
|
|
||||||
|
def __init__(self, data: dict):
|
||||||
|
self.id: Optional[Snowflake] = data.get("id")
|
||||||
|
self.type: ApplicationCommandType = ApplicationCommandType(
|
||||||
|
data.get("type", 1)
|
||||||
|
) # Default to CHAT_INPUT
|
||||||
|
self.application_id: Optional[Snowflake] = data.get("application_id")
|
||||||
|
self.guild_id: Optional[Snowflake] = data.get("guild_id")
|
||||||
|
self.name: str = data["name"]
|
||||||
|
self.description: str = data.get(
|
||||||
|
"description", ""
|
||||||
|
) # Empty for USER/MESSAGE commands
|
||||||
|
|
||||||
|
self.options: Optional[List[ApplicationCommandOption]] = (
|
||||||
|
[ApplicationCommandOption(o) for o in data["options"]]
|
||||||
|
if data.get("options")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.default_member_permissions: Optional[str] = data.get(
|
||||||
|
"default_member_permissions"
|
||||||
|
)
|
||||||
|
self.dm_permission: Optional[bool] = data.get("dm_permission") # Deprecated
|
||||||
|
self.nsfw: bool = data.get("nsfw", False)
|
||||||
|
self.version: Optional[Snowflake] = data.get("version")
|
||||||
|
self.name_localizations: Optional[Dict[str, str]] = data.get(
|
||||||
|
"name_localizations"
|
||||||
|
)
|
||||||
|
self.description_localizations: Optional[Dict[str, str]] = data.get(
|
||||||
|
"description_localizations"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.integration_types: Optional[List[IntegrationType]] = (
|
||||||
|
[IntegrationType(it) for it in data["integration_types"]]
|
||||||
|
if data.get("integration_types")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.contexts: Optional[List[InteractionContextType]] = (
|
||||||
|
[InteractionContextType(c) for c in data["contexts"]]
|
||||||
|
if data.get("contexts")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"<ApplicationCommand id='{self.id}' name='{self.name}' type={self.type!r}>"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Based on Interaction Object's Resolved Data Structure
|
||||||
|
class ResolvedData:
|
||||||
|
"""Represents resolved data for an interaction."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, data: dict, client_instance: Optional["Client"] = None
|
||||||
|
): # client_instance for model hydration
|
||||||
|
# Models are now imported in TYPE_CHECKING block
|
||||||
|
|
||||||
|
users_data = data.get("users", {})
|
||||||
|
self.users: Dict[Snowflake, "User"] = {
|
||||||
|
uid: User(udata) for uid, udata in users_data.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
self.members: Dict[Snowflake, "Member"] = {}
|
||||||
|
for mid, mdata in data.get("members", {}).items():
|
||||||
|
member_payload = dict(mdata)
|
||||||
|
member_payload.setdefault("id", mid)
|
||||||
|
if "user" not in member_payload and mid in users_data:
|
||||||
|
member_payload["user"] = users_data[mid]
|
||||||
|
self.members[mid] = Member(member_payload, client_instance=client_instance)
|
||||||
|
|
||||||
|
self.roles: Dict[Snowflake, "Role"] = {
|
||||||
|
rid: Role(rdata) for rid, rdata in data.get("roles", {}).items()
|
||||||
|
}
|
||||||
|
|
||||||
|
self.channels: Dict[Snowflake, "PartialChannel"] = {
|
||||||
|
cid: PartialChannel(cdata, client_instance=client_instance)
|
||||||
|
for cid, cdata in data.get("channels", {}).items()
|
||||||
|
}
|
||||||
|
|
||||||
|
self.messages: Dict[Snowflake, "Message"] = (
|
||||||
|
{
|
||||||
|
mid: Message(mdata, client_instance=client_instance) for mid, mdata in data.get("messages", {}).items() # type: ignore[misc]
|
||||||
|
}
|
||||||
|
if client_instance
|
||||||
|
else {}
|
||||||
|
) # Only hydrate if client is available
|
||||||
|
|
||||||
|
self.attachments: Dict[Snowflake, "Attachment"] = {
|
||||||
|
aid: Attachment(adata) for aid, adata in data.get("attachments", {}).items()
|
||||||
|
}
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<ResolvedData users={len(self.users)} members={len(self.members)} roles={len(self.roles)} channels={len(self.channels)} messages={len(self.messages)} attachments={len(self.attachments)}>"
|
||||||
|
|
||||||
|
|
||||||
|
# Based on Interaction Object's Data Structure (for Application Commands)
|
||||||
|
class InteractionData:
|
||||||
|
"""Represents the data payload for an interaction."""
|
||||||
|
|
||||||
|
def __init__(self, data: dict, client_instance: Optional["Client"] = None):
|
||||||
|
self.id: Optional[Snowflake] = data.get("id") # Command ID
|
||||||
|
self.name: Optional[str] = data.get("name") # Command name
|
||||||
|
self.type: Optional[ApplicationCommandType] = (
|
||||||
|
ApplicationCommandType(data["type"]) if data.get("type") else None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.resolved: Optional[ResolvedData] = (
|
||||||
|
ResolvedData(data["resolved"], client_instance=client_instance)
|
||||||
|
if data.get("resolved")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# For CHAT_INPUT, this is List[ApplicationCommandInteractionDataOption]
|
||||||
|
# For USER/MESSAGE, this is not present or different.
|
||||||
|
# For now, storing as raw list of dicts. Parsing can happen in handler.
|
||||||
|
self.options: Optional[List[Dict[str, Any]]] = data.get("options")
|
||||||
|
|
||||||
|
# For message components
|
||||||
|
self.custom_id: Optional[str] = data.get("custom_id")
|
||||||
|
self.component_type: Optional[int] = data.get("component_type")
|
||||||
|
self.values: Optional[List[str]] = data.get("values")
|
||||||
|
|
||||||
|
self.guild_id: Optional[Snowflake] = data.get("guild_id")
|
||||||
|
self.target_id: Optional[Snowflake] = data.get(
|
||||||
|
"target_id"
|
||||||
|
) # For USER/MESSAGE commands
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<InteractionData id='{self.id}' name='{self.name}' type={self.type!r}>"
|
||||||
|
|
||||||
|
|
||||||
|
# Based on Interaction Object Structure
|
||||||
|
class Interaction:
|
||||||
|
"""Represents an interaction from Discord."""
|
||||||
|
|
||||||
|
def __init__(self, data: dict, client_instance: "Client"):
|
||||||
|
self._client: "Client" = client_instance
|
||||||
|
|
||||||
|
self.id: Snowflake = data["id"]
|
||||||
|
self.application_id: Snowflake = data["application_id"]
|
||||||
|
self.type: InteractionType = InteractionType(data["type"])
|
||||||
|
|
||||||
|
self.data: Optional[InteractionData] = (
|
||||||
|
InteractionData(data["data"], client_instance=client_instance)
|
||||||
|
if data.get("data")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.guild_id: Optional[Snowflake] = data.get("guild_id")
|
||||||
|
self.channel_id: Optional[Snowflake] = data.get(
|
||||||
|
"channel_id"
|
||||||
|
) # Will be present on command invocations
|
||||||
|
|
||||||
|
member_data = data.get("member")
|
||||||
|
user_data_from_member = (
|
||||||
|
member_data.get("user") if isinstance(member_data, dict) else None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.member: Optional["Member"] = (
|
||||||
|
Member(member_data, client_instance=self._client) if member_data else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# User object is included within member if in guild, otherwise it's top-level
|
||||||
|
# If self.member was successfully hydrated, its .user attribute should be preferred if it exists.
|
||||||
|
# However, Member.__init__ handles setting User attributes.
|
||||||
|
# The primary source for User is data.get("user") or member_data.get("user").
|
||||||
|
|
||||||
|
if data.get("user"):
|
||||||
|
self.user: Optional["User"] = User(data["user"])
|
||||||
|
elif user_data_from_member:
|
||||||
|
self.user: Optional["User"] = User(user_data_from_member)
|
||||||
|
elif (
|
||||||
|
self.member
|
||||||
|
): # If member was hydrated and has user attributes (e.g. from Member(User) inheritance)
|
||||||
|
# This assumes Member correctly populates its User parts.
|
||||||
|
self.user: Optional["User"] = self.member # Member is a User subclass
|
||||||
|
else:
|
||||||
|
self.user: Optional["User"] = None
|
||||||
|
|
||||||
|
self.token: str = data["token"] # For responding to the interaction
|
||||||
|
self.version: int = data["version"]
|
||||||
|
|
||||||
|
self.message: Optional["Message"] = (
|
||||||
|
Message(data["message"], client_instance=client_instance)
|
||||||
|
if data.get("message")
|
||||||
|
else None
|
||||||
|
) # For component interactions
|
||||||
|
|
||||||
|
self.app_permissions: Optional[str] = data.get(
|
||||||
|
"app_permissions"
|
||||||
|
) # Bitwise set of permissions the app has in the source channel
|
||||||
|
self.locale: Optional[str] = data.get(
|
||||||
|
"locale"
|
||||||
|
) # Selected language of the invoking user
|
||||||
|
self.guild_locale: Optional[str] = data.get(
|
||||||
|
"guild_locale"
|
||||||
|
) # Guild's preferred language
|
||||||
|
|
||||||
|
self.response = InteractionResponse(self)
|
||||||
|
|
||||||
|
async def respond(
|
||||||
|
self,
|
||||||
|
content: Optional[str] = None,
|
||||||
|
*,
|
||||||
|
embed: Optional[Embed] = None,
|
||||||
|
embeds: Optional[List[Embed]] = None,
|
||||||
|
components: Optional[List[ActionRow]] = None,
|
||||||
|
ephemeral: bool = False,
|
||||||
|
tts: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""|coro|
|
||||||
|
|
||||||
|
Responds to this interaction.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
content (Optional[str]): The content of the message.
|
||||||
|
embed (Optional[Embed]): A single embed to send.
|
||||||
|
embeds (Optional[List[Embed]]): A list of embeds to send.
|
||||||
|
components (Optional[List[ActionRow]]): A list of ActionRow components.
|
||||||
|
ephemeral (bool): Whether the response should be ephemeral (only visible to the user).
|
||||||
|
tts (bool): Whether the message should be sent with text-to-speech.
|
||||||
|
"""
|
||||||
|
if embed and embeds:
|
||||||
|
raise ValueError("Cannot provide both embed and embeds.")
|
||||||
|
|
||||||
|
data: Dict[str, Any] = {}
|
||||||
|
if tts:
|
||||||
|
data["tts"] = True
|
||||||
|
if content:
|
||||||
|
data["content"] = content
|
||||||
|
if embed:
|
||||||
|
data["embeds"] = [embed.to_dict()]
|
||||||
|
elif embeds:
|
||||||
|
data["embeds"] = [e.to_dict() for e in embeds]
|
||||||
|
if components:
|
||||||
|
data["components"] = [c.to_dict() for c in components]
|
||||||
|
if ephemeral:
|
||||||
|
data["flags"] = 1 << 6 # EPHEMERAL flag
|
||||||
|
|
||||||
|
payload = InteractionResponsePayload(
|
||||||
|
type=InteractionCallbackType.CHANNEL_MESSAGE_WITH_SOURCE,
|
||||||
|
data=InteractionCallbackData(data),
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._client._http.create_interaction_response(
|
||||||
|
interaction_id=self.id,
|
||||||
|
interaction_token=self.token,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def respond_modal(self, modal: "Modal") -> None:
|
||||||
|
"""|coro| Send a modal in response to this interaction."""
|
||||||
|
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"type": InteractionCallbackType.MODAL.value,
|
||||||
|
"data": modal.to_dict(),
|
||||||
|
}
|
||||||
|
await self._client._http.create_interaction_response(
|
||||||
|
interaction_id=self.id,
|
||||||
|
interaction_token=self.token,
|
||||||
|
payload=cast(Any, payload),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def edit(
|
||||||
|
self,
|
||||||
|
content: Optional[str] = None,
|
||||||
|
*,
|
||||||
|
embed: Optional[Embed] = None,
|
||||||
|
embeds: Optional[List[Embed]] = None,
|
||||||
|
components: Optional[List[ActionRow]] = None,
|
||||||
|
attachments: Optional[List[Any]] = None,
|
||||||
|
allowed_mentions: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""|coro|
|
||||||
|
|
||||||
|
Edits the original response to this interaction.
|
||||||
|
|
||||||
|
If the interaction is from a component, this will acknowledge the
|
||||||
|
interaction and update the message in one operation.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
content (Optional[str]): The new message content.
|
||||||
|
embed (Optional[Embed]): A single embed to send. Ignored if
|
||||||
|
``embeds`` is provided.
|
||||||
|
embeds (Optional[List[Embed]]): A list of embeds to send.
|
||||||
|
components (Optional[List[ActionRow]]): Updated components for the
|
||||||
|
message.
|
||||||
|
attachments (Optional[List[Any]]): Attachments to include with the
|
||||||
|
message.
|
||||||
|
allowed_mentions (Optional[Dict[str, Any]]): Controls mentions in the
|
||||||
|
message.
|
||||||
|
"""
|
||||||
|
if embed and embeds:
|
||||||
|
raise ValueError("Cannot provide both embed and embeds.")
|
||||||
|
|
||||||
|
payload_data: Dict[str, Any] = {}
|
||||||
|
if content is not None:
|
||||||
|
payload_data["content"] = content
|
||||||
|
if embed:
|
||||||
|
payload_data["embeds"] = [embed.to_dict()]
|
||||||
|
elif embeds is not None:
|
||||||
|
payload_data["embeds"] = [e.to_dict() for e in embeds]
|
||||||
|
if components is not None:
|
||||||
|
payload_data["components"] = [c.to_dict() for c in components]
|
||||||
|
if attachments is not None:
|
||||||
|
payload_data["attachments"] = [
|
||||||
|
a.to_dict() if hasattr(a, "to_dict") else a for a in attachments
|
||||||
|
]
|
||||||
|
if allowed_mentions is not None:
|
||||||
|
payload_data["allowed_mentions"] = allowed_mentions
|
||||||
|
|
||||||
|
if self.type == InteractionType.MESSAGE_COMPONENT:
|
||||||
|
# For component interactions, we send an UPDATE_MESSAGE response
|
||||||
|
# to acknowledge the interaction and edit the message simultaneously.
|
||||||
|
payload = InteractionResponsePayload(
|
||||||
|
type=InteractionCallbackType.UPDATE_MESSAGE,
|
||||||
|
data=InteractionCallbackData(payload_data),
|
||||||
|
)
|
||||||
|
await self._client._http.create_interaction_response(
|
||||||
|
self.id, self.token, payload
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# For other interaction types (like an initial slash command response),
|
||||||
|
# we edit the original response via the webhook endpoint.
|
||||||
|
await self._client._http.edit_original_interaction_response(
|
||||||
|
application_id=self.application_id,
|
||||||
|
interaction_token=self.token,
|
||||||
|
payload=payload_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<Interaction id='{self.id}' type={self.type!r}>"
|
||||||
|
|
||||||
|
|
||||||
|
class InteractionResponse:
|
||||||
|
"""Helper for sending responses for an :class:`Interaction`."""
|
||||||
|
|
||||||
|
def __init__(self, interaction: "Interaction") -> None:
|
||||||
|
self._interaction = interaction
|
||||||
|
|
||||||
|
async def send_modal(self, modal: "Modal") -> None:
|
||||||
|
"""Sends a modal response."""
|
||||||
|
payload = InteractionResponsePayload(
|
||||||
|
type=InteractionCallbackType.MODAL,
|
||||||
|
data=InteractionCallbackData(modal.to_dict()),
|
||||||
|
)
|
||||||
|
await self._interaction._client._http.create_interaction_response(
|
||||||
|
self._interaction.id,
|
||||||
|
self._interaction.token,
|
||||||
|
payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Based on Interaction Response Object's Data Structure
|
||||||
|
class InteractionCallbackData:
|
||||||
|
"""Data for an interaction response."""
|
||||||
|
|
||||||
|
def __init__(self, data: dict):
|
||||||
|
self.tts: Optional[bool] = data.get("tts")
|
||||||
|
self.content: Optional[str] = data.get("content")
|
||||||
|
self.embeds: Optional[List[Embed]] = (
|
||||||
|
[Embed(e) for e in data.get("embeds", [])] if data.get("embeds") else None
|
||||||
|
)
|
||||||
|
self.allowed_mentions: Optional[AllowedMentions] = (
|
||||||
|
AllowedMentions(data["allowed_mentions"])
|
||||||
|
if data.get("allowed_mentions")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
self.flags: Optional[int] = data.get("flags") # MessageFlags enum could be used
|
||||||
|
from .components import component_factory
|
||||||
|
|
||||||
|
self.components: Optional[List[Component]] = (
|
||||||
|
[component_factory(c) for c in data.get("components", [])]
|
||||||
|
if data.get("components")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
self.attachments: Optional[List[Attachment]] = (
|
||||||
|
[Attachment(a) for a in data.get("attachments", [])]
|
||||||
|
if data.get("attachments")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
# Helper to convert to dict for sending to Discord API
|
||||||
|
payload = {}
|
||||||
|
if self.tts is not None:
|
||||||
|
payload["tts"] = self.tts
|
||||||
|
if self.content is not None:
|
||||||
|
payload["content"] = self.content
|
||||||
|
if self.embeds is not None:
|
||||||
|
payload["embeds"] = [e.to_dict() for e in self.embeds]
|
||||||
|
if self.allowed_mentions is not None:
|
||||||
|
payload["allowed_mentions"] = self.allowed_mentions.to_dict()
|
||||||
|
if self.flags is not None:
|
||||||
|
payload["flags"] = self.flags
|
||||||
|
if self.components is not None:
|
||||||
|
payload["components"] = [c.to_dict() for c in self.components]
|
||||||
|
if self.attachments is not None:
|
||||||
|
payload["attachments"] = [a.to_dict() for a in self.attachments]
|
||||||
|
return payload
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<InteractionCallbackData content='{self.content[:20] if self.content else None}'>"
|
||||||
|
|
||||||
|
|
||||||
|
# Based on Interaction Response Object Structure
|
||||||
|
class InteractionResponsePayload:
|
||||||
|
"""Payload for responding to an interaction."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
type: InteractionCallbackType,
|
||||||
|
data: Optional[InteractionCallbackData] = None,
|
||||||
|
):
|
||||||
|
self.type: InteractionCallbackType = type
|
||||||
|
self.data: Optional[InteractionCallbackData] = data
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
payload: Dict[str, Any] = {"type": self.type.value}
|
||||||
|
if self.data:
|
||||||
|
payload["data"] = self.data.to_dict()
|
||||||
|
return payload
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<InteractionResponsePayload type={self.type!r}>"
|
26
disagreement/logging_config.py
Normal file
26
disagreement/logging_config.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(level: int, file: Optional[str] = None) -> None:
|
||||||
|
"""Configure logging for the library.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
level:
|
||||||
|
Logging level from the :mod:`logging` module.
|
||||||
|
file:
|
||||||
|
Optional file path to write logs to. If ``None``, logs are sent to
|
||||||
|
standard output.
|
||||||
|
"""
|
||||||
|
handlers: list[logging.Handler] = []
|
||||||
|
if file is None:
|
||||||
|
handlers.append(logging.StreamHandler())
|
||||||
|
else:
|
||||||
|
handlers.append(logging.FileHandler(file))
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=level,
|
||||||
|
handlers=handlers,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
1642
disagreement/models.py
Normal file
1642
disagreement/models.py
Normal file
File diff suppressed because it is too large
Load Diff
109
disagreement/oauth.py
Normal file
109
disagreement/oauth.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
"""OAuth2 utilities."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from typing import List, Optional, Dict, Any, Union
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
from .errors import HTTPException
|
||||||
|
|
||||||
|
|
||||||
|
def build_authorization_url(
|
||||||
|
client_id: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
scope: Union[str, List[str]],
|
||||||
|
*,
|
||||||
|
state: Optional[str] = None,
|
||||||
|
response_type: str = "code",
|
||||||
|
prompt: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Return the Discord OAuth2 authorization URL."""
|
||||||
|
if isinstance(scope, list):
|
||||||
|
scope = " ".join(scope)
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"client_id": client_id,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"response_type": response_type,
|
||||||
|
"scope": scope,
|
||||||
|
}
|
||||||
|
if state is not None:
|
||||||
|
params["state"] = state
|
||||||
|
if prompt is not None:
|
||||||
|
params["prompt"] = prompt
|
||||||
|
|
||||||
|
return "https://discord.com/oauth2/authorize?" + urlencode(params)
|
||||||
|
|
||||||
|
|
||||||
|
async def exchange_code_for_token(
|
||||||
|
client_id: str,
|
||||||
|
client_secret: str,
|
||||||
|
code: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
*,
|
||||||
|
session: Optional[aiohttp.ClientSession] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Exchange an authorization code for an access token."""
|
||||||
|
close = False
|
||||||
|
if session is None:
|
||||||
|
session = aiohttp.ClientSession()
|
||||||
|
close = True
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_secret": client_secret,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = await session.post(
|
||||||
|
"https://discord.com/api/v10/oauth2/token",
|
||||||
|
data=data,
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
json_data = await resp.json()
|
||||||
|
if resp.status != 200:
|
||||||
|
raise HTTPException(resp, message="OAuth token exchange failed")
|
||||||
|
finally:
|
||||||
|
if close:
|
||||||
|
await session.close()
|
||||||
|
return json_data
|
||||||
|
|
||||||
|
|
||||||
|
async def refresh_access_token(
|
||||||
|
refresh_token: str,
|
||||||
|
client_id: str,
|
||||||
|
client_secret: str,
|
||||||
|
*,
|
||||||
|
session: Optional[aiohttp.ClientSession] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Refresh an access token using a refresh token."""
|
||||||
|
|
||||||
|
close = False
|
||||||
|
if session is None:
|
||||||
|
session = aiohttp.ClientSession()
|
||||||
|
close = True
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_secret": client_secret,
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = await session.post(
|
||||||
|
"https://discord.com/api/v10/oauth2/token",
|
||||||
|
data=data,
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
json_data = await resp.json()
|
||||||
|
if resp.status != 200:
|
||||||
|
raise HTTPException(resp, message="OAuth token refresh failed")
|
||||||
|
finally:
|
||||||
|
if close:
|
||||||
|
await session.close()
|
||||||
|
return json_data
|
99
disagreement/permissions.py
Normal file
99
disagreement/permissions.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
"""Utility helpers for working with Discord permission bitmasks."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import IntFlag
|
||||||
|
from typing import Iterable, List
|
||||||
|
|
||||||
|
|
||||||
|
class Permissions(IntFlag):
|
||||||
|
"""Discord guild and channel permissions."""
|
||||||
|
|
||||||
|
CREATE_INSTANT_INVITE = 1 << 0
|
||||||
|
KICK_MEMBERS = 1 << 1
|
||||||
|
BAN_MEMBERS = 1 << 2
|
||||||
|
ADMINISTRATOR = 1 << 3
|
||||||
|
MANAGE_CHANNELS = 1 << 4
|
||||||
|
MANAGE_GUILD = 1 << 5
|
||||||
|
ADD_REACTIONS = 1 << 6
|
||||||
|
VIEW_AUDIT_LOG = 1 << 7
|
||||||
|
PRIORITY_SPEAKER = 1 << 8
|
||||||
|
STREAM = 1 << 9
|
||||||
|
VIEW_CHANNEL = 1 << 10
|
||||||
|
SEND_MESSAGES = 1 << 11
|
||||||
|
SEND_TTS_MESSAGES = 1 << 12
|
||||||
|
MANAGE_MESSAGES = 1 << 13
|
||||||
|
EMBED_LINKS = 1 << 14
|
||||||
|
ATTACH_FILES = 1 << 15
|
||||||
|
READ_MESSAGE_HISTORY = 1 << 16
|
||||||
|
MENTION_EVERYONE = 1 << 17
|
||||||
|
USE_EXTERNAL_EMOJIS = 1 << 18
|
||||||
|
VIEW_GUILD_INSIGHTS = 1 << 19
|
||||||
|
CONNECT = 1 << 20
|
||||||
|
SPEAK = 1 << 21
|
||||||
|
MUTE_MEMBERS = 1 << 22
|
||||||
|
DEAFEN_MEMBERS = 1 << 23
|
||||||
|
MOVE_MEMBERS = 1 << 24
|
||||||
|
USE_VAD = 1 << 25
|
||||||
|
CHANGE_NICKNAME = 1 << 26
|
||||||
|
MANAGE_NICKNAMES = 1 << 27
|
||||||
|
MANAGE_ROLES = 1 << 28
|
||||||
|
MANAGE_WEBHOOKS = 1 << 29
|
||||||
|
MANAGE_GUILD_EXPRESSIONS = 1 << 30
|
||||||
|
USE_APPLICATION_COMMANDS = 1 << 31
|
||||||
|
REQUEST_TO_SPEAK = 1 << 32
|
||||||
|
MANAGE_EVENTS = 1 << 33
|
||||||
|
MANAGE_THREADS = 1 << 34
|
||||||
|
CREATE_PUBLIC_THREADS = 1 << 35
|
||||||
|
CREATE_PRIVATE_THREADS = 1 << 36
|
||||||
|
USE_EXTERNAL_STICKERS = 1 << 37
|
||||||
|
SEND_MESSAGES_IN_THREADS = 1 << 38
|
||||||
|
USE_EMBEDDED_ACTIVITIES = 1 << 39
|
||||||
|
MODERATE_MEMBERS = 1 << 40
|
||||||
|
VIEW_CREATOR_MONETIZATION_ANALYTICS = 1 << 41
|
||||||
|
USE_SOUNDBOARD = 1 << 42
|
||||||
|
CREATE_GUILD_EXPRESSIONS = 1 << 43
|
||||||
|
CREATE_EVENTS = 1 << 44
|
||||||
|
USE_EXTERNAL_SOUNDS = 1 << 45
|
||||||
|
SEND_VOICE_MESSAGES = 1 << 46
|
||||||
|
|
||||||
|
|
||||||
|
def permissions_value(*perms: Permissions | int | Iterable[Permissions | int]) -> int:
|
||||||
|
"""Return a combined integer value for multiple permissions."""
|
||||||
|
|
||||||
|
value = 0
|
||||||
|
for perm in perms:
|
||||||
|
if isinstance(perm, Iterable) and not isinstance(perm, (Permissions, int)):
|
||||||
|
value |= permissions_value(*perm)
|
||||||
|
else:
|
||||||
|
value |= int(perm)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def has_permissions(
|
||||||
|
current: int | str | Permissions,
|
||||||
|
*perms: Permissions | int | Iterable[Permissions | int],
|
||||||
|
) -> bool:
|
||||||
|
"""Return ``True`` if ``current`` includes all ``perms``."""
|
||||||
|
|
||||||
|
current_val = int(current)
|
||||||
|
needed = permissions_value(*perms)
|
||||||
|
return (current_val & needed) == needed
|
||||||
|
|
||||||
|
|
||||||
|
def missing_permissions(
|
||||||
|
current: int | str | Permissions,
|
||||||
|
*perms: Permissions | int | Iterable[Permissions | int],
|
||||||
|
) -> List[Permissions]:
|
||||||
|
"""Return the subset of ``perms`` not present in ``current``."""
|
||||||
|
|
||||||
|
current_val = int(current)
|
||||||
|
missing: List[Permissions] = []
|
||||||
|
for perm in perms:
|
||||||
|
if isinstance(perm, Iterable) and not isinstance(perm, (Permissions, int)):
|
||||||
|
missing.extend(missing_permissions(current_val, *perm))
|
||||||
|
else:
|
||||||
|
perm_val = int(perm)
|
||||||
|
if not current_val & perm_val:
|
||||||
|
missing.append(Permissions(perm_val))
|
||||||
|
return missing
|
75
disagreement/rate_limiter.py
Normal file
75
disagreement/rate_limiter.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
"""Asynchronous rate limiter for Discord HTTP requests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import Dict, Mapping
|
||||||
|
|
||||||
|
|
||||||
|
class _Bucket:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.remaining: int = 1
|
||||||
|
self.reset_at: float = 0.0
|
||||||
|
self.lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
"""Rate limiter implementing per-route buckets and a global queue."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._buckets: Dict[str, _Bucket] = {}
|
||||||
|
self._global_event = asyncio.Event()
|
||||||
|
self._global_event.set()
|
||||||
|
|
||||||
|
def _get_bucket(self, route: str) -> _Bucket:
|
||||||
|
bucket = self._buckets.get(route)
|
||||||
|
if bucket is None:
|
||||||
|
bucket = _Bucket()
|
||||||
|
self._buckets[route] = bucket
|
||||||
|
return bucket
|
||||||
|
|
||||||
|
async def acquire(self, route: str) -> _Bucket:
|
||||||
|
bucket = self._get_bucket(route)
|
||||||
|
while True:
|
||||||
|
await self._global_event.wait()
|
||||||
|
async with bucket.lock:
|
||||||
|
now = time.monotonic()
|
||||||
|
if bucket.remaining <= 0 and now < bucket.reset_at:
|
||||||
|
await asyncio.sleep(bucket.reset_at - now)
|
||||||
|
continue
|
||||||
|
if bucket.remaining > 0:
|
||||||
|
bucket.remaining -= 1
|
||||||
|
return bucket
|
||||||
|
|
||||||
|
def release(self, route: str, headers: Mapping[str, str]) -> None:
|
||||||
|
bucket = self._get_bucket(route)
|
||||||
|
try:
|
||||||
|
remaining = int(headers.get("X-RateLimit-Remaining", bucket.remaining))
|
||||||
|
reset_after = float(headers.get("X-RateLimit-Reset-After", "0"))
|
||||||
|
bucket.remaining = remaining
|
||||||
|
bucket.reset_at = time.monotonic() + reset_after
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if headers.get("X-RateLimit-Global", "false").lower() == "true":
|
||||||
|
retry_after = float(headers.get("Retry-After", "0"))
|
||||||
|
self._global_event.clear()
|
||||||
|
asyncio.create_task(self._lift_global(retry_after))
|
||||||
|
|
||||||
|
async def handle_rate_limit(
|
||||||
|
self, route: str, retry_after: float, is_global: bool
|
||||||
|
) -> None:
|
||||||
|
bucket = self._get_bucket(route)
|
||||||
|
bucket.remaining = 0
|
||||||
|
bucket.reset_at = time.monotonic() + retry_after
|
||||||
|
if is_global:
|
||||||
|
self._global_event.clear()
|
||||||
|
await asyncio.sleep(retry_after)
|
||||||
|
self._global_event.set()
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(retry_after)
|
||||||
|
|
||||||
|
async def _lift_global(self, delay: float) -> None:
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
self._global_event.set()
|
65
disagreement/shard_manager.py
Normal file
65
disagreement/shard_manager.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
# disagreement/shard_manager.py
|
||||||
|
|
||||||
|
"""Sharding utilities for managing multiple gateway connections."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import List, TYPE_CHECKING
|
||||||
|
|
||||||
|
from .gateway import GatewayClient
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma: no cover - for type checking only
|
||||||
|
from .client import Client
|
||||||
|
|
||||||
|
|
||||||
|
class Shard:
|
||||||
|
"""Represents a single gateway shard."""
|
||||||
|
|
||||||
|
def __init__(self, shard_id: int, shard_count: int, gateway: GatewayClient) -> None:
|
||||||
|
self.id: int = shard_id
|
||||||
|
self.count: int = shard_count
|
||||||
|
self.gateway: GatewayClient = gateway
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
"""Connects this shard's gateway."""
|
||||||
|
await self.gateway.connect()
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Closes this shard's gateway."""
|
||||||
|
await self.gateway.close()
|
||||||
|
|
||||||
|
|
||||||
|
class ShardManager:
|
||||||
|
"""Manages multiple :class:`Shard` instances."""
|
||||||
|
|
||||||
|
def __init__(self, client: "Client", shard_count: int) -> None:
|
||||||
|
self.client: "Client" = client
|
||||||
|
self.shard_count: int = shard_count
|
||||||
|
self.shards: List[Shard] = []
|
||||||
|
|
||||||
|
def _create_shards(self) -> None:
|
||||||
|
if self.shards:
|
||||||
|
return
|
||||||
|
for shard_id in range(self.shard_count):
|
||||||
|
gateway = GatewayClient(
|
||||||
|
http_client=self.client._http,
|
||||||
|
event_dispatcher=self.client._event_dispatcher,
|
||||||
|
token=self.client.token,
|
||||||
|
intents=self.client.intents,
|
||||||
|
client_instance=self.client,
|
||||||
|
verbose=self.client.verbose,
|
||||||
|
shard_id=shard_id,
|
||||||
|
shard_count=self.shard_count,
|
||||||
|
)
|
||||||
|
self.shards.append(Shard(shard_id, self.shard_count, gateway))
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""Starts all shards."""
|
||||||
|
self._create_shards()
|
||||||
|
await asyncio.gather(*(s.connect() for s in self.shards))
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Closes all shards."""
|
||||||
|
await asyncio.gather(*(s.close() for s in self.shards))
|
||||||
|
self.shards.clear()
|
42
disagreement/typing.py
Normal file
42
disagreement/typing.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import asyncio
|
||||||
|
from contextlib import suppress
|
||||||
|
from typing import Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
from .errors import DisagreementException
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .client import Client
|
||||||
|
|
||||||
|
if __name__ == "typing":
|
||||||
|
# For direct module execution testing
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Typing:
|
||||||
|
"""Async context manager for Discord typing indicator."""
|
||||||
|
|
||||||
|
def __init__(self, client: "Client", channel_id: str) -> None:
|
||||||
|
self._client = client
|
||||||
|
self._channel_id = channel_id
|
||||||
|
self._task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
async def _run(self) -> None:
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
await self._client._http.trigger_typing(self._channel_id)
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "Typing":
|
||||||
|
if self._client._closed:
|
||||||
|
raise DisagreementException("Client is closed.")
|
||||||
|
await self._client._http.trigger_typing(self._channel_id)
|
||||||
|
self._task = asyncio.create_task(self._run())
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||||
|
if self._task:
|
||||||
|
self._task.cancel()
|
||||||
|
with suppress(asyncio.CancelledError):
|
||||||
|
await self._task
|
17
disagreement/ui/__init__.py
Normal file
17
disagreement/ui/__init__.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from .view import View
|
||||||
|
from .item import Item
|
||||||
|
from .button import Button, button
|
||||||
|
from .select import Select, select
|
||||||
|
from .modal import Modal, TextInput, text_input
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"View",
|
||||||
|
"Item",
|
||||||
|
"Button",
|
||||||
|
"button",
|
||||||
|
"Select",
|
||||||
|
"select",
|
||||||
|
"Modal",
|
||||||
|
"TextInput",
|
||||||
|
"text_input",
|
||||||
|
]
|
99
disagreement/ui/button.py
Normal file
99
disagreement/ui/button.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Any, Callable, Coroutine, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
from .item import Item
|
||||||
|
from ..enums import ComponentType, ButtonStyle
|
||||||
|
from ..models import PartialEmoji, to_partial_emoji
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..interactions import Interaction
|
||||||
|
|
||||||
|
|
||||||
|
class Button(Item):
|
||||||
|
"""Represents a button component in a View.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
style (ButtonStyle): The style of the button.
|
||||||
|
label (Optional[str]): The text that appears on the button.
|
||||||
|
emoji (Optional[str | PartialEmoji]): The emoji that appears on the button.
|
||||||
|
custom_id (Optional[str]): The developer-defined identifier for the button.
|
||||||
|
url (Optional[str]): The URL for the button.
|
||||||
|
disabled (bool): Whether the button is disabled.
|
||||||
|
row (Optional[int]): The row the button should be placed in, from 0 to 4.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
style: ButtonStyle = ButtonStyle.SECONDARY,
|
||||||
|
label: Optional[str] = None,
|
||||||
|
emoji: Optional[str | PartialEmoji] = None,
|
||||||
|
custom_id: Optional[str] = None,
|
||||||
|
url: Optional[str] = None,
|
||||||
|
disabled: bool = False,
|
||||||
|
row: Optional[int] = None,
|
||||||
|
):
|
||||||
|
super().__init__(type=ComponentType.BUTTON)
|
||||||
|
if not label and not emoji:
|
||||||
|
raise ValueError("A button must have a label and/or an emoji.")
|
||||||
|
|
||||||
|
if url and custom_id:
|
||||||
|
raise ValueError("A button cannot have both a URL and a custom_id.")
|
||||||
|
|
||||||
|
self.style = style
|
||||||
|
self.label = label
|
||||||
|
self.emoji = to_partial_emoji(emoji)
|
||||||
|
self.custom_id = custom_id
|
||||||
|
self.url = url
|
||||||
|
self.disabled = disabled
|
||||||
|
self._row = row
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Converts the button to a dictionary that can be sent to Discord."""
|
||||||
|
payload = {
|
||||||
|
"type": ComponentType.BUTTON.value,
|
||||||
|
"style": self.style.value,
|
||||||
|
"disabled": self.disabled,
|
||||||
|
}
|
||||||
|
if self.label:
|
||||||
|
payload["label"] = self.label
|
||||||
|
if self.emoji:
|
||||||
|
payload["emoji"] = self.emoji.to_dict()
|
||||||
|
if self.url:
|
||||||
|
payload["url"] = self.url
|
||||||
|
if self.custom_id:
|
||||||
|
payload["custom_id"] = self.custom_id
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def button(
|
||||||
|
*,
|
||||||
|
label: Optional[str] = None,
|
||||||
|
custom_id: Optional[str] = None,
|
||||||
|
style: ButtonStyle = ButtonStyle.SECONDARY,
|
||||||
|
emoji: Optional[str | PartialEmoji] = None,
|
||||||
|
url: Optional[str] = None,
|
||||||
|
disabled: bool = False,
|
||||||
|
row: Optional[int] = None,
|
||||||
|
) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Button]:
|
||||||
|
"""A decorator to create a button in a View."""
|
||||||
|
|
||||||
|
def decorator(func: Callable[..., Coroutine[Any, Any, Any]]) -> Button:
|
||||||
|
if not asyncio.iscoroutinefunction(func):
|
||||||
|
raise TypeError("Button callback must be a coroutine function.")
|
||||||
|
|
||||||
|
item = Button(
|
||||||
|
label=label,
|
||||||
|
custom_id=custom_id,
|
||||||
|
style=style,
|
||||||
|
emoji=emoji,
|
||||||
|
url=url,
|
||||||
|
disabled=disabled,
|
||||||
|
row=row,
|
||||||
|
)
|
||||||
|
item.callback = func
|
||||||
|
return item
|
||||||
|
|
||||||
|
return decorator
|
38
disagreement/ui/item.py
Normal file
38
disagreement/ui/item.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Callable, Coroutine, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
from ..models import Component
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .view import View
|
||||||
|
from ..interactions import Interaction
|
||||||
|
|
||||||
|
|
||||||
|
class Item(Component):
|
||||||
|
"""Represents a UI item that can be placed in a View.
|
||||||
|
|
||||||
|
This is a base class and is not meant to be used directly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._view: Optional[View] = None
|
||||||
|
self._row: Optional[int] = None
|
||||||
|
# This is the callback associated with this item.
|
||||||
|
self.callback: Optional[
|
||||||
|
Callable[["View", Interaction], Coroutine[Any, Any, Any]]
|
||||||
|
] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def view(self) -> Optional[View]:
|
||||||
|
return self._view
|
||||||
|
|
||||||
|
@property
|
||||||
|
def row(self) -> Optional[int]:
|
||||||
|
return self._row
|
||||||
|
|
||||||
|
def _refresh_from_data(self, data: dict[str, Any]) -> None:
|
||||||
|
# This is used to update the item's state from incoming interaction data.
|
||||||
|
# For example, a button's disabled state could be updated here.
|
||||||
|
pass
|
132
disagreement/ui/modal.py
Normal file
132
disagreement/ui/modal.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Callable, Coroutine, Optional, List, TYPE_CHECKING
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from .item import Item
|
||||||
|
from .view import View
|
||||||
|
from ..enums import ComponentType, TextInputStyle
|
||||||
|
from ..models import ActionRow
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma: no cover - for type hints only
|
||||||
|
from ..interactions import Interaction
|
||||||
|
|
||||||
|
|
||||||
|
class TextInput(Item):
|
||||||
|
"""Represents a text input component inside a modal."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
label: str,
|
||||||
|
custom_id: Optional[str] = None,
|
||||||
|
style: TextInputStyle = TextInputStyle.SHORT,
|
||||||
|
placeholder: Optional[str] = None,
|
||||||
|
required: bool = True,
|
||||||
|
min_length: Optional[int] = None,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
row: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(type=ComponentType.TEXT_INPUT)
|
||||||
|
self.label = label
|
||||||
|
self.custom_id = custom_id
|
||||||
|
self.style = style
|
||||||
|
self.placeholder = placeholder
|
||||||
|
self.required = required
|
||||||
|
self.min_length = min_length
|
||||||
|
self.max_length = max_length
|
||||||
|
self._row = row
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
payload = {
|
||||||
|
"type": ComponentType.TEXT_INPUT.value,
|
||||||
|
"style": self.style.value,
|
||||||
|
"label": self.label,
|
||||||
|
"required": self.required,
|
||||||
|
}
|
||||||
|
if self.custom_id:
|
||||||
|
payload["custom_id"] = self.custom_id
|
||||||
|
if self.placeholder:
|
||||||
|
payload["placeholder"] = self.placeholder
|
||||||
|
if self.min_length is not None:
|
||||||
|
payload["min_length"] = self.min_length
|
||||||
|
if self.max_length is not None:
|
||||||
|
payload["max_length"] = self.max_length
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def text_input(
|
||||||
|
*,
|
||||||
|
label: str,
|
||||||
|
custom_id: Optional[str] = None,
|
||||||
|
style: TextInputStyle = TextInputStyle.SHORT,
|
||||||
|
placeholder: Optional[str] = None,
|
||||||
|
required: bool = True,
|
||||||
|
min_length: Optional[int] = None,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
row: Optional[int] = None,
|
||||||
|
) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], TextInput]:
|
||||||
|
"""Decorator to define a text input callback inside a :class:`Modal`."""
|
||||||
|
|
||||||
|
def decorator(func: Callable[..., Coroutine[Any, Any, Any]]) -> TextInput:
|
||||||
|
if not asyncio.iscoroutinefunction(func):
|
||||||
|
raise TypeError("TextInput callback must be a coroutine function.")
|
||||||
|
|
||||||
|
item = TextInput(
|
||||||
|
label=label,
|
||||||
|
custom_id=custom_id,
|
||||||
|
style=style,
|
||||||
|
placeholder=placeholder,
|
||||||
|
required=required,
|
||||||
|
min_length=min_length,
|
||||||
|
max_length=max_length,
|
||||||
|
row=row,
|
||||||
|
)
|
||||||
|
item.callback = func
|
||||||
|
return item
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class Modal:
|
||||||
|
"""Represents a modal dialog."""
|
||||||
|
|
||||||
|
def __init__(self, *, title: str, custom_id: str) -> None:
|
||||||
|
self.title = title
|
||||||
|
self.custom_id = custom_id
|
||||||
|
self._children: List[TextInput] = []
|
||||||
|
|
||||||
|
for item in self.__class__.__dict__.values():
|
||||||
|
if isinstance(item, TextInput):
|
||||||
|
self.add_item(item)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def children(self) -> List[TextInput]:
|
||||||
|
return self._children
|
||||||
|
|
||||||
|
def add_item(self, item: TextInput) -> None:
|
||||||
|
if not isinstance(item, TextInput):
|
||||||
|
raise TypeError("Only TextInput items can be added to a Modal.")
|
||||||
|
if len(self._children) >= 5:
|
||||||
|
raise ValueError("A modal can only have up to 5 text inputs.")
|
||||||
|
item._view = None # Not part of a view but reuse item base
|
||||||
|
self._children.append(item)
|
||||||
|
|
||||||
|
def to_components(self) -> List[ActionRow]:
|
||||||
|
rows: List[ActionRow] = []
|
||||||
|
for child in self.children:
|
||||||
|
row = ActionRow(components=[child])
|
||||||
|
rows.append(row)
|
||||||
|
return rows
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"title": self.title,
|
||||||
|
"custom_id": self.custom_id,
|
||||||
|
"components": [r.to_dict() for r in self.to_components()],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def callback(
|
||||||
|
self, interaction: Interaction
|
||||||
|
) -> None: # pragma: no cover - default
|
||||||
|
pass
|
92
disagreement/ui/select.py
Normal file
92
disagreement/ui/select.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Any, Callable, Coroutine, List, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
from .item import Item
|
||||||
|
from ..enums import ComponentType
|
||||||
|
from ..models import SelectOption
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..interactions import Interaction
|
||||||
|
|
||||||
|
|
||||||
|
class Select(Item):
|
||||||
|
"""Represents a select menu component in a View.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
custom_id (str): The developer-defined identifier for the select menu.
|
||||||
|
options (List[SelectOption]): The choices in the select menu.
|
||||||
|
placeholder (Optional[str]): The placeholder text that is shown if nothing is selected.
|
||||||
|
min_values (int): The minimum number of items that must be chosen.
|
||||||
|
max_values (int): The maximum number of items that can be chosen.
|
||||||
|
disabled (bool): Whether the select menu is disabled.
|
||||||
|
row (Optional[int]): The row the select menu should be placed in, from 0 to 4.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
custom_id: str,
|
||||||
|
options: List[SelectOption],
|
||||||
|
placeholder: Optional[str] = None,
|
||||||
|
min_values: int = 1,
|
||||||
|
max_values: int = 1,
|
||||||
|
disabled: bool = False,
|
||||||
|
row: Optional[int] = None,
|
||||||
|
):
|
||||||
|
super().__init__(type=ComponentType.STRING_SELECT)
|
||||||
|
self.custom_id = custom_id
|
||||||
|
self.options = options
|
||||||
|
self.placeholder = placeholder
|
||||||
|
self.min_values = min_values
|
||||||
|
self.max_values = max_values
|
||||||
|
self.disabled = disabled
|
||||||
|
self._row = row
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Converts the select menu to a dictionary that can be sent to Discord."""
|
||||||
|
payload = {
|
||||||
|
"type": ComponentType.STRING_SELECT.value,
|
||||||
|
"custom_id": self.custom_id,
|
||||||
|
"options": [option.to_dict() for option in self.options],
|
||||||
|
"disabled": self.disabled,
|
||||||
|
}
|
||||||
|
if self.placeholder:
|
||||||
|
payload["placeholder"] = self.placeholder
|
||||||
|
if self.min_values is not None:
|
||||||
|
payload["min_values"] = self.min_values
|
||||||
|
if self.max_values is not None:
|
||||||
|
payload["max_values"] = self.max_values
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def select(
|
||||||
|
*,
|
||||||
|
custom_id: str,
|
||||||
|
options: List[SelectOption],
|
||||||
|
placeholder: Optional[str] = None,
|
||||||
|
min_values: int = 1,
|
||||||
|
max_values: int = 1,
|
||||||
|
disabled: bool = False,
|
||||||
|
row: Optional[int] = None,
|
||||||
|
) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Select]:
|
||||||
|
"""A decorator to create a select menu in a View."""
|
||||||
|
|
||||||
|
def decorator(func: Callable[..., Coroutine[Any, Any, Any]]) -> Select:
|
||||||
|
if not asyncio.iscoroutinefunction(func):
|
||||||
|
raise TypeError("Select callback must be a coroutine function.")
|
||||||
|
|
||||||
|
item = Select(
|
||||||
|
custom_id=custom_id,
|
||||||
|
options=options,
|
||||||
|
placeholder=placeholder,
|
||||||
|
min_values=min_values,
|
||||||
|
max_values=max_values,
|
||||||
|
disabled=disabled,
|
||||||
|
row=row,
|
||||||
|
)
|
||||||
|
item.callback = func
|
||||||
|
return item
|
||||||
|
|
||||||
|
return decorator
|
165
disagreement/ui/view.py
Normal file
165
disagreement/ui/view.py
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Callable, Coroutine, Dict, List, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
from ..models import ActionRow
|
||||||
|
from .item import Item
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..client import Client
|
||||||
|
from ..interactions import Interaction
|
||||||
|
|
||||||
|
|
||||||
|
class View:
|
||||||
|
"""Represents a container for UI components that can be sent with a message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout (Optional[float]): The number of seconds to wait for an interaction before the view times out.
|
||||||
|
Defaults to 180.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *, timeout: Optional[float] = 180.0):
|
||||||
|
self.timeout = timeout
|
||||||
|
self.id = str(uuid.uuid4())
|
||||||
|
self.__children: List[Item] = []
|
||||||
|
self.__stopped = asyncio.Event()
|
||||||
|
self._client: Optional[Client] = None
|
||||||
|
self._message_id: Optional[str] = None
|
||||||
|
|
||||||
|
for item in self.__class__.__dict__.values():
|
||||||
|
if isinstance(item, Item):
|
||||||
|
self.add_item(item)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def children(self) -> List[Item]:
|
||||||
|
return self.__children
|
||||||
|
|
||||||
|
def add_item(self, item: Item):
|
||||||
|
"""Adds an item to the view."""
|
||||||
|
if not isinstance(item, Item):
|
||||||
|
raise TypeError("Only instances of 'Item' can be added to a View.")
|
||||||
|
|
||||||
|
if len(self.__children) >= 25:
|
||||||
|
raise ValueError("A view can only have a maximum of 25 components.")
|
||||||
|
|
||||||
|
item._view = self
|
||||||
|
self.__children.append(item)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def message_id(self) -> Optional[str]:
|
||||||
|
return self._message_id
|
||||||
|
|
||||||
|
@message_id.setter
|
||||||
|
def message_id(self, value: str):
|
||||||
|
self._message_id = value
|
||||||
|
|
||||||
|
def to_components(self) -> List[ActionRow]:
|
||||||
|
"""Converts the view's children into a list of ActionRow components.
|
||||||
|
|
||||||
|
This retains the original, simple layout behaviour where each item is
|
||||||
|
placed in its own :class:`ActionRow` to ensure backward compatibility.
|
||||||
|
"""
|
||||||
|
|
||||||
|
rows: List[ActionRow] = []
|
||||||
|
|
||||||
|
for item in self.children:
|
||||||
|
if item.custom_id is None:
|
||||||
|
item.custom_id = (
|
||||||
|
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
rows.append(ActionRow(components=[item]))
|
||||||
|
|
||||||
|
return rows
|
||||||
|
|
||||||
|
def layout_components_advanced(self) -> List[ActionRow]:
|
||||||
|
"""Group compatible components into rows following Discord rules."""
|
||||||
|
|
||||||
|
rows: List[ActionRow] = []
|
||||||
|
|
||||||
|
for item in self.children:
|
||||||
|
if item.custom_id is None:
|
||||||
|
item.custom_id = (
|
||||||
|
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
target_row = item.row
|
||||||
|
if target_row is not None:
|
||||||
|
if not 0 <= target_row <= 4:
|
||||||
|
raise ValueError("Row index must be between 0 and 4.")
|
||||||
|
|
||||||
|
while len(rows) <= target_row:
|
||||||
|
if len(rows) >= 5:
|
||||||
|
raise ValueError("A view can have at most 5 action rows.")
|
||||||
|
rows.append(ActionRow())
|
||||||
|
|
||||||
|
rows[target_row].add_component(item)
|
||||||
|
continue
|
||||||
|
|
||||||
|
placed = False
|
||||||
|
for row in rows:
|
||||||
|
try:
|
||||||
|
row.add_component(item)
|
||||||
|
placed = True
|
||||||
|
break
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not placed:
|
||||||
|
if len(rows) >= 5:
|
||||||
|
raise ValueError("A view can have at most 5 action rows.")
|
||||||
|
new_row = ActionRow([item])
|
||||||
|
rows.append(new_row)
|
||||||
|
|
||||||
|
return rows
|
||||||
|
|
||||||
|
def to_components_payload(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Converts the view's children into a list of component dictionaries
|
||||||
|
that can be sent to the Discord API."""
|
||||||
|
return [row.to_dict() for row in self.to_components()]
|
||||||
|
|
||||||
|
async def _dispatch(self, interaction: Interaction):
|
||||||
|
"""Called by the client to dispatch an interaction to the correct item."""
|
||||||
|
if self.timeout is not None:
|
||||||
|
self.__stopped.set() # Reset the timeout on each interaction
|
||||||
|
self.__stopped.clear()
|
||||||
|
|
||||||
|
if interaction.data:
|
||||||
|
custom_id = interaction.data.custom_id
|
||||||
|
for child in self.children:
|
||||||
|
if child.custom_id == custom_id:
|
||||||
|
if child.callback:
|
||||||
|
await child.callback(self, interaction)
|
||||||
|
break
|
||||||
|
|
||||||
|
async def wait(self) -> bool:
|
||||||
|
"""Waits until the view has stopped interacting."""
|
||||||
|
return await self.__stopped.wait()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""Stops the view from listening to interactions."""
|
||||||
|
if not self.__stopped.is_set():
|
||||||
|
self.__stopped.set()
|
||||||
|
|
||||||
|
async def on_timeout(self):
|
||||||
|
"""Called when the view times out."""
|
||||||
|
pass # User can override this
|
||||||
|
|
||||||
|
async def _start(self, client: Client):
|
||||||
|
"""Starts the view's internal listener."""
|
||||||
|
self._client = client
|
||||||
|
if self.timeout is not None:
|
||||||
|
asyncio.create_task(self._timeout_task())
|
||||||
|
|
||||||
|
async def _timeout_task(self):
|
||||||
|
"""The task that waits for the timeout and then stops the view."""
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self.wait(), timeout=self.timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self.stop()
|
||||||
|
await self.on_timeout()
|
||||||
|
if self._client and self._message_id:
|
||||||
|
# Remove the view from the client's listeners
|
||||||
|
self._client._views.pop(self._message_id, None)
|
120
disagreement/voice_client.py
Normal file
120
disagreement/voice_client.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
# disagreement/voice_client.py
|
||||||
|
"""Voice gateway and UDP audio client."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
import socket
|
||||||
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceClient:
|
||||||
|
"""Handles the Discord voice WebSocket connection and UDP streaming."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
endpoint: str,
|
||||||
|
session_id: str,
|
||||||
|
token: str,
|
||||||
|
guild_id: int,
|
||||||
|
user_id: int,
|
||||||
|
*,
|
||||||
|
ws=None,
|
||||||
|
udp: Optional[socket.socket] = None,
|
||||||
|
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> None:
|
||||||
|
self.endpoint = endpoint
|
||||||
|
self.session_id = session_id
|
||||||
|
self.token = token
|
||||||
|
self.guild_id = str(guild_id)
|
||||||
|
self.user_id = str(user_id)
|
||||||
|
self._ws: Optional[aiohttp.ClientWebSocketResponse] = ws
|
||||||
|
self._udp = udp
|
||||||
|
self._session: Optional[aiohttp.ClientSession] = None
|
||||||
|
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||||
|
self._heartbeat_interval: Optional[float] = None
|
||||||
|
self._loop = loop or asyncio.get_event_loop()
|
||||||
|
self.verbose = verbose
|
||||||
|
self.ssrc: Optional[int] = None
|
||||||
|
self.secret_key: Optional[Sequence[int]] = None
|
||||||
|
self._server_ip: Optional[str] = None
|
||||||
|
self._server_port: Optional[int] = None
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
if self._ws is None:
|
||||||
|
self._session = aiohttp.ClientSession()
|
||||||
|
self._ws = await self._session.ws_connect(self.endpoint)
|
||||||
|
|
||||||
|
hello = await self._ws.receive_json()
|
||||||
|
self._heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000
|
||||||
|
self._heartbeat_task = self._loop.create_task(self._heartbeat())
|
||||||
|
|
||||||
|
await self._ws.send_json(
|
||||||
|
{
|
||||||
|
"op": 0,
|
||||||
|
"d": {
|
||||||
|
"server_id": self.guild_id,
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"session_id": self.session_id,
|
||||||
|
"token": self.token,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
ready = await self._ws.receive_json()
|
||||||
|
data = ready["d"]
|
||||||
|
self.ssrc = data["ssrc"]
|
||||||
|
self._server_ip = data["ip"]
|
||||||
|
self._server_port = data["port"]
|
||||||
|
|
||||||
|
if self._udp is None:
|
||||||
|
self._udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
self._udp.connect((self._server_ip, self._server_port))
|
||||||
|
|
||||||
|
await self._ws.send_json(
|
||||||
|
{
|
||||||
|
"op": 1,
|
||||||
|
"d": {
|
||||||
|
"protocol": "udp",
|
||||||
|
"data": {
|
||||||
|
"address": self._udp.getsockname()[0],
|
||||||
|
"port": self._udp.getsockname()[1],
|
||||||
|
"mode": "xsalsa20_poly1305",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
session_desc = await self._ws.receive_json()
|
||||||
|
self.secret_key = session_desc["d"].get("secret_key")
|
||||||
|
|
||||||
|
async def _heartbeat(self) -> None:
|
||||||
|
assert self._ws is not None
|
||||||
|
assert self._heartbeat_interval is not None
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
await self._ws.send_json({"op": 3, "d": int(self._loop.time() * 1000)})
|
||||||
|
await asyncio.sleep(self._heartbeat_interval)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def send_audio_frame(self, frame: bytes) -> None:
|
||||||
|
if not self._udp:
|
||||||
|
raise RuntimeError("UDP socket not initialised")
|
||||||
|
self._udp.send(frame)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
if self._heartbeat_task:
|
||||||
|
self._heartbeat_task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await self._heartbeat_task
|
||||||
|
if self._ws:
|
||||||
|
await self._ws.close()
|
||||||
|
if self._session:
|
||||||
|
await self._session.close()
|
||||||
|
if self._udp:
|
||||||
|
self._udp.close()
|
18
docs/caching.md
Normal file
18
docs/caching.md
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# Caching
|
||||||
|
|
||||||
|
Disagreement ships with a simple in-memory cache used by the HTTP and Gateway clients. Cached objects reduce API requests and improve performance.
|
||||||
|
|
||||||
|
The client automatically caches guilds, channels and users as they are received from events or HTTP calls. You can access cached data through lookup helpers such as `Client.get_guild`.
|
||||||
|
|
||||||
|
The cache can be cleared manually if needed:
|
||||||
|
|
||||||
|
```python
|
||||||
|
client.cache.clear()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
|
||||||
|
- [Components](using_components.md)
|
||||||
|
- [Slash Commands](slash_commands.md)
|
||||||
|
- [Voice Features](voice_features.md)
|
||||||
|
|
51
docs/commands.md
Normal file
51
docs/commands.md
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
# Commands Extension
|
||||||
|
|
||||||
|
This guide covers the built-in prefix command system.
|
||||||
|
|
||||||
|
## Help Command
|
||||||
|
|
||||||
|
The command handler registers a `help` command automatically. Use it to list all available commands or get information about a single command.
|
||||||
|
|
||||||
|
```
|
||||||
|
!help # lists commands
|
||||||
|
!help ping # shows help for the "ping" command
|
||||||
|
```
|
||||||
|
|
||||||
|
The help command will show each command's brief description if provided.
|
||||||
|
|
||||||
|
## Checks
|
||||||
|
|
||||||
|
Use `commands.check` to prevent a command from running unless a predicate
|
||||||
|
returns ``True``. Checks may be regular or async callables that accept a
|
||||||
|
`CommandContext`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.ext.commands import command, check, CheckFailure
|
||||||
|
|
||||||
|
def is_owner(ctx):
|
||||||
|
return ctx.author.id == "1"
|
||||||
|
|
||||||
|
@command()
|
||||||
|
@check(is_owner)
|
||||||
|
async def secret(ctx):
|
||||||
|
await ctx.send("Only for the owner!")
|
||||||
|
```
|
||||||
|
|
||||||
|
When a check fails a :class:`CheckFailure` is raised and dispatched through the
|
||||||
|
command error handler.
|
||||||
|
|
||||||
|
## Cooldowns
|
||||||
|
|
||||||
|
Commands can be rate limited using the ``cooldown`` decorator. The example
|
||||||
|
below restricts usage to once every three seconds per user:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.ext.commands import command, cooldown
|
||||||
|
|
||||||
|
@command()
|
||||||
|
@cooldown(1, 3.0)
|
||||||
|
async def ping(ctx):
|
||||||
|
await ctx.send("Pong!")
|
||||||
|
```
|
||||||
|
|
||||||
|
Invoking a command while it is on cooldown raises :class:`CommandOnCooldown`.
|
21
docs/context_menus.md
Normal file
21
docs/context_menus.md
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# Context Menu Commands
|
||||||
|
|
||||||
|
`disagreement` supports Discord's user and message context menu commands. Use the
|
||||||
|
`user_command` and `message_command` decorators from `ext.app_commands` to
|
||||||
|
define them.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.ext.app_commands import user_command, message_command, AppCommandContext
|
||||||
|
from disagreement.models import User, Message
|
||||||
|
|
||||||
|
@user_command(name="User Info")
|
||||||
|
async def user_info(ctx: AppCommandContext, user: User) -> None:
|
||||||
|
await ctx.send(f"User: {user.username}#{user.discriminator}")
|
||||||
|
|
||||||
|
@message_command(name="Quote")
|
||||||
|
async def quote(ctx: AppCommandContext, message: Message) -> None:
|
||||||
|
await ctx.send(message.content)
|
||||||
|
```
|
||||||
|
|
||||||
|
Add the commands to your client's handler and run `sync_commands()` to register
|
||||||
|
them with Discord.
|
25
docs/converters.md
Normal file
25
docs/converters.md
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# Command Argument Converters
|
||||||
|
|
||||||
|
`disagreement.ext.commands` provides a number of built in converters that will parse string arguments into richer objects. These converters are automatically used when a command callback annotates its parameters with one of the supported types.
|
||||||
|
|
||||||
|
## Supported Types
|
||||||
|
|
||||||
|
- `int`, `float`, `bool`, and `str`
|
||||||
|
- `Member` – resolves a user mention or ID to a `Member` object for the current guild
|
||||||
|
- `Role` – resolves a role mention or ID to a `Role` object
|
||||||
|
- `Guild` – resolves a guild ID to a `Guild` object
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.ext.commands import command
|
||||||
|
from disagreement.ext.commands.core import CommandContext
|
||||||
|
from disagreement.models import Member
|
||||||
|
|
||||||
|
@command()
|
||||||
|
async def kick(ctx: CommandContext, target: Member):
|
||||||
|
await target.kick()
|
||||||
|
await ctx.send(f"Kicked {target.display_name}")
|
||||||
|
```
|
||||||
|
|
||||||
|
The framework will automatically convert the first argument to a `Member` using the mention or ID provided by the user.
|
25
docs/events.md
Normal file
25
docs/events.md
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# Events
|
||||||
|
|
||||||
|
Disagreement dispatches Gateway events to asynchronous callbacks. Handlers can be registered with `@client.event` or `client.on_event`.
|
||||||
|
Listeners may be removed later using `EventDispatcher.unregister(event_name, coro)`.
|
||||||
|
|
||||||
|
|
||||||
|
## PRESENCE_UPDATE
|
||||||
|
|
||||||
|
Triggered when a user's presence changes. The callback receives a `PresenceUpdate` model.
|
||||||
|
|
||||||
|
```python
|
||||||
|
@client.event
|
||||||
|
async def on_presence_update(presence: disagreement.PresenceUpdate):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## TYPING_START
|
||||||
|
|
||||||
|
Dispatched when a user begins typing in a channel. The callback receives a `TypingStart` model.
|
||||||
|
|
||||||
|
```python
|
||||||
|
@client.event
|
||||||
|
async def on_typing_start(typing: disagreement.TypingStart):
|
||||||
|
...
|
||||||
|
```
|
17
docs/gateway.md
Normal file
17
docs/gateway.md
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# Gateway Connection and Reconnection
|
||||||
|
|
||||||
|
`GatewayClient` manages the library's WebSocket connection. When the connection drops unexpectedly, it will now automatically attempt to reconnect using an exponential backoff strategy with jitter.
|
||||||
|
|
||||||
|
The default behaviour tries up to five reconnect attempts, doubling the delay each time up to a configurable maximum. A small random jitter is added to spread out reconnect attempts when multiple clients restart at once.
|
||||||
|
|
||||||
|
You can control the maximum number of retries and the backoff cap when constructing `Client`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
bot = Client(
|
||||||
|
token="your-token",
|
||||||
|
gateway_max_retries=10,
|
||||||
|
gateway_max_backoff=120.0,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
These values are passed to `GatewayClient` and applied whenever the connection needs to be re-established.
|
36
docs/i18n.md
Normal file
36
docs/i18n.md
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
# Internationalization
|
||||||
|
|
||||||
|
Disagreement can translate command names, descriptions and other text using JSON translation files.
|
||||||
|
|
||||||
|
## Providing Translations
|
||||||
|
|
||||||
|
Use `disagreement.i18n.load_translations` to load a JSON file for a locale.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement import i18n
|
||||||
|
|
||||||
|
i18n.load_translations("es", "path/to/es.json")
|
||||||
|
```
|
||||||
|
|
||||||
|
The JSON file should map translation keys to translated strings:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"greet": "Hola",
|
||||||
|
"description": "Comando de saludo"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also set translations programmatically with `i18n.set_translations`.
|
||||||
|
|
||||||
|
## Using with Commands
|
||||||
|
|
||||||
|
Pass a `locale` argument when defining an `AppCommand` or using the decorators. The command name and description will be looked up using the loaded translations.
|
||||||
|
|
||||||
|
```python
|
||||||
|
@slash_command(name="greet", description="description", locale="es")
|
||||||
|
async def greet(ctx):
|
||||||
|
await ctx.send(i18n.translate("greet", ctx.locale or "es"))
|
||||||
|
```
|
||||||
|
|
||||||
|
If a translation is missing the key itself is returned.
|
48
docs/oauth2.md
Normal file
48
docs/oauth2.md
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
# OAuth2 Setup
|
||||||
|
|
||||||
|
This guide explains how to perform a basic OAuth2 flow with `disagreement`.
|
||||||
|
|
||||||
|
1. Generate the authorization URL:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.oauth import build_authorization_url
|
||||||
|
|
||||||
|
url = build_authorization_url(
|
||||||
|
client_id="YOUR_CLIENT_ID",
|
||||||
|
redirect_uri="https://your.app/callback",
|
||||||
|
scope=["identify"],
|
||||||
|
)
|
||||||
|
print(url)
|
||||||
|
```
|
||||||
|
|
||||||
|
2. After the user authorizes your application and you receive a code, exchange it for a token:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import aiohttp
|
||||||
|
from disagreement.oauth import exchange_code_for_token
|
||||||
|
|
||||||
|
async def get_token(code: str):
|
||||||
|
return await exchange_code_for_token(
|
||||||
|
client_id="YOUR_CLIENT_ID",
|
||||||
|
client_secret="YOUR_CLIENT_SECRET",
|
||||||
|
code=code,
|
||||||
|
redirect_uri="https://your.app/callback",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
`exchange_code_for_token` returns the JSON payload from Discord which includes
|
||||||
|
`access_token`, `refresh_token` and expiry information.
|
||||||
|
|
||||||
|
3. When the access token expires, you can refresh it using the provided refresh
|
||||||
|
token:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.oauth import refresh_access_token
|
||||||
|
|
||||||
|
async def refresh(token: str):
|
||||||
|
return await refresh_access_token(
|
||||||
|
refresh_token=token,
|
||||||
|
client_id="YOUR_CLIENT_ID",
|
||||||
|
client_secret="YOUR_CLIENT_SECRET",
|
||||||
|
)
|
||||||
|
```
|
62
docs/permissions.md
Normal file
62
docs/permissions.md
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
# Permission Helpers
|
||||||
|
|
||||||
|
The `disagreement.permissions` module defines an :class:`~enum.IntFlag`
|
||||||
|
`Permissions` enumeration along with helper functions for working with
|
||||||
|
the Discord permission bitmask.
|
||||||
|
|
||||||
|
## Permissions Enum
|
||||||
|
|
||||||
|
Each attribute of ``Permissions`` represents a single permission bit. The value
|
||||||
|
is a power of two so multiple permissions can be combined using bitwise OR.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.permissions import Permissions
|
||||||
|
|
||||||
|
value = Permissions.SEND_MESSAGES | Permissions.MANAGE_MESSAGES
|
||||||
|
```
|
||||||
|
|
||||||
|
## Helper Functions
|
||||||
|
|
||||||
|
### ``permissions_value``
|
||||||
|
|
||||||
|
```python
|
||||||
|
permissions_value(*perms) -> int
|
||||||
|
```
|
||||||
|
|
||||||
|
Return an integer bitmask from one or more ``Permissions`` values. Nested
|
||||||
|
iterables are flattened automatically.
|
||||||
|
|
||||||
|
### ``has_permissions``
|
||||||
|
|
||||||
|
```python
|
||||||
|
has_permissions(current, *perms) -> bool
|
||||||
|
```
|
||||||
|
|
||||||
|
Return ``True`` if ``current`` (an ``int`` or ``Permissions``) contains all of
|
||||||
|
the provided permissions.
|
||||||
|
|
||||||
|
### ``missing_permissions``
|
||||||
|
|
||||||
|
```python
|
||||||
|
missing_permissions(current, *perms) -> List[Permissions]
|
||||||
|
```
|
||||||
|
|
||||||
|
Return a list of permissions that ``current`` does not contain.
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.permissions import (
|
||||||
|
Permissions,
|
||||||
|
has_permissions,
|
||||||
|
missing_permissions,
|
||||||
|
)
|
||||||
|
|
||||||
|
current = Permissions.SEND_MESSAGES | Permissions.MANAGE_MESSAGES
|
||||||
|
|
||||||
|
if has_permissions(current, Permissions.SEND_MESSAGES):
|
||||||
|
print("Can send messages")
|
||||||
|
|
||||||
|
print(missing_permissions(current, Permissions.ADMINISTRATOR))
|
||||||
|
```
|
||||||
|
|
29
docs/presence.md
Normal file
29
docs/presence.md
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# Updating Presence
|
||||||
|
|
||||||
|
The `Client.change_presence` method allows you to update the bot's status and displayed activity.
|
||||||
|
|
||||||
|
## Status Strings
|
||||||
|
|
||||||
|
- `online` – show the bot as online
|
||||||
|
- `idle` – mark the bot as away
|
||||||
|
- `dnd` – do not disturb
|
||||||
|
- `invisible` – appear offline
|
||||||
|
|
||||||
|
## Activity Types
|
||||||
|
|
||||||
|
An activity dictionary must include a `name` and a `type` field. The type value corresponds to Discord's activity types:
|
||||||
|
|
||||||
|
| Type | Meaning |
|
||||||
|
|-----:|--------------|
|
||||||
|
| `0` | Playing |
|
||||||
|
| `1` | Streaming |
|
||||||
|
| `2` | Listening |
|
||||||
|
| `3` | Watching |
|
||||||
|
| `4` | Custom |
|
||||||
|
| `5` | Competing |
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
await client.change_presence(status="idle", activity={"name": "with Discord", "type": 0})
|
||||||
|
```
|
32
docs/reactions.md
Normal file
32
docs/reactions.md
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# Handling Reactions
|
||||||
|
|
||||||
|
`disagreement` provides simple helpers for adding and removing message reactions.
|
||||||
|
|
||||||
|
## HTTP Methods
|
||||||
|
|
||||||
|
Use the `HTTPClient` methods directly if you need lower level control:
|
||||||
|
|
||||||
|
```python
|
||||||
|
await client._http.create_reaction(channel_id, message_id, "👍")
|
||||||
|
await client._http.delete_reaction(channel_id, message_id, "👍")
|
||||||
|
users = await client._http.get_reactions(channel_id, message_id, "👍")
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also use the higher level helpers on :class:`Client`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
await client.create_reaction(channel_id, message_id, "👍")
|
||||||
|
await client.delete_reaction(channel_id, message_id, "👍")
|
||||||
|
users = await client.get_reactions(channel_id, message_id, "👍")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Reaction Events
|
||||||
|
|
||||||
|
Register listeners for `MESSAGE_REACTION_ADD` and `MESSAGE_REACTION_REMOVE`.
|
||||||
|
Each listener receives a `Reaction` model instance.
|
||||||
|
|
||||||
|
```python
|
||||||
|
@client.on_event("MESSAGE_REACTION_ADD")
|
||||||
|
async def on_reaction(reaction: disagreement.Reaction):
|
||||||
|
print(f"{reaction.user_id} reacted with {reaction.emoji}")
|
||||||
|
```
|
22
docs/slash_commands.md
Normal file
22
docs/slash_commands.md
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# Using Slash Commands
|
||||||
|
|
||||||
|
The library provides a slash command framework via the `ext.app_commands` package. Define commands with decorators and register them with Discord.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.ext.app_commands import AppCommandGroup
|
||||||
|
|
||||||
|
bot_commands = AppCommandGroup("bot", "Bot commands")
|
||||||
|
|
||||||
|
@bot_commands.command(name="ping")
|
||||||
|
async def ping(ctx):
|
||||||
|
await ctx.respond("Pong!")
|
||||||
|
```
|
||||||
|
|
||||||
|
Use `AppCommandGroup` to group related commands. See the [components guide](using_components.md) for building interactive responses.
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
|
||||||
|
- [Components](using_components.md)
|
||||||
|
- [Caching](caching.md)
|
||||||
|
- [Voice Features](voice_features.md)
|
||||||
|
|
15
docs/task_loop.md
Normal file
15
docs/task_loop.md
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# Task Loops
|
||||||
|
|
||||||
|
The tasks extension allows you to run functions periodically. Decorate an async function with `@tasks.loop(seconds=...)` and start it using `.start()`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.ext import tasks
|
||||||
|
|
||||||
|
@tasks.loop(seconds=5.0)
|
||||||
|
async def announce():
|
||||||
|
print("Hello from a loop")
|
||||||
|
|
||||||
|
announce.start()
|
||||||
|
```
|
||||||
|
|
||||||
|
Stop the loop with `.stop()` when you no longer need it.
|
16
docs/typing_indicator.md
Normal file
16
docs/typing_indicator.md
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
# Typing Indicator
|
||||||
|
|
||||||
|
The library exposes an async context manager to send the typing indicator for a channel.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import asyncio
|
||||||
|
import disagreement
|
||||||
|
|
||||||
|
client = disagreement.Client(token="YOUR_TOKEN")
|
||||||
|
|
||||||
|
async def indicate(channel_id: str):
|
||||||
|
async with client.typing(channel_id):
|
||||||
|
await long_running_task()
|
||||||
|
```
|
||||||
|
|
||||||
|
This uses the underlying HTTP endpoint `/channels/{channel_id}/typing`.
|
168
docs/using_components.md
Normal file
168
docs/using_components.md
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
# Using Message Components
|
||||||
|
|
||||||
|
This guide explains how to work with the `disagreement` message component models. These examples are up to date with the current code base.
|
||||||
|
|
||||||
|
## Enabling the New Component System
|
||||||
|
|
||||||
|
Messages that use the component system must include the flag `IS_COMPONENTS_V2` (value `1 << 15`). Once this flag is set on a message it cannot be removed.
|
||||||
|
|
||||||
|
## Component Categories
|
||||||
|
|
||||||
|
The library exposes three broad categories of components:
|
||||||
|
|
||||||
|
- **Layout Components** – organize the placement of other components.
|
||||||
|
- **Content Components** – display static text or media.
|
||||||
|
- **Interactive Components** – allow the user to interact with your message.
|
||||||
|
|
||||||
|
## Action Row
|
||||||
|
|
||||||
|
`ActionRow` is a layout container. It may hold up to five buttons or a single select menu.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.models import ActionRow, Button
|
||||||
|
from disagreement.enums import ButtonStyle
|
||||||
|
|
||||||
|
row = ActionRow(components=[
|
||||||
|
Button(style=ButtonStyle.PRIMARY, label="Click", custom_id="btn")
|
||||||
|
])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Button
|
||||||
|
|
||||||
|
Buttons provide a clickable UI element.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.models import Button
|
||||||
|
from disagreement.enums import ButtonStyle
|
||||||
|
|
||||||
|
button = Button(
|
||||||
|
style=ButtonStyle.SUCCESS,
|
||||||
|
label="Confirm",
|
||||||
|
custom_id="confirm_button",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Select Menus
|
||||||
|
|
||||||
|
`SelectMenu` lets the user choose one or more options. The `type` parameter controls the menu variety (`STRING_SELECT`, `USER_SELECT`, `ROLE_SELECT`, `MENTIONABLE_SELECT`, `CHANNEL_SELECT`).
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.models import SelectMenu, SelectOption
|
||||||
|
from disagreement.enums import ComponentType, ChannelType
|
||||||
|
|
||||||
|
menu = SelectMenu(
|
||||||
|
custom_id="example",
|
||||||
|
options=[
|
||||||
|
SelectOption(label="Option 1", value="1"),
|
||||||
|
SelectOption(label="Option 2", value="2"),
|
||||||
|
],
|
||||||
|
placeholder="Choose an option",
|
||||||
|
min_values=1,
|
||||||
|
max_values=1,
|
||||||
|
type=ComponentType.STRING_SELECT,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
For channel selects you may pass `channel_types` with a list of allowed `ChannelType` values.
|
||||||
|
|
||||||
|
## Section
|
||||||
|
|
||||||
|
`Section` groups one or more `TextDisplay` components and can include an accessory `Button` or `Thumbnail`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.models import Section, TextDisplay, Thumbnail, UnfurledMediaItem
|
||||||
|
|
||||||
|
section = Section(
|
||||||
|
components=[
|
||||||
|
TextDisplay(content="## Section Title"),
|
||||||
|
TextDisplay(content="Sections can hold multiple text displays."),
|
||||||
|
],
|
||||||
|
accessory=Thumbnail(media=UnfurledMediaItem(url="https://example.com/img.png")),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Text Display
|
||||||
|
|
||||||
|
`TextDisplay` simply renders markdown text.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.models import TextDisplay
|
||||||
|
|
||||||
|
text_display = TextDisplay(content="**Bold text**")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Thumbnail
|
||||||
|
|
||||||
|
`Thumbnail` shows a small image. Set `spoiler=True` to hide the image until clicked.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.models import Thumbnail, UnfurledMediaItem
|
||||||
|
|
||||||
|
thumb = Thumbnail(
|
||||||
|
media=UnfurledMediaItem(url="https://example.com/image.png"),
|
||||||
|
description="A picture",
|
||||||
|
spoiler=False,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Media Gallery
|
||||||
|
|
||||||
|
`MediaGallery` holds multiple `MediaGalleryItem` objects.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.models import MediaGallery, MediaGalleryItem, UnfurledMediaItem
|
||||||
|
|
||||||
|
gallery = MediaGallery(
|
||||||
|
items=[
|
||||||
|
MediaGalleryItem(media=UnfurledMediaItem(url="https://example.com/1.png")),
|
||||||
|
MediaGalleryItem(media=UnfurledMediaItem(url="https://example.com/2.png")),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## File
|
||||||
|
|
||||||
|
`File` displays an uploaded file. Use `spoiler=True` to mark it as a spoiler.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.models import File, UnfurledMediaItem
|
||||||
|
|
||||||
|
file_component = File(
|
||||||
|
file=UnfurledMediaItem(url="attachment://file.zip"),
|
||||||
|
spoiler=False,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Separator
|
||||||
|
|
||||||
|
`Separator` adds vertical spacing or an optional divider line between components.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.models import Separator
|
||||||
|
|
||||||
|
separator = Separator(divider=True, spacing=2)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Container
|
||||||
|
|
||||||
|
`Container` visually groups a set of components and can apply an accent colour or spoiler.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.models import Container, TextDisplay
|
||||||
|
|
||||||
|
container = Container(
|
||||||
|
components=[TextDisplay(content="Inside a container")],
|
||||||
|
accent_color=0xFF0000,
|
||||||
|
spoiler=False,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
A container can itself contain layout and content components, letting you build complex messages.
|
||||||
|
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
|
||||||
|
- [Slash Commands](slash_commands.md)
|
||||||
|
- [Caching](caching.md)
|
||||||
|
- [Voice Features](voice_features.md)
|
||||||
|
|
29
docs/voice_client.md
Normal file
29
docs/voice_client.md
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# VoiceClient
|
||||||
|
|
||||||
|
`VoiceClient` provides a minimal interface to Discord's voice gateway. It handles the WebSocket handshake and lets you stream audio over UDP.
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import disagreement
|
||||||
|
|
||||||
|
vc = disagreement.VoiceClient(
|
||||||
|
os.environ["DISCORD_VOICE_ENDPOINT"],
|
||||||
|
os.environ["DISCORD_SESSION_ID"],
|
||||||
|
os.environ["DISCORD_VOICE_TOKEN"],
|
||||||
|
int(os.environ["DISCORD_GUILD_ID"]),
|
||||||
|
int(os.environ["DISCORD_USER_ID"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.run(vc.connect())
|
||||||
|
```
|
||||||
|
|
||||||
|
After connecting you can send raw Opus frames:
|
||||||
|
|
||||||
|
```python
|
||||||
|
await vc.send_audio_frame(opus_bytes)
|
||||||
|
```
|
||||||
|
|
||||||
|
Call `await vc.close()` when finished.
|
17
docs/voice_features.md
Normal file
17
docs/voice_features.md
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# Voice Features
|
||||||
|
|
||||||
|
Disagreement includes experimental support for connecting to voice channels. You can join a voice channel and play audio using an FFmpeg subprocess.
|
||||||
|
|
||||||
|
```python
|
||||||
|
voice = await client.join_voice(guild_id, channel_id)
|
||||||
|
voice.play_file("welcome.mp3")
|
||||||
|
```
|
||||||
|
|
||||||
|
Voice support is optional and may require additional system dependencies such as FFmpeg.
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
|
||||||
|
- [Components](using_components.md)
|
||||||
|
- [Slash Commands](slash_commands.md)
|
||||||
|
- [Caching](caching.md)
|
||||||
|
|
34
docs/webhooks.md
Normal file
34
docs/webhooks.md
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
# Working with Webhooks
|
||||||
|
|
||||||
|
The `HTTPClient` includes helper methods for creating, editing and deleting Discord webhooks.
|
||||||
|
|
||||||
|
## Create a webhook
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.http import HTTPClient
|
||||||
|
|
||||||
|
http = HTTPClient(token="TOKEN")
|
||||||
|
payload = {"name": "My Webhook"}
|
||||||
|
webhook_data = await http.create_webhook("123", payload)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Edit a webhook
|
||||||
|
|
||||||
|
```python
|
||||||
|
await http.edit_webhook("456", {"name": "Renamed"})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Delete a webhook
|
||||||
|
|
||||||
|
```python
|
||||||
|
await http.delete_webhook("456")
|
||||||
|
```
|
||||||
|
|
||||||
|
The methods return the raw webhook JSON. You can construct a `Webhook` model if needed:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from disagreement.models import Webhook
|
||||||
|
|
||||||
|
webhook = Webhook(webhook_data)
|
||||||
|
print(webhook.id, webhook.name)
|
||||||
|
```
|
218
examples/basic_bot.py
Normal file
218
examples/basic_bot.py
Normal file
@ -0,0 +1,218 @@
|
|||||||
|
# examples/basic_bot.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
A basic example bot using the Disagreement library.
|
||||||
|
|
||||||
|
To run this bot:
|
||||||
|
1. Make sure you have the 'disagreement' library installed or accessible in your PYTHONPATH.
|
||||||
|
If running from the project root, it should be discoverable.
|
||||||
|
2. Set the DISCORD_BOT_TOKEN environment variable to your bot's token.
|
||||||
|
e.g., export DISCORD_BOT_TOKEN="your_actual_token_here" (Linux/macOS)
|
||||||
|
set DISCORD_BOT_TOKEN="your_actual_token_here" (Windows CMD)
|
||||||
|
$env:DISCORD_BOT_TOKEN="your_actual_token_here" (Windows PowerShell)
|
||||||
|
3. Run this script: python examples/basic_bot.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import logging # Optional: for more detailed logging
|
||||||
|
|
||||||
|
# Assuming the 'disagreement' package is in the parent directory or installed
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
# Add project root to path if running script directly from examples folder
|
||||||
|
# and disagreement is not installed
|
||||||
|
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
|
try:
|
||||||
|
import disagreement
|
||||||
|
from disagreement.ext import commands # Import the new commands extension
|
||||||
|
except ImportError:
|
||||||
|
print(
|
||||||
|
"Failed to import disagreement. Make sure it's installed or PYTHONPATH is set correctly."
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"If running from the 'examples' directory, try running from the project root: python -m examples.basic_bot"
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Optional: Configure logging for more insight, especially for gateway events
|
||||||
|
# logging.basicConfig(level=logging.DEBUG) # For very verbose output
|
||||||
|
# logging.getLogger('disagreement.gateway').setLevel(logging.INFO) # Or DEBUG
|
||||||
|
# logging.getLogger('disagreement.http').setLevel(logging.INFO)
|
||||||
|
|
||||||
|
# --- Bot Configuration ---
|
||||||
|
BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
||||||
|
|
||||||
|
# --- Intents Configuration ---
|
||||||
|
# Define the intents your bot needs. For basic message reading and responding:
|
||||||
|
intents = (
|
||||||
|
disagreement.GatewayIntent.GUILDS
|
||||||
|
| disagreement.GatewayIntent.GUILD_MESSAGES
|
||||||
|
| disagreement.GatewayIntent.MESSAGE_CONTENT
|
||||||
|
) # MESSAGE_CONTENT is privileged!
|
||||||
|
|
||||||
|
# If you don't need message content and only react to commands/mentions,
|
||||||
|
# you might not need MESSAGE_CONTENT intent.
|
||||||
|
# intents = disagreement.GatewayIntent.default() # A good starting point without privileged intents
|
||||||
|
# intents |= disagreement.GatewayIntent.MESSAGE_CONTENT # Add if needed
|
||||||
|
|
||||||
|
# --- Initialize the Client ---
|
||||||
|
if not BOT_TOKEN:
|
||||||
|
print("Error: The DISCORD_BOT_TOKEN environment variable is not set.")
|
||||||
|
print("Please set it before running the bot.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Initialize Client with a command prefix
|
||||||
|
client = disagreement.Client(token=BOT_TOKEN, intents=intents, command_prefix="!")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Define a Cog for example commands ---
|
||||||
|
class ExampleCog(commands.Cog): # Ensuring this uses commands.Cog
|
||||||
|
def __init__(
|
||||||
|
self, bot_client
|
||||||
|
): # Renamed client to bot_client to avoid conflict with self.client
|
||||||
|
super().__init__(bot_client) # Pass the client instance to the base Cog
|
||||||
|
|
||||||
|
@commands.command(name="hello", aliases=["hi"])
|
||||||
|
async def hello_command(self, ctx: commands.CommandContext, *, who: str = "world"):
|
||||||
|
"""Greets someone."""
|
||||||
|
await ctx.reply(f"Hello {ctx.author.mention} and {who}!")
|
||||||
|
print(f"Executed 'hello' command for {ctx.author.username}, greeting {who}")
|
||||||
|
|
||||||
|
@commands.command()
|
||||||
|
async def ping(self, ctx: commands.CommandContext):
|
||||||
|
"""Responds with Pong!"""
|
||||||
|
await ctx.reply("Pong!")
|
||||||
|
print(f"Executed 'ping' command for {ctx.author.username}")
|
||||||
|
|
||||||
|
@commands.command()
|
||||||
|
async def me(self, ctx: commands.CommandContext):
|
||||||
|
"""Shows information about the invoking user."""
|
||||||
|
reply_content = (
|
||||||
|
f"Hello {ctx.author.mention}!\n"
|
||||||
|
f"Your User ID is: {ctx.author.id}\n"
|
||||||
|
f"Your Username: {ctx.author.username}#{ctx.author.discriminator}\n"
|
||||||
|
f"Are you a bot? {'Yes' if ctx.author.bot else 'No'}"
|
||||||
|
)
|
||||||
|
await ctx.reply(reply_content)
|
||||||
|
print(f"Executed 'me' command for {ctx.author.username}")
|
||||||
|
|
||||||
|
@commands.command(name="add")
|
||||||
|
async def add_numbers(self, ctx: commands.CommandContext, num1: int, num2: int):
|
||||||
|
"""Adds two numbers."""
|
||||||
|
result = num1 + num2
|
||||||
|
await ctx.reply(f"The sum of {num1} and {num2} is {result}.")
|
||||||
|
print(
|
||||||
|
f"Executed 'add' command for {ctx.author.username}: {num1} + {num2} = {result}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@commands.command(name="say")
|
||||||
|
async def say_something(self, ctx: commands.CommandContext, *, text_to_say: str):
|
||||||
|
"""Repeats the text you provide."""
|
||||||
|
await ctx.reply(f"You said: {text_to_say}")
|
||||||
|
print(
|
||||||
|
f"Executed 'say' command for {ctx.author.username}, saying: {text_to_say}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@commands.command(name="quit")
|
||||||
|
async def quit_command(self, ctx: commands.CommandContext):
|
||||||
|
"""Shuts down the bot (requires YOUR_USER_ID to be set)."""
|
||||||
|
# Replace YOUR_USER_ID with your actual Discord User ID for a safe shutdown command
|
||||||
|
your_user_id = "YOUR_USER_ID_REPLACE_ME" # IMPORTANT: Replace this
|
||||||
|
if str(ctx.author.id) == your_user_id:
|
||||||
|
print("Quit command received. Shutting down...")
|
||||||
|
await ctx.reply("Shutting down...")
|
||||||
|
await self.client.close() # Access client via self.client from Cog
|
||||||
|
else:
|
||||||
|
await ctx.reply("You are not authorized to use this command.")
|
||||||
|
print(
|
||||||
|
f"Unauthorized quit attempt by {ctx.author.username} ({ctx.author.id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Event Handlers ---
|
||||||
|
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_ready():
|
||||||
|
"""Called when the bot is ready and connected to Discord."""
|
||||||
|
if client.user:
|
||||||
|
print(
|
||||||
|
f"Bot is ready! Logged in as {client.user.username}#{client.user.discriminator}"
|
||||||
|
)
|
||||||
|
print(f"User ID: {client.user.id}")
|
||||||
|
else:
|
||||||
|
print("Bot is ready, but client.user is missing!")
|
||||||
|
print("------")
|
||||||
|
print("Disagreement Bot is operational.")
|
||||||
|
print("Listening for commands...")
|
||||||
|
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_message(message: disagreement.Message):
|
||||||
|
"""Called when a message is created and received."""
|
||||||
|
# Command processing is now handled by the CommandHandler via client._process_message_for_commands
|
||||||
|
# This on_message can be used for other message-related logic if needed,
|
||||||
|
# or removed if all message handling is command-based.
|
||||||
|
|
||||||
|
# Example: Log all messages (excluding bot's own, if client.user was available)
|
||||||
|
# if client.user and message.author.id == client.user.id:
|
||||||
|
# return
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"General on_message: #{message.channel_id} from {message.author.username}: {message.content}"
|
||||||
|
)
|
||||||
|
# The old if/elif command structure is no longer needed here.
|
||||||
|
|
||||||
|
|
||||||
|
@client.on_event(
|
||||||
|
"GUILD_CREATE"
|
||||||
|
) # Example of listening to a specific event by its Discord name
|
||||||
|
async def on_guild_available(guild_data: dict): # Receives raw data for now
|
||||||
|
# In a real scenario, guild_data would be parsed into a Guild model
|
||||||
|
print(f"Guild available: {guild_data.get('name')} (ID: {guild_data.get('id')})")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Main Execution ---
|
||||||
|
async def main():
|
||||||
|
print("Starting Disagreement Bot...")
|
||||||
|
try:
|
||||||
|
# Add the Cog to the client
|
||||||
|
client.add_cog(ExampleCog(client)) # Pass client instance to Cog constructor
|
||||||
|
# client.add_cog is synchronous, but it schedules cog.cog_load() if it's async.
|
||||||
|
|
||||||
|
await client.run()
|
||||||
|
except disagreement.AuthenticationError:
|
||||||
|
print(
|
||||||
|
"Authentication failed. Please check your bot token and ensure it's correct."
|
||||||
|
)
|
||||||
|
except disagreement.DisagreementException as e:
|
||||||
|
print(f"A Disagreement library error occurred: {e}")
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("Bot shutting down due to KeyboardInterrupt...")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An unexpected error occurred: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
finally:
|
||||||
|
if not client.is_closed():
|
||||||
|
print("Ensuring client is closed...")
|
||||||
|
await client.close()
|
||||||
|
print("Bot has been shut down.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Note: On Windows, the default asyncio event loop policy might not support add_signal_handler.
|
||||||
|
# If you encounter issues with Ctrl+C not working as expected,
|
||||||
|
# you might need to adjust the event loop policy or handle shutdown differently.
|
||||||
|
# For example, for Windows:
|
||||||
|
# if os.name == 'nt':
|
||||||
|
# asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||||
|
|
||||||
|
asyncio.run(main())
|
292
examples/component_bot.py
Normal file
292
examples/component_bot.py
Normal file
@ -0,0 +1,292 @@
|
|||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from typing import Union, Optional
|
||||||
|
from disagreement import Client, ui, HybridContext
|
||||||
|
from disagreement.models import (
|
||||||
|
Message,
|
||||||
|
SelectOption,
|
||||||
|
User,
|
||||||
|
Member,
|
||||||
|
Role,
|
||||||
|
Attachment,
|
||||||
|
Channel,
|
||||||
|
ActionRow,
|
||||||
|
Button,
|
||||||
|
Section,
|
||||||
|
TextDisplay,
|
||||||
|
Thumbnail,
|
||||||
|
UnfurledMediaItem,
|
||||||
|
MediaGallery,
|
||||||
|
MediaGalleryItem,
|
||||||
|
Container,
|
||||||
|
)
|
||||||
|
from disagreement.enums import (
|
||||||
|
ButtonStyle,
|
||||||
|
GatewayIntent,
|
||||||
|
ChannelType,
|
||||||
|
MessageFlags,
|
||||||
|
InteractionCallbackType,
|
||||||
|
MessageFlags,
|
||||||
|
)
|
||||||
|
from disagreement.ext.commands.cog import Cog
|
||||||
|
from disagreement.ext.commands.core import CommandContext
|
||||||
|
from disagreement.ext.app_commands.decorators import hybrid_command, slash_command
|
||||||
|
from disagreement.ext.app_commands.context import AppCommandContext
|
||||||
|
from disagreement.interactions import (
|
||||||
|
Interaction,
|
||||||
|
InteractionResponsePayload,
|
||||||
|
InteractionCallbackData,
|
||||||
|
)
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Get the bot token and application ID from the environment variables
|
||||||
|
token = os.getenv("DISCORD_BOT_TOKEN")
|
||||||
|
application_id = os.getenv("DISCORD_APPLICATION_ID")
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
raise ValueError("Bot token not found in environment variables")
|
||||||
|
if not application_id:
|
||||||
|
raise ValueError("Application ID not found in environment variables")
|
||||||
|
|
||||||
|
# Define the intents
|
||||||
|
intents = GatewayIntent.default() | GatewayIntent.MESSAGE_CONTENT
|
||||||
|
|
||||||
|
# Create a new client
|
||||||
|
client = Client(
|
||||||
|
|
||||||
|
token=token,
|
||||||
|
|
||||||
|
intents=intents,
|
||||||
|
|
||||||
|
command_prefix="!",
|
||||||
|
|
||||||
|
mention_replies=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Simple stock data used for the stock cycling example
|
||||||
|
STOCKS = [
|
||||||
|
{"symbol": "AAPL", "name": "Apple Inc.", "price": "$175"},
|
||||||
|
{"symbol": "MSFT", "name": "Microsoft Corp.", "price": "$315"},
|
||||||
|
{"symbol": "GOOGL", "name": "Alphabet Inc.", "price": "$128"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Define a View class that contains our components
|
||||||
|
class MyView(ui.View):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(timeout=180) # 180-second timeout
|
||||||
|
self.click_count = 0
|
||||||
|
|
||||||
|
@ui.button(label="Click Me!", style=ButtonStyle.SUCCESS, emoji="🖱️")
|
||||||
|
async def click_me(self, interaction: Interaction):
|
||||||
|
self.click_count += 1
|
||||||
|
await interaction.respond(
|
||||||
|
content=f"You've clicked the button {self.click_count} times!",
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@ui.select(
|
||||||
|
custom_id="string_select",
|
||||||
|
placeholder="Choose an option",
|
||||||
|
options=[
|
||||||
|
SelectOption(
|
||||||
|
label="Option 1", value="opt1", description="This is the first option"
|
||||||
|
),
|
||||||
|
SelectOption(
|
||||||
|
label="Option 2", value="opt2", description="This is the second option"
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def select_menu(self, interaction: Interaction):
|
||||||
|
if interaction.data and interaction.data.values:
|
||||||
|
await interaction.respond(
|
||||||
|
content=f"You selected: {interaction.data.values[0]}",
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_timeout(self):
|
||||||
|
# This method is called when the view times out.
|
||||||
|
# You can use this to edit the original message, for example.
|
||||||
|
print("View has timed out!")
|
||||||
|
|
||||||
|
|
||||||
|
# View for cycling through available stocks
|
||||||
|
class StockView(ui.View):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(timeout=180)
|
||||||
|
self.index = 0
|
||||||
|
|
||||||
|
@ui.button(label="Next Stock", style=ButtonStyle.PRIMARY)
|
||||||
|
async def next_stock(self, interaction: Interaction):
|
||||||
|
self.index = (self.index + 1) % len(STOCKS)
|
||||||
|
stock = STOCKS[self.index]
|
||||||
|
# Edit the message by responding to the interaction with an update
|
||||||
|
await interaction.edit(
|
||||||
|
content=f"**{stock['symbol']}** - {stock['name']}\nPrice: {stock['price']}",
|
||||||
|
components=self.to_components(),
|
||||||
|
# Preserve the original reply mention
|
||||||
|
allowed_mentions={"replied_user": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ComponentCommandsCog(Cog):
|
||||||
|
def __init__(self, client: Client):
|
||||||
|
super().__init__(client)
|
||||||
|
|
||||||
|
@hybrid_command(name="components", description="Sends interactive components.")
|
||||||
|
async def components_command(self, ctx: Union[CommandContext, AppCommandContext]):
|
||||||
|
# Send a message with the view using a hybrid context helper
|
||||||
|
hybrid = HybridContext(ctx)
|
||||||
|
await hybrid.send("Here are your components:", view=MyView())
|
||||||
|
|
||||||
|
@hybrid_command(
|
||||||
|
name="stocks", description="Shows stock information with navigation."
|
||||||
|
)
|
||||||
|
async def stocks_command(self, ctx: Union[CommandContext, AppCommandContext]):
|
||||||
|
# Show the first stock and attach a button to cycle through them
|
||||||
|
first = STOCKS[0]
|
||||||
|
hybrid = HybridContext(ctx)
|
||||||
|
await hybrid.send(
|
||||||
|
f"**{first['symbol']}** - {first['name']}\nPrice: {first['price']}",
|
||||||
|
view=StockView(),
|
||||||
|
)
|
||||||
|
|
||||||
|
@hybrid_command(name="sectiondemo", description="Shows a section layout.")
|
||||||
|
async def section_demo(self, ctx: Union[CommandContext, AppCommandContext]) -> None:
|
||||||
|
section = Section(
|
||||||
|
components=[
|
||||||
|
TextDisplay(content="## Advanced Components"),
|
||||||
|
TextDisplay(content="Sections group text with accessories."),
|
||||||
|
],
|
||||||
|
accessory=Thumbnail(
|
||||||
|
media=UnfurledMediaItem(url="https://placehold.co/100x100.png")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
container = Container(components=[section], accent_color=0x5865F2)
|
||||||
|
hybrid = HybridContext(ctx)
|
||||||
|
await hybrid.send(
|
||||||
|
components=[container],
|
||||||
|
flags=MessageFlags.IS_COMPONENTS_V2.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@hybrid_command(name="gallerydemo", description="Shows a media gallery.")
|
||||||
|
async def gallery_demo(self, ctx: Union[CommandContext, AppCommandContext]) -> None:
|
||||||
|
gallery = MediaGallery(
|
||||||
|
items=[
|
||||||
|
MediaGalleryItem(
|
||||||
|
media=UnfurledMediaItem(url="https://placehold.co/600x400.png")
|
||||||
|
),
|
||||||
|
MediaGalleryItem(
|
||||||
|
media=UnfurledMediaItem(url="https://placehold.co/600x400.jpg")
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
hybrid = HybridContext(ctx)
|
||||||
|
await hybrid.send(
|
||||||
|
components=[gallery],
|
||||||
|
flags=MessageFlags.IS_COMPONENTS_V2.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@hybrid_command(
|
||||||
|
name="complex_components",
|
||||||
|
description="Shows a complex layout with multiple containers.",
|
||||||
|
)
|
||||||
|
async def complex_components(
|
||||||
|
self, ctx: Union[CommandContext, AppCommandContext]
|
||||||
|
) -> None:
|
||||||
|
container1 = Container(
|
||||||
|
components=[
|
||||||
|
Section(
|
||||||
|
components=[
|
||||||
|
TextDisplay(content="## Complex Layout Example"),
|
||||||
|
TextDisplay(
|
||||||
|
content="This container has an accent color and includes a section with an action row of buttons. There is a thumbnail accessory to the right."
|
||||||
|
),
|
||||||
|
],
|
||||||
|
accessory=Thumbnail(
|
||||||
|
media=UnfurledMediaItem(url="https://placehold.co/100x100.png")
|
||||||
|
),
|
||||||
|
),
|
||||||
|
ActionRow(
|
||||||
|
components=[
|
||||||
|
Button(
|
||||||
|
style=ButtonStyle.PRIMARY,
|
||||||
|
label="Primary",
|
||||||
|
custom_id="complex_primary",
|
||||||
|
),
|
||||||
|
Button(
|
||||||
|
style=ButtonStyle.SUCCESS,
|
||||||
|
label="Success",
|
||||||
|
custom_id="complex_success",
|
||||||
|
),
|
||||||
|
Button(
|
||||||
|
style=ButtonStyle.DANGER,
|
||||||
|
label="Destructive",
|
||||||
|
custom_id="complex_destructive",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
accent_color=0x5865F2,
|
||||||
|
)
|
||||||
|
container2 = Container(
|
||||||
|
components=[
|
||||||
|
TextDisplay(
|
||||||
|
content="## Another Container\nThis container has no accent color and includes a media gallery."
|
||||||
|
),
|
||||||
|
MediaGallery(
|
||||||
|
items=[
|
||||||
|
MediaGalleryItem(
|
||||||
|
media=UnfurledMediaItem(
|
||||||
|
url="https://placehold.co/300x200.png"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
MediaGalleryItem(
|
||||||
|
media=UnfurledMediaItem(
|
||||||
|
url="https://placehold.co/300x200.jpg"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
MediaGalleryItem(
|
||||||
|
media=UnfurledMediaItem(
|
||||||
|
url="https://placehold.co/300x200.gif"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
hybrid = HybridContext(ctx)
|
||||||
|
await hybrid.send(
|
||||||
|
components=[container1, container2],
|
||||||
|
flags=MessageFlags.IS_COMPONENTS_V2.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
@client.event
|
||||||
|
async def on_ready():
|
||||||
|
if client.user:
|
||||||
|
print(f"Logged in as {client.user.username}")
|
||||||
|
if client.application_id:
|
||||||
|
try:
|
||||||
|
print("Attempting to sync application commands...")
|
||||||
|
await client.app_command_handler.sync_commands(
|
||||||
|
application_id=client.application_id
|
||||||
|
)
|
||||||
|
print("Application commands synced successfully.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error syncing application commands: {e}")
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
"Client's application ID is not set. Skipping application command sync."
|
||||||
|
)
|
||||||
|
|
||||||
|
client.add_cog(ComponentCommandsCog(client))
|
||||||
|
await client.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
71
examples/context_menus.py
Normal file
71
examples/context_menus.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
"""Examples showing how to use context menu commands."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Allow running example from repository root
|
||||||
|
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.ext.app_commands import (
|
||||||
|
user_command,
|
||||||
|
message_command,
|
||||||
|
AppCommandContext,
|
||||||
|
)
|
||||||
|
from disagreement.models import User, Message
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "")
|
||||||
|
APP_ID = os.environ.get("DISCORD_APPLICATION_ID", "")
|
||||||
|
client = Client(token=BOT_TOKEN, application_id=APP_ID)
|
||||||
|
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_ready():
|
||||||
|
"""Called when the bot is ready and connected to Discord."""
|
||||||
|
if client.user:
|
||||||
|
print(f"Bot is ready! Logged in as {client.user.username}")
|
||||||
|
print("Attempting to sync application commands...")
|
||||||
|
try:
|
||||||
|
if client.application_id:
|
||||||
|
await client.app_command_handler.sync_commands(
|
||||||
|
application_id=client.application_id
|
||||||
|
)
|
||||||
|
print("Application commands synced successfully.")
|
||||||
|
else:
|
||||||
|
print("Skipping command sync: application ID is not set.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error syncing application commands: {e}")
|
||||||
|
else:
|
||||||
|
print("Bot is ready, but client.user is missing!")
|
||||||
|
print("------")
|
||||||
|
|
||||||
|
|
||||||
|
@user_command(name="User Info")
|
||||||
|
async def user_info(ctx: AppCommandContext, user: User) -> None:
|
||||||
|
await ctx.send(
|
||||||
|
f"Selected user: {user.username}#{user.discriminator}", ephemeral=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@message_command(name="Quote")
|
||||||
|
async def quote(ctx: AppCommandContext, message: Message) -> None:
|
||||||
|
await ctx.send(f"Quoted message: {message.content}", ephemeral=True)
|
||||||
|
|
||||||
|
|
||||||
|
client.app_command_handler.add_command(user_info)
|
||||||
|
client.app_command_handler.add_command(quote)
|
||||||
|
|
||||||
|
|
||||||
|
async def main() -> None:
|
||||||
|
await client.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(main())
|
315
examples/hybrid_bot.py
Normal file
315
examples/hybrid_bot.py
Normal file
@ -0,0 +1,315 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional, Literal, Union
|
||||||
|
|
||||||
|
from disagreement import HybridContext
|
||||||
|
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.ext.commands.cog import Cog
|
||||||
|
from disagreement.ext.commands.core import CommandContext
|
||||||
|
from disagreement.ext.app_commands.decorators import (
|
||||||
|
slash_command,
|
||||||
|
user_command,
|
||||||
|
message_command,
|
||||||
|
hybrid_command,
|
||||||
|
)
|
||||||
|
from disagreement.ext.app_commands.commands import (
|
||||||
|
AppCommandGroup,
|
||||||
|
) # For defining groups
|
||||||
|
from disagreement.ext.app_commands.context import AppCommandContext # Added
|
||||||
|
from disagreement.models import (
|
||||||
|
User,
|
||||||
|
Member,
|
||||||
|
Role,
|
||||||
|
Attachment,
|
||||||
|
Message,
|
||||||
|
Channel,
|
||||||
|
) # For type hints
|
||||||
|
from disagreement.enums import (
|
||||||
|
ChannelType,
|
||||||
|
) # For channel option type hints, assuming it exists
|
||||||
|
|
||||||
|
# from disagreement.interactions import Interaction # Replaced by AppCommandContext
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
# --- Define a Test Cog ---
|
||||||
|
class TestCog(Cog):
|
||||||
|
def __init__(self, client: Client):
|
||||||
|
super().__init__(client)
|
||||||
|
|
||||||
|
@slash_command(name="greet", description="Sends a greeting.")
|
||||||
|
async def greet_slash(self, ctx: AppCommandContext, name: str):
|
||||||
|
await ctx.send(f"Hello, {name}! (Slash)")
|
||||||
|
|
||||||
|
@user_command(name="Show User Info")
|
||||||
|
async def show_user_info_user(
|
||||||
|
self, ctx: AppCommandContext, user: User
|
||||||
|
): # Target user is in ctx.interaction.data.target_id and resolved
|
||||||
|
target_user = (
|
||||||
|
ctx.interaction.data.resolved.users.get(ctx.interaction.data.target_id)
|
||||||
|
if ctx.interaction.data
|
||||||
|
and ctx.interaction.data.resolved
|
||||||
|
and ctx.interaction.data.target_id
|
||||||
|
else user
|
||||||
|
)
|
||||||
|
if target_user:
|
||||||
|
await ctx.send(
|
||||||
|
f"User: {target_user.username}#{target_user.discriminator} (ID: {target_user.id}) (User Cmd)",
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await ctx.send("Could not find user information.", ephemeral=True)
|
||||||
|
|
||||||
|
@message_command(name="Quote Message")
|
||||||
|
async def quote_message_msg(
|
||||||
|
self, ctx: AppCommandContext, message: Message
|
||||||
|
): # Target message is in ctx.interaction.data.target_id and resolved
|
||||||
|
target_message = (
|
||||||
|
ctx.interaction.data.resolved.messages.get(ctx.interaction.data.target_id)
|
||||||
|
if ctx.interaction.data
|
||||||
|
and ctx.interaction.data.resolved
|
||||||
|
and ctx.interaction.data.target_id
|
||||||
|
else message
|
||||||
|
)
|
||||||
|
if target_message:
|
||||||
|
await ctx.send(
|
||||||
|
f'Quoting {target_message.author.username}: "{target_message.content}" (Message Cmd)',
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await ctx.send("Could not find message to quote.", ephemeral=True)
|
||||||
|
|
||||||
|
@hybrid_command(name="ping", description="Checks bot latency.", aliases=["pong"])
|
||||||
|
async def ping_hybrid(
|
||||||
|
self, ctx: Union[CommandContext, AppCommandContext], arg: Optional[str] = None
|
||||||
|
):
|
||||||
|
# latency = self.client.latency # Assuming client has latency attribute from gateway - Commented out for now
|
||||||
|
latency_ms = "N/A" # Placeholder
|
||||||
|
hybrid = HybridContext(ctx)
|
||||||
|
await hybrid.send(f"Pong! Arg: {arg} (Hybrid)")
|
||||||
|
|
||||||
|
@slash_command(name="options_test", description="Tests various option types.")
|
||||||
|
async def options_test_slash(
|
||||||
|
self,
|
||||||
|
ctx: AppCommandContext,
|
||||||
|
text: str,
|
||||||
|
integer: int,
|
||||||
|
boolean: bool,
|
||||||
|
number: float,
|
||||||
|
user_option: User,
|
||||||
|
role_option: Role,
|
||||||
|
attachment_option: Attachment,
|
||||||
|
choice_option_str: Literal["apple", "banana", "cherry"],
|
||||||
|
# Channel and member options as well as numeric Literal choices are
|
||||||
|
# not yet exercised in tests pending full library support.
|
||||||
|
):
|
||||||
|
response_parts = [
|
||||||
|
f"Text: {text}",
|
||||||
|
f"Integer: {integer}",
|
||||||
|
f"Boolean: {boolean}",
|
||||||
|
f"Number: {number}",
|
||||||
|
f"User: {user_option.username}#{user_option.discriminator}",
|
||||||
|
f"Role: {role_option.name}",
|
||||||
|
f"Attachment: {attachment_option.filename} (URL: {attachment_option.url})",
|
||||||
|
f"Choice Str: {choice_option_str}",
|
||||||
|
]
|
||||||
|
await ctx.send("\n".join(response_parts), ephemeral=True)
|
||||||
|
|
||||||
|
# --- Subcommand Group Test ---
|
||||||
|
# Define the group as a class attribute.
|
||||||
|
# The AppCommandHandler's discovery mechanism (via Cog) should pick up AppCommandGroup instances.
|
||||||
|
settings_group = AppCommandGroup(
|
||||||
|
name="settings",
|
||||||
|
description="Manage bot settings.",
|
||||||
|
# guild_ids can be added here if the group is guild-specific
|
||||||
|
)
|
||||||
|
|
||||||
|
@slash_command(
|
||||||
|
name="show", description="Shows current setting values.", parent=settings_group
|
||||||
|
)
|
||||||
|
async def settings_show(
|
||||||
|
self, ctx: AppCommandContext, setting_name: Optional[str] = None
|
||||||
|
):
|
||||||
|
if setting_name:
|
||||||
|
await ctx.send(
|
||||||
|
f"Showing value for setting: {setting_name} (Value: Placeholder)",
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await ctx.send(
|
||||||
|
"Showing all settings: (Placeholder for all settings)", ephemeral=True
|
||||||
|
)
|
||||||
|
|
||||||
|
@slash_command(
|
||||||
|
name="update", description="Updates a setting.", parent=settings_group
|
||||||
|
)
|
||||||
|
async def settings_update(
|
||||||
|
self, ctx: AppCommandContext, setting_name: str, value: str
|
||||||
|
):
|
||||||
|
await ctx.send(
|
||||||
|
f"Updated setting: {setting_name} to value: {value}", ephemeral=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# The Cog's metaclass or command registration logic should handle adding `settings_group`
|
||||||
|
# (and its subcommands) to the client's AppCommandHandler.
|
||||||
|
# The decorators now handle associating subcommands with their parent group.
|
||||||
|
|
||||||
|
@slash_command(
|
||||||
|
name="numeric_choices_test", description="Tests integer and float choices."
|
||||||
|
)
|
||||||
|
async def numeric_choices_test_slash(
|
||||||
|
self,
|
||||||
|
ctx: AppCommandContext,
|
||||||
|
int_choice: Literal[10, 20, 30, 42],
|
||||||
|
float_choice: float,
|
||||||
|
):
|
||||||
|
response = (
|
||||||
|
f"Integer Choice: {int_choice} (Type: {type(int_choice).__name__})\n"
|
||||||
|
f"Float Choice: {float_choice} (Type: {type(float_choice).__name__})"
|
||||||
|
)
|
||||||
|
await ctx.send(response, ephemeral=True)
|
||||||
|
|
||||||
|
@slash_command(
|
||||||
|
name="numeric_choices_extended",
|
||||||
|
description="Tests additional integer and float choice handling.",
|
||||||
|
)
|
||||||
|
async def numeric_choices_extended_slash(
|
||||||
|
self,
|
||||||
|
ctx: AppCommandContext,
|
||||||
|
int_choice: Literal[-5, 0, 5],
|
||||||
|
float_choice: float,
|
||||||
|
):
|
||||||
|
response = (
|
||||||
|
f"Int Choice: {int_choice} (Type: {type(int_choice).__name__})\n"
|
||||||
|
f"Float Choice: {float_choice} (Type: {type(float_choice).__name__})"
|
||||||
|
)
|
||||||
|
await ctx.send(response, ephemeral=True)
|
||||||
|
|
||||||
|
@slash_command(
|
||||||
|
name="channel_member_test",
|
||||||
|
description="Tests channel and member options.",
|
||||||
|
)
|
||||||
|
async def channel_member_test_slash(
|
||||||
|
self,
|
||||||
|
ctx: AppCommandContext,
|
||||||
|
channel: Channel,
|
||||||
|
member: Member,
|
||||||
|
):
|
||||||
|
response = (
|
||||||
|
f"Channel: {channel.name} (Type: {channel.type.name})\n"
|
||||||
|
f"Member: {member.username}#{member.discriminator}"
|
||||||
|
)
|
||||||
|
await ctx.send(response, ephemeral=True)
|
||||||
|
|
||||||
|
@slash_command(
|
||||||
|
name="channel_types_test",
|
||||||
|
description="Demonstrates multiple channel type options.",
|
||||||
|
)
|
||||||
|
async def channel_types_test_slash(
|
||||||
|
self,
|
||||||
|
ctx: AppCommandContext,
|
||||||
|
text_channel: Channel,
|
||||||
|
voice_channel: Channel,
|
||||||
|
category_channel: Channel,
|
||||||
|
):
|
||||||
|
response = (
|
||||||
|
f"Text: {text_channel.type.name}\n"
|
||||||
|
f"Voice: {voice_channel.type.name}\n"
|
||||||
|
f"Category: {category_channel.type.name}"
|
||||||
|
)
|
||||||
|
await ctx.send(response, ephemeral=True)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Main Bot Script ---
|
||||||
|
async def main():
|
||||||
|
bot_token = os.getenv("DISCORD_BOT_TOKEN")
|
||||||
|
application_id = os.getenv("DISCORD_APPLICATION_ID")
|
||||||
|
|
||||||
|
if not bot_token:
|
||||||
|
logger.error("Error: DISCORD_BOT_TOKEN environment variable not set.")
|
||||||
|
return
|
||||||
|
if not application_id:
|
||||||
|
logger.error("Error: DISCORD_APPLICATION_ID environment variable not set.")
|
||||||
|
return
|
||||||
|
|
||||||
|
client = Client(token=bot_token, command_prefix="!", application_id=application_id)
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_ready():
|
||||||
|
if client.user:
|
||||||
|
logger.info(
|
||||||
|
f"Bot logged in as {client.user.username}#{client.user.discriminator}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
"Client ready, but client.user is not populated! This should not happen."
|
||||||
|
)
|
||||||
|
return # Avoid proceeding if basic client info isn't there
|
||||||
|
|
||||||
|
if client.application_id:
|
||||||
|
logger.info(f"Application ID is: {client.application_id}")
|
||||||
|
# Sync application commands (global in this case)
|
||||||
|
try:
|
||||||
|
logger.info("Attempting to sync application commands...")
|
||||||
|
# Ensure application_id is not None before passing
|
||||||
|
app_id_to_sync = client.application_id
|
||||||
|
if (
|
||||||
|
app_id_to_sync is not None
|
||||||
|
): # Redundant due to outer if, but good for clarity
|
||||||
|
await client.app_command_handler.sync_commands(
|
||||||
|
application_id=app_id_to_sync
|
||||||
|
)
|
||||||
|
logger.info("Application commands synced successfully.")
|
||||||
|
else: # Should not be reached if outer if client.application_id is true
|
||||||
|
logger.error(
|
||||||
|
"Application ID was None despite initial check. Skipping sync."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error syncing application commands: {e}", exc_info=True)
|
||||||
|
else:
|
||||||
|
# This case should be less likely now that Client gets it from READY.
|
||||||
|
# If DISCORD_APPLICATION_ID was critical as a fallback, that logic would be here.
|
||||||
|
# For now, we rely on the READY event.
|
||||||
|
logger.warning(
|
||||||
|
"Client's application ID is not set after READY. Skipping application command sync."
|
||||||
|
)
|
||||||
|
# Check if the environment variable was provided, as a diagnostic.
|
||||||
|
if not application_id:
|
||||||
|
logger.warning(
|
||||||
|
"DISCORD_APPLICATION_ID environment variable was also not provided."
|
||||||
|
)
|
||||||
|
|
||||||
|
client.add_cog(TestCog(client))
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.run()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Bot shutting down...")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"An error occurred in the bot's main run loop: {e}", exc_info=True
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if not client.is_closed():
|
||||||
|
await client.close()
|
||||||
|
logger.info("Bot has been closed.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# For Windows, to allow graceful shutdown with Ctrl+C
|
||||||
|
if os.name == "nt":
|
||||||
|
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||||
|
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Main loop interrupted. Exiting.")
|
65
examples/modal_command.py
Normal file
65
examples/modal_command.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
"""Example showing how to present a modal using a slash command."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from disagreement import Client, ui
|
||||||
|
from disagreement.enums import GatewayIntent, TextInputStyle
|
||||||
|
from disagreement.ext.app_commands.decorators import slash_command
|
||||||
|
from disagreement.ext.app_commands.context import AppCommandContext
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
token = os.getenv("DISCORD_BOT_TOKEN", "")
|
||||||
|
application_id = os.getenv("DISCORD_APPLICATION_ID", "")
|
||||||
|
|
||||||
|
client = Client(
|
||||||
|
token=token, application_id=application_id, intents=GatewayIntent.default()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackModal(ui.Modal):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__(title="Feedback", custom_id="feedback")
|
||||||
|
|
||||||
|
@ui.text_input(label="Your feedback", style=TextInputStyle.PARAGRAPH)
|
||||||
|
async def feedback(self, interaction):
|
||||||
|
await interaction.respond(content="Thanks for your feedback!", ephemeral=True)
|
||||||
|
|
||||||
|
|
||||||
|
@slash_command(name="feedback", description="Send feedback via a modal")
|
||||||
|
async def feedback_command(ctx: AppCommandContext):
|
||||||
|
await ctx.interaction.respond_modal(FeedbackModal())
|
||||||
|
|
||||||
|
|
||||||
|
client.app_command_handler.add_command(feedback_command)
|
||||||
|
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_ready():
|
||||||
|
"""Called when the bot is ready and connected to Discord."""
|
||||||
|
if client.user:
|
||||||
|
print(f"Bot is ready! Logged in as {client.user.username}")
|
||||||
|
print("Attempting to sync application commands...")
|
||||||
|
try:
|
||||||
|
if client.application_id:
|
||||||
|
await client.app_command_handler.sync_commands(
|
||||||
|
application_id=client.application_id
|
||||||
|
)
|
||||||
|
print("Application commands synced successfully.")
|
||||||
|
else:
|
||||||
|
print("Skipping command sync: application ID is not set.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error syncing application commands: {e}")
|
||||||
|
else:
|
||||||
|
print("Bot is ready, but client.user is missing!")
|
||||||
|
print("------")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
await client.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
63
examples/modal_send.py
Normal file
63
examples/modal_send.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
"""Example showing how to send a modal."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
|
from disagreement import Client, GatewayIntent, ui # type: ignore
|
||||||
|
from disagreement.ext.app_commands.decorators import slash_command
|
||||||
|
from disagreement.ext.app_commands.context import AppCommandContext
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
TOKEN = os.getenv("DISCORD_BOT_TOKEN", "")
|
||||||
|
APP_ID = os.getenv("DISCORD_APPLICATION_ID", "")
|
||||||
|
|
||||||
|
if not TOKEN:
|
||||||
|
print("DISCORD_BOT_TOKEN not set")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
client = Client(token=TOKEN, intents=GatewayIntent.default(), application_id=APP_ID)
|
||||||
|
|
||||||
|
|
||||||
|
class NameModal(ui.Modal):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(title="Your Name", custom_id="name_modal")
|
||||||
|
self.name = ui.TextInput(label="Name", custom_id="name")
|
||||||
|
|
||||||
|
|
||||||
|
@slash_command(name="namemodal", description="Shows a modal")
|
||||||
|
async def _namemodal(ctx: AppCommandContext):
|
||||||
|
await ctx.interaction.response.send_modal(NameModal())
|
||||||
|
|
||||||
|
|
||||||
|
client.app_command_handler.add_command(_namemodal)
|
||||||
|
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_ready():
|
||||||
|
"""Called when the bot is ready and connected to Discord."""
|
||||||
|
if client.user:
|
||||||
|
print(f"Bot is ready! Logged in as {client.user.username}")
|
||||||
|
print("Attempting to sync application commands...")
|
||||||
|
try:
|
||||||
|
if client.application_id:
|
||||||
|
await client.app_command_handler.sync_commands(
|
||||||
|
application_id=client.application_id
|
||||||
|
)
|
||||||
|
print("Application commands synced successfully.")
|
||||||
|
else:
|
||||||
|
print("Skipping command sync: application ID is not set.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error syncing application commands: {e}")
|
||||||
|
else:
|
||||||
|
print("Bot is ready, but client.user is missing!")
|
||||||
|
print("------")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(client.run())
|
36
examples/sharded_bot.py
Normal file
36
examples/sharded_bot.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
"""Example bot demonstrating gateway sharding."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Ensure local package is importable when running from the examples directory
|
||||||
|
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
|
import disagreement
|
||||||
|
|
||||||
|
TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
||||||
|
if not TOKEN:
|
||||||
|
raise RuntimeError("DISCORD_BOT_TOKEN environment variable not set")
|
||||||
|
|
||||||
|
client = disagreement.Client(token=TOKEN, shard_count=2)
|
||||||
|
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_ready():
|
||||||
|
if client.user:
|
||||||
|
print(f"Shard bot ready as {client.user.username}#{client.user.discriminator}")
|
||||||
|
else:
|
||||||
|
print("Shard bot ready")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
if not TOKEN:
|
||||||
|
print("DISCORD_BOT_TOKEN environment variable not set")
|
||||||
|
return
|
||||||
|
await client.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
30
examples/task_loop.py
Normal file
30
examples/task_loop.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
"""Example showing the tasks extension."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Allow running from the examples folder without installing
|
||||||
|
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
|
from disagreement.ext import tasks
|
||||||
|
|
||||||
|
counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
@tasks.loop(seconds=1.0)
|
||||||
|
async def ticker() -> None:
|
||||||
|
global counter
|
||||||
|
counter += 1
|
||||||
|
print(f"Tick {counter}")
|
||||||
|
|
||||||
|
|
||||||
|
async def main() -> None:
|
||||||
|
ticker.start()
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
ticker.stop()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
61
examples/voice_bot.py
Normal file
61
examples/voice_bot.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
"""Example bot demonstrating VoiceClient usage."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# If running from the examples directory
|
||||||
|
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
import disagreement
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
_VOICE_ENDPOINT = os.getenv("DISCORD_VOICE_ENDPOINT")
|
||||||
|
_VOICE_TOKEN = os.getenv("DISCORD_VOICE_TOKEN")
|
||||||
|
_VOICE_SESSION_ID = os.getenv("DISCORD_SESSION_ID")
|
||||||
|
_GUILD_ID = os.getenv("DISCORD_GUILD_ID")
|
||||||
|
_USER_ID = os.getenv("DISCORD_USER_ID")
|
||||||
|
|
||||||
|
if not all([_VOICE_ENDPOINT, _VOICE_TOKEN, _VOICE_SESSION_ID, _GUILD_ID, _USER_ID]):
|
||||||
|
print("Missing one or more required environment variables for voice connection")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
assert _VOICE_ENDPOINT
|
||||||
|
assert _VOICE_TOKEN
|
||||||
|
assert _VOICE_SESSION_ID
|
||||||
|
assert _GUILD_ID
|
||||||
|
assert _USER_ID
|
||||||
|
|
||||||
|
VOICE_ENDPOINT = cast(str, _VOICE_ENDPOINT)
|
||||||
|
VOICE_TOKEN = cast(str, _VOICE_TOKEN)
|
||||||
|
VOICE_SESSION_ID = cast(str, _VOICE_SESSION_ID)
|
||||||
|
GUILD_ID = int(cast(str, _GUILD_ID))
|
||||||
|
USER_ID = int(cast(str, _USER_ID))
|
||||||
|
|
||||||
|
|
||||||
|
async def main() -> None:
|
||||||
|
vc = disagreement.VoiceClient(
|
||||||
|
VOICE_ENDPOINT,
|
||||||
|
VOICE_SESSION_ID,
|
||||||
|
VOICE_TOKEN,
|
||||||
|
GUILD_ID,
|
||||||
|
USER_ID,
|
||||||
|
)
|
||||||
|
await vc.connect()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Send silence frame as an example
|
||||||
|
await vc.send_audio_frame(b"\xf8\xff\xfe")
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
finally:
|
||||||
|
await vc.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
55
pyproject.toml
Normal file
55
pyproject.toml
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
[project]
|
||||||
|
name = "disagreement"
|
||||||
|
version = "0.0.1"
|
||||||
|
description = "A Python library for the Discord API."
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
license = {text = "BSD 3-Clause"}
|
||||||
|
authors = [
|
||||||
|
{name = "Slipstream", email = "me@slipstreamm.dev"}
|
||||||
|
]
|
||||||
|
keywords = ["discord", "api", "bot", "async", "aiohttp"]
|
||||||
|
classifiers = [
|
||||||
|
"Development Status :: 4 - Beta",
|
||||||
|
"Intended Audience :: Developers",
|
||||||
|
"License :: OSI Approved :: BSD License",
|
||||||
|
"Operating System :: OS Independent",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Programming Language :: Python :: 3.12",
|
||||||
|
"Programming Language :: Python :: 3.13",
|
||||||
|
"Topic :: Software Development :: Libraries",
|
||||||
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||||
|
"Topic :: Internet",
|
||||||
|
]
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
"aiohttp>=3.9.0,<4.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
test = [
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
"pytest-asyncio>=1.0.0",
|
||||||
|
"hypothesis>=6.89.0",
|
||||||
|
]
|
||||||
|
dev = [
|
||||||
|
"dotenv>=0.0.5",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
Homepage = "https://github.com/Slipstreamm/disagreement"
|
||||||
|
Issues = "https://github.com/Slipstreamm/disagreement/issues"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=61.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
# Optional: for linting/formatting, e.g., Ruff
|
||||||
|
# [tool.ruff]
|
||||||
|
# line-length = 88
|
||||||
|
# select = ["E", "W", "F", "I", "UP", "C4", "B"] # Example rule set
|
||||||
|
# ignore = []
|
||||||
|
|
||||||
|
# [tool.ruff.format]
|
||||||
|
# quote-style = "double"
|
9
pyrightconfig.json
Normal file
9
pyrightconfig.json
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"include": ["."],
|
||||||
|
"exclude": ["**/node_modules", "**/__pycache__", "**/.venv", "**/.git", "**/dist", "**/build", "**/tests/**", "tavilytool.py"],
|
||||||
|
"ignore": [],
|
||||||
|
"reportMissingImports": true,
|
||||||
|
"reportMissingTypeStubs": false,
|
||||||
|
"pythonVersion": "3.13",
|
||||||
|
"typeCheckingMode": "standard"
|
||||||
|
}
|
1
requirements.txt
Normal file
1
requirements.txt
Normal file
@ -0,0 +1 @@
|
|||||||
|
-e .[test,dev]
|
171
tavilytool.py
Normal file
171
tavilytool.py
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Tavily API Script for AI Agents
|
||||||
|
Execute with: python tavily.py "your search query"
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import requests # type: ignore
|
||||||
|
import argparse
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class TavilyAPI:
|
||||||
|
def __init__(self, api_key: str):
|
||||||
|
self.api_key = api_key
|
||||||
|
self.base_url = "https://api.tavily.com"
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
search_depth: str = "basic",
|
||||||
|
include_answer: bool = True,
|
||||||
|
include_images: bool = False,
|
||||||
|
include_raw_content: bool = False,
|
||||||
|
max_results: int = 5,
|
||||||
|
include_domains: Optional[List[str]] = None,
|
||||||
|
exclude_domains: Optional[List[str]] = None,
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
Perform a search using Tavily API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query string
|
||||||
|
search_depth: "basic" or "advanced"
|
||||||
|
include_answer: Include AI-generated answer
|
||||||
|
include_images: Include images in results
|
||||||
|
include_raw_content: Include raw HTML content
|
||||||
|
max_results: Maximum number of results (1-20)
|
||||||
|
include_domains: List of domains to include
|
||||||
|
exclude_domains: List of domains to exclude
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing search results
|
||||||
|
"""
|
||||||
|
url = f"{self.base_url}/search"
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"api_key": self.api_key,
|
||||||
|
"query": query,
|
||||||
|
"search_depth": search_depth,
|
||||||
|
"include_answer": include_answer,
|
||||||
|
"include_images": include_images,
|
||||||
|
"include_raw_content": include_raw_content,
|
||||||
|
"max_results": max_results,
|
||||||
|
}
|
||||||
|
|
||||||
|
if include_domains:
|
||||||
|
payload["include_domains"] = include_domains
|
||||||
|
if exclude_domains:
|
||||||
|
payload["exclude_domains"] = exclude_domains
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(url, json=payload, timeout=30)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
return {"error": f"API request failed: {str(e)}"}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return {"error": "Invalid JSON response from API"}
|
||||||
|
|
||||||
|
|
||||||
|
def format_results(results: Dict) -> str:
|
||||||
|
"""Format search results for display"""
|
||||||
|
if "error" in results:
|
||||||
|
return f"❌ Error: {results['error']}"
|
||||||
|
|
||||||
|
output = []
|
||||||
|
|
||||||
|
# Add answer if available
|
||||||
|
if results.get("answer"):
|
||||||
|
output.append("🤖 AI Answer:")
|
||||||
|
output.append(f" {results['answer']}")
|
||||||
|
output.append("")
|
||||||
|
|
||||||
|
# Add search results
|
||||||
|
if results.get("results"):
|
||||||
|
output.append("🔍 Search Results:")
|
||||||
|
for i, result in enumerate(results["results"], 1):
|
||||||
|
output.append(f" {i}. {result.get('title', 'No title')}")
|
||||||
|
output.append(f" URL: {result.get('url', 'No URL')}")
|
||||||
|
if result.get("content"):
|
||||||
|
# Truncate content to first 200 chars
|
||||||
|
content = (
|
||||||
|
result["content"][:200] + "..."
|
||||||
|
if len(result["content"]) > 200
|
||||||
|
else result["content"]
|
||||||
|
)
|
||||||
|
output.append(f" Content: {content}")
|
||||||
|
output.append("")
|
||||||
|
|
||||||
|
# Add images if available
|
||||||
|
if results.get("images"):
|
||||||
|
output.append("🖼️ Images:")
|
||||||
|
for img in results["images"][:3]: # Show first 3 images
|
||||||
|
output.append(f" {img}")
|
||||||
|
output.append("")
|
||||||
|
|
||||||
|
return "\n".join(output)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Search using Tavily API")
|
||||||
|
parser.add_argument("query", help="Search query")
|
||||||
|
parser.add_argument(
|
||||||
|
"--depth",
|
||||||
|
choices=["basic", "advanced"],
|
||||||
|
default="basic",
|
||||||
|
help="Search depth (default: basic)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-results",
|
||||||
|
type=int,
|
||||||
|
default=5,
|
||||||
|
help="Maximum number of results (default: 5)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--include-images", action="store_true", help="Include images in results"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-answer", action="store_true", help="Don't include AI-generated answer"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--include-domains", nargs="+", help="Include only these domains"
|
||||||
|
)
|
||||||
|
parser.add_argument("--exclude-domains", nargs="+", help="Exclude these domains")
|
||||||
|
parser.add_argument("--raw", action="store_true", help="Output raw JSON response")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Get API key from environment
|
||||||
|
api_key = os.getenv("TAVILY_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
print("❌ Error: TAVILY_API_KEY environment variable not set")
|
||||||
|
print("Set it with: export TAVILY_API_KEY='your-api-key-here'")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Initialize Tavily API
|
||||||
|
tavily = TavilyAPI(api_key)
|
||||||
|
|
||||||
|
# Perform search
|
||||||
|
results = tavily.search(
|
||||||
|
query=args.query,
|
||||||
|
search_depth=args.depth,
|
||||||
|
include_answer=not args.no_answer,
|
||||||
|
include_images=args.include_images,
|
||||||
|
max_results=args.max_results,
|
||||||
|
include_domains=args.include_domains,
|
||||||
|
exclude_domains=args.exclude_domains,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Output results
|
||||||
|
if args.raw:
|
||||||
|
print(json.dumps(results, indent=2))
|
||||||
|
else:
|
||||||
|
print(format_results(results))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
111
tests/conftest.py
Normal file
111
tests/conftest.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from disagreement.interactions import Interaction
|
||||||
|
from disagreement.enums import InteractionType
|
||||||
|
from disagreement.models import Message
|
||||||
|
|
||||||
|
|
||||||
|
class DummyHTTP:
|
||||||
|
def __init__(self):
|
||||||
|
self.create_interaction_response = AsyncMock()
|
||||||
|
self.create_followup_message = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"id": "123",
|
||||||
|
"channel_id": "c",
|
||||||
|
"author": {"id": "1", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "hi",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.edit_original_interaction_response = AsyncMock(return_value={"id": "123"})
|
||||||
|
self.edit_followup_message = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"id": "123",
|
||||||
|
"channel_id": "c",
|
||||||
|
"author": {"id": "1", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "hi",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyBot:
|
||||||
|
def __init__(self):
|
||||||
|
self._http = DummyHTTP()
|
||||||
|
self.application_id = "app123"
|
||||||
|
self._guilds = {}
|
||||||
|
self._channels = {}
|
||||||
|
|
||||||
|
def get_guild(self, gid):
|
||||||
|
return self._guilds.get(gid)
|
||||||
|
|
||||||
|
async def fetch_channel(self, cid):
|
||||||
|
return self._channels.get(cid)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyCommandHTTP:
|
||||||
|
def __init__(self):
|
||||||
|
self.edit_message = AsyncMock(return_value={"id": "321", "channel_id": "c"})
|
||||||
|
|
||||||
|
|
||||||
|
class DummyCommandBot:
|
||||||
|
def __init__(self):
|
||||||
|
self._http = DummyCommandHTTP()
|
||||||
|
|
||||||
|
async def edit_message(self, channel_id, message_id, *, content=None, **kwargs):
|
||||||
|
return await self._http.edit_message(
|
||||||
|
channel_id, message_id, {"content": content, **kwargs}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyClient:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DummyInteraction:
|
||||||
|
data = None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def dummy_bot():
|
||||||
|
return DummyBot()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def interaction(dummy_bot):
|
||||||
|
data = {
|
||||||
|
"id": "1",
|
||||||
|
"application_id": dummy_bot.application_id,
|
||||||
|
"type": InteractionType.APPLICATION_COMMAND.value,
|
||||||
|
"token": "tok",
|
||||||
|
"version": 1,
|
||||||
|
}
|
||||||
|
return Interaction(data, client_instance=dummy_bot)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def command_bot():
|
||||||
|
return DummyCommandBot()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def message(command_bot):
|
||||||
|
message_data = {
|
||||||
|
"id": "1",
|
||||||
|
"channel_id": "c",
|
||||||
|
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "hi",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
return Message(message_data, client_instance=command_bot)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def dummy_client():
|
||||||
|
return DummyClient()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def basic_interaction():
|
||||||
|
return DummyInteraction()
|
172
tests/test_additional_converters.py
Normal file
172
tests/test_additional_converters.py
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.ext.commands.converters import run_converters
|
||||||
|
from disagreement.ext.commands.core import CommandContext, Command
|
||||||
|
from disagreement.ext.commands.errors import BadArgument
|
||||||
|
from disagreement.models import Message, Member, Role, Guild
|
||||||
|
from disagreement.enums import (
|
||||||
|
VerificationLevel,
|
||||||
|
MessageNotificationLevel,
|
||||||
|
ExplicitContentFilterLevel,
|
||||||
|
MFALevel,
|
||||||
|
GuildNSFWLevel,
|
||||||
|
PremiumTier,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyBot:
|
||||||
|
def __init__(self, guild: Guild):
|
||||||
|
self._guilds = {guild.id: guild}
|
||||||
|
|
||||||
|
def get_guild(self, gid):
|
||||||
|
return self._guilds.get(gid)
|
||||||
|
|
||||||
|
async def fetch_member(self, gid, mid):
|
||||||
|
guild = self._guilds.get(gid)
|
||||||
|
return guild.get_member(mid) if guild else None
|
||||||
|
|
||||||
|
async def fetch_role(self, gid, rid):
|
||||||
|
guild = self._guilds.get(gid)
|
||||||
|
return guild.get_role(rid) if guild else None
|
||||||
|
|
||||||
|
async def fetch_guild(self, gid):
|
||||||
|
return self._guilds.get(gid)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def guild_objects():
|
||||||
|
guild_data = {
|
||||||
|
"id": "1",
|
||||||
|
"name": "g",
|
||||||
|
"owner_id": "2",
|
||||||
|
"afk_timeout": 60,
|
||||||
|
"verification_level": VerificationLevel.NONE.value,
|
||||||
|
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||||
|
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||||
|
"roles": [],
|
||||||
|
"emojis": [],
|
||||||
|
"features": [],
|
||||||
|
"mfa_level": MFALevel.NONE.value,
|
||||||
|
"system_channel_flags": 0,
|
||||||
|
"premium_tier": PremiumTier.NONE.value,
|
||||||
|
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||||
|
}
|
||||||
|
guild = Guild(guild_data, client_instance=None)
|
||||||
|
|
||||||
|
member = Member(
|
||||||
|
{
|
||||||
|
"user": {"id": "3", "username": "m", "discriminator": "0001"},
|
||||||
|
"joined_at": "t",
|
||||||
|
"roles": [],
|
||||||
|
},
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
member.guild_id = guild.id
|
||||||
|
|
||||||
|
role = Role(
|
||||||
|
{
|
||||||
|
"id": "5",
|
||||||
|
"name": "r",
|
||||||
|
"color": 0,
|
||||||
|
"hoist": False,
|
||||||
|
"position": 0,
|
||||||
|
"permissions": "0",
|
||||||
|
"managed": False,
|
||||||
|
"mentionable": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
guild._members[member.id] = member
|
||||||
|
guild.roles.append(role)
|
||||||
|
|
||||||
|
return guild, member, role
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def command_context(guild_objects):
|
||||||
|
guild, member, role = guild_objects
|
||||||
|
bot = DummyBot(guild)
|
||||||
|
message_data = {
|
||||||
|
"id": "10",
|
||||||
|
"channel_id": "20",
|
||||||
|
"guild_id": guild.id,
|
||||||
|
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "hi",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
msg = Message(message_data, client_instance=bot)
|
||||||
|
|
||||||
|
async def dummy(ctx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
cmd = Command(dummy)
|
||||||
|
return CommandContext(
|
||||||
|
message=msg, bot=bot, prefix="!", command=cmd, invoked_with="dummy"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_member_converter(command_context, guild_objects):
|
||||||
|
_, member, _ = guild_objects
|
||||||
|
mention = f"<@!{member.id}>"
|
||||||
|
result = await run_converters(command_context, Member, mention)
|
||||||
|
assert result is member
|
||||||
|
result = await run_converters(command_context, Member, member.id)
|
||||||
|
assert result is member
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_role_converter(command_context, guild_objects):
|
||||||
|
_, _, role = guild_objects
|
||||||
|
mention = f"<@&{role.id}>"
|
||||||
|
result = await run_converters(command_context, Role, mention)
|
||||||
|
assert result is role
|
||||||
|
result = await run_converters(command_context, Role, role.id)
|
||||||
|
assert result is role
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_guild_converter(command_context, guild_objects):
|
||||||
|
guild, _, _ = guild_objects
|
||||||
|
result = await run_converters(command_context, Guild, guild.id)
|
||||||
|
assert result is guild
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_member_converter_no_guild():
|
||||||
|
guild_data = {
|
||||||
|
"id": "99",
|
||||||
|
"name": "g",
|
||||||
|
"owner_id": "2",
|
||||||
|
"afk_timeout": 60,
|
||||||
|
"verification_level": VerificationLevel.NONE.value,
|
||||||
|
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||||
|
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||||
|
"roles": [],
|
||||||
|
"emojis": [],
|
||||||
|
"features": [],
|
||||||
|
"mfa_level": MFALevel.NONE.value,
|
||||||
|
"system_channel_flags": 0,
|
||||||
|
"premium_tier": PremiumTier.NONE.value,
|
||||||
|
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||||
|
}
|
||||||
|
guild = Guild(guild_data, client_instance=None)
|
||||||
|
bot = DummyBot(guild)
|
||||||
|
message_data = {
|
||||||
|
"id": "11",
|
||||||
|
"channel_id": "20",
|
||||||
|
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "hi",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
msg = Message(message_data, client_instance=bot)
|
||||||
|
|
||||||
|
async def dummy(ctx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
ctx = CommandContext(
|
||||||
|
message=msg, bot=bot, prefix="!", command=Command(dummy), invoked_with="dummy"
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(BadArgument):
|
||||||
|
await run_converters(ctx, Member, "<@!1>")
|
17
tests/test_cache.py
Normal file
17
tests/test_cache.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
from disagreement.cache import Cache
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_store_and_get():
|
||||||
|
cache = Cache()
|
||||||
|
cache.set("a", 123)
|
||||||
|
assert cache.get("a") == 123
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_ttl_expiry():
|
||||||
|
cache = Cache(ttl=0.01)
|
||||||
|
cache.set("b", 1)
|
||||||
|
assert cache.get("b") == 1
|
||||||
|
time.sleep(0.02)
|
||||||
|
assert cache.get("b") is None
|
33
tests/test_client_context_manager.py
Normal file
33
tests/test_client_context_manager.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from disagreement.client import Client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_async_context_closes(monkeypatch):
|
||||||
|
client = Client(token="t")
|
||||||
|
monkeypatch.setattr(client, "connect", AsyncMock())
|
||||||
|
monkeypatch.setattr(client._http, "close", AsyncMock())
|
||||||
|
|
||||||
|
async with client:
|
||||||
|
client.connect.assert_awaited_once()
|
||||||
|
|
||||||
|
client._http.close.assert_awaited_once()
|
||||||
|
assert client.is_closed()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_async_context_closes_on_exception(monkeypatch):
|
||||||
|
client = Client(token="t")
|
||||||
|
monkeypatch.setattr(client, "connect", AsyncMock())
|
||||||
|
monkeypatch.setattr(client._http, "close", AsyncMock())
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
async with client:
|
||||||
|
raise ValueError("boom")
|
||||||
|
|
||||||
|
client.connect.assert_awaited_once()
|
||||||
|
client._http.close.assert_awaited_once()
|
||||||
|
assert client.is_closed()
|
51
tests/test_command_checks.py
Normal file
51
tests/test_command_checks.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.ext.commands.core import Command, CommandContext
|
||||||
|
from disagreement.ext.commands.decorators import check, cooldown
|
||||||
|
from disagreement.ext.commands.errors import CheckFailure, CommandOnCooldown
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_decorator_blocks(message):
|
||||||
|
async def cb(ctx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
cmd = Command(check(lambda c: False)(cb))
|
||||||
|
ctx = CommandContext(
|
||||||
|
message=message,
|
||||||
|
bot=message._client,
|
||||||
|
prefix="!",
|
||||||
|
command=cmd,
|
||||||
|
invoked_with="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(CheckFailure):
|
||||||
|
await cmd.invoke(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cooldown_per_user(message):
|
||||||
|
uses = []
|
||||||
|
|
||||||
|
@cooldown(1, 0.05)
|
||||||
|
async def cb(ctx):
|
||||||
|
uses.append(1)
|
||||||
|
|
||||||
|
cmd = Command(cb)
|
||||||
|
ctx = CommandContext(
|
||||||
|
message=message,
|
||||||
|
bot=message._client,
|
||||||
|
prefix="!",
|
||||||
|
command=cmd,
|
||||||
|
invoked_with="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
await cmd.invoke(ctx)
|
||||||
|
|
||||||
|
with pytest.raises(CommandOnCooldown):
|
||||||
|
await cmd.invoke(ctx)
|
||||||
|
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
await cmd.invoke(ctx)
|
||||||
|
assert len(uses) == 2
|
29
tests/test_components_factory.py
Normal file
29
tests/test_components_factory.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from disagreement.components import component_factory
|
||||||
|
from disagreement.enums import ComponentType, ButtonStyle
|
||||||
|
|
||||||
|
|
||||||
|
def test_component_factory_button():
|
||||||
|
data = {
|
||||||
|
"type": ComponentType.BUTTON.value,
|
||||||
|
"style": ButtonStyle.PRIMARY.value,
|
||||||
|
"label": "Click",
|
||||||
|
"custom_id": "x",
|
||||||
|
}
|
||||||
|
comp = component_factory(data)
|
||||||
|
assert comp.to_dict()["label"] == "Click"
|
||||||
|
|
||||||
|
|
||||||
|
def test_component_factory_action_row():
|
||||||
|
data = {
|
||||||
|
"type": ComponentType.ACTION_ROW.value,
|
||||||
|
"components": [
|
||||||
|
{
|
||||||
|
"type": ComponentType.BUTTON.value,
|
||||||
|
"style": ButtonStyle.PRIMARY.value,
|
||||||
|
"label": "A",
|
||||||
|
"custom_id": "b",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
row = component_factory(data)
|
||||||
|
assert len(row.components) == 1
|
199
tests/test_context.py
Normal file
199
tests/test_context.py
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.ext.app_commands.context import AppCommandContext
|
||||||
|
from disagreement.interactions import Interaction
|
||||||
|
from disagreement.enums import InteractionType, MessageFlags
|
||||||
|
from disagreement.models import (
|
||||||
|
Embed,
|
||||||
|
ActionRow,
|
||||||
|
Button,
|
||||||
|
Container,
|
||||||
|
TextDisplay,
|
||||||
|
)
|
||||||
|
from disagreement.enums import ButtonStyle, ComponentType
|
||||||
|
|
||||||
|
|
||||||
|
class DummyHTTP:
|
||||||
|
def __init__(self):
|
||||||
|
self.create_interaction_response = AsyncMock()
|
||||||
|
self.create_followup_message = AsyncMock(return_value={"id": "123"})
|
||||||
|
self.edit_original_interaction_response = AsyncMock(return_value={"id": "123"})
|
||||||
|
self.edit_followup_message = AsyncMock(return_value={"id": "123"})
|
||||||
|
|
||||||
|
|
||||||
|
class DummyBot:
|
||||||
|
def __init__(self):
|
||||||
|
self._http = DummyHTTP()
|
||||||
|
self.application_id = "app123"
|
||||||
|
self._guilds = {}
|
||||||
|
self._channels = {}
|
||||||
|
|
||||||
|
def get_guild(self, gid):
|
||||||
|
return self._guilds.get(gid)
|
||||||
|
|
||||||
|
async def fetch_channel(self, cid):
|
||||||
|
return self._channels.get(cid)
|
||||||
|
from disagreement.ext.commands.core import CommandContext, Command
|
||||||
|
from disagreement.enums import MessageFlags, ButtonStyle, ComponentType
|
||||||
|
from disagreement.models import ActionRow, Button, Container, TextDisplay
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sends_extra_payload(dummy_bot, interaction):
|
||||||
|
ctx = AppCommandContext(dummy_bot, interaction)
|
||||||
|
button = Button(style=ButtonStyle.PRIMARY, label="Click", custom_id="a")
|
||||||
|
row = ActionRow([button])
|
||||||
|
await ctx.send(
|
||||||
|
content="hi",
|
||||||
|
tts=True,
|
||||||
|
components=[row],
|
||||||
|
allowed_mentions={"parse": []},
|
||||||
|
files=[{"id": 1, "filename": "f.txt"}],
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
dummy_bot._http.create_interaction_response.assert_called_once()
|
||||||
|
payload = dummy_bot._http.create_interaction_response.call_args.kwargs[
|
||||||
|
"payload"
|
||||||
|
].data.to_dict()
|
||||||
|
assert payload["tts"] is True
|
||||||
|
assert payload["components"]
|
||||||
|
assert payload["allowed_mentions"] == {"parse": []}
|
||||||
|
assert payload["attachments"] == [{"id": 1, "filename": "f.txt"}]
|
||||||
|
assert payload["flags"] == MessageFlags.EPHEMERAL.value
|
||||||
|
await ctx.send_followup(content="again")
|
||||||
|
dummy_bot._http.create_followup_message.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_second_send_is_followup(dummy_bot, interaction):
|
||||||
|
ctx = AppCommandContext(dummy_bot, interaction)
|
||||||
|
await ctx.send(content="first")
|
||||||
|
await ctx.send_followup(content="second")
|
||||||
|
assert dummy_bot._http.create_interaction_response.call_count == 1
|
||||||
|
assert dummy_bot._http.create_followup_message.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_edit_with_components_and_attachments(dummy_bot, interaction):
|
||||||
|
ctx = AppCommandContext(dummy_bot, interaction)
|
||||||
|
await ctx.send(content="orig")
|
||||||
|
row = ActionRow([Button(style=ButtonStyle.PRIMARY, label="B", custom_id="b")])
|
||||||
|
await ctx.edit(content="new", components=[row], attachments=[{"id": 1}])
|
||||||
|
dummy_bot._http.edit_original_interaction_response.assert_called_once()
|
||||||
|
payload = dummy_bot._http.edit_original_interaction_response.call_args.kwargs[
|
||||||
|
"payload"
|
||||||
|
]
|
||||||
|
assert payload["components"]
|
||||||
|
assert payload["attachments"] == [{"id": 1}]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_with_flags(dummy_bot, interaction):
|
||||||
|
ctx = AppCommandContext(dummy_bot, interaction)
|
||||||
|
await ctx.send(content="hi", flags=MessageFlags.IS_COMPONENTS_V2.value)
|
||||||
|
payload = dummy_bot._http.create_interaction_response.call_args.kwargs[
|
||||||
|
"payload"
|
||||||
|
].data.to_dict()
|
||||||
|
assert payload["flags"] == MessageFlags.IS_COMPONENTS_V2.value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_container_component(dummy_bot, interaction):
|
||||||
|
ctx = AppCommandContext(dummy_bot, interaction)
|
||||||
|
container = Container(components=[TextDisplay(content="hi")])
|
||||||
|
await ctx.send(components=[container], flags=MessageFlags.IS_COMPONENTS_V2.value)
|
||||||
|
payload = dummy_bot._http.create_interaction_response.call_args.kwargs[
|
||||||
|
"payload"
|
||||||
|
].data.to_dict()
|
||||||
|
assert payload["components"][0]["type"] == ComponentType.CONTAINER.value
|
||||||
|
assert payload["flags"] == MessageFlags.IS_COMPONENTS_V2.value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_command_context_edit(command_bot, message):
|
||||||
|
async def dummy(ctx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
cmd = Command(dummy)
|
||||||
|
ctx = CommandContext(
|
||||||
|
message=message, bot=command_bot, prefix="!", command=cmd, invoked_with="dummy"
|
||||||
|
)
|
||||||
|
await ctx.edit(message.id, content="new")
|
||||||
|
command_bot._http.edit_message.assert_called_once()
|
||||||
|
args = command_bot._http.edit_message.call_args[0]
|
||||||
|
assert args[0] == message.channel_id
|
||||||
|
assert args[1] == message.id
|
||||||
|
assert args[2]["content"] == "new"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_http_error_propagates(dummy_bot, interaction):
|
||||||
|
ctx = AppCommandContext(dummy_bot, interaction)
|
||||||
|
dummy_bot._http.create_interaction_response.side_effect = RuntimeError("boom")
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await ctx.send(content="hi")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_send_only_initial_once(dummy_bot, interaction):
|
||||||
|
ctx = AppCommandContext(dummy_bot, interaction)
|
||||||
|
|
||||||
|
async def send_msg(i: int):
|
||||||
|
if i == 0:
|
||||||
|
await ctx.send(content=str(i))
|
||||||
|
else:
|
||||||
|
await ctx.send_followup(content=str(i))
|
||||||
|
|
||||||
|
await asyncio.gather(*(send_msg(i) for i in range(50)))
|
||||||
|
assert dummy_bot._http.create_interaction_response.call_count == 1
|
||||||
|
assert dummy_bot._http.create_followup_message.call_count == 49
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_with_flags_2():
|
||||||
|
bot = DummyBot()
|
||||||
|
interaction = Interaction(
|
||||||
|
{
|
||||||
|
"id": "1",
|
||||||
|
"application_id": bot.application_id,
|
||||||
|
"type": InteractionType.APPLICATION_COMMAND.value,
|
||||||
|
"token": "tok",
|
||||||
|
"version": 1,
|
||||||
|
},
|
||||||
|
client_instance=bot,
|
||||||
|
)
|
||||||
|
ctx = AppCommandContext(bot, interaction)
|
||||||
|
await ctx.send(content="hi", flags=MessageFlags.IS_COMPONENTS_V2.value)
|
||||||
|
payload = bot._http.create_interaction_response.call_args[1][
|
||||||
|
"payload"
|
||||||
|
].data.to_dict()
|
||||||
|
assert payload["flags"] == MessageFlags.IS_COMPONENTS_V2.value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_container_component_2():
|
||||||
|
bot = DummyBot()
|
||||||
|
interaction = Interaction(
|
||||||
|
{
|
||||||
|
"id": "1",
|
||||||
|
"application_id": bot.application_id,
|
||||||
|
"type": InteractionType.APPLICATION_COMMAND.value,
|
||||||
|
"token": "tok",
|
||||||
|
"version": 1,
|
||||||
|
},
|
||||||
|
client_instance=bot,
|
||||||
|
)
|
||||||
|
ctx = AppCommandContext(bot, interaction)
|
||||||
|
container = Container(components=[TextDisplay(content="hi")])
|
||||||
|
await ctx.send(
|
||||||
|
components=[container],
|
||||||
|
flags=MessageFlags.IS_COMPONENTS_V2.value,
|
||||||
|
)
|
||||||
|
payload = bot._http.create_interaction_response.call_args[1][
|
||||||
|
"payload"
|
||||||
|
].data.to_dict()
|
||||||
|
assert payload["components"][0]["type"] == ComponentType.CONTAINER.value
|
||||||
|
assert payload["flags"] == MessageFlags.IS_COMPONENTS_V2.value
|
80
tests/test_context_menus.py
Normal file
80
tests/test_context_menus.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.ext.app_commands.handler import AppCommandHandler
|
||||||
|
from disagreement.ext.app_commands.decorators import user_command, message_command
|
||||||
|
from disagreement.enums import ApplicationCommandType, InteractionType
|
||||||
|
from disagreement.interactions import Interaction
|
||||||
|
from disagreement.models import User, Message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_context_menu_invokes(dummy_bot):
|
||||||
|
handler = AppCommandHandler(dummy_bot)
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
@user_command(name="Info")
|
||||||
|
async def info(ctx, user: User):
|
||||||
|
captured["user"] = user
|
||||||
|
|
||||||
|
handler.add_command(info)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"id": "cmd",
|
||||||
|
"name": "Info",
|
||||||
|
"type": ApplicationCommandType.USER.value,
|
||||||
|
"target_id": "42",
|
||||||
|
"resolved": {
|
||||||
|
"users": {"42": {"id": "42", "username": "Bob", "discriminator": "0001"}}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
payload = {
|
||||||
|
"id": "1",
|
||||||
|
"application_id": dummy_bot.application_id,
|
||||||
|
"type": InteractionType.APPLICATION_COMMAND.value,
|
||||||
|
"token": "tok",
|
||||||
|
"version": 1,
|
||||||
|
"data": data,
|
||||||
|
}
|
||||||
|
interaction = Interaction(payload, client_instance=dummy_bot)
|
||||||
|
await handler.process_interaction(interaction)
|
||||||
|
assert isinstance(captured.get("user"), User)
|
||||||
|
assert captured["user"].id == "42"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_message_context_menu_invokes(dummy_bot):
|
||||||
|
handler = AppCommandHandler(dummy_bot)
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
@message_command(name="Quote")
|
||||||
|
async def quote(ctx, message: Message):
|
||||||
|
captured["msg"] = message
|
||||||
|
|
||||||
|
handler.add_command(quote)
|
||||||
|
|
||||||
|
msg_data = {
|
||||||
|
"id": "99",
|
||||||
|
"channel_id": "c",
|
||||||
|
"author": {"id": "2", "username": "Ann", "discriminator": "0001"},
|
||||||
|
"content": "Hello",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
data = {
|
||||||
|
"id": "cmd",
|
||||||
|
"name": "Quote",
|
||||||
|
"type": ApplicationCommandType.MESSAGE.value,
|
||||||
|
"target_id": "99",
|
||||||
|
"resolved": {"messages": {"99": msg_data}},
|
||||||
|
}
|
||||||
|
payload = {
|
||||||
|
"id": "1",
|
||||||
|
"application_id": dummy_bot.application_id,
|
||||||
|
"type": InteractionType.APPLICATION_COMMAND.value,
|
||||||
|
"token": "tok",
|
||||||
|
"version": 1,
|
||||||
|
"data": data,
|
||||||
|
}
|
||||||
|
interaction = Interaction(payload, client_instance=dummy_bot)
|
||||||
|
await handler.process_interaction(interaction)
|
||||||
|
assert isinstance(captured.get("msg"), Message)
|
||||||
|
assert captured["msg"].id == "99"
|
30
tests/test_converter_registration.py
Normal file
30
tests/test_converter_registration.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.ext.app_commands.handler import AppCommandHandler
|
||||||
|
from disagreement.ext.app_commands.converters import Converter
|
||||||
|
|
||||||
|
|
||||||
|
class MyType:
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
|
||||||
|
class MyConverter(Converter[MyType]):
|
||||||
|
async def convert(self, interaction, value):
|
||||||
|
return MyType(f"converted-{value}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_custom_converter_registration(dummy_client):
|
||||||
|
handler = AppCommandHandler(client=dummy_client)
|
||||||
|
handler.register_converter(MyType, MyConverter)
|
||||||
|
assert handler.get_converter(MyType) is MyConverter
|
||||||
|
|
||||||
|
result = await handler._resolve_value(
|
||||||
|
value="example",
|
||||||
|
expected_type=MyType,
|
||||||
|
resolved_data=None,
|
||||||
|
guild_id=None,
|
||||||
|
)
|
||||||
|
assert isinstance(result, MyType)
|
||||||
|
assert result.value == "converted-example"
|
113
tests/test_converters.py
Normal file
113
tests/test_converters.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
from hypothesis import given, strategies as st
|
||||||
|
|
||||||
|
from disagreement.ext.app_commands.converters import run_converters
|
||||||
|
from disagreement.enums import ApplicationCommandOptionType
|
||||||
|
from disagreement.errors import AppCommandOptionConversionError
|
||||||
|
from conftest import DummyInteraction, DummyClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"py_type, option_type, input_value, expected",
|
||||||
|
[
|
||||||
|
(str, ApplicationCommandOptionType.STRING, "hello", "hello"),
|
||||||
|
(int, ApplicationCommandOptionType.INTEGER, "42", 42),
|
||||||
|
(bool, ApplicationCommandOptionType.BOOLEAN, "true", True),
|
||||||
|
(float, ApplicationCommandOptionType.NUMBER, "3.14", pytest.approx(3.14)),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_basic_type_converters(
|
||||||
|
basic_interaction, dummy_client, py_type, option_type, input_value, expected
|
||||||
|
):
|
||||||
|
result = await run_converters(
|
||||||
|
basic_interaction, py_type, option_type, input_value, dummy_client
|
||||||
|
)
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_converters_error_cases(basic_interaction, dummy_client):
|
||||||
|
with pytest.raises(AppCommandOptionConversionError):
|
||||||
|
await run_converters(
|
||||||
|
basic_interaction,
|
||||||
|
bool,
|
||||||
|
ApplicationCommandOptionType.BOOLEAN,
|
||||||
|
"maybe",
|
||||||
|
dummy_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(AppCommandOptionConversionError):
|
||||||
|
await run_converters(
|
||||||
|
basic_interaction,
|
||||||
|
list,
|
||||||
|
ApplicationCommandOptionType.MENTIONABLE,
|
||||||
|
"x",
|
||||||
|
dummy_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@given(st.text())
|
||||||
|
def test_string_roundtrip(value):
|
||||||
|
interaction = DummyInteraction()
|
||||||
|
client = DummyClient()
|
||||||
|
result = asyncio.run(
|
||||||
|
run_converters(
|
||||||
|
interaction,
|
||||||
|
str,
|
||||||
|
ApplicationCommandOptionType.STRING,
|
||||||
|
value,
|
||||||
|
client,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert result == value
|
||||||
|
|
||||||
|
|
||||||
|
@given(st.integers())
|
||||||
|
def test_integer_roundtrip(value):
|
||||||
|
interaction = DummyInteraction()
|
||||||
|
client = DummyClient()
|
||||||
|
result = asyncio.run(
|
||||||
|
run_converters(
|
||||||
|
interaction,
|
||||||
|
int,
|
||||||
|
ApplicationCommandOptionType.INTEGER,
|
||||||
|
str(value),
|
||||||
|
client,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert result == value
|
||||||
|
|
||||||
|
|
||||||
|
@given(st.booleans())
|
||||||
|
def test_boolean_roundtrip(value):
|
||||||
|
interaction = DummyInteraction()
|
||||||
|
client = DummyClient()
|
||||||
|
raw = "true" if value else "false"
|
||||||
|
result = asyncio.run(
|
||||||
|
run_converters(
|
||||||
|
interaction,
|
||||||
|
bool,
|
||||||
|
ApplicationCommandOptionType.BOOLEAN,
|
||||||
|
raw,
|
||||||
|
client,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert result is value
|
||||||
|
|
||||||
|
|
||||||
|
@given(st.floats(allow_nan=False, allow_infinity=False))
|
||||||
|
def test_number_roundtrip(value):
|
||||||
|
interaction = DummyInteraction()
|
||||||
|
client = DummyClient()
|
||||||
|
result = asyncio.run(
|
||||||
|
run_converters(
|
||||||
|
interaction,
|
||||||
|
float,
|
||||||
|
ApplicationCommandOptionType.NUMBER,
|
||||||
|
str(value),
|
||||||
|
client,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert result == pytest.approx(value)
|
38
tests/test_error_handler.py
Normal file
38
tests/test_error_handler.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.error_handler import setup_global_error_handler
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_exception_logs_error(monkeypatch, capsys):
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
records = []
|
||||||
|
|
||||||
|
def fake_error(msg, *args, **kwargs):
|
||||||
|
records.append(msg % args if args else msg)
|
||||||
|
|
||||||
|
monkeypatch.setattr(logging, "error", fake_error)
|
||||||
|
setup_global_error_handler(loop)
|
||||||
|
exc = RuntimeError("boom")
|
||||||
|
loop.call_exception_handler({"exception": exc})
|
||||||
|
assert any("Unhandled exception" in r for r in records)
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_message_logs_error(monkeypatch):
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
logged = {}
|
||||||
|
|
||||||
|
def fake_error(msg, *args, **kwargs):
|
||||||
|
logged["msg"] = msg % args if args else msg
|
||||||
|
|
||||||
|
monkeypatch.setattr(logging, "error", fake_error)
|
||||||
|
setup_global_error_handler(loop)
|
||||||
|
loop.call_exception_handler({"message": "oops"})
|
||||||
|
assert "oops" in logged["msg"]
|
||||||
|
loop.close()
|
23
tests/test_errors.py
Normal file
23
tests/test_errors.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.errors import (
|
||||||
|
HTTPException,
|
||||||
|
RateLimitError,
|
||||||
|
AppCommandOptionConversionError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_http_exception_message():
|
||||||
|
exc = HTTPException(message="Bad", status=400)
|
||||||
|
assert str(exc) == "HTTP 400: Bad"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit_error_inherits_httpexception():
|
||||||
|
exc = RateLimitError(response=None, retry_after=1.0, is_global=True)
|
||||||
|
assert isinstance(exc, HTTPException)
|
||||||
|
assert "Rate limited" in str(exc)
|
||||||
|
|
||||||
|
|
||||||
|
def test_app_command_option_conversion_error():
|
||||||
|
exc = AppCommandOptionConversionError("bad", option_name="opt", original_value="x")
|
||||||
|
assert "opt" in str(exc) and "x" in str(exc)
|
68
tests/test_event_dispatcher.py
Normal file
68
tests/test_event_dispatcher.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.event_dispatcher import EventDispatcher
|
||||||
|
|
||||||
|
|
||||||
|
class DummyClient:
|
||||||
|
def __init__(self):
|
||||||
|
self.parsed = {}
|
||||||
|
|
||||||
|
def parse_message(self, data):
|
||||||
|
self.parsed["message"] = True
|
||||||
|
return data
|
||||||
|
|
||||||
|
def parse_guild(self, data):
|
||||||
|
self.parsed["guild"] = True
|
||||||
|
return data
|
||||||
|
|
||||||
|
def parse_channel(self, data):
|
||||||
|
self.parsed["channel"] = True
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_calls_listener():
|
||||||
|
client = DummyClient()
|
||||||
|
dispatcher = EventDispatcher(client)
|
||||||
|
called = {}
|
||||||
|
|
||||||
|
async def listener(payload):
|
||||||
|
called["data"] = payload
|
||||||
|
|
||||||
|
dispatcher.register("MESSAGE_CREATE", listener)
|
||||||
|
await dispatcher.dispatch("MESSAGE_CREATE", {"id": 1})
|
||||||
|
assert called["data"] == {"id": 1}
|
||||||
|
assert client.parsed.get("message")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_listener_no_args():
|
||||||
|
client = DummyClient()
|
||||||
|
dispatcher = EventDispatcher(client)
|
||||||
|
called = False
|
||||||
|
|
||||||
|
async def listener():
|
||||||
|
nonlocal called
|
||||||
|
called = True
|
||||||
|
|
||||||
|
dispatcher.register("GUILD_CREATE", listener)
|
||||||
|
await dispatcher.dispatch("GUILD_CREATE", {"id": 123})
|
||||||
|
assert called
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unregister_listener():
|
||||||
|
client = DummyClient()
|
||||||
|
dispatcher = EventDispatcher(client)
|
||||||
|
called = False
|
||||||
|
|
||||||
|
async def listener(_):
|
||||||
|
nonlocal called
|
||||||
|
called = True
|
||||||
|
|
||||||
|
dispatcher.register("MESSAGE_CREATE", listener)
|
||||||
|
dispatcher.unregister("MESSAGE_CREATE", listener)
|
||||||
|
await dispatcher.dispatch("MESSAGE_CREATE", {"id": 1})
|
||||||
|
assert not called
|
42
tests/test_event_error_hook.py
Normal file
42
tests/test_event_error_hook.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from disagreement.event_dispatcher import EventDispatcher
|
||||||
|
|
||||||
|
|
||||||
|
class DummyClient:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_error_hook_called():
|
||||||
|
dispatcher = EventDispatcher(DummyClient())
|
||||||
|
hook = AsyncMock()
|
||||||
|
dispatcher.on_dispatch_error = hook
|
||||||
|
|
||||||
|
async def listener(_):
|
||||||
|
raise RuntimeError("boom")
|
||||||
|
|
||||||
|
dispatcher.register("TEST_EVENT", listener)
|
||||||
|
await dispatcher.dispatch("TEST_EVENT", {})
|
||||||
|
|
||||||
|
hook.assert_awaited_once()
|
||||||
|
args = hook.call_args.args
|
||||||
|
assert args[0] == "TEST_EVENT"
|
||||||
|
assert isinstance(args[1], RuntimeError)
|
||||||
|
assert args[2] is listener
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_error_hook_not_called_when_ok():
|
||||||
|
dispatcher = EventDispatcher(DummyClient())
|
||||||
|
hook = AsyncMock()
|
||||||
|
dispatcher.on_dispatch_error = hook
|
||||||
|
|
||||||
|
async def listener(_):
|
||||||
|
return
|
||||||
|
|
||||||
|
dispatcher.register("TEST_EVENT", listener)
|
||||||
|
await dispatcher.dispatch("TEST_EVENT", {})
|
||||||
|
|
||||||
|
hook.assert_not_awaited()
|
44
tests/test_extension_loader.py
Normal file
44
tests/test_extension_loader.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.ext import loader
|
||||||
|
|
||||||
|
|
||||||
|
def create_dummy_module(name):
|
||||||
|
mod = types.ModuleType(name)
|
||||||
|
called = {"setup": False, "teardown": False}
|
||||||
|
|
||||||
|
def setup():
|
||||||
|
called["setup"] = True
|
||||||
|
|
||||||
|
def teardown():
|
||||||
|
called["teardown"] = True
|
||||||
|
|
||||||
|
mod.setup = setup
|
||||||
|
mod.teardown = teardown
|
||||||
|
sys.modules[name] = mod
|
||||||
|
return called
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_and_unload_extension():
|
||||||
|
called = create_dummy_module("dummy_ext")
|
||||||
|
|
||||||
|
module = loader.load_extension("dummy_ext")
|
||||||
|
assert module is sys.modules["dummy_ext"]
|
||||||
|
assert called["setup"] is True
|
||||||
|
|
||||||
|
loader.unload_extension("dummy_ext")
|
||||||
|
assert called["teardown"] is True
|
||||||
|
assert "dummy_ext" not in loader._loaded_extensions
|
||||||
|
assert "dummy_ext" not in sys.modules
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_extension_twice_raises():
|
||||||
|
called = create_dummy_module("repeat_ext")
|
||||||
|
loader.load_extension("repeat_ext")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
loader.load_extension("repeat_ext")
|
||||||
|
loader.unload_extension("repeat_ext")
|
||||||
|
assert called["teardown"] is True
|
67
tests/test_gateway_backoff.py
Normal file
67
tests/test_gateway_backoff.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.gateway import GatewayClient, GatewayException
|
||||||
|
from disagreement.client import Client
|
||||||
|
|
||||||
|
|
||||||
|
class DummyHTTP:
|
||||||
|
async def get_gateway_bot(self):
|
||||||
|
return {"url": "ws://example"}
|
||||||
|
|
||||||
|
async def _ensure_session(self):
|
||||||
|
self._session = AsyncMock()
|
||||||
|
self._session.ws_connect = AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
|
class DummyDispatcher:
|
||||||
|
async def dispatch(self, *_):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DummyClient:
|
||||||
|
def __init__(self):
|
||||||
|
self.loop = asyncio.get_event_loop()
|
||||||
|
self.application_id = None # Mock application_id for Client.connect
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_connect_backoff(monkeypatch):
|
||||||
|
http = DummyHTTP()
|
||||||
|
# Mock the GatewayClient's connect method to simulate failures and then success
|
||||||
|
mock_gateway_connect = AsyncMock(
|
||||||
|
side_effect=[GatewayException("boom"), GatewayException("boom"), None]
|
||||||
|
)
|
||||||
|
# Create a dummy client instance
|
||||||
|
client = Client(
|
||||||
|
token="test_token",
|
||||||
|
intents=0,
|
||||||
|
loop=asyncio.get_event_loop(),
|
||||||
|
command_prefix="!",
|
||||||
|
verbose=False,
|
||||||
|
mention_replies=False,
|
||||||
|
shard_count=None,
|
||||||
|
)
|
||||||
|
# Patch the internal _gateway attribute after client initialization
|
||||||
|
# This ensures _initialize_gateway is called and _gateway is set
|
||||||
|
await client._initialize_gateway()
|
||||||
|
monkeypatch.setattr(client._gateway, "connect", mock_gateway_connect)
|
||||||
|
|
||||||
|
# Mock wait_until_ready to prevent it from blocking the test
|
||||||
|
monkeypatch.setattr(client, "wait_until_ready", AsyncMock())
|
||||||
|
|
||||||
|
delays = []
|
||||||
|
|
||||||
|
async def fake_sleep(d):
|
||||||
|
delays.append(d)
|
||||||
|
|
||||||
|
monkeypatch.setattr(asyncio, "sleep", fake_sleep)
|
||||||
|
|
||||||
|
# Call the client's connect method, which contains the backoff logic
|
||||||
|
await client.connect()
|
||||||
|
|
||||||
|
# Assert that GatewayClient.connect was called the correct number of times
|
||||||
|
assert mock_gateway_connect.call_count == 3
|
||||||
|
# Assert the delays experienced due to exponential backoff
|
||||||
|
assert delays == [5, 10]
|
57
tests/test_help_command.py
Normal file
57
tests/test_help_command.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.ext.commands.core import CommandHandler, Command
|
||||||
|
from disagreement.models import Message
|
||||||
|
|
||||||
|
|
||||||
|
class DummyBot:
|
||||||
|
def __init__(self):
|
||||||
|
self.sent = []
|
||||||
|
|
||||||
|
async def send_message(self, channel_id, content, **kwargs):
|
||||||
|
self.sent.append(content)
|
||||||
|
return {"id": "1", "channel_id": channel_id, "content": content}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_help_lists_commands():
|
||||||
|
bot = DummyBot()
|
||||||
|
handler = CommandHandler(client=bot, prefix="!")
|
||||||
|
|
||||||
|
async def foo(ctx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
handler.add_command(Command(foo, name="foo", brief="Foo cmd"))
|
||||||
|
|
||||||
|
msg_data = {
|
||||||
|
"id": "1",
|
||||||
|
"channel_id": "c",
|
||||||
|
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "!help",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
msg = Message(msg_data, client_instance=bot)
|
||||||
|
await handler.process_commands(msg)
|
||||||
|
assert any("foo" in m for m in bot.sent)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_help_specific_command():
|
||||||
|
bot = DummyBot()
|
||||||
|
handler = CommandHandler(client=bot, prefix="!")
|
||||||
|
|
||||||
|
async def bar(ctx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
handler.add_command(Command(bar, name="bar", description="Bar desc"))
|
||||||
|
|
||||||
|
msg_data = {
|
||||||
|
"id": "1",
|
||||||
|
"channel_id": "c",
|
||||||
|
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "!help bar",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
msg = Message(msg_data, client_instance=bot)
|
||||||
|
await handler.process_commands(msg)
|
||||||
|
assert any("Bar desc" in m for m in bot.sent)
|
57
tests/test_http_reactions.py
Normal file
57
tests/test_http_reactions.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import pytest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from disagreement.client import Client
|
||||||
|
from disagreement.errors import DisagreementException
|
||||||
|
from disagreement.models import User
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_reaction_calls_http():
|
||||||
|
http = SimpleNamespace(create_reaction=AsyncMock())
|
||||||
|
client = Client.__new__(Client)
|
||||||
|
client._http = http
|
||||||
|
client._closed = False
|
||||||
|
|
||||||
|
await client.create_reaction("1", "2", "😀")
|
||||||
|
|
||||||
|
http.create_reaction.assert_called_once_with("1", "2", "😀")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_reaction_closed():
|
||||||
|
http = SimpleNamespace(create_reaction=AsyncMock())
|
||||||
|
client = Client.__new__(Client)
|
||||||
|
client._http = http
|
||||||
|
client._closed = True
|
||||||
|
|
||||||
|
with pytest.raises(DisagreementException):
|
||||||
|
await client.create_reaction("1", "2", "😀")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_reaction_calls_http():
|
||||||
|
http = SimpleNamespace(delete_reaction=AsyncMock())
|
||||||
|
client = Client.__new__(Client)
|
||||||
|
client._http = http
|
||||||
|
client._closed = False
|
||||||
|
|
||||||
|
await client.delete_reaction("1", "2", "😀")
|
||||||
|
|
||||||
|
http.delete_reaction.assert_called_once_with("1", "2", "😀")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_reactions_parses_users():
|
||||||
|
users_payload = [{"id": "1", "username": "u", "discriminator": "0001"}]
|
||||||
|
http = SimpleNamespace(get_reactions=AsyncMock(return_value=users_payload))
|
||||||
|
client = Client.__new__(Client)
|
||||||
|
client._http = http
|
||||||
|
client._closed = False
|
||||||
|
client._users = {}
|
||||||
|
|
||||||
|
users = await client.get_reactions("1", "2", "😀")
|
||||||
|
|
||||||
|
http.get_reactions.assert_called_once_with("1", "2", "😀")
|
||||||
|
assert isinstance(users[0], User)
|
44
tests/test_hybrid_context.py
Normal file
44
tests/test_hybrid_context.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.hybrid_context import HybridContext
|
||||||
|
from disagreement.ext.app_commands.context import AppCommandContext
|
||||||
|
|
||||||
|
|
||||||
|
class DummyCommandCtx:
|
||||||
|
def __init__(self):
|
||||||
|
self.sent = []
|
||||||
|
|
||||||
|
async def reply(self, *a, **kw):
|
||||||
|
self.sent.append(("reply", a, kw))
|
||||||
|
|
||||||
|
async def edit(self, *a, **kw):
|
||||||
|
self.sent.append(("edit", a, kw))
|
||||||
|
|
||||||
|
|
||||||
|
class DummyAppCtx(AppCommandContext):
|
||||||
|
def __init__(self):
|
||||||
|
self.sent = []
|
||||||
|
|
||||||
|
async def send(self, *a, **kw):
|
||||||
|
self.sent.append(("send", a, kw))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_routes_based_on_context():
|
||||||
|
cctx = DummyCommandCtx()
|
||||||
|
actx = DummyAppCtx()
|
||||||
|
await HybridContext(cctx).send("hi")
|
||||||
|
await HybridContext(actx).send("hi")
|
||||||
|
assert cctx.sent[0][0] == "reply"
|
||||||
|
assert actx.sent[0][0] == "send"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_edit_delegation_and_error():
|
||||||
|
cctx = DummyCommandCtx()
|
||||||
|
hctx = HybridContext(cctx)
|
||||||
|
await hctx.edit("m")
|
||||||
|
assert cctx.sent[0][0] == "edit"
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
await HybridContext(DummyAppCtx()).edit("m")
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user