Initial commit

This commit is contained in:
Slipstream 2025-06-09 22:25:14 -06:00
commit 7c7cb4137c
Signed by: slipstream
GPG Key ID: 13E498CE010AC6FD
120 changed files with 15345 additions and 0 deletions

34
.github/workflows/ci.yml vendored Normal file
View File

@ -0,0 +1,34 @@
name: Python CI
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -e .
npm install -g pyright
- name: Run Pyright
run: pyright
- name: Run tests
run: |
pytest tests/

129
.gitignore vendored Normal file
View File

@ -0,0 +1,129 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a PyInstaller script; this is deployed to PyInstaller's temporary directory.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# PEP 582; __pypackages__
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# IDE specific files
.idea/
.vscode/
.kilocode/
*.swp
*.swo
*~

64
AGENTS.md Normal file
View File

@ -0,0 +1,64 @@
# Agents
- There are no nested `AGENTS.md` files; this is the only one in the project.
- Tools to use for testing: `pyright`, `pylint`, `pytest`, `black`
- You have a Python script `tavilytool.py` in the project root that you can use to search the web.
# Tavily API Script Usage Instructions
## Basic Usage
Search for information using simple queries:
```bash
python tavilytool.py "your search query"
```
## Examples
```bash
python tavilytool.py "latest AI development 2024"
python tavilytool.py "how to make chocolate chip cookies"
python tavilytool.py "current weather in New York"
python tavilytool.py "best programming practices Python"
```
## Advanced Options
### Search Depth
- **Basic search**: `python tavilytool.py "query"` (default)
- **Advanced search**: `python tavilytool.py "query" --depth advanced`
### Control Results
- **Limit results**: `python tavilytool.py "query" --max-results 3`
- **Include images**: `python tavilytool.py "query" --include-images`
- **Skip AI answer**: `python tavilytool.py "query" --no-answer`
### Domain Filtering
- **Include specific domains**: `python tavilytool.py "query" --include-domains reddit.com stackoverflow.com`
- **Exclude domains**: `python tavilytool.py "query" --exclude-domains wikipedia.org`
### Output Format
- **Formatted output**: `python tavilytool.py "query"` (default - human readable)
- **Raw JSON**: `python tavilytool.py "query" --raw` (for programmatic processing)
## Output Structure
The default formatted output includes:
- 🤖 **AI Answer**: Direct answer to your query
- 🔍 **Search Results**: Titles, URLs, and content snippets
- 🖼️ **Images**: Relevant images (when `--include-images` is used)
## Command Combinations
```bash
# Advanced search with images, limited results
python tavilytool.py "machine learning tutorials" --depth advanced --include-images --max-results 3
# Search specific sites only, raw output
python tavilytool.py "Python best practices" --include-domains github.com stackoverflow.com --raw
# Quick search without AI answer
python tavilytool.py "today's news" --no-answer --max-results 5
```
## Tips
- Always quote your search queries to handle spaces and special characters
- Use `--max-results` to control response length and API usage
- Use `--raw` when you need to parse results programmatically
- Combine options as needed for specific use cases

26
LICENSE Normal file
View File

@ -0,0 +1,26 @@
Copyright (c) 2025, Slipstream
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

3
MANIFEST.in Normal file
View File

@ -0,0 +1,3 @@
graft docs
graft examples
include LICENSE

131
README.md Normal file
View File

@ -0,0 +1,131 @@
# Disagreement
A Python library for interacting with the Discord API, with a focus on bot development.
## Features
- Asynchronous design using `aiohttp`
- Gateway and HTTP API clients
- Slash command framework
- Message component helpers
- Built-in caching layer
- Experimental voice support
- Helpful error handling utilities
## Installation
```bash
python -m pip install -U pip
pip install disagreement
# or install from source for development
pip install -e .
```
Requires Python 3.11 or newer.
## Basic Usage
```python
import asyncio
import os
import disagreement
# Ensure DISCORD_BOT_TOKEN is set in your environment
client = disagreement.Client(token=os.environ.get("DISCORD_BOT_TOKEN"))
@client.on_event('MESSAGE_CREATE')
async def on_message(message: disagreement.Message):
print(f"Received: {message.content} from {message.author.username}")
if message.content.lower() == '!ping':
await message.reply('Pong!')
async def main():
if not client.token:
print("Error: DISCORD_BOT_TOKEN environment variable not set.")
return
try:
async with client:
await asyncio.Future() # run until cancelled
except KeyboardInterrupt:
print("Bot shutting down...")
# Add any other specific exception handling from your library, e.g., disagreement.AuthenticationError
if __name__ == '__main__':
asyncio.run(main())
```
### Global Error Handling
To ensure unexpected errors don't crash your bot, you can enable the library's
global error handler:
```python
import disagreement
disagreement.setup_global_error_handler()
```
Call this early in your program to log unhandled exceptions instead of letting
them terminate the process.
### Configuring Logging
Use :func:`disagreement.logging_config.setup_logging` to configure logging for
your bot. The helper accepts a logging level and an optional file path.
```python
import logging
from disagreement.logging_config import setup_logging
setup_logging(logging.INFO)
# Or log to a file
setup_logging(logging.DEBUG, file="bot.log")
```
### Defining Subcommands with `AppCommandGroup`
```python
from disagreement.ext.app_commands import AppCommandGroup
settings = AppCommandGroup("settings", "Manage settings")
@settings.command(name="show")
async def show(ctx):
"""Displays a setting."""
...
@settings.group("admin", description="Admin settings")
def admin_group():
pass
@admin_group.command(name="set")
async def set_setting(ctx, key: str, value: str):
...
## Fetching Guilds
Use `Client.fetch_guild` to retrieve a guild from the Discord API if it
isn't already cached. This is useful when working with guild IDs from
outside the gateway events.
```python
guild = await client.fetch_guild("123456789012345678")
roles = await client.fetch_roles(guild.id)
```
## Sharding
To run your bot across multiple gateway shards, pass `shard_count` when creating
the client:
```python
client = disagreement.Client(token=BOT_TOKEN, shard_count=2)
```
See `examples/sharded_bot.py` for a full example.
## Contributing
Contributions are welcome! Please open an issue or submit a pull request.
See the [docs](docs/) directory for detailed guides on components, slash commands, caching, and voice features.

36
disagreement/__init__.py Normal file
View File

@ -0,0 +1,36 @@
# disagreement/__init__.py
"""
Disagreement
~~~~~~~~~~~~
A Python library for interacting with the Discord API.
:copyright: (c) 2025 Slipstream
:license: BSD 3-Clause License, see LICENSE for more details.
"""
__title__ = "disagreement"
__author__ = "Slipstream"
__license__ = "BSD 3-Clause License"
__copyright__ = "Copyright 2025 Slipstream"
__version__ = "0.0.1"
from .client import Client
from .models import Message, User
from .voice_client import VoiceClient
from .typing import Typing
from .errors import (
DisagreementException,
HTTPException,
GatewayException,
AuthenticationError,
)
from .enums import GatewayIntent, GatewayOpcode # Export enums
from .error_handler import setup_global_error_handler
from .hybrid_context import HybridContext
from .ext import tasks
# Set up logging if desired
# import logging
# logging.getLogger(__name__).addHandler(logging.NullHandler())

55
disagreement/cache.py Normal file
View File

@ -0,0 +1,55 @@
from __future__ import annotations
import time
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
if TYPE_CHECKING:
from .models import Channel, Guild
T = TypeVar("T")
class Cache(Generic[T]):
"""Simple in-memory cache with optional TTL support."""
def __init__(self, ttl: Optional[float] = None) -> None:
self.ttl = ttl
self._data: Dict[str, tuple[T, Optional[float]]] = {}
def set(self, key: str, value: T) -> None:
expiry = time.monotonic() + self.ttl if self.ttl is not None else None
self._data[key] = (value, expiry)
def get(self, key: str) -> Optional[T]:
item = self._data.get(key)
if not item:
return None
value, expiry = item
if expiry is not None and expiry < time.monotonic():
self.invalidate(key)
return None
return value
def invalidate(self, key: str) -> None:
self._data.pop(key, None)
def clear(self) -> None:
self._data.clear()
def values(self) -> list[T]:
now = time.monotonic()
items = []
for key, (value, expiry) in list(self._data.items()):
if expiry is not None and expiry < now:
self.invalidate(key)
else:
items.append(value)
return items
class GuildCache(Cache["Guild"]):
"""Cache specifically for :class:`Guild` objects."""
class ChannelCache(Cache["Channel"]):
"""Cache specifically for :class:`Channel` objects."""

1144
disagreement/client.py Normal file

File diff suppressed because it is too large Load Diff

166
disagreement/components.py Normal file
View File

@ -0,0 +1,166 @@
"""Message component utilities."""
from __future__ import annotations
from typing import Any, Dict, Optional, TYPE_CHECKING
from .enums import ComponentType, ButtonStyle, ChannelType, TextInputStyle
from .models import (
ActionRow,
Button,
Component,
SelectMenu,
SelectOption,
PartialEmoji,
PartialEmoji,
Section,
TextDisplay,
Thumbnail,
MediaGallery,
MediaGalleryItem,
File,
Separator,
Container,
UnfurledMediaItem,
)
if TYPE_CHECKING: # pragma: no cover - optional client for future use
from .client import Client
def component_factory(
data: Dict[str, Any], client: Optional["Client"] = None
) -> "Component":
"""Create a component object from raw API data."""
ctype = ComponentType(data["type"])
if ctype == ComponentType.ACTION_ROW:
row = ActionRow()
for comp in data.get("components", []):
row.add_component(component_factory(comp, client))
return row
if ctype == ComponentType.BUTTON:
return Button(
style=ButtonStyle(data["style"]),
label=data.get("label"),
emoji=PartialEmoji(data["emoji"]) if data.get("emoji") else None,
custom_id=data.get("custom_id"),
url=data.get("url"),
disabled=data.get("disabled", False),
)
if ctype in {
ComponentType.STRING_SELECT,
ComponentType.USER_SELECT,
ComponentType.ROLE_SELECT,
ComponentType.MENTIONABLE_SELECT,
ComponentType.CHANNEL_SELECT,
}:
options = [
SelectOption(
label=o["label"],
value=o["value"],
description=o.get("description"),
emoji=PartialEmoji(o["emoji"]) if o.get("emoji") else None,
default=o.get("default", False),
)
for o in data.get("options", [])
]
channel_types = None
if ctype == ComponentType.CHANNEL_SELECT and data.get("channel_types"):
channel_types = [ChannelType(ct) for ct in data.get("channel_types", [])]
return SelectMenu(
custom_id=data["custom_id"],
options=options,
placeholder=data.get("placeholder"),
min_values=data.get("min_values", 1),
max_values=data.get("max_values", 1),
disabled=data.get("disabled", False),
channel_types=channel_types,
type=ctype,
)
if ctype == ComponentType.TEXT_INPUT:
from .ui.modal import TextInput
return TextInput(
label=data.get("label", ""),
custom_id=data.get("custom_id"),
style=TextInputStyle(data.get("style", TextInputStyle.SHORT.value)),
placeholder=data.get("placeholder"),
required=data.get("required", True),
min_length=data.get("min_length"),
max_length=data.get("max_length"),
)
if ctype == ComponentType.SECTION:
# The components in a section can only be TextDisplay
section_components = []
for c in data.get("components", []):
comp = component_factory(c, client)
if isinstance(comp, TextDisplay):
section_components.append(comp)
accessory = None
if data.get("accessory"):
acc_comp = component_factory(data["accessory"], client)
if isinstance(acc_comp, (Thumbnail, Button)):
accessory = acc_comp
return Section(
components=section_components,
accessory=accessory,
id=data.get("id"),
)
if ctype == ComponentType.TEXT_DISPLAY:
return TextDisplay(content=data["content"], id=data.get("id"))
if ctype == ComponentType.THUMBNAIL:
return Thumbnail(
media=UnfurledMediaItem(**data["media"]),
description=data.get("description"),
spoiler=data.get("spoiler", False),
id=data.get("id"),
)
if ctype == ComponentType.MEDIA_GALLERY:
return MediaGallery(
items=[
MediaGalleryItem(
media=UnfurledMediaItem(**i["media"]),
description=i.get("description"),
spoiler=i.get("spoiler", False),
)
for i in data.get("items", [])
],
id=data.get("id"),
)
if ctype == ComponentType.FILE:
return File(
file=UnfurledMediaItem(**data["file"]),
spoiler=data.get("spoiler", False),
id=data.get("id"),
)
if ctype == ComponentType.SEPARATOR:
return Separator(
divider=data.get("divider", True),
spacing=data.get("spacing", 1),
id=data.get("id"),
)
if ctype == ComponentType.CONTAINER:
return Container(
components=[
component_factory(c, client) for c in data.get("components", [])
],
accent_color=data.get("accent_color"),
spoiler=data.get("spoiler", False),
id=data.get("id"),
)
raise ValueError(f"Unsupported component type: {ctype}")

357
disagreement/enums.py Normal file
View File

@ -0,0 +1,357 @@
# disagreement/enums.py
"""
Enums for Discord constants.
"""
from enum import IntEnum, Enum # Import Enum
class GatewayOpcode(IntEnum):
"""Represents a Discord Gateway Opcode."""
DISPATCH = 0
HEARTBEAT = 1
IDENTIFY = 2
PRESENCE_UPDATE = 3
VOICE_STATE_UPDATE = 4
RESUME = 6
RECONNECT = 7
REQUEST_GUILD_MEMBERS = 8
INVALID_SESSION = 9
HELLO = 10
HEARTBEAT_ACK = 11
class GatewayIntent(IntEnum):
"""Represents a Discord Gateway Intent bit.
Intents are used to subscribe to specific groups of events from the Gateway.
"""
GUILDS = 1 << 0
GUILD_MEMBERS = 1 << 1 # Privileged
GUILD_MODERATION = 1 << 2 # Formerly GUILD_BANS
GUILD_EMOJIS_AND_STICKERS = 1 << 3
GUILD_INTEGRATIONS = 1 << 4
GUILD_WEBHOOKS = 1 << 5
GUILD_INVITES = 1 << 6
GUILD_VOICE_STATES = 1 << 7
GUILD_PRESENCES = 1 << 8 # Privileged
GUILD_MESSAGES = 1 << 9
GUILD_MESSAGE_REACTIONS = 1 << 10
GUILD_MESSAGE_TYPING = 1 << 11
DIRECT_MESSAGES = 1 << 12
DIRECT_MESSAGE_REACTIONS = 1 << 13
DIRECT_MESSAGE_TYPING = 1 << 14
MESSAGE_CONTENT = 1 << 15 # Privileged (as of Aug 31, 2022)
GUILD_SCHEDULED_EVENTS = 1 << 16
AUTO_MODERATION_CONFIGURATION = 1 << 20
AUTO_MODERATION_EXECUTION = 1 << 21
@classmethod
def default(cls) -> int:
"""Returns default intents (excluding privileged ones like members, presences, message content)."""
return (
cls.GUILDS
| cls.GUILD_MODERATION
| cls.GUILD_EMOJIS_AND_STICKERS
| cls.GUILD_INTEGRATIONS
| cls.GUILD_WEBHOOKS
| cls.GUILD_INVITES
| cls.GUILD_VOICE_STATES
| cls.GUILD_MESSAGES
| cls.GUILD_MESSAGE_REACTIONS
| cls.GUILD_MESSAGE_TYPING
| cls.DIRECT_MESSAGES
| cls.DIRECT_MESSAGE_REACTIONS
| cls.DIRECT_MESSAGE_TYPING
| cls.GUILD_SCHEDULED_EVENTS
| cls.AUTO_MODERATION_CONFIGURATION
| cls.AUTO_MODERATION_EXECUTION
)
@classmethod
def all(cls) -> int:
"""Returns all intents, including privileged ones. Use with caution."""
val = 0
for intent in cls:
val |= intent.value
return val
@classmethod
def privileged(cls) -> int:
"""Returns a bitmask of all privileged intents."""
return cls.GUILD_MEMBERS | cls.GUILD_PRESENCES | cls.MESSAGE_CONTENT
# --- Application Command Enums ---
class ApplicationCommandType(IntEnum):
"""Type of application command."""
CHAT_INPUT = 1
USER = 2
MESSAGE = 3
PRIMARY_ENTRY_POINT = 4
class ApplicationCommandOptionType(IntEnum):
"""Type of application command option."""
SUB_COMMAND = 1
SUB_COMMAND_GROUP = 2
STRING = 3
INTEGER = 4 # Any integer between -2^53 and 2^53
BOOLEAN = 5
USER = 6
CHANNEL = 7 # Includes all channel types + categories
ROLE = 8
MENTIONABLE = 9 # Includes users and roles
NUMBER = 10 # Any double between -2^53 and 2^53
ATTACHMENT = 11
class InteractionType(IntEnum):
"""Type of interaction."""
PING = 1
APPLICATION_COMMAND = 2
MESSAGE_COMPONENT = 3
APPLICATION_COMMAND_AUTOCOMPLETE = 4
MODAL_SUBMIT = 5
class InteractionCallbackType(IntEnum):
"""Type of interaction callback."""
PONG = 1
CHANNEL_MESSAGE_WITH_SOURCE = 4
DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE = 5
DEFERRED_UPDATE_MESSAGE = 6
UPDATE_MESSAGE = 7
APPLICATION_COMMAND_AUTOCOMPLETE_RESULT = 8
MODAL = 9 # Response to send a modal
class IntegrationType(IntEnum):
"""
Installation context(s) where the command is available,
only for globally-scoped commands.
"""
GUILD_INSTALL = (
0 # Command is available when the app is installed to a guild (default)
)
USER_INSTALL = 1 # Command is available when the app is installed to a user
class InteractionContextType(IntEnum):
"""
Interaction context(s) where the command can be used,
only for globally-scoped commands.
"""
GUILD = 0 # Command can be used in guilds
BOT_DM = 1 # Command can be used in DMs with the app's bot user
PRIVATE_CHANNEL = 2 # Command can be used in Group DMs and DMs (requires USER_INSTALL integration_type)
class MessageFlags(IntEnum):
"""Represents the flags of a message."""
CROSSPOSTED = 1 << 0
IS_CROSSPOST = 1 << 1
SUPPRESS_EMBEDS = 1 << 2
SOURCE_MESSAGE_DELETED = 1 << 3
URGENT = 1 << 4
HAS_THREAD = 1 << 5
EPHEMERAL = 1 << 6
LOADING = 1 << 7
FAILED_TO_MENTION_SOME_ROLES_IN_THREAD = 1 << 8
SUPPRESS_NOTIFICATIONS = (
1 << 12
) # Discord specific, was previously 1 << 4 (IS_VOICE_MESSAGE)
IS_COMPONENTS_V2 = 1 << 15
# --- Guild Enums ---
class VerificationLevel(IntEnum):
"""Guild verification level."""
NONE = 0
LOW = 1
MEDIUM = 2
HIGH = 3
VERY_HIGH = 4
class MessageNotificationLevel(IntEnum):
"""Default message notification level for a guild."""
ALL_MESSAGES = 0
ONLY_MENTIONS = 1
class ExplicitContentFilterLevel(IntEnum):
"""Explicit content filter level for a guild."""
DISABLED = 0
MEMBERS_WITHOUT_ROLES = 1
ALL_MEMBERS = 2
class MFALevel(IntEnum):
"""Multi-Factor Authentication level for a guild."""
NONE = 0
ELEVATED = 1
class GuildNSFWLevel(IntEnum):
"""NSFW level of a guild."""
DEFAULT = 0
EXPLICIT = 1
SAFE = 2
AGE_RESTRICTED = 3
class PremiumTier(IntEnum):
"""Guild premium tier (boost level)."""
NONE = 0
TIER_1 = 1
TIER_2 = 2
TIER_3 = 3
class GuildFeature(str, Enum): # Changed from IntEnum to Enum
"""Features that a guild can have.
Note: This is not an exhaustive list and Discord may add more.
Using str as a base allows for unknown features to be stored as strings.
"""
ANIMATED_BANNER = "ANIMATED_BANNER"
ANIMATED_ICON = "ANIMATED_ICON"
APPLICATION_COMMAND_PERMISSIONS_V2 = "APPLICATION_COMMAND_PERMISSIONS_V2"
AUTO_MODERATION = "AUTO_MODERATION"
BANNER = "BANNER"
COMMUNITY = "COMMUNITY"
CREATOR_MONETIZABLE_PROVISIONAL = "CREATOR_MONETIZABLE_PROVISIONAL"
CREATOR_STORE_PAGE = "CREATOR_STORE_PAGE"
DEVELOPER_SUPPORT_SERVER = "DEVELOPER_SUPPORT_SERVER"
DISCOVERABLE = "DISCOVERABLE"
FEATURABLE = "FEATURABLE"
INVITES_DISABLED = "INVITES_DISABLED"
INVITE_SPLASH = "INVITE_SPLASH"
MEMBER_VERIFICATION_GATE_ENABLED = "MEMBER_VERIFICATION_GATE_ENABLED"
MORE_STICKERS = "MORE_STICKERS"
NEWS = "NEWS"
PARTNERED = "PARTNERED"
PREVIEW_ENABLED = "PREVIEW_ENABLED"
RAID_ALERTS_DISABLED = "RAID_ALERTS_DISABLED"
ROLE_ICONS = "ROLE_ICONS"
ROLE_SUBSCRIPTIONS_AVAILABLE_FOR_PURCHASE = (
"ROLE_SUBSCRIPTIONS_AVAILABLE_FOR_PURCHASE"
)
ROLE_SUBSCRIPTIONS_ENABLED = "ROLE_SUBSCRIPTIONS_ENABLED"
TICKETED_EVENTS_ENABLED = "TICKETED_EVENTS_ENABLED"
VANITY_URL = "VANITY_URL"
VERIFIED = "VERIFIED"
VIP_REGIONS = "VIP_REGIONS"
WELCOME_SCREEN_ENABLED = "WELCOME_SCREEN_ENABLED"
# Add more as they become known or needed
# This allows GuildFeature("UNKNOWN_FEATURE_STRING") to work
@classmethod
def _missing_(cls, value): # type: ignore
return str(value)
# --- Channel Enums ---
class ChannelType(IntEnum):
"""Type of channel."""
GUILD_TEXT = 0 # a text channel within a server
DM = 1 # a direct message between users
GUILD_VOICE = 2 # a voice channel within a server
GROUP_DM = 3 # a direct message between multiple users
GUILD_CATEGORY = 4 # an organizational category that contains up to 50 channels
GUILD_ANNOUNCEMENT = 5 # a channel that users can follow and crosspost into their own server (formerly GUILD_NEWS)
ANNOUNCEMENT_THREAD = (
10 # a temporary sub-channel within a GUILD_ANNOUNCEMENT channel
)
PUBLIC_THREAD = (
11 # a temporary sub-channel within a GUILD_TEXT or GUILD_ANNOUNCEMENT channel
)
PRIVATE_THREAD = 12 # a temporary sub-channel within a GUILD_TEXT channel that is only viewable by those invited and those with the MANAGE_THREADS permission
GUILD_STAGE_VOICE = (
13 # a voice channel for hosting events with speakers and audiences
)
GUILD_DIRECTORY = 14 # a channel in a hub containing the listed servers
GUILD_FORUM = 15 # (Still in development) a channel that can only contain threads
GUILD_MEDIA = 16 # (Still in development) a channel that can only contain media
class OverwriteType(IntEnum):
"""Type of target for a permission overwrite."""
ROLE = 0
MEMBER = 1
# --- Component Enums ---
class ComponentType(IntEnum):
"""Type of message component."""
ACTION_ROW = 1
BUTTON = 2
STRING_SELECT = 3 # Formerly SELECT_MENU
TEXT_INPUT = 4
USER_SELECT = 5
ROLE_SELECT = 6
MENTIONABLE_SELECT = 7
CHANNEL_SELECT = 8
SECTION = 9
TEXT_DISPLAY = 10
THUMBNAIL = 11
MEDIA_GALLERY = 12
FILE = 13
SEPARATOR = 14
CONTAINER = 17
class ButtonStyle(IntEnum):
"""Style of a button component."""
# Blurple
PRIMARY = 1
# Grey
SECONDARY = 2
# Green
SUCCESS = 3
# Red
DANGER = 4
# Grey, navigates to a URL
LINK = 5
class TextInputStyle(IntEnum):
"""Style of a text input component."""
SHORT = 1
PARAGRAPH = 2
# Example of how you might combine intents:
# intents = GatewayIntent.GUILDS | GatewayIntent.GUILD_MESSAGES | GatewayIntent.MESSAGE_CONTENT
# client = Client(token="YOUR_TOKEN", intents=intents)

View File

@ -0,0 +1,33 @@
import asyncio
import logging
import traceback
from typing import Optional
from .logging_config import setup_logging
def setup_global_error_handler(
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
"""Configure a basic global error handler for the provided loop.
The handler logs unhandled exceptions so they don't crash the bot.
"""
if loop is None:
loop = asyncio.get_event_loop()
if not logging.getLogger().hasHandlers():
setup_logging(logging.ERROR)
def handle_exception(loop: asyncio.AbstractEventLoop, context: dict) -> None:
exception = context.get("exception")
if exception:
logging.error("Unhandled exception in event loop: %s", exception)
traceback.print_exception(
type(exception), exception, exception.__traceback__
)
else:
message = context.get("message")
logging.error("Event loop error: %s", message)
loop.set_exception_handler(handle_exception)

112
disagreement/errors.py Normal file
View File

@ -0,0 +1,112 @@
# disagreement/errors.py
"""
Custom exceptions for the Disagreement library.
"""
from typing import Optional, Any # Add Optional and Any here
class DisagreementException(Exception):
"""Base exception class for all errors raised by this library."""
pass
class HTTPException(DisagreementException):
"""Exception raised for HTTP-related errors.
Attributes:
response: The aiohttp response object, if available.
status: The HTTP status code.
text: The response text, if available.
error_code: Discord specific error code, if available.
"""
def __init__(
self, response=None, message=None, *, status=None, text=None, error_code=None
):
self.response = response
self.status = status or (response.status if response else None)
self.text = text or (
response.text if response else None
) # Or await response.text() if in async context
self.error_code = error_code
full_message = f"HTTP {self.status}"
if message:
full_message += f": {message}"
elif self.text:
full_message += f": {self.text}"
if self.error_code:
full_message += f" (Discord Error Code: {self.error_code})"
super().__init__(full_message)
class GatewayException(DisagreementException):
"""Exception raised for errors related to the Discord Gateway connection or protocol."""
pass
class AuthenticationError(DisagreementException):
"""Exception raised for authentication failures (e.g., invalid token)."""
pass
class RateLimitError(HTTPException):
"""
Exception raised when a rate limit is encountered.
Attributes:
retry_after (float): The number of seconds to wait before retrying.
is_global (bool): Whether this is a global rate limit.
"""
def __init__(
self, response, message=None, *, retry_after: float, is_global: bool = False
):
self.retry_after = retry_after
self.is_global = is_global
super().__init__(
response,
message
or f"Rate limited. Retry after: {retry_after}s. Global: {is_global}",
)
# You can add more specific exceptions as needed, e.g.:
# class NotFound(HTTPException):
# """Raised for 404 Not Found errors."""
# pass
# class Forbidden(HTTPException):
# """Raised for 403 Forbidden errors."""
# pass
class AppCommandError(DisagreementException):
"""Base exception for application command related errors."""
pass
class AppCommandOptionConversionError(AppCommandError):
"""Exception raised when an application command option fails to convert."""
def __init__(
self,
message: str,
option_name: Optional[str] = None,
original_value: Any = None,
):
self.option_name = option_name
self.original_value = original_value
full_message = message
if option_name:
full_message = f"Failed to convert option '{option_name}': {message}"
if original_value is not None:
full_message += f" (Original value: '{original_value}')"
super().__init__(full_message)

View File

@ -0,0 +1,243 @@
# disagreement/event_dispatcher.py
"""
Event dispatcher for handling Discord Gateway events.
"""
import asyncio
import inspect
from collections import defaultdict
from typing import (
Callable,
Coroutine,
Any,
Dict,
List,
Set,
TYPE_CHECKING,
Awaitable,
Optional,
)
from .models import Message, User # Assuming User might be part of other events
from .errors import DisagreementException
if TYPE_CHECKING:
from .client import Client # For type hinting to avoid circular imports
from .interactions import Interaction
# Type alias for an event listener
EventListener = Callable[..., Awaitable[None]]
class EventDispatcher:
"""
Manages registration and dispatching of event listeners.
"""
def __init__(self, client_instance: "Client"):
self._client: "Client" = client_instance
self._listeners: Dict[str, List[EventListener]] = defaultdict(list)
self._waiters: Dict[
str, List[tuple[asyncio.Future, Optional[Callable[[Any], bool]]]]
] = defaultdict(list)
self.on_dispatch_error: Optional[
Callable[[str, Exception, EventListener], Awaitable[None]]
] = None
# Pre-defined parsers for specific event types to convert raw data to models
self._event_parsers: Dict[str, Callable[[Dict[str, Any]], Any]] = {
"MESSAGE_CREATE": self._parse_message_create,
"INTERACTION_CREATE": self._parse_interaction_create,
"GUILD_CREATE": self._parse_guild_create,
"CHANNEL_CREATE": self._parse_channel_create,
"PRESENCE_UPDATE": self._parse_presence_update,
"TYPING_START": self._parse_typing_start,
}
def _parse_message_create(self, data: Dict[str, Any]) -> Message:
"""Parses raw MESSAGE_CREATE data into a Message object."""
return self._client.parse_message(data)
def _parse_interaction_create(self, data: Dict[str, Any]) -> "Interaction":
"""Parses raw INTERACTION_CREATE data into an Interaction object."""
from .interactions import Interaction
return Interaction(data=data, client_instance=self._client)
def _parse_guild_create(self, data: Dict[str, Any]):
"""Parses raw GUILD_CREATE data into a Guild object."""
return self._client.parse_guild(data)
def _parse_channel_create(self, data: Dict[str, Any]):
"""Parses raw CHANNEL_CREATE data into a Channel object."""
return self._client.parse_channel(data)
def _parse_presence_update(self, data: Dict[str, Any]):
"""Parses raw PRESENCE_UPDATE data into a PresenceUpdate object."""
from .models import PresenceUpdate
return PresenceUpdate(data, client_instance=self._client)
def _parse_typing_start(self, data: Dict[str, Any]):
"""Parses raw TYPING_START data into a TypingStart object."""
from .models import TypingStart
return TypingStart(data, client_instance=self._client)
# Potentially add _parse_user for events that directly provide a full user object
# def _parse_user_update(self, data: Dict[str, Any]) -> User:
# return User(data=data)
def register(self, event_name: str, coro: EventListener):
"""
Registers a coroutine function to listen for a specific event.
Args:
event_name (str): The name of the event (e.g., 'MESSAGE_CREATE').
coro (Callable): The coroutine function to call when the event occurs.
It should accept arguments appropriate for the event.
Raises:
TypeError: If the provided callback is not a coroutine function.
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(
f"Event listener for '{event_name}' must be a coroutine function (async def)."
)
# Normalize event name, e.g., 'on_message' -> 'MESSAGE_CREATE'
# For now, we assume event_name is already the Discord event type string.
# If using decorators like @client.on_message, the decorator would handle this mapping.
self._listeners[event_name.upper()].append(coro)
def unregister(self, event_name: str, coro: EventListener):
"""
Unregisters a coroutine function from an event.
Args:
event_name (str): The name of the event.
coro (Callable): The coroutine function to unregister.
"""
event_name_upper = event_name.upper()
if event_name_upper in self._listeners:
try:
self._listeners[event_name_upper].remove(coro)
except ValueError:
pass # Listener not in list
def add_waiter(
self,
event_name: str,
future: asyncio.Future,
check: Optional[Callable[[Any], bool]] = None,
) -> None:
self._waiters[event_name.upper()].append((future, check))
def remove_waiter(self, event_name: str, future: asyncio.Future) -> None:
waiters = self._waiters.get(event_name.upper())
if not waiters:
return
self._waiters[event_name.upper()] = [
(f, c) for f, c in waiters if f is not future
]
if not self._waiters[event_name.upper()]:
self._waiters.pop(event_name.upper(), None)
def _resolve_waiters(self, event_name: str, data: Any) -> None:
waiters = self._waiters.get(event_name)
if not waiters:
return
to_remove: List[tuple[asyncio.Future, Optional[Callable[[Any], bool]]]] = []
for future, check in waiters:
if future.cancelled():
to_remove.append((future, check))
continue
try:
if check is None or check(data):
future.set_result(data)
to_remove.append((future, check))
except Exception as exc:
future.set_exception(exc)
to_remove.append((future, check))
for item in to_remove:
if item in waiters:
waiters.remove(item)
if not waiters:
self._waiters.pop(event_name, None)
async def dispatch(self, event_name: str, raw_data: Dict[str, Any]):
"""
Dispatches an event to all registered listeners.
Args:
event_name (str): The name of the event (e.g., 'MESSAGE_CREATE').
raw_data (Dict[str, Any]): The raw data payload from the Discord Gateway for this event.
"""
event_name_upper = event_name.upper()
listeners = self._listeners.get(event_name_upper)
if not listeners:
# print(f"No listeners for event {event_name_upper}")
return
parsed_data: Any = raw_data
if event_name_upper in self._event_parsers:
try:
parser = self._event_parsers[event_name_upper]
parsed_data = parser(raw_data)
except Exception as e:
print(f"Error parsing event data for {event_name_upper}: {e}")
# Optionally, dispatch with raw_data or raise, or log more formally
# For now, we'll proceed to dispatch with raw_data if parsing fails,
# or just log and return if parsed_data is critical.
# Let's assume if a parser exists, its output is critical.
return
self._resolve_waiters(event_name_upper, parsed_data)
# print(f"Dispatching event {event_name_upper} with data: {parsed_data} to {len(listeners)} listeners.")
for listener in listeners:
try:
# Inspect the listener to see how many arguments it expects
sig = inspect.signature(listener)
num_params = len(sig.parameters)
if num_params == 0: # Listener takes no arguments
await listener()
elif (
num_params == 1
): # Listener takes one argument (the parsed data or model)
await listener(parsed_data)
# elif num_params == 2 and event_name_upper == "MESSAGE_CREATE": # Special case for (client, message)
# await listener(self._client, parsed_data) # This might be too specific here
else:
# Fallback or error if signature doesn't match expected patterns
# For now, assume one arg is the most common for parsed data.
# Or, if you want to be strict:
print(
f"Warning: Listener {listener.__name__} for {event_name_upper} has an unhandled number of parameters ({num_params}). Skipping or attempting with one arg."
)
if num_params > 0: # Try with one arg if it takes any
await listener(parsed_data)
except Exception as e:
callback = self.on_dispatch_error
if callback is not None:
try:
await callback(event_name_upper, e, listener)
except Exception as hook_error:
print(f"Error in on_dispatch_error hook itself: {hook_error}")
else:
# Default error handling if no hook is set
print(
f"Error in event listener {listener.__name__} for {event_name_upper}: {e}"
)
if hasattr(self._client, "on_error"):
try:
await self._client.on_error(event_name_upper, e, listener)
except Exception as client_err_e:
print(f"Error in client.on_error itself: {client_err_e}")

View File

@ -0,0 +1,46 @@
# disagreement/ext/app_commands/__init__.py
"""
Application Commands Extension for Disagreement.
This package provides the framework for creating and handling
Discord Application Commands (slash commands, user commands, message commands).
"""
from .commands import (
AppCommand,
SlashCommand,
UserCommand,
MessageCommand,
AppCommandGroup,
)
from .decorators import (
slash_command,
user_command,
message_command,
hybrid_command,
group,
subcommand,
subcommand_group,
OptionMetadata,
)
from .context import AppCommandContext
# from .handler import AppCommandHandler # Will be imported when defined
__all__ = [
"AppCommand",
"SlashCommand",
"UserCommand",
"MessageCommand",
"AppCommandGroup", # To be defined
"slash_command",
"user_command",
"message_command",
"hybrid_command",
"group",
"subcommand",
"subcommand_group",
"OptionMetadata",
"AppCommandContext", # To be defined
]

View File

@ -0,0 +1,513 @@
# disagreement/ext/app_commands/commands.py
import inspect
from typing import Callable, Optional, List, Dict, Any, Union, TYPE_CHECKING
if TYPE_CHECKING:
from disagreement.ext.commands.core import (
Command as PrefixCommand,
) # Alias to avoid name clash
from disagreement.interactions import ApplicationCommandOption, Snowflake
from disagreement.enums import (
ApplicationCommandType,
IntegrationType,
InteractionContextType,
ApplicationCommandOptionType, # Added
)
from disagreement.ext.commands.cog import Cog # Corrected import path
# Placeholder for Cog if not using the existing one or if it needs adaptation
if not TYPE_CHECKING:
# This dynamic Cog = Any might not be ideal if Cog is used in runtime type checks.
# However, for type hinting purposes when TYPE_CHECKING is false, it avoids import.
# If Cog is needed at runtime by this module (it is, for AppCommand.cog type hint),
# it should be imported directly.
# For now, the TYPE_CHECKING block handles the proper import for static analysis.
# Let's ensure Cog is available at runtime if AppCommand.cog is accessed.
# A simple way is to import it outside TYPE_CHECKING too, if it doesn't cause circularity.
# Given its usage, a forward reference string 'Cog' might be better in AppCommand.cog type hint.
# Let's try importing it directly for runtime, assuming no circularity with this specific module.
try:
from disagreement.ext.commands.cog import Cog
except ImportError:
Cog = Any # Fallback if direct import fails (e.g. during partial builds/tests)
# Import PrefixCommand at runtime for HybridCommand
try:
from disagreement.ext.commands.core import Command as PrefixCommand
except Exception: # pragma: no cover - safeguard against unusual import issues
PrefixCommand = Any # type: ignore
# Import enums used at runtime
try:
from disagreement.enums import (
ApplicationCommandType,
IntegrationType,
InteractionContextType,
ApplicationCommandOptionType,
)
from disagreement.interactions import ApplicationCommandOption, Snowflake
except Exception: # pragma: no cover
ApplicationCommandType = ApplicationCommandOptionType = IntegrationType = (
InteractionContextType
) = Any # type: ignore
ApplicationCommandOption = Snowflake = Any # type: ignore
else: # When TYPE_CHECKING is true, Cog and PrefixCommand are already imported above.
pass
class AppCommand:
"""
Base class for an application command.
Attributes:
name (str): The name of the command.
description (Optional[str]): The description of the command.
Required for CHAT_INPUT, empty for USER and MESSAGE commands.
callback (Callable[..., Any]): The coroutine function that will be called when the command is executed.
type (ApplicationCommandType): The type of the application command.
options (Optional[List[ApplicationCommandOption]]): Parameters for the command. Populated by decorators.
guild_ids (Optional[List[Snowflake]]): List of guild IDs where this command is active. None for global.
default_member_permissions (Optional[str]): Bitwise permissions required by default for users to run the command.
nsfw (bool): Whether the command is age-restricted.
parent (Optional['AppCommandGroup']): The parent group if this is a subcommand.
cog (Optional[Cog]): The cog this command belongs to, if any.
_full_description (Optional[str]): Stores the original full description, e.g. from docstring,
even if the payload description is different (like for User/Message commands).
name_localizations (Optional[Dict[str, str]]): Localizations for the command's name.
description_localizations (Optional[Dict[str, str]]): Localizations for the command's description.
integration_types (Optional[List[IntegrationType]]): Installation contexts.
contexts (Optional[List[InteractionContextType]]): Interaction contexts.
"""
def __init__(
self,
callback: Callable[..., Any],
*,
name: str,
description: Optional[str] = None,
locale: Optional[str] = None,
type: "ApplicationCommandType",
guild_ids: Optional[List["Snowflake"]] = None,
default_member_permissions: Optional[str] = None,
nsfw: bool = False,
parent: Optional["AppCommandGroup"] = None,
cog: Optional[
Any
] = None, # Changed 'Cog' to Any to avoid runtime import issues if Cog is complex
name_localizations: Optional[Dict[str, str]] = None,
description_localizations: Optional[Dict[str, str]] = None,
integration_types: Optional[List["IntegrationType"]] = None,
contexts: Optional[List["InteractionContextType"]] = None,
):
if not asyncio.iscoroutinefunction(callback):
raise TypeError(
"Application command callback must be a coroutine function."
)
if locale:
from disagreement import i18n
translate = i18n.translate
self.name = translate(name, locale)
self.description = (
translate(description, locale) if description is not None else None
)
else:
self.name = name
self.description = description
self.locale: Optional[str] = locale
self.callback: Callable[..., Any] = callback
self.type: "ApplicationCommandType" = type
self.options: List["ApplicationCommandOption"] = [] # Populated by decorator
self.guild_ids: Optional[List["Snowflake"]] = guild_ids
self.default_member_permissions: Optional[str] = default_member_permissions
self.nsfw: bool = nsfw
self.parent: Optional["AppCommandGroup"] = parent
self.cog: Optional[Any] = cog # Changed 'Cog' to Any
self.name_localizations: Optional[Dict[str, str]] = name_localizations
self.description_localizations: Optional[Dict[str, str]] = (
description_localizations
)
self.integration_types: Optional[List["IntegrationType"]] = integration_types
self.contexts: Optional[List["InteractionContextType"]] = contexts
self._full_description: Optional[str] = (
None # Initialized by decorator if needed
)
# Signature for argument parsing by decorators/handlers
self.params = inspect.signature(callback).parameters
async def invoke(
self, context: "AppCommandContext", *args: Any, **kwargs: Any
) -> None:
"""Invokes the command's callback with the given context and arguments."""
# Similar to Command.invoke, handle cog if present
actual_args = []
if self.cog:
actual_args.append(self.cog)
actual_args.append(context)
actual_args.extend(args)
await self.callback(*actual_args, **kwargs)
def to_dict(self) -> Dict[str, Any]:
"""Converts the command to a dictionary payload for Discord API."""
payload: Dict[str, Any] = {
"name": self.name,
"type": self.type.value,
# CHAT_INPUT commands require a description.
# USER and MESSAGE commands must have an empty description in the payload if not omitted.
# The constructor for UserCommand/MessageCommand already sets self.description to ""
"description": (
self.description
if self.type == ApplicationCommandType.CHAT_INPUT
else ""
),
}
# For CHAT_INPUT commands, options are its parameters.
# For USER/MESSAGE commands, options should be empty or not present.
if self.type == ApplicationCommandType.CHAT_INPUT and self.options:
payload["options"] = [opt.to_dict() for opt in self.options]
if self.default_member_permissions is not None: # Can be "0" for no permissions
payload["default_member_permissions"] = str(self.default_member_permissions)
# nsfw defaults to False, only include if True
if self.nsfw:
payload["nsfw"] = True
if self.name_localizations:
payload["name_localizations"] = self.name_localizations
# Description localizations only apply if there's a description (CHAT_INPUT commands)
if (
self.type == ApplicationCommandType.CHAT_INPUT
and self.description
and self.description_localizations
):
payload["description_localizations"] = self.description_localizations
if self.integration_types:
payload["integration_types"] = [it.value for it in self.integration_types]
if self.contexts:
payload["contexts"] = [ict.value for ict in self.contexts]
# According to Discord API, guild_id is not part of this payload,
# it's used in the URL path for guild-specific command registration.
# However, the global command registration takes an 'application_id' in the payload,
# but that's handled by the HTTPClient.
return payload
def __repr__(self) -> str:
return f"<{self.__class__.__name__} name='{self.name}' type={self.type!r}>"
class SlashCommand(AppCommand):
"""Represents a CHAT_INPUT (slash) command."""
def __init__(self, callback: Callable[..., Any], **kwargs: Any):
if not kwargs.get("description"):
raise ValueError("SlashCommand requires a description.")
super().__init__(callback, type=ApplicationCommandType.CHAT_INPUT, **kwargs)
class UserCommand(AppCommand):
"""Represents a USER context menu command."""
def __init__(self, callback: Callable[..., Any], **kwargs: Any):
# Description is not allowed by Discord API for User Commands, but can be set to empty string.
kwargs["description"] = kwargs.get(
"description", ""
) # Ensure it's empty or not present in payload
super().__init__(callback, type=ApplicationCommandType.USER, **kwargs)
class MessageCommand(AppCommand):
"""Represents a MESSAGE context menu command."""
def __init__(self, callback: Callable[..., Any], **kwargs: Any):
# Description is not allowed by Discord API for Message Commands.
kwargs["description"] = kwargs.get("description", "")
super().__init__(callback, type=ApplicationCommandType.MESSAGE, **kwargs)
class HybridCommand(SlashCommand, PrefixCommand): # Inherit from both
"""
Represents a command that can be invoked as both a slash command
and a traditional prefix-based command.
"""
def __init__(self, callback: Callable[..., Any], **kwargs: Any):
# Initialize SlashCommand part (which calls AppCommand.__init__)
# We need to ensure 'type' is correctly passed for AppCommand
# kwargs for SlashCommand: name, description, guild_ids, default_member_permissions, nsfw, parent, cog, etc.
# kwargs for PrefixCommand: name, aliases, brief, description, cog
# Pop prefix-specific args before passing to SlashCommand constructor
prefix_aliases = kwargs.pop("aliases", [])
prefix_brief = kwargs.pop("brief", None)
# Description is used by both, AppCommand's constructor will handle it.
# Name is used by both. Cog is used by both.
# Call SlashCommand's __init__
# This will set up name, description, callback, type=CHAT_INPUT, options, etc.
super().__init__(callback, **kwargs) # This is SlashCommand.__init__
# Now, explicitly initialize the PrefixCommand parts that SlashCommand didn't cover
# or that need specific values for the prefix version.
# PrefixCommand.__init__(self, callback, name=self.name, aliases=prefix_aliases, brief=prefix_brief, description=self.description, cog=self.cog)
# However, PrefixCommand.__init__ also sets self.params, which AppCommand already did.
# We need to be careful not to re-initialize things unnecessarily or incorrectly.
# Let's manually set the distinct attributes for the PrefixCommand aspect.
# Attributes from PrefixCommand:
# self.callback is already set by AppCommand
# self.name is already set by AppCommand
self.aliases: List[str] = (
prefix_aliases # This was specific to HybridCommand before, now aligns with PrefixCommand
)
self.brief: Optional[str] = prefix_brief
# self.description is already set by AppCommand (SlashCommand ensures it exists)
# self.cog is already set by AppCommand
# self.params is already set by AppCommand
# Ensure the MRO is handled correctly. Python's MRO (C3 linearization)
# should call SlashCommand's __init__ then AppCommand's __init__.
# PrefixCommand.__init__ won't be called automatically unless we explicitly call it.
# By setting attributes directly, we avoid potential issues with multiple __init__ calls
# if their logic overlaps too much (e.g., both trying to set self.params).
# We might need to override invoke if the context or argument passing differs significantly
# between app command invocation and prefix command invocation.
# For now, SlashCommand.invoke and PrefixCommand.invoke are separate.
# The correct one will be called depending on how the command is dispatched.
# The AppCommandHandler will use AppCommand.invoke (via SlashCommand).
# The prefix CommandHandler will use PrefixCommand.invoke.
# This seems acceptable.
class AppCommandGroup:
"""
Represents a group of application commands (subcommands or subcommand groups).
This itself is not directly callable but acts as a namespace.
"""
def __init__(
self,
name: str,
description: Optional[
str
] = None, # Required for top-level groups that form part of a slash command
guild_ids: Optional[List["Snowflake"]] = None,
parent: Optional["AppCommandGroup"] = None,
default_member_permissions: Optional[str] = None,
nsfw: bool = False,
name_localizations: Optional[Dict[str, str]] = None,
description_localizations: Optional[Dict[str, str]] = None,
integration_types: Optional[List["IntegrationType"]] = None,
contexts: Optional[List["InteractionContextType"]] = None,
):
self.name: str = name
self.description: Optional[str] = description
self.guild_ids: Optional[List["Snowflake"]] = guild_ids
self.parent: Optional["AppCommandGroup"] = parent
self.commands: Dict[str, Union[AppCommand, "AppCommandGroup"]] = {}
self.default_member_permissions: Optional[str] = default_member_permissions
self.nsfw: bool = nsfw
self.name_localizations: Optional[Dict[str, str]] = name_localizations
self.description_localizations: Optional[Dict[str, str]] = (
description_localizations
)
self.integration_types: Optional[List["IntegrationType"]] = integration_types
self.contexts: Optional[List["InteractionContextType"]] = contexts
# A group itself doesn't have a cog directly, its commands do.
def add_command(self, command: Union[AppCommand, "AppCommandGroup"]) -> None:
if command.name in self.commands:
raise ValueError(
f"Command or group '{command.name}' already exists in group '{self.name}'."
)
command.parent = self
self.commands[command.name] = command
def get_command(self, name: str) -> Optional[Union[AppCommand, "AppCommandGroup"]]:
return self.commands.get(name)
def command(self, *d_args: Any, **d_kwargs: Any):
d_kwargs.setdefault("parent", self)
from .decorators import slash_command
return slash_command(*d_args, **d_kwargs)
def group(
self,
name: str,
description: Optional[str] = None,
**kwargs: Any,
):
sub_group = AppCommandGroup(
name=name,
description=description,
parent=self,
guild_ids=kwargs.get("guild_ids"),
default_member_permissions=kwargs.get("default_member_permissions"),
nsfw=kwargs.get("nsfw", False),
name_localizations=kwargs.get("name_localizations"),
description_localizations=kwargs.get("description_localizations"),
integration_types=kwargs.get("integration_types"),
contexts=kwargs.get("contexts"),
)
self.add_command(sub_group)
def decorator(func: Optional[Callable[..., Any]] = None):
if func is not None:
setattr(func, "__app_command_object__", sub_group)
return sub_group
return sub_group
return decorator
def __repr__(self) -> str:
return f"<AppCommandGroup name='{self.name}' commands={len(self.commands)}>"
def to_dict(self) -> Dict[str, Any]:
"""
Converts the command group to a dictionary payload for Discord API.
This represents a top-level command that has subcommands/subcommand groups.
"""
payload: Dict[str, Any] = {
"name": self.name,
"type": ApplicationCommandType.CHAT_INPUT.value, # Groups are implicitly CHAT_INPUT
"description": self.description
or "No description provided", # Top-level groups require a description
"options": [],
}
if self.default_member_permissions is not None:
payload["default_member_permissions"] = str(self.default_member_permissions)
if self.nsfw:
payload["nsfw"] = True
if self.name_localizations:
payload["name_localizations"] = self.name_localizations
if (
self.description and self.description_localizations
): # Only if description is not empty
payload["description_localizations"] = self.description_localizations
if self.integration_types:
payload["integration_types"] = [it.value for it in self.integration_types]
if self.contexts:
payload["contexts"] = [ict.value for ict in self.contexts]
# guild_ids are handled at the registration level, not in this specific payload part.
options_payload: List[Dict[str, Any]] = []
for cmd_name, command_or_group in self.commands.items():
if isinstance(command_or_group, AppCommand): # This is a Subcommand
# Subcommands use their own options (parameters)
sub_options = (
[opt.to_dict() for opt in command_or_group.options]
if command_or_group.options
else []
)
option_dict = {
"type": ApplicationCommandOptionType.SUB_COMMAND.value,
"name": command_or_group.name,
"description": command_or_group.description
or "No description provided",
"options": sub_options,
}
# Add localization for subcommand name and description if available
if command_or_group.name_localizations:
option_dict["name_localizations"] = (
command_or_group.name_localizations
)
if (
command_or_group.description
and command_or_group.description_localizations
):
option_dict["description_localizations"] = (
command_or_group.description_localizations
)
options_payload.append(option_dict)
elif isinstance(
command_or_group, AppCommandGroup
): # This is a Subcommand Group
# Subcommand groups have their subcommands/groups as options
sub_group_options: List[Dict[str, Any]] = []
for sub_cmd_name, sub_command in command_or_group.commands.items():
# Nested groups can only contain subcommands, not further nested groups as per Discord rules.
# So, sub_command here must be an AppCommand.
if isinstance(
sub_command, AppCommand
): # Should always be AppCommand if structure is valid
sub_cmd_options = (
[opt.to_dict() for opt in sub_command.options]
if sub_command.options
else []
)
sub_group_option_entry = {
"type": ApplicationCommandOptionType.SUB_COMMAND.value,
"name": sub_command.name,
"description": sub_command.description
or "No description provided",
"options": sub_cmd_options,
}
# Add localization for subcommand name and description if available
if sub_command.name_localizations:
sub_group_option_entry["name_localizations"] = (
sub_command.name_localizations
)
if (
sub_command.description
and sub_command.description_localizations
):
sub_group_option_entry["description_localizations"] = (
sub_command.description_localizations
)
sub_group_options.append(sub_group_option_entry)
# else:
# # This case implies a group nested inside a group, which then contains another group.
# # Discord's structure is:
# # command -> option (SUB_COMMAND_GROUP) -> option (SUB_COMMAND) -> option (param)
# # This should be caught by validation logic in decorators or add_command.
# # For now, we assume valid structure where AppCommandGroup's commands are AppCommands.
# pass
option_dict = {
"type": ApplicationCommandOptionType.SUB_COMMAND_GROUP.value,
"name": command_or_group.name,
"description": command_or_group.description
or "No description provided",
"options": sub_group_options, # These are the SUB_COMMANDs
}
# Add localization for subcommand group name and description if available
if command_or_group.name_localizations:
option_dict["name_localizations"] = (
command_or_group.name_localizations
)
if (
command_or_group.description
and command_or_group.description_localizations
):
option_dict["description_localizations"] = (
command_or_group.description_localizations
)
options_payload.append(option_dict)
payload["options"] = options_payload
return payload
# Need to import asyncio for iscoroutinefunction check
import asyncio
if TYPE_CHECKING:
from .context import AppCommandContext # For type hint in AppCommand.invoke
# Ensure ApplicationCommandOptionType is available for the to_dict method
from disagreement.enums import ApplicationCommandOptionType

View File

@ -0,0 +1,556 @@
# disagreement/ext/app_commands/context.py
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, List, Union, Any, Dict
if TYPE_CHECKING:
from disagreement.client import Client
from disagreement.interactions import (
Interaction,
InteractionCallbackData,
InteractionResponsePayload,
Snowflake,
)
from disagreement.enums import InteractionCallbackType, MessageFlags
from disagreement.models import (
User,
Member,
Message,
Channel,
ActionRow,
)
from disagreement.ui.view import View
# For full model hints, these would be imported from disagreement.models when defined:
Embed = Any
PartialAttachment = Any
Guild = Any # from disagreement.models import Guild
TextChannel = Any # from disagreement.models import TextChannel, etc.
from .commands import AppCommand
from disagreement.enums import InteractionCallbackType, MessageFlags
from disagreement.interactions import (
Interaction,
InteractionCallbackData,
InteractionResponsePayload,
Snowflake,
)
from disagreement.models import Message
from disagreement.typing import Typing
class AppCommandContext:
"""
Represents the context in which an application command is being invoked.
Provides methods to respond to the interaction.
"""
def __init__(
self,
bot: "Client",
interaction: "Interaction",
command: Optional["AppCommand"] = None,
):
self.bot: "Client" = bot
self.interaction: "Interaction" = interaction
self.command: Optional["AppCommand"] = command # The command that was invoked
self._responded: bool = False
self._deferred: bool = False
@property
def token(self) -> str:
"""The interaction token."""
return self.interaction.token
@property
def interaction_id(self) -> "Snowflake":
"""The interaction ID."""
return self.interaction.id
@property
def application_id(self) -> "Snowflake":
"""The application ID of the interaction."""
return self.interaction.application_id
@property
def guild_id(self) -> Optional["Snowflake"]:
"""The ID of the guild where the interaction occurred, if any."""
return self.interaction.guild_id
@property
def channel_id(self) -> Optional["Snowflake"]:
"""The ID of the channel where the interaction occurred."""
return self.interaction.channel_id
@property
def author(self) -> Optional[Union["User", "Member"]]:
"""The user or member who invoked the interaction."""
return self.interaction.member or self.interaction.user
@property
def user(self) -> Optional["User"]:
"""The user who invoked the interaction.
If in a guild, this is the user part of the member.
If in a DM, this is the top-level user.
"""
return self.interaction.user
@property
def member(self) -> Optional["Member"]:
"""The member who invoked the interaction, if this occurred in a guild."""
return self.interaction.member
@property
def locale(self) -> Optional[str]:
"""The selected language of the invoking user."""
return self.interaction.locale
@property
def guild_locale(self) -> Optional[str]:
"""The guild's preferred language, if applicable."""
return self.interaction.guild_locale
@property
async def guild(self) -> Optional["Guild"]:
"""The guild object where the interaction occurred, if available."""
if not self.guild_id:
return None
guild = None
if hasattr(self.bot, "get_guild"):
guild = self.bot.get_guild(self.guild_id)
if not guild and hasattr(self.bot, "fetch_guild"):
try:
guild = await self.bot.fetch_guild(self.guild_id)
except Exception:
guild = None
return guild
@property
async def channel(self) -> Optional[Any]:
"""The channel object where the interaction occurred, if available."""
if not self.channel_id:
return None
channel = None
if hasattr(self.bot, "get_channel"):
channel = self.bot.get_channel(self.channel_id)
elif hasattr(self.bot, "_channels"):
channel = self.bot._channels.get(self.channel_id)
if not channel and hasattr(self.bot, "fetch_channel"):
try:
channel = await self.bot.fetch_channel(self.channel_id)
except Exception:
channel = None
return channel
async def _send_response(
self,
response_type: "InteractionCallbackType",
data: Optional[Dict[str, Any]] = None,
) -> None:
"""Internal helper to send interaction responses."""
if (
self._responded
and not self._deferred
and response_type
!= InteractionCallbackType.APPLICATION_COMMAND_AUTOCOMPLETE_RESULT
):
# If already responded and not deferred, subsequent responses must be followups
# (unless it's an autocomplete result which is a special case)
# For now, let's assume followups are handled by separate methods.
# This logic might need refinement based on how followups are exposed.
raise RuntimeError(
"Interaction has already been responded to. Use send_followup()."
)
callback_data = InteractionCallbackData(data) if data else None
payload = InteractionResponsePayload(type=response_type, data=callback_data)
await self.bot._http.create_interaction_response(
interaction_id=self.interaction_id,
interaction_token=self.token,
payload=payload,
)
if (
response_type
!= InteractionCallbackType.APPLICATION_COMMAND_AUTOCOMPLETE_RESULT
):
self._responded = True
if (
response_type
== InteractionCallbackType.DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE
or response_type == InteractionCallbackType.DEFERRED_UPDATE_MESSAGE
):
self._deferred = True
async def defer(self, ephemeral: bool = False, thinking: bool = True) -> None:
"""
Defers the interaction response.
This is typically used when your command might take longer than 3 seconds to process.
You must send a followup message within 15 minutes.
Args:
ephemeral (bool): Whether the subsequent followup response should be ephemeral.
Only applicable if `thinking` is True.
thinking (bool): If True (default), responds with a "Bot is thinking..." message
(DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE).
If False, responds with DEFERRED_UPDATE_MESSAGE (for components).
"""
if self._responded:
raise RuntimeError("Interaction has already been responded to or deferred.")
response_type = (
InteractionCallbackType.DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE
if thinking
else InteractionCallbackType.DEFERRED_UPDATE_MESSAGE
)
data = None
if ephemeral and thinking:
data = {
"flags": MessageFlags.EPHEMERAL.value
} # Assuming MessageFlags enum exists
await self._send_response(response_type, data)
self._deferred = True # Mark as deferred
async def send(
self,
content: Optional[str] = None,
embed: Optional["Embed"] = None, # Convenience for single embed
embeds: Optional[List["Embed"]] = None,
*,
tts: bool = False,
files: Optional[List[Any]] = None,
components: Optional[List[ActionRow]] = None,
view: Optional[View] = None,
allowed_mentions: Optional[Dict[str, Any]] = None,
ephemeral: bool = False,
flags: Optional[int] = None,
) -> Optional[
"Message"
]: # Returns Message if not ephemeral and response was not deferred
"""
Sends a response to the interaction.
If the interaction was previously deferred, this will edit the original deferred response.
Otherwise, it sends an initial response.
Args:
content (Optional[str]): The message content.
embed (Optional[Embed]): A single embed to send. If `embeds` is also provided, this is ignored.
embeds (Optional[List[Embed]]): A list of embeds to send (max 10).
ephemeral (bool): Whether the message should be ephemeral (only visible to the invoker).
flags (Optional[int]): Additional message flags to apply.
Returns:
Optional[Message]: The sent message object if a new message was created and not ephemeral.
None if the response was ephemeral or an edit to a deferred message.
"""
if not self._responded and self._deferred: # Editing a deferred response
# Use edit_original_interaction_response
payload: Dict[str, Any] = {}
if content is not None:
payload["content"] = content
if tts:
payload["tts"] = True
actual_embeds = embeds
if embed and not embeds:
actual_embeds = [embed]
if actual_embeds:
payload["embeds"] = [e.to_dict() for e in actual_embeds]
if view:
await view._start(self.bot)
payload["components"] = view.to_components_payload()
elif components:
payload["components"] = [c.to_dict() for c in components]
if files is not None:
payload["attachments"] = [
f.to_dict() if hasattr(f, "to_dict") else f for f in files
]
if allowed_mentions is not None:
payload["allowed_mentions"] = allowed_mentions
# Flags (like ephemeral) cannot be set when editing the original deferred response this way.
# Ephemeral for deferred must be set during defer().
msg_data = await self.bot._http.edit_original_interaction_response(
application_id=self.application_id,
interaction_token=self.token,
payload=payload,
)
self._responded = True # Ensure it's marked as fully responded
if view and msg_data and "id" in msg_data:
view.message_id = msg_data["id"]
self.bot._views[msg_data["id"]] = view
# Construct and return Message object if needed, for now returns None for edits
return None
elif not self._responded: # Sending an initial response
data: Dict[str, Any] = {}
if content is not None:
data["content"] = content
if tts:
data["tts"] = True
actual_embeds = embeds
if embed and not embeds:
actual_embeds = [embed]
if actual_embeds:
data["embeds"] = [
e.to_dict() for e in actual_embeds
] # Assuming embeds have to_dict()
if view:
await view._start(self.bot)
data["components"] = view.to_components_payload()
elif components:
data["components"] = [c.to_dict() for c in components]
if files is not None:
data["attachments"] = [
f.to_dict() if hasattr(f, "to_dict") else f for f in files
]
if allowed_mentions is not None:
data["allowed_mentions"] = allowed_mentions
flags_value = 0
if ephemeral:
flags_value |= MessageFlags.EPHEMERAL.value
if flags:
flags_value |= flags
if flags_value:
data["flags"] = flags_value
await self._send_response(
InteractionCallbackType.CHANNEL_MESSAGE_WITH_SOURCE, data
)
if view and not ephemeral:
try:
msg_data = await self.bot._http.get_original_interaction_response(
application_id=self.application_id,
interaction_token=self.token,
)
if msg_data and "id" in msg_data:
view.message_id = msg_data["id"]
self.bot._views[msg_data["id"]] = view
except Exception:
pass
if not ephemeral:
return None
return None
else:
# If already responded and not deferred, this should be a followup.
# This method is for initial response or editing deferred.
raise RuntimeError(
"Interaction has already been responded to. Use send_followup()."
)
async def send_followup(
self,
content: Optional[str] = None,
embed: Optional["Embed"] = None,
embeds: Optional[List["Embed"]] = None,
*,
ephemeral: bool = False,
tts: bool = False,
files: Optional[List[Any]] = None,
components: Optional[List["ActionRow"]] = None,
view: Optional[View] = None,
allowed_mentions: Optional[Dict[str, Any]] = None,
flags: Optional[int] = None,
) -> Optional["Message"]:
"""
Sends a followup message to an interaction.
This can be used after an initial response or a deferred response.
Args:
content (Optional[str]): The message content.
embed (Optional[Embed]): A single embed to send.
embeds (Optional[List[Embed]]): A list of embeds to send.
ephemeral (bool): Whether the followup message should be ephemeral.
flags (Optional[int]): Additional message flags to apply.
Returns:
Message: The sent followup message object.
"""
if not self._responded:
raise RuntimeError(
"Must acknowledge or defer the interaction before sending a followup."
)
payload: Dict[str, Any] = {}
if content is not None:
payload["content"] = content
if tts:
payload["tts"] = True
actual_embeds = embeds
if embed and not embeds:
actual_embeds = [embed]
if actual_embeds:
payload["embeds"] = [
e.to_dict() for e in actual_embeds
] # Assuming embeds have to_dict()
if view:
await view._start(self.bot)
payload["components"] = view.to_components_payload()
elif components:
payload["components"] = [c.to_dict() for c in components]
if files is not None:
payload["attachments"] = [
f.to_dict() if hasattr(f, "to_dict") else f for f in files
]
if allowed_mentions is not None:
payload["allowed_mentions"] = allowed_mentions
flags_value = 0
if ephemeral:
flags_value |= MessageFlags.EPHEMERAL.value
if flags:
flags_value |= flags
if flags_value:
payload["flags"] = flags_value
# Followup messages are sent to a webhook endpoint
message_data = await self.bot._http.create_followup_message(
application_id=self.application_id,
interaction_token=self.token,
payload=payload,
)
if view and message_data and "id" in message_data:
view.message_id = message_data["id"]
self.bot._views[message_data["id"]] = view
from disagreement.models import Message # Ensure Message is available
return Message(data=message_data, client_instance=self.bot)
async def edit(
self,
message_id: "Snowflake" = "@original", # Defaults to editing the original response
content: Optional[str] = None,
embed: Optional["Embed"] = None,
embeds: Optional[List["Embed"]] = None,
*,
components: Optional[List["ActionRow"]] = None,
attachments: Optional[List[Any]] = None,
allowed_mentions: Optional[Dict[str, Any]] = None,
) -> Optional["Message"]:
"""
Edits a message previously sent in response to this interaction.
Can edit the original response or a followup message.
Args:
message_id (Snowflake): The ID of the message to edit. Defaults to "@original"
to edit the initial interaction response.
content (Optional[str]): The new message content.
embed (Optional[Embed]): A single new embed.
embeds (Optional[List[Embed]]): A list of new embeds.
Returns:
Optional[Message]: The edited message object if available.
"""
if not self._responded:
raise RuntimeError(
"Cannot edit response if interaction hasn't been responded to or deferred."
)
payload: Dict[str, Any] = {}
if content is not None:
payload["content"] = content # Use None to clear
actual_embeds = embeds
if embed and not embeds:
actual_embeds = [embed]
if actual_embeds is not None: # Allow passing empty list to clear embeds
payload["embeds"] = [
e.to_dict() for e in actual_embeds
] # Assuming embeds have to_dict()
if components is not None:
payload["components"] = [c.to_dict() for c in components]
if attachments is not None:
payload["attachments"] = [
a.to_dict() if hasattr(a, "to_dict") else a for a in attachments
]
if allowed_mentions is not None:
payload["allowed_mentions"] = allowed_mentions
if message_id == "@original":
edited_message_data = (
await self.bot._http.edit_original_interaction_response(
application_id=self.application_id,
interaction_token=self.token,
payload=payload,
)
)
else:
edited_message_data = await self.bot._http.edit_followup_message(
application_id=self.application_id,
interaction_token=self.token,
message_id=message_id,
payload=payload,
)
# The HTTP methods used in tests return minimal data that is insufficient
# to construct a full ``Message`` instance, so we simply return ``None``
# rather than attempting to parse the response.
return None
async def delete(self, message_id: "Snowflake" = "@original") -> None:
"""
Deletes a message previously sent in response to this interaction.
Can delete the original response or a followup message.
Args:
message_id (Snowflake): The ID of the message to delete. Defaults to "@original"
to delete the initial interaction response.
"""
if not self._responded:
# If not responded, there's nothing to delete via this interaction's lifecycle.
# Deferral doesn't create a message to delete until a followup is sent.
raise RuntimeError(
"Cannot delete response if interaction hasn't been responded to."
)
if message_id == "@original":
await self.bot._http.delete_original_interaction_response(
application_id=self.application_id, interaction_token=self.token
)
else:
await self.bot._http.delete_followup_message(
application_id=self.application_id,
interaction_token=self.token,
message_id=message_id,
)
# After deleting the original response, further followups might be problematic.
# Discord docs: "Once the original message is deleted, you can no longer edit the message or send followups."
# Consider implications for context state.
def typing(self) -> Typing:
"""Return a typing context manager for this interaction's channel."""
if not self.channel_id:
raise RuntimeError("Cannot send typing indicator without a channel.")
return self.bot.typing(self.channel_id)

View File

@ -0,0 +1,478 @@
# disagreement/ext/app_commands/converters.py
"""
Converters for transforming application command option values.
"""
from typing import (
Any,
Awaitable,
Callable,
Dict,
Protocol,
TypeVar,
Union,
TYPE_CHECKING,
)
from disagreement.enums import ApplicationCommandOptionType
from disagreement.errors import (
AppCommandOptionConversionError,
) # To be created in disagreement/errors.py
if TYPE_CHECKING:
from disagreement.interactions import Interaction # For context if needed
from disagreement.models import (
User,
Member,
Role,
Channel,
Attachment,
) # Discord models
from disagreement.client import Client # For fetching objects
T = TypeVar("T", covariant=True)
class Converter(Protocol[T]):
"""
A protocol for classes that can convert an interaction option value to a specific type.
"""
async def convert(self, interaction: "Interaction", value: Any) -> T:
"""
Converts the given value to the target type.
Parameters:
interaction (Interaction): The interaction context.
value (Any): The raw value from the interaction option.
Returns:
T: The converted value.
Raises:
AppCommandOptionConversionError: If conversion fails.
"""
...
# Basic Type Converters
class StringConverter(Converter[str]):
async def convert(self, interaction: "Interaction", value: Any) -> str:
if not isinstance(value, str):
raise AppCommandOptionConversionError(
f"Expected a string, but got {type(value).__name__}: {value}"
)
return value
class IntegerConverter(Converter[int]):
async def convert(self, interaction: "Interaction", value: Any) -> int:
if not isinstance(value, int):
try:
return int(value)
except (ValueError, TypeError):
raise AppCommandOptionConversionError(
f"Expected an integer, but got {type(value).__name__}: {value}"
)
return value
class BooleanConverter(Converter[bool]):
async def convert(self, interaction: "Interaction", value: Any) -> bool:
if not isinstance(value, bool):
if isinstance(value, str):
if value.lower() == "true":
return True
elif value.lower() == "false":
return False
raise AppCommandOptionConversionError(
f"Expected a boolean, but got {type(value).__name__}: {value}"
)
return value
class NumberConverter(Converter[float]): # Discord 'NUMBER' type is float
async def convert(self, interaction: "Interaction", value: Any) -> float:
if not isinstance(value, (int, float)):
try:
return float(value)
except (ValueError, TypeError):
raise AppCommandOptionConversionError(
f"Expected a number (float), but got {type(value).__name__}: {value}"
)
return float(value) # Ensure it's a float even if int is passed
# Discord Model Converters
class UserConverter(Converter["User"]):
def __init__(self, client: "Client"):
self._client = client
async def convert(self, interaction: "Interaction", value: Any) -> "User":
if isinstance(value, str): # Assume it's a user ID
user_id = value
# Attempt to get from interaction resolved data first
if (
interaction.data
and interaction.data.resolved
and interaction.data.resolved.users
):
user_object = interaction.data.resolved.users.get(
user_id
) # This is already a User object
if user_object:
return user_object # Return the already parsed User object
# Fallback to fetching if not in resolved or if interaction has no resolved data
try:
user = await self._client.fetch_user(
user_id
) # fetch_user now also parses and caches
if user:
return user
raise AppCommandOptionConversionError(
f"User with ID '{user_id}' not found.",
option_name="user",
original_value=value,
)
except Exception as e: # Catch potential HTTP errors from fetch_user
raise AppCommandOptionConversionError(
f"Failed to fetch user '{user_id}': {e}",
option_name="user",
original_value=value,
)
elif (
isinstance(value, dict) and "id" in value
): # If it's raw user data dict (less common path now)
return self._client.parse_user(value) # parse_user handles dict -> User
raise AppCommandOptionConversionError(
f"Expected a user ID string or user data dict, got {type(value).__name__}",
option_name="user",
original_value=value,
)
class MemberConverter(Converter["Member"]):
def __init__(self, client: "Client"):
self._client = client
async def convert(self, interaction: "Interaction", value: Any) -> "Member":
if not interaction.guild_id:
raise AppCommandOptionConversionError(
"Cannot convert to Member outside of a guild context.",
option_name="member",
)
if isinstance(value, str): # Assume it's a user ID
member_id = value
# Attempt to get from interaction resolved data first
if (
interaction.data
and interaction.data.resolved
and interaction.data.resolved.members
):
# The Member object from resolved.members should already be correctly initialized
# by ResolvedData, including its User part.
member = interaction.data.resolved.members.get(member_id)
if member:
return (
member # Return the already resolved and parsed Member object
)
# Fallback to fetching if not in resolved
try:
member = await self._client.fetch_member(
interaction.guild_id, member_id
)
if member:
return member
raise AppCommandOptionConversionError(
f"Member with ID '{member_id}' not found in guild '{interaction.guild_id}'.",
option_name="member",
original_value=value,
)
except Exception as e:
raise AppCommandOptionConversionError(
f"Failed to fetch member '{member_id}': {e}",
option_name="member",
original_value=value,
)
elif isinstance(value, dict) and "id" in value.get(
"user", {}
): # If it's already a member data dict
return self._client.parse_member(value, interaction.guild_id)
raise AppCommandOptionConversionError(
f"Expected a member ID string or member data dict, got {type(value).__name__}",
option_name="member",
original_value=value,
)
class RoleConverter(Converter["Role"]):
def __init__(self, client: "Client"):
self._client = client
async def convert(self, interaction: "Interaction", value: Any) -> "Role":
if not interaction.guild_id:
raise AppCommandOptionConversionError(
"Cannot convert to Role outside of a guild context.", option_name="role"
)
if isinstance(value, str): # Assume it's a role ID
role_id = value
# Attempt to get from interaction resolved data first
if (
interaction.data
and interaction.data.resolved
and interaction.data.resolved.roles
):
role_object = interaction.data.resolved.roles.get(
role_id
) # Should be a Role object
if role_object:
return role_object
# Fallback to fetching from guild if not in resolved
# This requires Client to have a fetch_role method or similar
try:
# Assuming Client.fetch_role(guild_id, role_id) will be implemented
role = await self._client.fetch_role(interaction.guild_id, role_id)
if role:
return role
raise AppCommandOptionConversionError(
f"Role with ID '{role_id}' not found in guild '{interaction.guild_id}'.",
option_name="role",
original_value=value,
)
except Exception as e:
raise AppCommandOptionConversionError(
f"Failed to fetch role '{role_id}': {e}",
option_name="role",
original_value=value,
)
elif (
isinstance(value, dict) and "id" in value
): # If it's already role data dict
if (
not interaction.guild_id
): # Should have been caught earlier, but as a safeguard
raise AppCommandOptionConversionError(
"Guild context is required to parse role data.",
option_name="role",
original_value=value,
)
return self._client.parse_role(value, interaction.guild_id)
# This path is reached if value is not a string (role ID) and not a dict (role data)
# or if it's a string but all fetching/lookup attempts failed.
# The final raise AppCommandOptionConversionError should be outside the if/elif for string values.
# If value was a string, an error should have been raised within the 'if isinstance(value, str)' block.
# If it wasn't a string or dict, this is the correct place to raise.
# The previous structure was slightly off, as the final raise was inside the string check block.
# Let's ensure the final raise is at the correct scope.
# The current structure seems to imply that if it's not a string, it must be a dict or error.
# If it's a string and all lookups fail, an error is raised within that block.
# If it's a dict and parsing fails (or guild_id missing), error raised.
# If it's neither, this final raise is correct.
# The "Function with declared return type "Role" must return value on all code paths"
# error suggests a path where no return or raise happens.
# This happens if `isinstance(value, str)` is true, but then all internal paths
# (resolved check, fetch try/except) don't lead to a return or raise *before*
# falling out of the `if isinstance(value, str)` block.
# The `raise AppCommandOptionConversionError` at the end of the `if isinstance(value, str)` block
# (line 156 in previous version) handles the case where a role ID is given but not found.
# The one at the very end (line 164 in previous) handles cases where value is not str/dict.
# Corrected structure for the final raise:
# It should be at the same level as the initial `if isinstance(value, str):`
# to catch cases where `value` is neither a str nor a dict.
# However, the current logic within the `if isinstance(value, str):` block
# ensures a raise if the role ID is not found.
# The `elif isinstance(value, dict)` handles the dict case.
# The final `raise` (line 164) is for types other than str or dict.
# The Pylance error "Function with declared return type "Role" must return value on all code paths"
# implies that if `value` is a string, and `interaction.data.resolved.roles.get(role_id)` is None,
# AND `self._client.fetch_role` returns None (which it can), then the
# `raise AppCommandOptionConversionError` on line 156 is correctly hit.
# The issue might be that Pylance doesn't see `AppCommandOptionConversionError` as definitively terminating.
# This is unlikely. Let's re-verify the logic flow.
# The `raise` on line 156 is correct if role_id is not found after fetching.
# The `raise` on line 164 is for when `value` is not a str and not a dict.
# This seems logically sound. The Pylance error might be a misinterpretation or a subtle issue.
# For now, the duplicated `except` is the primary syntax error.
# The "must return value on all code paths" often occurs if an if/elif chain doesn't
# exhaust all possibilities or if a path through a try/except doesn't guarantee a return/raise.
# In this case, if `value` is a string, it either returns a Role or raises an error.
# If `value` is a dict, it either returns a Role or raises an error.
# If `value` is neither, it raises an error. All paths seem covered.
# The syntax error from the duplicated `except` is the most likely culprit for Pylance's confusion.
raise AppCommandOptionConversionError(
f"Expected a role ID string or role data dict, got {type(value).__name__}",
option_name="role",
original_value=value,
)
class ChannelConverter(Converter["Channel"]):
def __init__(self, client: "Client"):
self._client = client
async def convert(self, interaction: "Interaction", value: Any) -> "Channel":
if isinstance(value, str): # Assume it's a channel ID
channel_id = value
# Attempt to get from interaction resolved data first
if (
interaction.data
and interaction.data.resolved
and interaction.data.resolved.channels
):
# Resolved channels are PartialChannel. Client.fetch_channel will get the full typed one.
partial_channel = interaction.data.resolved.channels.get(channel_id)
if partial_channel:
# Client.fetch_channel should handle fetching and parsing to the correct Channel subtype
full_channel = await self._client.fetch_channel(partial_channel.id)
if full_channel:
return full_channel
# If fetch_channel returns None even with a resolved ID, it's an issue.
raise AppCommandOptionConversionError(
f"Failed to fetch full channel for resolved ID '{channel_id}'.",
option_name="channel",
original_value=value,
)
# Fallback to fetching directly if not in resolved or if resolved fetch failed
try:
channel = await self._client.fetch_channel(
channel_id
) # fetch_channel handles parsing
if channel:
return channel
raise AppCommandOptionConversionError(
f"Channel with ID '{channel_id}' not found.",
option_name="channel",
original_value=value,
)
except Exception as e:
raise AppCommandOptionConversionError(
f"Failed to fetch channel '{channel_id}': {e}",
option_name="channel",
original_value=value,
)
# Raw channel data dicts are not typically provided for slash command options.
raise AppCommandOptionConversionError(
f"Expected a channel ID string, got {type(value).__name__}",
option_name="channel",
original_value=value,
)
class AttachmentConverter(Converter["Attachment"]):
def __init__(
self, client: "Client"
): # Client might be needed for future enhancements or consistency
self._client = client
async def convert(self, interaction: "Interaction", value: Any) -> "Attachment":
if isinstance(value, str): # Value is the attachment ID
attachment_id = value
if (
interaction.data
and interaction.data.resolved
and interaction.data.resolved.attachments
):
attachment_object = interaction.data.resolved.attachments.get(
attachment_id
) # This is already an Attachment object
if attachment_object:
return (
attachment_object # Return the already parsed Attachment object
)
raise AppCommandOptionConversionError(
f"Attachment with ID '{attachment_id}' not found in resolved data.",
option_name="attachment",
original_value=value,
)
raise AppCommandOptionConversionError(
f"Expected an attachment ID string, got {type(value).__name__}",
option_name="attachment",
original_value=value,
)
# Converters can be registered dynamically using
# :meth:`disagreement.ext.app_commands.handler.AppCommandHandler.register_converter`.
# Mapping from ApplicationCommandOptionType to default converters
# This will be used by the AppCommandHandler to automatically apply converters
# if no explicit converter is specified for a command option's type hint.
DEFAULT_CONVERTERS: Dict[
ApplicationCommandOptionType, Callable[..., Converter[Any]]
] = { # Changed Callable signature
ApplicationCommandOptionType.STRING: StringConverter,
ApplicationCommandOptionType.INTEGER: IntegerConverter,
ApplicationCommandOptionType.BOOLEAN: BooleanConverter,
ApplicationCommandOptionType.NUMBER: NumberConverter,
ApplicationCommandOptionType.USER: UserConverter,
ApplicationCommandOptionType.CHANNEL: ChannelConverter,
ApplicationCommandOptionType.ROLE: RoleConverter,
# ApplicationCommandOptionType.MENTIONABLE: MentionableConverter, # Special case, can be User or Role
ApplicationCommandOptionType.ATTACHMENT: AttachmentConverter, # Added
}
async def run_converters(
interaction: "Interaction",
param_type: Any, # The type hint of the parameter
option_type: ApplicationCommandOptionType, # The Discord option type
value: Any,
client: "Client", # Needed for model converters
) -> Any:
"""
Runs the appropriate converter for a given parameter type and value.
This function will be more complex, handling custom converters, unions, optionals etc.
For now, a basic lookup.
"""
converter_class_factory = DEFAULT_CONVERTERS.get(option_type)
if converter_class_factory:
# Check if the factory needs the client instance
# This is a bit simplistic; a more robust way might involve inspecting __init__ signature
# or having converters register their needs.
if option_type in [
ApplicationCommandOptionType.USER,
ApplicationCommandOptionType.CHANNEL, # Anticipating these
ApplicationCommandOptionType.ROLE,
ApplicationCommandOptionType.MENTIONABLE,
ApplicationCommandOptionType.ATTACHMENT,
]:
converter_instance = converter_class_factory(client=client)
else:
converter_instance = converter_class_factory()
return await converter_instance.convert(interaction, value)
# Fallback for unhandled types or if direct type matching is needed
if param_type is str and isinstance(value, str):
return value
if param_type is int and isinstance(value, int):
return value
if param_type is bool and isinstance(value, bool):
return value
if param_type is float and isinstance(value, (float, int)):
return float(value)
# If no specific converter, and it's not a basic type match, raise error or return raw
# For now, let's raise if no converter found for a specific option type
if option_type in DEFAULT_CONVERTERS: # Should have been handled
pass # This path implies a logic error above or missing converter in DEFAULT_CONVERTERS
# If it's a model type but no converter yet, this will need to be handled
# e.g. if param_type is User and option_type is ApplicationCommandOptionType.USER
raise AppCommandOptionConversionError(
f"No suitable converter found for option type {option_type.name} "
f"with value '{value}' to target type {param_type.__name__ if hasattr(param_type, '__name__') else param_type}"
)

View File

@ -0,0 +1,569 @@
# disagreement/ext/app_commands/decorators.py
import inspect
import asyncio
from dataclasses import dataclass
from typing import (
Callable,
Optional,
List,
Dict,
Any,
Union,
Type,
get_origin,
get_args,
TYPE_CHECKING,
Literal,
Annotated,
TypeVar,
cast,
)
from .commands import (
SlashCommand,
UserCommand,
MessageCommand,
AppCommand,
AppCommandGroup,
HybridCommand,
)
from disagreement.interactions import (
ApplicationCommandOption,
ApplicationCommandOptionChoice,
Snowflake,
)
from disagreement.enums import (
ApplicationCommandOptionType,
IntegrationType,
InteractionContextType,
# Assuming ChannelType will be added to disagreement.enums
)
if TYPE_CHECKING:
from disagreement.client import Client # For potential future use
from disagreement.models import Channel, User
# Assuming TextChannel, VoiceChannel etc. might be defined or aliased
# For now, we'll use string comparisons for channel types or rely on a yet-to-be-defined ChannelType enum
Channel = Any
Member = Any
Role = Any
Attachment = Any
# from .cog import Cog # Placeholder
else:
# Runtime fallbacks for optional model classes
from disagreement.models import Channel
Client = Any # type: ignore
User = Any # type: ignore
Member = Any # type: ignore
Role = Any # type: ignore
Attachment = Any # type: ignore
# Mapping Python types to Discord ApplicationCommandOptionType
# This will need to be expanded and made more robust.
# Consider using a registry or a more sophisticated type mapping system.
_type_mapping: Dict[Any, ApplicationCommandOptionType] = (
{ # Changed Type to Any for key due to placeholders
str: ApplicationCommandOptionType.STRING,
int: ApplicationCommandOptionType.INTEGER,
bool: ApplicationCommandOptionType.BOOLEAN,
float: ApplicationCommandOptionType.NUMBER, # Discord 'NUMBER' type is for float/double
User: ApplicationCommandOptionType.USER,
Channel: ApplicationCommandOptionType.CHANNEL,
# Placeholders for actual model types from disagreement.models
# These will be resolved to their actual types when TYPE_CHECKING is False or via isinstance checks
}
)
# Helper dataclass for storing extra option metadata
@dataclass
class OptionMetadata:
channel_types: Optional[List[int]] = None
min_value: Optional[Union[int, float]] = None
max_value: Optional[Union[int, float]] = None
min_length: Optional[int] = None
max_length: Optional[int] = None
autocomplete: bool = False
# Ensure these are updated if model names/locations change
if TYPE_CHECKING:
# _type_mapping[User] = ApplicationCommandOptionType.USER # Already added above
_type_mapping[Member] = ApplicationCommandOptionType.USER # Member implies User
_type_mapping[Role] = ApplicationCommandOptionType.ROLE
_type_mapping[Attachment] = ApplicationCommandOptionType.ATTACHMENT
_type_mapping[Channel] = ApplicationCommandOptionType.CHANNEL
# TypeVar for the app command decorator factory
AppCmdType = TypeVar("AppCmdType", bound=AppCommand)
def _extract_options_from_signature(
func: Callable[..., Any], option_meta: Optional[Dict[str, OptionMetadata]] = None
) -> List[ApplicationCommandOption]:
"""
Inspects a function signature and generates ApplicationCommandOption list.
"""
options: List[ApplicationCommandOption] = []
params = inspect.signature(func).parameters
doc = inspect.getdoc(func)
param_descriptions: Dict[str, str] = {}
if doc:
for line in inspect.cleandoc(doc).splitlines():
line = line.strip()
if line.startswith(":param"):
try:
_, rest = line.split(" ", 1)
name, desc = rest.split(":", 1)
param_descriptions[name.strip()] = desc.strip()
except ValueError:
continue
# Skip 'self' (for cogs) and 'ctx' (context) parameters
param_iter = iter(params.values())
first_param = next(param_iter, None)
# Heuristic: if the function is bound to a class (cog), 'self' might be the first param.
# A more robust way would be to check if `func` is a method of a Cog instance later.
# For now, simple name check.
if first_param and first_param.name == "self":
first_param = next(param_iter, None) # Consume 'self', get next
if first_param and first_param.name == "ctx": # Consume 'ctx'
pass # ctx is handled, now iterate over actual command options
elif (
first_param
): # If first_param was not 'self' and not 'ctx', it's a command option
param_iter = iter(params.values()) # Reset iterator to include the first param
for param in param_iter:
if param.name == "self" or param.name == "ctx": # Should have been skipped
continue
if param.kind == param.VAR_POSITIONAL or param.kind == param.VAR_KEYWORD:
# *args and **kwargs are not directly supported by slash command options structure.
# Could raise an error or ignore. For now, ignore.
# print(f"Warning: *args/**kwargs ({param.name}) are not supported for slash command options.")
continue
option_name = param.name
option_description = param_descriptions.get(
option_name, f"Description for {option_name}"
)
meta = option_meta.get(option_name) if option_meta else None
param_type_hint = param.annotation
if param_type_hint == inspect.Parameter.empty:
# Default to string if no type hint, or raise error.
# Forcing type hints is generally better for slash commands.
# raise TypeError(f"Option '{option_name}' must have a type hint for slash commands.")
param_type_hint = str # Defaulting to string, can be made stricter
option_type: Optional[ApplicationCommandOptionType] = None
choices: Optional[List[ApplicationCommandOptionChoice]] = None
origin = get_origin(param_type_hint)
args = get_args(param_type_hint)
if origin is Annotated:
param_type_hint = args[0]
for extra in args[1:]:
if isinstance(extra, OptionMetadata):
meta = extra
origin = get_origin(param_type_hint)
args = get_args(param_type_hint)
actual_type_for_mapping = param_type_hint
is_optional = False
if origin is Union: # Handles Optional[T] which is Union[T, NoneType]
# Filter out NoneType to get the actual type for mapping
union_types = [t for t in args if t is not type(None)]
if len(union_types) == 1:
actual_type_for_mapping = union_types[0]
is_optional = True
else:
# More complex Unions are not directly supported by a single option type.
# Could default to STRING or raise.
# For now, let's assume simple Optional[T] or direct types.
# print(f"Warning: Complex Union type for '{option_name}' not fully supported, defaulting to STRING.")
actual_type_for_mapping = str
elif origin is list and len(args) == 1:
# List[T] is not a direct option type. Discord handles multiple values for some types
# via repeated options or specific component interactions, not directly in slash command options.
# This might indicate a need for a different interaction pattern or custom parsing.
# For now, treat List[str] as a string, others might error or default.
# print(f"Warning: List type for '{option_name}' not directly supported as a single option. Consider type {args[0]}.")
actual_type_for_mapping = args[
0
] # Use the inner type for mapping, but this is a simplification.
if origin is Literal: # typing.Literal['a', 'b']
choices = []
for choice_val in args:
if not isinstance(choice_val, (str, int, float)):
raise TypeError(
f"Literal choices for '{option_name}' must be str, int, or float. Got {type(choice_val)}."
)
choices.append(
ApplicationCommandOptionChoice(
data={"name": str(choice_val), "value": choice_val}
)
)
# The type of the Literal's arguments determines the option type
if choices:
literal_arg_type = type(choices[0].value)
option_type = _type_mapping.get(literal_arg_type)
if (
not option_type and literal_arg_type is float
): # float maps to NUMBER
option_type = ApplicationCommandOptionType.NUMBER
if not option_type: # If not determined by Literal
option_type = _type_mapping.get(actual_type_for_mapping)
# Special handling for User, Member, Role, Attachment, Channel if not directly in _type_mapping
# This is a bit crude; a proper registry or isinstance checks would be better.
if not option_type:
if (
actual_type_for_mapping.__name__ == "User"
or actual_type_for_mapping.__name__ == "Member"
):
option_type = ApplicationCommandOptionType.USER
elif actual_type_for_mapping.__name__ == "Role":
option_type = ApplicationCommandOptionType.ROLE
elif actual_type_for_mapping.__name__ == "Attachment":
option_type = ApplicationCommandOptionType.ATTACHMENT
elif (
inspect.isclass(actual_type_for_mapping)
and isinstance(Channel, type)
and issubclass(actual_type_for_mapping, cast(type, Channel))
):
option_type = ApplicationCommandOptionType.CHANNEL
if not option_type:
# Fallback or error if type couldn't be mapped
# print(f"Warning: Could not map type '{actual_type_for_mapping}' for option '{option_name}'. Defaulting to STRING.")
option_type = ApplicationCommandOptionType.STRING # Default fallback
required = (param.default == inspect.Parameter.empty) and not is_optional
data: Dict[str, Any] = {
"name": option_name,
"description": option_description,
"type": option_type.value,
"required": required,
"choices": ([c.to_dict() for c in choices] if choices else None),
}
if meta:
if meta.channel_types is not None:
data["channel_types"] = meta.channel_types
if meta.min_value is not None:
data["min_value"] = meta.min_value
if meta.max_value is not None:
data["max_value"] = meta.max_value
if meta.min_length is not None:
data["min_length"] = meta.min_length
if meta.max_length is not None:
data["max_length"] = meta.max_length
if meta.autocomplete:
data["autocomplete"] = True
options.append(ApplicationCommandOption(data=data))
return options
def _app_command_decorator(
cls: Type[AppCmdType],
option_meta: Optional[Dict[str, OptionMetadata]] = None,
**attrs: Any,
) -> Callable[[Callable[..., Any]], AppCmdType]:
"""Generic factory for creating app command decorators."""
def decorator(func: Callable[..., Any]) -> AppCmdType:
if not asyncio.iscoroutinefunction(func):
raise TypeError(
"Application command callback must be a coroutine function."
)
name = attrs.pop("name", None) or func.__name__
description = attrs.pop("description", None) or inspect.getdoc(func)
if description: # Clean up docstring
description = inspect.cleandoc(description).split("\n\n", 1)[
0
] # Use first paragraph
parent_group = attrs.pop("parent", None)
if parent_group and not isinstance(parent_group, AppCommandGroup):
raise TypeError(
"The 'parent' argument must be an AppCommandGroup instance."
)
# For User/Message commands, description should be empty for payload, but can be stored for help.
if cls is UserCommand or cls is MessageCommand:
actual_description_for_payload = ""
else:
actual_description_for_payload = description
if not actual_description_for_payload and cls is SlashCommand:
raise ValueError(f"Slash command '{name}' must have a description.")
# Create the command instance
cmd_instance = cls(
callback=func,
name=name,
description=actual_description_for_payload, # Use payload-appropriate description
**attrs, # Remaining attributes like guild_ids, nsfw, etc.
)
# Store original description if different (e.g. for User/Message commands for help text)
if description != actual_description_for_payload:
cmd_instance._full_description = (
description # Custom attribute for library use
)
if isinstance(cmd_instance, SlashCommand):
cmd_instance.options = _extract_options_from_signature(func, option_meta)
if parent_group:
parent_group.add_command(cmd_instance) # This also sets cmd_instance.parent
# Attach command object to the function for later collection by Cog or Client
# This is a common pattern.
if hasattr(func, "__app_command_object__"):
# Function might already be decorated (e.g. hybrid or stacked decorators)
# Decide on behavior: error, overwrite, or store list of commands.
# For now, let's assume one app command decorator of a specific type per function.
# Hybrid commands will need special handling.
print(
f"Warning: Function {func.__name__} is already an app command or has one attached. Overwriting."
)
setattr(func, "__app_command_object__", cmd_instance)
setattr(cmd_instance, "__app_command_object__", cmd_instance)
# If the command is a HybridCommand, also set the attribute
# that the prefix command system's Cog._inject looks for.
if isinstance(cmd_instance, HybridCommand):
setattr(func, "__command_object__", cmd_instance)
setattr(cmd_instance, "__command_object__", cmd_instance)
return cmd_instance # Return the command instance itself, not the function
# This allows it to be added to cogs/handlers directly.
return decorator
def slash_command(
name: Optional[str] = None,
description: Optional[str] = None,
guild_ids: Optional[List[Snowflake]] = None,
default_member_permissions: Optional[str] = None,
nsfw: bool = False,
name_localizations: Optional[Dict[str, str]] = None,
description_localizations: Optional[Dict[str, str]] = None,
integration_types: Optional[List[IntegrationType]] = None,
contexts: Optional[List[InteractionContextType]] = None,
*,
guilds: bool = True,
dms: bool = True,
private_channels: bool = True,
parent: Optional[AppCommandGroup] = None, # Added parent parameter
locale: Optional[str] = None,
option_meta: Optional[Dict[str, OptionMetadata]] = None,
) -> Callable[[Callable[..., Any]], SlashCommand]:
"""
Decorator to create a CHAT_INPUT (slash) command.
Options are inferred from the function's type hints.
"""
if contexts is None:
ctxs: List[InteractionContextType] = []
if guilds:
ctxs.append(InteractionContextType.GUILD)
if dms:
ctxs.append(InteractionContextType.BOT_DM)
if private_channels:
ctxs.append(InteractionContextType.PRIVATE_CHANNEL)
if len(ctxs) != 3:
contexts = ctxs
attrs = {
"name": name,
"description": description,
"guild_ids": guild_ids,
"default_member_permissions": default_member_permissions,
"nsfw": nsfw,
"name_localizations": name_localizations,
"description_localizations": description_localizations,
"integration_types": integration_types,
"contexts": contexts,
"parent": parent, # Pass parent to attrs
"locale": locale,
}
# Filter out None values to avoid passing them as explicit None to command constructor
# Keep 'parent' even if None, as _app_command_decorator handles None parent.
# nsfw default is False, so it's fine if not present and defaults.
attrs = {k: v for k, v in attrs.items() if v is not None or k in ["nsfw", "parent"]}
return _app_command_decorator(SlashCommand, option_meta, **attrs)
def user_command(
name: Optional[str] = None,
guild_ids: Optional[List[Snowflake]] = None,
default_member_permissions: Optional[str] = None,
nsfw: bool = False, # Though less common for user commands
name_localizations: Optional[Dict[str, str]] = None,
integration_types: Optional[List[IntegrationType]] = None,
contexts: Optional[List[InteractionContextType]] = None,
locale: Optional[str] = None,
# description is not used by Discord for User commands
) -> Callable[[Callable[..., Any]], UserCommand]:
"""Decorator to create a USER context menu command."""
attrs = {
"name": name,
"guild_ids": guild_ids,
"default_member_permissions": default_member_permissions,
"nsfw": nsfw,
"name_localizations": name_localizations,
"integration_types": integration_types,
"contexts": contexts,
"locale": locale,
}
attrs = {k: v for k, v in attrs.items() if v is not None or k in ["nsfw"]}
return _app_command_decorator(UserCommand, **attrs)
def message_command(
name: Optional[str] = None,
guild_ids: Optional[List[Snowflake]] = None,
default_member_permissions: Optional[str] = None,
nsfw: bool = False, # Though less common for message commands
name_localizations: Optional[Dict[str, str]] = None,
integration_types: Optional[List[IntegrationType]] = None,
contexts: Optional[List[InteractionContextType]] = None,
locale: Optional[str] = None,
# description is not used by Discord for Message commands
) -> Callable[[Callable[..., Any]], MessageCommand]:
"""Decorator to create a MESSAGE context menu command."""
attrs = {
"name": name,
"guild_ids": guild_ids,
"default_member_permissions": default_member_permissions,
"nsfw": nsfw,
"name_localizations": name_localizations,
"integration_types": integration_types,
"contexts": contexts,
"locale": locale,
}
attrs = {k: v for k, v in attrs.items() if v is not None or k in ["nsfw"]}
return _app_command_decorator(MessageCommand, **attrs)
def hybrid_command(
name: Optional[str] = None,
description: Optional[str] = None,
guild_ids: Optional[List[Snowflake]] = None,
default_member_permissions: Optional[str] = None,
nsfw: bool = False,
name_localizations: Optional[Dict[str, str]] = None,
description_localizations: Optional[Dict[str, str]] = None,
integration_types: Optional[List[IntegrationType]] = None,
contexts: Optional[List[InteractionContextType]] = None,
*,
guilds: bool = True,
dms: bool = True,
private_channels: bool = True,
aliases: Optional[List[str]] = None, # Specific to prefix command aspect
# Other prefix-specific options can be added here (e.g., help, brief)
option_meta: Optional[Dict[str, OptionMetadata]] = None,
locale: Optional[str] = None,
) -> Callable[[Callable[..., Any]], HybridCommand]:
"""
Decorator to create a command that can be invoked as both a slash command
and a traditional prefix-based command.
Options for the slash command part are inferred from the function's type hints.
"""
if contexts is None:
ctxs: List[InteractionContextType] = []
if guilds:
ctxs.append(InteractionContextType.GUILD)
if dms:
ctxs.append(InteractionContextType.BOT_DM)
if private_channels:
ctxs.append(InteractionContextType.PRIVATE_CHANNEL)
if len(ctxs) != 3:
contexts = ctxs
attrs = {
"name": name,
"description": description,
"guild_ids": guild_ids,
"default_member_permissions": default_member_permissions,
"nsfw": nsfw,
"name_localizations": name_localizations,
"description_localizations": description_localizations,
"integration_types": integration_types,
"contexts": contexts,
"aliases": aliases or [], # Ensure aliases is a list
"locale": locale,
}
# Filter out None values to avoid passing them as explicit None to command constructor
# Keep 'nsfw' and 'aliases' as they have defaults (False, [])
attrs = {
k: v for k, v in attrs.items() if v is not None or k in ["nsfw", "aliases"]
}
return _app_command_decorator(HybridCommand, option_meta, **attrs)
def subcommand(
parent: AppCommandGroup, *d_args: Any, **d_kwargs: Any
) -> Callable[[Callable[..., Any]], SlashCommand]:
"""Create a subcommand under an existing :class:`AppCommandGroup`."""
d_kwargs.setdefault("parent", parent)
return slash_command(*d_args, **d_kwargs)
def group(
name: str,
description: Optional[str] = None,
**kwargs: Any,
) -> Callable[[Optional[Callable[..., Any]]], AppCommandGroup]:
"""Decorator to declare a top level :class:`AppCommandGroup`."""
def decorator(func: Optional[Callable[..., Any]] = None) -> AppCommandGroup:
grp = AppCommandGroup(
name=name,
description=description,
guild_ids=kwargs.get("guild_ids"),
parent=kwargs.get("parent"),
default_member_permissions=kwargs.get("default_member_permissions"),
nsfw=kwargs.get("nsfw", False),
name_localizations=kwargs.get("name_localizations"),
description_localizations=kwargs.get("description_localizations"),
integration_types=kwargs.get("integration_types"),
contexts=kwargs.get("contexts"),
)
if func is not None:
setattr(func, "__app_command_object__", grp)
return grp
return decorator
def subcommand_group(
parent: AppCommandGroup,
name: str,
description: Optional[str] = None,
**kwargs: Any,
) -> Callable[[Optional[Callable[..., Any]]], AppCommandGroup]:
"""Create a nested :class:`AppCommandGroup` under ``parent``."""
return parent.group(
name=name,
description=description,
**kwargs,
)

View File

@ -0,0 +1,627 @@
# disagreement/ext/app_commands/handler.py
import inspect
from typing import (
TYPE_CHECKING,
Dict,
Optional,
List,
Any,
Tuple,
Union,
get_origin,
get_args,
Literal,
)
if TYPE_CHECKING:
from disagreement.client import Client
from disagreement.interactions import Interaction, ResolvedData, Snowflake
from disagreement.enums import (
ApplicationCommandType,
ApplicationCommandOptionType,
InteractionType,
)
from .commands import (
AppCommand,
SlashCommand,
UserCommand,
MessageCommand,
AppCommandGroup,
)
from .context import AppCommandContext
from disagreement.models import (
User,
Member,
Role,
Attachment,
Message,
) # For resolved data
# Channel models would also go here
# Placeholder for models not yet fully defined or imported
if not TYPE_CHECKING:
from disagreement.enums import (
ApplicationCommandType,
ApplicationCommandOptionType,
InteractionType,
)
from .commands import (
AppCommand,
SlashCommand,
UserCommand,
MessageCommand,
AppCommandGroup,
)
from .context import AppCommandContext
User = Any
Member = Any
Role = Any
Attachment = Any
Channel = Any
Message = Any
class AppCommandHandler:
"""
Manages application command registration, parsing, and dispatching.
"""
def __init__(self, client: "Client"):
self.client: "Client" = client
# Store commands: key could be (name, type) for global, or (name, type, guild_id) for guild-specific
# For simplicity, let's start with a flat structure and refine if needed for guild commands.
# A more robust system might have separate dicts for global and guild commands.
self._slash_commands: Dict[str, SlashCommand] = {}
self._user_commands: Dict[str, UserCommand] = {}
self._message_commands: Dict[str, MessageCommand] = {}
self._app_command_groups: Dict[str, AppCommandGroup] = {}
self._converter_registry: Dict[type, type] = {}
def add_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None:
"""Adds an application command or a command group to the handler."""
if isinstance(command, AppCommandGroup):
if command.name in self._app_command_groups:
raise ValueError(
f"AppCommandGroup '{command.name}' is already registered."
)
self._app_command_groups[command.name] = command
return
if isinstance(command, SlashCommand):
if command.name in self._slash_commands:
raise ValueError(
f"SlashCommand '{command.name}' is already registered."
)
self._slash_commands[command.name] = command
return
if isinstance(command, UserCommand):
if command.name in self._user_commands:
raise ValueError(f"UserCommand '{command.name}' is already registered.")
self._user_commands[command.name] = command
return
if isinstance(command, MessageCommand):
if command.name in self._message_commands:
raise ValueError(
f"MessageCommand '{command.name}' is already registered."
)
self._message_commands[command.name] = command
return
if isinstance(command, AppCommand):
# Fallback for plain AppCommand objects
if command.type == ApplicationCommandType.CHAT_INPUT:
if command.name in self._slash_commands:
raise ValueError(
f"SlashCommand '{command.name}' is already registered."
)
self._slash_commands[command.name] = command # type: ignore
elif command.type == ApplicationCommandType.USER:
if command.name in self._user_commands:
raise ValueError(
f"UserCommand '{command.name}' is already registered."
)
self._user_commands[command.name] = command # type: ignore
elif command.type == ApplicationCommandType.MESSAGE:
if command.name in self._message_commands:
raise ValueError(
f"MessageCommand '{command.name}' is already registered."
)
self._message_commands[command.name] = command # type: ignore
else:
raise TypeError(
f"Unsupported command type: {command.type} for '{command.name}'"
)
else:
raise TypeError("Can only add AppCommand or AppCommandGroup instances.")
def remove_command(
self, name: str
) -> Optional[Union["AppCommand", "AppCommandGroup"]]:
"""Removes an application command or group by name."""
if name in self._slash_commands:
return self._slash_commands.pop(name)
if name in self._user_commands:
return self._user_commands.pop(name)
if name in self._message_commands:
return self._message_commands.pop(name)
if name in self._app_command_groups:
return self._app_command_groups.pop(name)
return None
def register_converter(self, annotation: type, converter_cls: type) -> None:
"""Register a custom converter class for a type annotation."""
self._converter_registry[annotation] = converter_cls
def get_converter(self, annotation: type) -> Optional[type]:
"""Retrieve a registered converter class for a type annotation."""
return self._converter_registry.get(annotation)
def get_command(
self,
name: str,
command_type: "ApplicationCommandType",
interaction_options: Optional[List[Dict[str, Any]]] = None,
) -> Optional["AppCommand"]:
"""Retrieves a command of a specific type."""
if command_type == ApplicationCommandType.CHAT_INPUT:
if not interaction_options:
return self._slash_commands.get(name)
# Handle subcommands/groups
current_options = interaction_options
target_command_or_group: Optional[Union[AppCommand, AppCommandGroup]] = (
self._app_command_groups.get(name)
)
if not target_command_or_group:
return self._slash_commands.get(name)
final_command: Optional[AppCommand] = None
while current_options:
opt_data = current_options[0]
opt_name = opt_data.get("name")
opt_type = (
ApplicationCommandOptionType(opt_data["type"])
if opt_data.get("type")
else None
)
if not opt_name or not isinstance(
target_command_or_group, AppCommandGroup
):
break
next_target = target_command_or_group.get_command(opt_name)
if isinstance(next_target, AppCommand) and (
opt_type == ApplicationCommandOptionType.SUB_COMMAND
or not opt_data.get("options")
):
final_command = next_target
break
elif (
isinstance(next_target, AppCommandGroup)
and opt_type == ApplicationCommandOptionType.SUB_COMMAND_GROUP
):
target_command_or_group = next_target
current_options = opt_data.get("options", [])
if not current_options:
break
else:
break
return final_command
if command_type == ApplicationCommandType.USER:
return self._user_commands.get(name)
if command_type == ApplicationCommandType.MESSAGE:
return self._message_commands.get(name)
return None
async def _resolve_option_value(
self,
value: Any,
expected_type: Any,
resolved_data: Optional["ResolvedData"],
guild_id: Optional["Snowflake"],
) -> Any:
"""
Resolves an option value to the expected Python type using resolved_data.
"""
converter_cls = self.get_converter(expected_type)
if converter_cls:
try:
init_params = inspect.signature(converter_cls.__init__).parameters
if "client" in init_params:
converter_instance = converter_cls(client=self.client) # type: ignore[arg-type]
else:
converter_instance = converter_cls()
return await converter_instance.convert(None, value) # type: ignore[arg-type]
except Exception:
pass
# This is a simplified resolver. A more robust one would use converters.
if resolved_data:
if expected_type is User or expected_type.__name__ == "User":
return resolved_data.users.get(value) if resolved_data.users else None
if expected_type is Member or expected_type.__name__ == "Member":
member_obj = (
resolved_data.members.get(value) if resolved_data.members else None
)
if member_obj:
if (
hasattr(member_obj, "username")
and not member_obj.username
and resolved_data.users
):
user_obj = resolved_data.users.get(value)
if user_obj:
member_obj.username = user_obj.username
member_obj.discriminator = user_obj.discriminator
member_obj.avatar = user_obj.avatar
member_obj.bot = user_obj.bot
member_obj.user = user_obj # type: ignore[attr-defined]
return member_obj
return None
if expected_type is Role or expected_type.__name__ == "Role":
return resolved_data.roles.get(value) if resolved_data.roles else None
if expected_type is Attachment or expected_type.__name__ == "Attachment":
return (
resolved_data.attachments.get(value)
if resolved_data.attachments
else None
)
if expected_type is Message or expected_type.__name__ == "Message":
return (
resolved_data.messages.get(value)
if resolved_data.messages
else None
)
if "Channel" in expected_type.__name__:
return (
resolved_data.channels.get(value)
if resolved_data.channels
else None
)
# For basic types, Discord already sends them correctly (string, int, bool, float)
if isinstance(value, expected_type):
return value
try: # Attempt direct conversion for basic types if Discord sent string for int/float/bool
if expected_type is int:
return int(value)
if expected_type is float:
return float(value)
if expected_type is bool: # Discord sends true/false
if isinstance(value, str):
return value.lower() == "true"
return bool(value)
except (ValueError, TypeError):
pass # Conversion failed
return value # Return as is if no specific resolution or conversion applied
async def _resolve_value(
self,
value: Any,
expected_type: Any,
resolved_data: Optional["ResolvedData"],
guild_id: Optional["Snowflake"],
) -> Any:
"""Public wrapper around ``_resolve_option_value`` used by tests."""
return await self._resolve_option_value(
value=value,
expected_type=expected_type,
resolved_data=resolved_data,
guild_id=guild_id,
)
async def _parse_interaction_options(
self,
command_params: Dict[str, inspect.Parameter], # From command.params
interaction_options: Optional[List[Dict[str, Any]]],
resolved_data: Optional["ResolvedData"],
guild_id: Optional["Snowflake"],
) -> Tuple[List[Any], Dict[str, Any]]:
"""
Parses options from an interaction payload and maps them to command function arguments.
"""
args_list: List[Any] = []
kwargs_dict: Dict[str, Any] = {}
if not interaction_options: # No options provided in interaction
# Check if command has required params without defaults
for name, param in command_params.items():
if param.default == inspect.Parameter.empty:
# This should ideally be caught by Discord if option is marked required
raise ValueError(f"Missing required option: {name}")
return args_list, kwargs_dict
# Create a dictionary of provided options by name for easier lookup
provided_options: Dict[str, Any] = {
opt["name"]: opt["value"] for opt in interaction_options if "value" in opt
}
for name, param in command_params.items():
if name in provided_options:
raw_value = provided_options[name]
expected_type = (
param.annotation
if param.annotation != inspect.Parameter.empty
else str
)
# Handle Optional[T]
origin_type = get_origin(expected_type)
if origin_type is Union:
union_args = get_args(expected_type)
# Assuming Optional[T] is Union[T, NoneType]
non_none_types = [t for t in union_args if t is not type(None)]
if len(non_none_types) == 1:
expected_type = non_none_types[0]
# Else, complex Union, might need more sophisticated handling or default to raw_value/str
elif origin_type is Literal:
literal_args = get_args(expected_type)
if literal_args:
expected_type = type(literal_args[0])
else:
expected_type = str
resolved_value = await self._resolve_option_value(
raw_value, expected_type, resolved_data, guild_id
)
if (
param.kind == inspect.Parameter.KEYWORD_ONLY
or param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
):
kwargs_dict[name] = resolved_value
# Note: Slash commands don't map directly to *args. All options are named.
# So, we'll primarily use kwargs_dict and then construct args_list based on param order if needed,
# but Discord sends named options, so direct kwarg usage is more natural.
elif param.default != inspect.Parameter.empty:
kwargs_dict[name] = param.default
else:
# Required parameter not provided by Discord - this implies an issue with command definition
# or Discord's validation, as Discord should enforce required options.
raise ValueError(
f"Required option '{name}' not found in interaction payload."
)
# Populate args_list based on the order in command_params for positional arguments
# This assumes that all args that are not keyword-only are passed positionally if present in kwargs_dict
for name, param in command_params.items():
if param.kind == inspect.Parameter.POSITIONAL_ONLY or (
param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
and name in kwargs_dict
):
if name in kwargs_dict: # Ensure it was resolved or had a default
args_list.append(kwargs_dict[name])
# If it was POSITIONAL_ONLY and not in kwargs_dict, it's an error (already raised)
elif param.kind == inspect.Parameter.VAR_POSITIONAL: # *args
# Slash commands don't map to *args well. This would be empty.
pass
# Filter kwargs_dict to only include actual KEYWORD_ONLY or POSITIONAL_OR_KEYWORD params
# that were not used for args_list (if strict positional/keyword separation is desired).
# For slash commands, it's simpler to pass all resolved named options as kwargs.
final_kwargs = {
k: v
for k, v in kwargs_dict.items()
if k in command_params
and command_params[k].kind != inspect.Parameter.POSITIONAL_ONLY
}
# For simplicity with slash commands, let's assume all resolved options are passed via kwargs
# and the command signature is primarily (self, ctx, **options) or (ctx, **options)
# or (self, ctx, option1, option2) where names match.
# The AppCommand.invoke will handle passing them.
# The current args_list and final_kwargs might be redundant if invoke just uses **final_kwargs.
# Let's return kwargs_dict directly for now, and AppCommand.invoke can map them.
return [], kwargs_dict # Return empty args, all in kwargs for now.
async def dispatch_app_command_error(
self, context: "AppCommandContext", error: Exception
) -> None:
"""Dispatches an app command error to the client if implemented."""
if hasattr(self.client, "on_app_command_error"):
await self.client.on_app_command_error(context, error)
async def process_interaction(self, interaction: "Interaction") -> None:
"""Processes an incoming interaction."""
if interaction.type == InteractionType.MODAL_SUBMIT:
callback = getattr(self.client, "on_modal_submit", None)
if callback is not None:
from typing import Awaitable, Callable, cast
await cast(Callable[["Interaction"], Awaitable[None]], callback)(
interaction
)
return
if interaction.type == InteractionType.APPLICATION_COMMAND_AUTOCOMPLETE:
callback = getattr(self.client, "on_autocomplete", None)
if callback is not None:
from typing import Awaitable, Callable, cast
await cast(Callable[["Interaction"], Awaitable[None]], callback)(
interaction
)
return
if interaction.type != InteractionType.APPLICATION_COMMAND:
return
if not interaction.data or not interaction.data.name:
from .context import AppCommandContext
ctx = AppCommandContext(
bot=self.client, interaction=interaction, command=None
)
await ctx.send("Command not found.", ephemeral=True)
return
command_name = interaction.data.name
command_type = interaction.data.type or ApplicationCommandType.CHAT_INPUT
command = self.get_command(
command_name,
command_type,
interaction.data.options if interaction.data else None,
)
if not command:
from .context import AppCommandContext
ctx = AppCommandContext(
bot=self.client, interaction=interaction, command=None
)
await ctx.send(f"Command '{command_name}' not found.", ephemeral=True)
return
# Create context
from .context import AppCommandContext # Ensure AppCommandContext is available
ctx = AppCommandContext(
bot=self.client, interaction=interaction, command=command
)
try:
# Prepare arguments for the command callback
# Skip 'self' and 'ctx' from command.params for parsing interaction options
params_to_parse = {
name: param
for name, param in command.params.items()
if name not in ("self", "ctx")
}
if command.type in (
ApplicationCommandType.USER,
ApplicationCommandType.MESSAGE,
):
# Context menu commands provide a target_id. Resolve and pass it
args = []
kwargs = {}
if params_to_parse and interaction.data and interaction.data.target_id:
first_param = next(iter(params_to_parse.values()))
expected = (
first_param.annotation
if first_param.annotation != inspect.Parameter.empty
else str
)
resolved = await self._resolve_option_value(
interaction.data.target_id,
expected,
interaction.data.resolved,
interaction.guild_id,
)
if first_param.kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
args.append(resolved)
else:
kwargs[first_param.name] = resolved
await command.invoke(ctx, *args, **kwargs)
else:
parsed_args, parsed_kwargs = await self._parse_interaction_options(
command_params=params_to_parse,
interaction_options=interaction.data.options,
resolved_data=interaction.data.resolved,
guild_id=interaction.guild_id,
)
await command.invoke(ctx, *parsed_args, **parsed_kwargs)
except Exception as e:
print(f"Error invoking app command '{command.name}': {e}")
await self.dispatch_app_command_error(ctx, e)
# else:
# # Default error reply if no handler on client
# try:
# await ctx.send(f"An error occurred: {e}", ephemeral=True)
# except Exception as send_e:
# print(f"Failed to send error message for app command: {send_e}")
async def sync_commands(
self, application_id: "Snowflake", guild_id: Optional["Snowflake"] = None
) -> None:
"""
Synchronizes (registers/updates) all application commands with Discord.
If guild_id is provided, syncs commands for that guild. Otherwise, syncs global commands.
"""
commands_to_sync: List[Dict[str, Any]] = []
# Collect commands based on scope (global or specific guild)
# This needs to be more sophisticated to handle guild_ids on commands/groups
source_commands = (
list(self._slash_commands.values())
+ list(self._user_commands.values())
+ list(self._message_commands.values())
+ list(self._app_command_groups.values())
)
for cmd_or_group in source_commands:
# Determine if this command/group should be synced for the current scope
is_guild_specific_command = (
cmd_or_group.guild_ids is not None and len(cmd_or_group.guild_ids) > 0
)
if guild_id: # Syncing for a specific guild
# Skip if not a guild-specific command OR if it's for a different guild
if not is_guild_specific_command or (
cmd_or_group.guild_ids is not None
and guild_id not in cmd_or_group.guild_ids
):
continue
else: # Syncing global commands
if is_guild_specific_command:
continue # Skip guild-specific commands when syncing global
# Use the to_dict() method from AppCommand or AppCommandGroup
try:
payload = cmd_or_group.to_dict()
commands_to_sync.append(payload)
except AttributeError:
print(
f"Warning: Command or group '{cmd_or_group.name}' does not have a to_dict() method. Skipping."
)
except Exception as e:
print(
f"Error converting command/group '{cmd_or_group.name}' to dict: {e}. Skipping."
)
if not commands_to_sync:
print(
f"No commands to sync for {'guild ' + str(guild_id) if guild_id else 'global'} scope."
)
return
try:
if guild_id:
print(
f"Syncing {len(commands_to_sync)} commands for guild {guild_id}..."
)
await self.client._http.bulk_overwrite_guild_application_commands(
application_id, guild_id, commands_to_sync
)
else:
print(f"Syncing {len(commands_to_sync)} global commands...")
await self.client._http.bulk_overwrite_global_application_commands(
application_id, commands_to_sync
)
print("Command sync successful.")
except Exception as e:
print(f"Error syncing application commands: {e}")
# Consider re-raising or specific error handling

View File

@ -0,0 +1,49 @@
# disagreement/ext/commands/__init__.py
"""
disagreement.ext.commands - A command framework extension for the Disagreement library.
"""
from .cog import Cog
from .core import (
Command,
CommandContext,
CommandHandler,
) # CommandHandler might be internal
from .decorators import command, listener, check, check_any, cooldown
from .errors import (
CommandError,
CommandNotFound,
BadArgument,
MissingRequiredArgument,
ArgumentParsingError,
CheckFailure,
CheckAnyFailure,
CommandOnCooldown,
CommandInvokeError,
)
__all__ = [
# Cog
"Cog",
# Core
"Command",
"CommandContext",
# "CommandHandler", # Usually not part of public API for direct use by bot devs
# Decorators
"command",
"listener",
"check",
"check_any",
"cooldown",
# Errors
"CommandError",
"CommandNotFound",
"BadArgument",
"MissingRequiredArgument",
"ArgumentParsingError",
"CheckFailure",
"CheckAnyFailure",
"CommandOnCooldown",
"CommandInvokeError",
]

View File

@ -0,0 +1,155 @@
# disagreement/ext/commands/cog.py
import inspect
from typing import TYPE_CHECKING, List, Tuple, Callable, Awaitable, Any, Dict, Union
if TYPE_CHECKING:
from disagreement.client import Client
from .core import Command
from disagreement.ext.app_commands.commands import (
AppCommand,
AppCommandGroup,
) # Added
else: # pragma: no cover - runtime imports for isinstance checks
from disagreement.ext.app_commands.commands import AppCommand, AppCommandGroup
# EventDispatcher might be needed if cogs register listeners directly
# from disagreement.event_dispatcher import EventDispatcher
class Cog:
"""
The base class for cogs, which are collections of commands and listeners.
"""
def __init__(self, client: "Client"):
self._client: "Client" = client
self._cog_name: str = self.__class__.__name__
self._commands: Dict[str, "Command"] = {}
self._listeners: List[Tuple[str, Callable[..., Awaitable[None]]]] = []
self._app_commands_and_groups: List[Union["AppCommand", "AppCommandGroup"]] = (
[]
) # Added
# Discover commands and listeners defined in this cog instance
self._inject()
@property
def client(self) -> "Client":
return self._client
@property
def cog_name(self) -> str:
return self._cog_name
def _inject(self) -> None:
"""
Called to discover and prepare commands and listeners within this cog.
This is typically called by the CommandHandler when adding the cog.
"""
# Clear any previously injected state (e.g., if re-injecting)
self._commands.clear()
self._listeners.clear()
self._app_commands_and_groups.clear() # Added
for member_name, member in inspect.getmembers(self):
if hasattr(member, "__command_object__"):
# This is a prefix or hybrid command object
cmd: "Command" = getattr(member, "__command_object__")
cmd.cog = self # Assign the cog instance to the command
if cmd.name in self._commands:
# This should ideally be caught earlier or handled by CommandHandler
print(
f"Warning: Duplicate command name '{cmd.name}' in cog '{self.cog_name}'. Overwriting."
)
self._commands[cmd.name.lower()] = cmd
# Also register aliases
for alias in cmd.aliases:
self._commands[alias.lower()] = cmd
# If this command is also an application command (HybridCommand)
if isinstance(cmd, (AppCommand, AppCommandGroup)):
self._app_commands_and_groups.append(cmd)
elif hasattr(member, "__app_command_object__"): # Added for app commands
app_cmd_obj = getattr(member, "__app_command_object__")
if isinstance(app_cmd_obj, (AppCommand, AppCommandGroup)):
if isinstance(app_cmd_obj, AppCommand):
app_cmd_obj.cog = self # Associate cog
# For AppCommandGroup, its commands will have cog set individually if they are AppCommands
self._app_commands_and_groups.append(app_cmd_obj)
else:
print(
f"Warning: Member '{member_name}' in cog '{self.cog_name}' has '__app_command_object__' but it's not an AppCommand or AppCommandGroup."
)
elif isinstance(member, (AppCommand, AppCommandGroup)):
if isinstance(member, AppCommand):
member.cog = self
self._app_commands_and_groups.append(member)
elif hasattr(member, "__listener_name__"):
# This is a method decorated with @commands.Cog.listener or @commands.listener
if not inspect.iscoroutinefunction(member):
# Decorator should have caught this, but double check
print(
f"Warning: Listener '{member_name}' in cog '{self.cog_name}' is not a coroutine. Skipping."
)
continue
event_name: str = getattr(member, "__listener_name__")
# The callback needs to be the bound method from this cog instance
self._listeners.append((event_name, member))
def _eject(self) -> None:
"""
Called when the cog is being removed.
The CommandHandler will handle unregistering commands/listeners.
This method is for any cog-specific cleanup before that.
"""
# For now, just clear local collections. Actual unregistration is external.
self._commands.clear()
self._listeners.clear()
self._app_commands_and_groups.clear() # Added
def get_commands(self) -> List["Command"]:
"""Returns a list of commands in this cog."""
# Avoid duplicates if aliases point to the same command object
return list(dict.fromkeys(self._commands.values()))
def get_listeners(self) -> List[Tuple[str, Callable[..., Awaitable[None]]]]:
"""Returns a list of (event_name, callback) tuples for listeners in this cog."""
return self._listeners
def get_app_commands_and_groups(
self,
) -> List[Union["AppCommand", "AppCommandGroup"]]:
"""Returns a list of application commands and groups in this cog."""
return self._app_commands_and_groups
async def cog_load(self) -> None:
"""
A special method that is called when the cog is loaded.
This is a good place for any asynchronous setup.
Subclasses should override this if they need async setup.
"""
pass
async def cog_unload(self) -> None:
"""
A special method that is called when the cog is unloaded.
This is a good place for any asynchronous cleanup.
Subclasses should override this if they need async cleanup.
"""
pass
# Example of how a listener might be defined within a Cog using the decorator
# from .decorators import listener # Would be imported at module level
#
# @listener(name="ON_MESSAGE_CREATE_CUSTOM") # Explicit name
# async def on_my_event(self, message: 'Message'):
# print(f"Cog '{self.cog_name}' received event with message: {message.content}")
#
# @listener() # Name derived from method: on_ready
# async def on_ready(self):
# print(f"Cog '{self.cog_name}' is ready.")

View File

@ -0,0 +1,175 @@
# disagreement/ext/commands/converters.py
from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, Generic
from abc import ABC, abstractmethod
import re
from .errors import BadArgument
from disagreement.models import Member, Guild, Role
if TYPE_CHECKING:
from .core import CommandContext
T = TypeVar("T")
class Converter(ABC, Generic[T]):
"""
Base class for custom command argument converters.
Subclasses must implement the `convert` method.
"""
async def convert(self, ctx: "CommandContext", argument: str) -> T:
"""
Converts the argument to the desired type.
Args:
ctx: The invocation context.
argument: The string argument to convert.
Returns:
The converted argument.
Raises:
BadArgument: If the conversion fails.
"""
raise NotImplementedError("Converter subclass must implement convert method.")
# --- Built-in Type Converters ---
class IntConverter(Converter[int]):
async def convert(self, ctx: "CommandContext", argument: str) -> int:
try:
return int(argument)
except ValueError:
raise BadArgument(f"'{argument}' is not a valid integer.")
class FloatConverter(Converter[float]):
async def convert(self, ctx: "CommandContext", argument: str) -> float:
try:
return float(argument)
except ValueError:
raise BadArgument(f"'{argument}' is not a valid number.")
class BoolConverter(Converter[bool]):
async def convert(self, ctx: "CommandContext", argument: str) -> bool:
lowered = argument.lower()
if lowered in ("yes", "y", "true", "t", "1", "on", "enable", "enabled"):
return True
elif lowered in ("no", "n", "false", "f", "0", "off", "disable", "disabled"):
return False
raise BadArgument(f"'{argument}' is not a valid boolean-like value.")
class StringConverter(Converter[str]):
async def convert(self, ctx: "CommandContext", argument: str) -> str:
# For basic string, no conversion is needed, but this provides a consistent interface
return argument
# --- Discord Model Converters ---
class MemberConverter(Converter["Member"]):
async def convert(self, ctx: "CommandContext", argument: str) -> "Member":
if not ctx.message.guild_id:
raise BadArgument("Member converter requires guild context.")
match = re.match(r"<@!?(\d+)>$", argument)
member_id = match.group(1) if match else argument
guild = ctx.bot.get_guild(ctx.message.guild_id)
if guild:
member = guild.get_member(member_id)
if member:
return member
member = await ctx.bot.fetch_member(ctx.message.guild_id, member_id)
if member:
return member
raise BadArgument(f"Member '{argument}' not found.")
class RoleConverter(Converter["Role"]):
async def convert(self, ctx: "CommandContext", argument: str) -> "Role":
if not ctx.message.guild_id:
raise BadArgument("Role converter requires guild context.")
match = re.match(r"<@&(?P<id>\d+)>$", argument)
role_id = match.group("id") if match else argument
guild = ctx.bot.get_guild(ctx.message.guild_id)
if guild:
role = guild.get_role(role_id)
if role:
return role
role = await ctx.bot.fetch_role(ctx.message.guild_id, role_id)
if role:
return role
raise BadArgument(f"Role '{argument}' not found.")
class GuildConverter(Converter["Guild"]):
async def convert(self, ctx: "CommandContext", argument: str) -> "Guild":
guild_id = argument.strip("<>") # allow <id> style
guild = ctx.bot.get_guild(guild_id)
if guild:
return guild
guild = await ctx.bot.fetch_guild(guild_id)
if guild:
return guild
raise BadArgument(f"Guild '{argument}' not found.")
# Default converters mapping
DEFAULT_CONVERTERS: dict[type, Converter[Any]] = {
int: IntConverter(),
float: FloatConverter(),
bool: BoolConverter(),
str: StringConverter(),
Member: MemberConverter(),
Guild: GuildConverter(),
Role: RoleConverter(),
# User: UserConverter(), # Add when User model and converter are ready
}
async def run_converters(ctx: "CommandContext", annotation: Any, argument: str) -> Any:
"""
Attempts to run a converter for the given annotation and argument.
"""
converter = DEFAULT_CONVERTERS.get(annotation)
if converter:
return await converter.convert(ctx, argument)
# If no direct converter, check if annotation itself is a Converter subclass
if inspect.isclass(annotation) and issubclass(annotation, Converter):
try:
instance = annotation() # type: ignore
return await instance.convert(ctx, argument)
except Exception as e: # Catch instantiation errors or other issues
raise BadArgument(
f"Failed to use custom converter {annotation.__name__}: {e}"
)
# If it's a custom class that's not a Converter, we can't handle it by default
# Or if it's a complex type hint like Union, Optional, Literal etc.
# This part needs more advanced logic for those.
# For now, if no specific converter, and it's not 'str', raise error or return as str?
# Let's be strict for now if an annotation is given but no converter found.
if annotation is not str and annotation is not inspect.Parameter.empty:
raise BadArgument(f"No converter found for type annotation '{annotation}'.")
return argument # Default to string if no annotation or annotation is str
# Need to import inspect for the run_converters function
import inspect

View File

@ -0,0 +1,490 @@
# disagreement/ext/commands/core.py
import asyncio
import inspect
from typing import (
TYPE_CHECKING,
Optional,
List,
Dict,
Any,
Union,
Callable,
Awaitable,
Tuple,
get_origin,
get_args,
)
from .view import StringView
from .errors import (
CommandError,
CommandNotFound,
BadArgument,
MissingRequiredArgument,
ArgumentParsingError,
CheckFailure,
CommandInvokeError,
)
from .converters import run_converters, DEFAULT_CONVERTERS, Converter
from .cog import Cog
from disagreement.typing import Typing
if TYPE_CHECKING:
from disagreement.client import Client
from disagreement.models import Message, User
class Command:
"""
Represents a bot command.
Attributes:
name (str): The primary name of the command.
callback (Callable[..., Awaitable[None]]): The coroutine function to execute.
aliases (List[str]): Alternative names for the command.
brief (Optional[str]): A short description for help commands.
description (Optional[str]): A longer description for help commands.
cog (Optional['Cog']): Reference to the Cog this command belongs to.
params (Dict[str, inspect.Parameter]): Cached parameters of the callback.
"""
def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any):
if not asyncio.iscoroutinefunction(callback):
raise TypeError("Command callback must be a coroutine function.")
self.callback: Callable[..., Awaitable[None]] = callback
self.name: str = attrs.get("name", callback.__name__)
self.aliases: List[str] = attrs.get("aliases", [])
self.brief: Optional[str] = attrs.get("brief")
self.description: Optional[str] = attrs.get("description") or callback.__doc__
self.cog: Optional["Cog"] = attrs.get("cog")
self.params = inspect.signature(callback).parameters
self.checks: List[Callable[["CommandContext"], Awaitable[bool] | bool]] = []
if hasattr(callback, "__command_checks__"):
self.checks.extend(getattr(callback, "__command_checks__"))
def add_check(
self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool]
) -> None:
self.checks.append(predicate)
async def invoke(self, ctx: "CommandContext", *args: Any, **kwargs: Any) -> None:
from .errors import CheckFailure
for predicate in self.checks:
result = predicate(ctx)
if inspect.isawaitable(result):
result = await result
if not result:
raise CheckFailure("Check predicate failed.")
if self.cog:
await self.callback(self.cog, ctx, *args, **kwargs)
else:
await self.callback(ctx, *args, **kwargs)
class CommandContext:
"""
Represents the context in which a command is being invoked.
"""
def __init__(
self,
*,
message: "Message",
bot: "Client",
prefix: str,
command: "Command",
invoked_with: str,
args: Optional[List[Any]] = None,
kwargs: Optional[Dict[str, Any]] = None,
cog: Optional["Cog"] = None,
):
self.message: "Message" = message
self.bot: "Client" = bot
self.prefix: str = prefix
self.command: "Command" = command
self.invoked_with: str = invoked_with
self.args: List[Any] = args or []
self.kwargs: Dict[str, Any] = kwargs or {}
self.cog: Optional["Cog"] = cog
self.author: "User" = message.author
async def reply(
self,
content: str,
*,
mention_author: Optional[bool] = None,
**kwargs: Any,
) -> "Message":
"""Replies to the invoking message.
Parameters
----------
content: str
The content to send.
mention_author: Optional[bool]
Whether to mention the author in the reply. If ``None`` the
client's :attr:`mention_replies` value is used.
"""
allowed_mentions = kwargs.pop("allowed_mentions", None)
if mention_author is None:
mention_author = getattr(self.bot, "mention_replies", False)
if allowed_mentions is None:
allowed_mentions = {"replied_user": mention_author}
else:
allowed_mentions = dict(allowed_mentions)
allowed_mentions.setdefault("replied_user", mention_author)
return await self.bot.send_message(
channel_id=self.message.channel_id,
content=content,
message_reference={
"message_id": self.message.id,
"channel_id": self.message.channel_id,
"guild_id": self.message.guild_id,
},
allowed_mentions=allowed_mentions,
**kwargs,
)
async def send(self, content: str, **kwargs: Any) -> "Message":
return await self.bot.send_message(
channel_id=self.message.channel_id, content=content, **kwargs
)
async def edit(
self,
message: Union[str, "Message"],
*,
content: Optional[str] = None,
**kwargs: Any,
) -> "Message":
"""Edits a message previously sent by the bot."""
message_id = message if isinstance(message, str) else message.id
return await self.bot.edit_message(
channel_id=self.message.channel_id,
message_id=message_id,
content=content,
**kwargs,
)
def typing(self) -> "Typing":
"""Return a typing context manager for this context's channel."""
return self.bot.typing(self.message.channel_id)
class CommandHandler:
"""
Manages command registration, parsing, and dispatching.
"""
def __init__(
self,
client: "Client",
prefix: Union[
str, List[str], Callable[["Client", "Message"], Union[str, List[str]]]
],
):
self.client: "Client" = client
self.prefix: Union[
str, List[str], Callable[["Client", "Message"], Union[str, List[str]]]
] = prefix
self.commands: Dict[str, Command] = {}
self.cogs: Dict[str, "Cog"] = {}
from .help import HelpCommand
self.add_command(HelpCommand(self))
def add_command(self, command: Command) -> None:
if command.name in self.commands:
raise ValueError(f"Command '{command.name}' is already registered.")
self.commands[command.name.lower()] = command
for alias in command.aliases:
if alias in self.commands:
print(
f"Warning: Alias '{alias}' for command '{command.name}' conflicts with an existing command or alias."
)
self.commands[alias.lower()] = command
def remove_command(self, name: str) -> Optional[Command]:
command = self.commands.pop(name.lower(), None)
if command:
for alias in command.aliases:
self.commands.pop(alias.lower(), None)
return command
def get_command(self, name: str) -> Optional[Command]:
return self.commands.get(name.lower())
def add_cog(self, cog_to_add: "Cog") -> None:
if not isinstance(cog_to_add, Cog):
raise TypeError("Argument must be a subclass of Cog.")
if cog_to_add.cog_name in self.cogs:
raise ValueError(
f"Cog with name '{cog_to_add.cog_name}' is already registered."
)
self.cogs[cog_to_add.cog_name] = cog_to_add
for cmd in cog_to_add.get_commands():
self.add_command(cmd)
if hasattr(self.client, "_event_dispatcher"):
for event_name, callback in cog_to_add.get_listeners():
self.client._event_dispatcher.register(event_name.upper(), callback)
else:
print(
f"Warning: Client does not have '_event_dispatcher'. Listeners for cog '{cog_to_add.cog_name}' not registered."
)
if hasattr(cog_to_add, "cog_load") and inspect.iscoroutinefunction(
cog_to_add.cog_load
):
asyncio.create_task(cog_to_add.cog_load())
print(f"Cog '{cog_to_add.cog_name}' added.")
def remove_cog(self, cog_name: str) -> Optional["Cog"]:
cog_to_remove = self.cogs.pop(cog_name, None)
if cog_to_remove:
for cmd in cog_to_remove.get_commands():
self.remove_command(cmd.name)
if hasattr(self.client, "_event_dispatcher"):
for event_name, callback in cog_to_remove.get_listeners():
print(
f"Note: Listener '{callback.__name__}' for event '{event_name}' from cog '{cog_name}' needs manual unregistration logic in EventDispatcher."
)
if hasattr(cog_to_remove, "cog_unload") and inspect.iscoroutinefunction(
cog_to_remove.cog_unload
):
asyncio.create_task(cog_to_remove.cog_unload())
cog_to_remove._eject()
print(f"Cog '{cog_name}' removed.")
return cog_to_remove
async def get_prefix(self, message: "Message") -> Union[str, List[str], None]:
if callable(self.prefix):
if inspect.iscoroutinefunction(self.prefix):
return await self.prefix(self.client, message)
else:
return self.prefix(self.client, message) # type: ignore
return self.prefix
async def _parse_arguments(
self, command: Command, ctx: CommandContext, view: StringView
) -> Tuple[List[Any], Dict[str, Any]]:
args_list = []
kwargs_dict = {}
params_to_parse = list(command.params.values())
if params_to_parse and params_to_parse[0].name == "self" and command.cog:
params_to_parse.pop(0)
if params_to_parse and params_to_parse[0].name == "ctx":
params_to_parse.pop(0)
for param in params_to_parse:
view.skip_whitespace()
final_value_for_param: Any = inspect.Parameter.empty
if param.kind == inspect.Parameter.VAR_POSITIONAL:
while not view.eof:
view.skip_whitespace()
if view.eof:
break
word = view.get_word()
if word or not view.eof:
args_list.append(word)
elif view.eof:
break
break
arg_str_value: Optional[str] = (
None # Holds the raw string for current param
)
if view.eof: # No more input string
if param.default is not inspect.Parameter.empty:
final_value_for_param = param.default
elif param.kind != inspect.Parameter.VAR_KEYWORD:
raise MissingRequiredArgument(param.name)
else: # VAR_KEYWORD at EOF is fine
break
else: # Input available
is_last_pos_str_greedy = (
param == params_to_parse[-1]
and param.annotation is str
and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
)
if is_last_pos_str_greedy:
arg_str_value = view.read_rest().strip()
if (
not arg_str_value
and param.default is not inspect.Parameter.empty
):
final_value_for_param = param.default
else: # Includes empty string if that's what's left
final_value_for_param = arg_str_value
else: # Not greedy, or not string, or not last positional
if view.buffer[view.index] == '"':
arg_str_value = view.get_quoted_string()
if arg_str_value == "" and view.buffer[view.index] == '"':
raise BadArgument(
f"Unterminated quoted string for argument '{param.name}'."
)
else:
arg_str_value = view.get_word()
# If final_value_for_param was not set by greedy logic, try conversion
if final_value_for_param is inspect.Parameter.empty:
if (
arg_str_value is None
): # Should not happen if view.get_word/get_quoted_string is robust
if param.default is not inspect.Parameter.empty:
final_value_for_param = param.default
else:
raise MissingRequiredArgument(param.name)
else: # We have an arg_str_value (could be empty string "" from quotes)
annotation = param.annotation
origin = get_origin(annotation)
if origin is Union: # Handles Optional[T] and Union[T1, T2]
union_args = get_args(annotation)
is_optional = (
len(union_args) == 2 and type(None) in union_args
)
converted_for_union = False
last_err_union: Optional[BadArgument] = None
for t_arg in union_args:
if t_arg is type(None):
continue
try:
final_value_for_param = await run_converters(
ctx, t_arg, arg_str_value
)
converted_for_union = True
break
except BadArgument as e:
last_err_union = e
if not converted_for_union:
if (
is_optional and param.default is None
): # Special handling for Optional[T] if conversion failed
# If arg_str_value was "" and type was Optional[str], StringConverter would return ""
# If arg_str_value was "" and type was Optional[int], BadArgument would be raised.
# This path is for when all actual types in Optional[T] fail conversion.
# If default is None, we can assign None.
final_value_for_param = None
elif last_err_union:
raise last_err_union
else: # Should not be reached if logic is correct
raise BadArgument(
f"Could not convert '{arg_str_value}' to any of {union_args} for param '{param.name}'."
)
elif annotation is inspect.Parameter.empty or annotation is str:
final_value_for_param = arg_str_value
else: # Standard type hint
final_value_for_param = await run_converters(
ctx, annotation, arg_str_value
)
# Final check if value was resolved
if final_value_for_param is inspect.Parameter.empty:
if param.default is not inspect.Parameter.empty:
final_value_for_param = param.default
elif param.kind != inspect.Parameter.VAR_KEYWORD:
# This state implies an issue if required and no default, and no input was parsed.
raise MissingRequiredArgument(
f"Parameter '{param.name}' could not be resolved."
)
# Assign to args_list or kwargs_dict if a value was determined
if final_value_for_param is not inspect.Parameter.empty:
if (
param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
or param.kind == inspect.Parameter.POSITIONAL_ONLY
):
args_list.append(final_value_for_param)
elif param.kind == inspect.Parameter.KEYWORD_ONLY:
kwargs_dict[param.name] = final_value_for_param
return args_list, kwargs_dict
async def process_commands(self, message: "Message") -> None:
if not message.content:
return
prefix_to_use = await self.get_prefix(message)
if not prefix_to_use:
return
actual_prefix: Optional[str] = None
if isinstance(prefix_to_use, list):
for p in prefix_to_use:
if message.content.startswith(p):
actual_prefix = p
break
if not actual_prefix:
return
elif isinstance(prefix_to_use, str):
if message.content.startswith(prefix_to_use):
actual_prefix = prefix_to_use
else:
return
else:
return
if actual_prefix is None:
return
content_without_prefix = message.content[len(actual_prefix) :]
view = StringView(content_without_prefix)
command_name = view.get_word()
if not command_name:
return
command = self.get_command(command_name)
if not command:
return
ctx = CommandContext(
message=message,
bot=self.client,
prefix=actual_prefix,
command=command,
invoked_with=command_name,
cog=command.cog,
)
try:
parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view)
ctx.args = parsed_args
ctx.kwargs = parsed_kwargs
await command.invoke(ctx, *parsed_args, **parsed_kwargs)
except CommandError as e:
print(f"Command error for '{command.name}': {e}")
if hasattr(self.client, "on_command_error"):
await self.client.on_command_error(ctx, e)
except Exception as e:
print(f"Unexpected error invoking command '{command.name}': {e}")
exc = CommandInvokeError(e)
if hasattr(self.client, "on_command_error"):
await self.client.on_command_error(ctx, exc)

View File

@ -0,0 +1,150 @@
# disagreement/ext/commands/decorators.py
import asyncio
import inspect
import time
from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
if TYPE_CHECKING:
from .core import Command, CommandContext # For type hinting return or internal use
# from .cog import Cog # For Cog specific decorators
def command(
name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any
) -> Callable:
"""
A decorator that transforms a function into a Command.
Args:
name (Optional[str]): The name of the command. Defaults to the function name.
aliases (Optional[List[str]]): Alternative names for the command.
**attrs: Additional attributes to pass to the Command constructor
(e.g., brief, description, hidden).
Returns:
Callable: A decorator that registers the command.
"""
def decorator(
func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
if not asyncio.iscoroutinefunction(func):
raise TypeError("Command callback must be a coroutine function.")
from .core import (
Command,
) # Late import to avoid circular dependencies at module load time
# The actual registration will happen when a Cog is added or if commands are global.
# For now, this decorator creates a Command instance and attaches it to the function,
# or returns a Command instance that can be collected.
cmd_name = name or func.__name__
# Store command attributes on the function itself for later collection by Cog or Client
# This is a common pattern.
if hasattr(func, "__command_attrs__"):
# This case might occur if decorators are stacked in an unusual way,
# or if a function is decorated multiple times (which should be disallowed or handled).
# For now, let's assume one @command decorator per function.
raise TypeError("Function is already a command or has command attributes.")
# Create the command object. It will be registered by the Cog or Client.
cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
# We can attach the command object to the function, so Cogs can find it.
func.__command_object__ = cmd # type: ignore # type: ignore[attr-defined]
return func # Return the original function, now marked.
# Or return `cmd` if commands are registered globally immediately.
# For Cogs, returning `func` and letting Cog collect is cleaner.
return decorator
def listener(
name: Optional[str] = None,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""
A decorator that marks a function as an event listener within a Cog.
The actual registration happens when the Cog is added to the client.
Args:
name (Optional[str]): The name of the event to listen to.
Defaults to the function name (e.g., `on_message`).
"""
def decorator(
func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
if not asyncio.iscoroutinefunction(func):
raise TypeError("Listener callback must be a coroutine function.")
# 'name' here is from the outer 'listener' scope (closure)
actual_event_name = name or func.__name__
# Store listener info on the function for Cog to collect
setattr(func, "__listener_name__", actual_event_name)
return func
return decorator # This must be correctly indented under 'listener'
def check(
predicate: Callable[["CommandContext"], Awaitable[bool] | bool],
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Decorator to add a check to a command."""
def decorator(
func: Callable[..., Awaitable[None]],
) -> Callable[..., Awaitable[None]]:
checks = getattr(func, "__command_checks__", [])
checks.append(predicate)
setattr(func, "__command_checks__", checks)
return func
return decorator
def check_any(
*predicates: Callable[["CommandContext"], Awaitable[bool] | bool]
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Decorator that passes if any predicate returns ``True``."""
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CheckAnyFailure, CheckFailure
errors = []
for p in predicates:
try:
result = p(ctx)
if inspect.isawaitable(result):
result = await result
if result:
return True
except CheckFailure as e:
errors.append(e)
raise CheckAnyFailure(errors)
return check(predicate)
def cooldown(
rate: int, per: float
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
"""Simple per-user cooldown decorator."""
buckets: dict[str, dict[str, float]] = {}
async def predicate(ctx: "CommandContext") -> bool:
from .errors import CommandOnCooldown
now = time.monotonic()
user_buckets = buckets.setdefault(ctx.command.name, {})
reset = user_buckets.get(ctx.author.id, 0)
if now < reset:
raise CommandOnCooldown(reset - now)
user_buckets[ctx.author.id] = now + per
return True
return check(predicate)

View File

@ -0,0 +1,76 @@
# disagreement/ext/commands/errors.py
"""
Custom exceptions for the command extension.
"""
from disagreement.errors import DisagreementException
class CommandError(DisagreementException):
"""Base exception for errors raised by the commands extension."""
pass
class CommandNotFound(CommandError):
"""Exception raised when a command is not found."""
def __init__(self, command_name: str):
self.command_name = command_name
super().__init__(f"Command '{command_name}' not found.")
class BadArgument(CommandError):
"""Exception raised when a command argument fails to parse or validate."""
pass
class MissingRequiredArgument(BadArgument):
"""Exception raised when a required command argument is missing."""
def __init__(self, param_name: str):
self.param_name = param_name
super().__init__(f"Missing required argument: {param_name}")
class ArgumentParsingError(BadArgument):
"""Exception raised during the argument parsing process."""
pass
class CheckFailure(CommandError):
"""Exception raised when a command check fails."""
pass
class CheckAnyFailure(CheckFailure):
"""Raised when :func:`check_any` fails all checks."""
def __init__(self, errors: list[CheckFailure]):
self.errors = errors
msg = "; ".join(str(e) for e in errors)
super().__init__(f"All checks failed: {msg}")
class CommandOnCooldown(CheckFailure):
"""Raised when a command is invoked while on cooldown."""
def __init__(self, retry_after: float):
self.retry_after = retry_after
super().__init__(f"Command is on cooldown. Retry in {retry_after:.2f}s")
class CommandInvokeError(CommandError):
"""Exception raised when an error occurs during command invocation."""
def __init__(self, original: Exception):
self.original = original
super().__init__(f"Error during command invocation: {original}")
# Add more specific errors as needed, e.g., UserNotFound, ChannelNotFound, etc.
# These might inherit from BadArgument.

View File

@ -0,0 +1,37 @@
# disagreement/ext/commands/help.py
from typing import List, Optional
from .core import Command, CommandContext, CommandHandler
class HelpCommand(Command):
"""Built-in command that displays help information for other commands."""
def __init__(self, handler: CommandHandler) -> None:
self.handler = handler
async def callback(ctx: CommandContext, command: Optional[str] = None) -> None:
if command:
cmd = handler.get_command(command)
if not cmd or cmd.name.lower() != command.lower():
await ctx.send(f"Command '{command}' not found.")
return
description = cmd.description or cmd.brief or "No description provided."
await ctx.send(f"**{ctx.prefix}{cmd.name}**\n{description}")
else:
lines: List[str] = []
for registered in dict.fromkeys(handler.commands.values()):
brief = registered.brief or registered.description or ""
lines.append(f"{ctx.prefix}{registered.name} - {brief}".strip())
if lines:
await ctx.send("\n".join(lines))
else:
await ctx.send("No commands available.")
super().__init__(
callback,
name="help",
brief="Show command help.",
description="Displays help for commands.",
)

View File

@ -0,0 +1,103 @@
# disagreement/ext/commands/view.py
import re
class StringView:
"""
A utility class to help with parsing strings, particularly for command arguments.
It keeps track of the current position in the string and provides methods
to read parts of it.
"""
def __init__(self, buffer: str):
self.buffer: str = buffer
self.original: str = buffer # Keep original for error reporting if needed
self.index: int = 0
self.end: int = len(buffer)
self.previous: int = 0 # Index before the last successful read
@property
def remaining(self) -> str:
"""Returns the rest of the string that hasn't been consumed."""
return self.buffer[self.index :]
@property
def eof(self) -> bool:
"""Checks if the end of the string has been reached."""
return self.index >= self.end
def skip_whitespace(self) -> None:
"""Skips any leading whitespace from the current position."""
while not self.eof and self.buffer[self.index].isspace():
self.index += 1
def get_word(self) -> str:
"""
Reads a "word" from the current position.
A word is a sequence of non-whitespace characters.
"""
self.skip_whitespace()
if self.eof:
return ""
self.previous = self.index
match = re.match(r"\S+", self.buffer[self.index :])
if match:
word = match.group(0)
self.index += len(word)
return word
return "" # Should not happen if not eof and skip_whitespace was called
def get_quoted_string(self) -> str:
"""
Reads a string enclosed in double quotes.
Handles escaped quotes inside the string.
"""
self.skip_whitespace()
if self.eof or self.buffer[self.index] != '"':
return "" # Or raise an error, or return None
self.previous = self.index
self.index += 1 # Skip the opening quote
result = []
escaped = False
while not self.eof:
char = self.buffer[self.index]
self.index += 1
if escaped:
result.append(char)
escaped = False
elif char == "\\":
escaped = True
elif char == '"':
return "".join(result) # Closing quote found
else:
result.append(char)
# If loop finishes, means EOF was reached before closing quote
# This is an error condition. Restore index and indicate failure.
self.index = self.previous
# Consider raising an error like UnterminatedQuotedStringError
return "" # Or raise
def read_rest(self) -> str:
"""Reads all remaining characters from the current position."""
self.skip_whitespace()
if self.eof:
return ""
self.previous = self.index
result = self.buffer[self.index :]
self.index = self.end
return result
def undo(self) -> None:
"""Resets the current position to before the last successful read."""
self.index = self.previous
# Could add more methods like:
# peek() - look at next char without consuming
# match_regex(pattern) - consume if regex matches

View File

@ -0,0 +1,43 @@
from __future__ import annotations
from importlib import import_module
import sys
from types import ModuleType
from typing import Dict
__all__ = ["load_extension", "unload_extension"]
_loaded_extensions: Dict[str, ModuleType] = {}
def load_extension(name: str) -> ModuleType:
"""Load an extension by name.
The extension module must define a ``setup`` coroutine or function that
will be called after loading. Any value returned by ``setup`` is ignored.
"""
if name in _loaded_extensions:
raise ValueError(f"Extension '{name}' already loaded")
module = import_module(name)
if not hasattr(module, "setup"):
raise ImportError(f"Extension '{name}' does not define a setup function")
module.setup()
_loaded_extensions[name] = module
return module
def unload_extension(name: str) -> None:
"""Unload a previously loaded extension."""
module = _loaded_extensions.pop(name, None)
if module is None:
raise ValueError(f"Extension '{name}' is not loaded")
if hasattr(module, "teardown"):
module.teardown()
sys.modules.pop(name, None)

89
disagreement/ext/tasks.py Normal file
View File

@ -0,0 +1,89 @@
import asyncio
from typing import Any, Awaitable, Callable, Optional
__all__ = ["loop", "Task"]
class Task:
"""Simple repeating task."""
def __init__(self, coro: Callable[..., Awaitable[Any]], *, seconds: float) -> None:
self._coro = coro
self._seconds = float(seconds)
self._task: Optional[asyncio.Task[None]] = None
async def _run(self, *args: Any, **kwargs: Any) -> None:
try:
while True:
await self._coro(*args, **kwargs)
await asyncio.sleep(self._seconds)
except asyncio.CancelledError:
pass
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
if self._task is None or self._task.done():
self._task = asyncio.create_task(self._run(*args, **kwargs))
return self._task
def stop(self) -> None:
if self._task is not None:
self._task.cancel()
self._task = None
@property
def running(self) -> bool:
return self._task is not None and not self._task.done()
class _Loop:
def __init__(self, func: Callable[..., Awaitable[Any]], seconds: float) -> None:
self.func = func
self.seconds = seconds
self._task: Optional[Task] = None
self._owner: Any = None
def __get__(self, obj: Any, objtype: Any) -> "_BoundLoop":
return _BoundLoop(self, obj)
def _coro(self, *args: Any, **kwargs: Any) -> Awaitable[Any]:
if self._owner is None:
return self.func(*args, **kwargs)
return self.func(self._owner, *args, **kwargs)
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
self._task = Task(self._coro, seconds=self.seconds)
return self._task.start(*args, **kwargs)
def stop(self) -> None:
if self._task is not None:
self._task.stop()
@property
def running(self) -> bool:
return self._task.running if self._task else False
class _BoundLoop:
def __init__(self, parent: _Loop, owner: Any) -> None:
self._parent = parent
self._owner = owner
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
self._parent._owner = self._owner
return self._parent.start(*args, **kwargs)
def stop(self) -> None:
self._parent.stop()
@property
def running(self) -> bool:
return self._parent.running
def loop(*, seconds: float) -> Callable[[Callable[..., Awaitable[Any]]], _Loop]:
"""Decorator to create a looping task."""
def decorator(func: Callable[..., Awaitable[Any]]) -> _Loop:
return _Loop(func, seconds)
return decorator

490
disagreement/gateway.py Normal file
View File

@ -0,0 +1,490 @@
# disagreement/gateway.py
"""
Manages the WebSocket connection to the Discord Gateway.
"""
import asyncio
import traceback
import aiohttp
import json
import zlib
import time
from typing import Optional, TYPE_CHECKING, Any, Dict
from .enums import GatewayOpcode, GatewayIntent
from .errors import GatewayException, DisagreementException, AuthenticationError
from .interactions import Interaction
if TYPE_CHECKING:
from .client import Client # For type hinting
from .event_dispatcher import EventDispatcher
from .http import HTTPClient
from .interactions import Interaction # Added for INTERACTION_CREATE
# ZLIB Decompression constants
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
MAX_DECOMPRESSION_SIZE = 10 * 1024 * 1024 # 10 MiB, adjust as needed
class GatewayClient:
"""
Handles the Discord Gateway WebSocket connection, heartbeating, and event dispatching.
"""
def __init__(
self,
http_client: "HTTPClient",
event_dispatcher: "EventDispatcher",
token: str,
intents: int,
client_instance: "Client", # Pass the main client instance
verbose: bool = False,
*,
shard_id: Optional[int] = None,
shard_count: Optional[int] = None,
):
self._http: "HTTPClient" = http_client
self._dispatcher: "EventDispatcher" = event_dispatcher
self._token: str = token
self._intents: int = intents
self._client_instance: "Client" = client_instance # Store client instance
self.verbose: bool = verbose
self._shard_id: Optional[int] = shard_id
self._shard_count: Optional[int] = shard_count
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
self._heartbeat_interval: Optional[float] = None
self._last_sequence: Optional[int] = None
self._session_id: Optional[str] = None
self._resume_gateway_url: Optional[str] = None
self._keep_alive_task: Optional[asyncio.Task] = None
self._receive_task: Optional[asyncio.Task] = None
# For zlib decompression
self._buffer = bytearray()
self._inflator = zlib.decompressobj()
async def _decompress_message(
self, message_bytes: bytes
) -> Optional[Dict[str, Any]]:
"""Decompresses a zlib-compressed message from the Gateway."""
self._buffer.extend(message_bytes)
if len(message_bytes) < 4 or message_bytes[-4:] != ZLIB_SUFFIX:
# Message is not complete or not zlib compressed in the expected way
return None
# Or handle partial messages if Discord ever sends them fragmented like this,
# but typically each binary message is a complete zlib stream.
try:
decompressed = self._inflator.decompress(self._buffer)
self._buffer.clear() # Reset buffer after successful decompression
return json.loads(decompressed.decode("utf-8"))
except zlib.error as e:
print(f"Zlib decompression error: {e}")
self._buffer.clear() # Clear buffer on error
self._inflator = zlib.decompressobj() # Reset inflator
return None
except json.JSONDecodeError as e:
print(f"JSON decode error after decompression: {e}")
return None
async def _send_json(self, payload: Dict[str, Any]):
if self._ws and not self._ws.closed:
if self.verbose:
print(f"GATEWAY SEND: {payload}")
await self._ws.send_json(payload)
else:
print("Gateway send attempted but WebSocket is closed or not available.")
# raise GatewayException("WebSocket is not connected.")
async def _heartbeat(self):
"""Sends a heartbeat to the Gateway."""
payload = {"op": GatewayOpcode.HEARTBEAT, "d": self._last_sequence}
await self._send_json(payload)
# print("Sent heartbeat.")
async def _keep_alive(self):
"""Manages the heartbeating loop."""
if self._heartbeat_interval is None:
# This should not happen if HELLO was processed correctly
print("Error: Heartbeat interval not set. Cannot start keep_alive.")
return
try:
while True:
await self._heartbeat()
await asyncio.sleep(
self._heartbeat_interval / 1000
) # Interval is in ms
except asyncio.CancelledError:
print("Keep_alive task cancelled.")
except Exception as e:
print(f"Error in keep_alive loop: {e}")
# Potentially trigger a reconnect here or notify client
await self._client_instance.close_gateway(code=1000) # Generic close
async def _identify(self):
"""Sends the IDENTIFY payload to the Gateway."""
payload = {
"op": GatewayOpcode.IDENTIFY,
"d": {
"token": self._token,
"intents": self._intents,
"properties": {
"$os": "python", # Or platform.system()
"$browser": "disagreement", # Library name
"$device": "disagreement", # Library name
},
"compress": True, # Request zlib compression
},
}
if self._shard_id is not None and self._shard_count is not None:
payload["d"]["shard"] = [self._shard_id, self._shard_count]
await self._send_json(payload)
print("Sent IDENTIFY.")
async def _resume(self):
"""Sends the RESUME payload to the Gateway."""
if not self._session_id or self._last_sequence is None:
print("Cannot RESUME: session_id or last_sequence is missing.")
await self._identify() # Fallback to identify
return
payload = {
"op": GatewayOpcode.RESUME,
"d": {
"token": self._token,
"session_id": self._session_id,
"seq": self._last_sequence,
},
}
await self._send_json(payload)
print(
f"Sent RESUME for session {self._session_id} at sequence {self._last_sequence}."
)
async def update_presence(
self,
status: str,
activity_name: Optional[str] = None,
activity_type: int = 0,
since: int = 0,
afk: bool = False,
):
"""Sends the presence update payload to the Gateway."""
payload = {
"op": GatewayOpcode.PRESENCE_UPDATE,
"d": {
"since": since,
"activities": [
{
"name": activity_name,
"type": activity_type,
}
]
if activity_name
else [],
"status": status,
"afk": afk,
},
}
await self._send_json(payload)
async def _handle_dispatch(self, data: Dict[str, Any]):
"""Handles DISPATCH events (actual Discord events)."""
event_name = data.get("t")
sequence_num = data.get("s")
raw_event_d_payload = data.get(
"d"
) # This is the 'd' field from the gateway event
if sequence_num is not None:
self._last_sequence = sequence_num
if event_name == "READY": # Special handling for READY
if not isinstance(raw_event_d_payload, dict):
print(
f"Error: READY event 'd' payload is not a dict or is missing: {raw_event_d_payload}"
)
# Consider raising an error or attempting a reconnect
return
self._session_id = raw_event_d_payload.get("session_id")
self._resume_gateway_url = raw_event_d_payload.get("resume_gateway_url")
app_id_str = "N/A"
# Store application_id on the client instance
if (
"application" in raw_event_d_payload
and isinstance(raw_event_d_payload["application"], dict)
and "id" in raw_event_d_payload["application"]
):
app_id_value = raw_event_d_payload["application"]["id"]
self._client_instance.application_id = (
app_id_value # Snowflake can be str or int
)
app_id_str = str(app_id_value)
else:
print(
f"Warning: Could not find application ID in READY payload. App commands may not work."
)
# Parse and store the bot's own user object
if "user" in raw_event_d_payload and isinstance(
raw_event_d_payload["user"], dict
):
try:
# Assuming Client has a parse_user method that takes user data dict
# and returns a User object, also caching it.
bot_user_obj = self._client_instance.parse_user(
raw_event_d_payload["user"]
)
self._client_instance.user = bot_user_obj
print(
f"Gateway READY. Bot User: {bot_user_obj.username}#{bot_user_obj.discriminator}. Session ID: {self._session_id}. App ID: {app_id_str}. Resume URL: {self._resume_gateway_url}"
)
except Exception as e:
print(f"Error parsing bot user from READY payload: {e}")
print(
f"Gateway READY (user parse failed). Session ID: {self._session_id}. App ID: {app_id_str}. Resume URL: {self._resume_gateway_url}"
)
else:
print(
f"Warning: Bot user object not found or invalid in READY payload."
)
print(
f"Gateway READY (no user). Session ID: {self._session_id}. App ID: {app_id_str}. Resume URL: {self._resume_gateway_url}"
)
await self._dispatcher.dispatch(event_name, raw_event_d_payload)
elif event_name == "INTERACTION_CREATE":
# print(f"GATEWAY RECV INTERACTION_CREATE: {raw_event_d_payload}")
if isinstance(raw_event_d_payload, dict):
interaction = Interaction(
data=raw_event_d_payload, client_instance=self._client_instance
)
await self._dispatcher.dispatch(
"INTERACTION_CREATE", raw_event_d_payload
)
# Dispatch to a new client method that will then call AppCommandHandler
if hasattr(self._client_instance, "process_interaction"):
asyncio.create_task(
self._client_instance.process_interaction(interaction)
) # type: ignore
else:
print(
"Warning: Client instance does not have process_interaction method for INTERACTION_CREATE."
)
else:
print(
f"Error: INTERACTION_CREATE event 'd' payload is not a dict: {raw_event_d_payload}"
)
elif event_name == "RESUMED":
print("Gateway RESUMED successfully.")
# RESUMED 'd' payload is often an empty object or debug info.
# Ensure it's a dict for the dispatcher.
event_data_to_dispatch = (
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
)
await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
elif event_name:
# For other events, ensure 'd' is a dict, or pass {} if 'd' is null/missing.
# Models/parsers in EventDispatcher will need to handle potentially empty dicts.
event_data_to_dispatch = (
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
)
# print(f"GATEWAY RECV EVENT: {event_name} | DATA: {event_data_to_dispatch}")
await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
else:
print(f"Received dispatch with no event name: {data}")
async def _process_message(self, msg: aiohttp.WSMessage):
"""Processes a single message from the WebSocket."""
if msg.type == aiohttp.WSMsgType.TEXT:
try:
data = json.loads(msg.data)
except json.JSONDecodeError:
print(
f"Failed to decode JSON from Gateway: {msg.data[:200]}"
) # Log snippet
return
elif msg.type == aiohttp.WSMsgType.BINARY:
decompressed_data = await self._decompress_message(msg.data)
if decompressed_data is None:
print("Failed to decompress or decode binary message from Gateway.")
return
data = decompressed_data
elif msg.type == aiohttp.WSMsgType.ERROR:
print(
f"WebSocket error: {self._ws.exception() if self._ws else 'Unknown WSError'}"
)
raise GatewayException(
f"WebSocket error: {self._ws.exception() if self._ws else 'Unknown WSError'}"
)
elif msg.type == aiohttp.WSMsgType.CLOSED:
close_code = (
self._ws.close_code
if self._ws and hasattr(self._ws, "close_code")
else "N/A"
)
print(f"WebSocket connection closed by server. Code: {close_code}")
# Raise an exception to signal the closure to the client's main run loop
raise GatewayException(f"WebSocket closed by server. Code: {close_code}")
else:
print(f"Received unhandled WebSocket message type: {msg.type}")
return
if self.verbose:
print(f"GATEWAY RECV: {data}")
op = data.get("op")
# 'd' payload (event_data) is handled specifically by each opcode handler below
if op == GatewayOpcode.DISPATCH:
await self._handle_dispatch(data) # _handle_dispatch will extract 'd'
elif op == GatewayOpcode.HEARTBEAT: # Server requests a heartbeat
await self._heartbeat()
elif op == GatewayOpcode.RECONNECT: # Server requests a reconnect
print("Gateway requested RECONNECT. Closing and will attempt to reconnect.")
await self.close(code=4000) # Use a non-1000 code to indicate reconnect
elif op == GatewayOpcode.INVALID_SESSION:
# The 'd' payload for INVALID_SESSION is a boolean indicating resumability
can_resume = data.get("d") is True
print(f"Gateway indicated INVALID_SESSION. Resumable: {can_resume}")
if not can_resume:
self._session_id = None # Clear session_id to force re-identify
self._last_sequence = None
# Close and reconnect. The connect logic will decide to resume or identify.
await self.close(
code=4000 if can_resume else 4009
) # 4009 for non-resumable
elif op == GatewayOpcode.HELLO:
hello_d_payload = data.get("d")
if (
not isinstance(hello_d_payload, dict)
or "heartbeat_interval" not in hello_d_payload
):
print(
f"Error: HELLO event 'd' payload is invalid or missing heartbeat_interval: {hello_d_payload}"
)
await self.close(code=1011) # Internal error, malformed HELLO
return
self._heartbeat_interval = hello_d_payload["heartbeat_interval"]
print(f"Gateway HELLO. Heartbeat interval: {self._heartbeat_interval}ms.")
# Start heartbeating
if self._keep_alive_task:
self._keep_alive_task.cancel()
self._keep_alive_task = self._loop.create_task(self._keep_alive())
# Identify or Resume
if self._session_id and self._resume_gateway_url: # Check if we can resume
print("Attempting to RESUME session.")
await self._resume()
else:
print("Performing initial IDENTIFY.")
await self._identify()
elif op == GatewayOpcode.HEARTBEAT_ACK:
# print("Received heartbeat ACK.")
pass # Good, connection is alive
else:
print(f"Received unhandled Gateway Opcode: {op} with data: {data}")
async def _receive_loop(self):
"""Continuously receives and processes messages from the WebSocket."""
if not self._ws or self._ws.closed:
print("Receive loop cannot start: WebSocket is not connected or closed.")
return
try:
async for msg in self._ws:
await self._process_message(msg)
except asyncio.CancelledError:
print("Receive_loop task cancelled.")
except aiohttp.ClientConnectionError as e:
print(f"ClientConnectionError in receive_loop: {e}. Attempting reconnect.")
# This might be handled by an outer reconnect loop in the Client class
await self.close(code=1006) # Abnormal closure
except Exception as e:
print(f"Unexpected error in receive_loop: {e}")
traceback.print_exc()
# Consider specific error types for more granular handling
await self.close(code=1011) # Internal error
finally:
print("Receive_loop ended.")
# If the loop ends unexpectedly (not due to explicit close),
# the main client might want to try reconnecting.
async def connect(self):
"""Connects to the Discord Gateway."""
if self._ws and not self._ws.closed:
print("Gateway already connected or connecting.")
return
gateway_url = (
self._resume_gateway_url or (await self._http.get_gateway_bot())["url"]
)
if not gateway_url.endswith("?v=10&encoding=json&compress=zlib-stream"):
gateway_url += "?v=10&encoding=json&compress=zlib-stream"
print(f"Connecting to Gateway: {gateway_url}")
try:
await self._http._ensure_session() # Ensure the HTTP client's session is active
assert (
self._http._session is not None
), "HTTPClient session not initialized after ensure_session"
self._ws = await self._http._session.ws_connect(gateway_url, max_msg_size=0)
print("Gateway WebSocket connection established.")
if self._receive_task:
self._receive_task.cancel()
self._receive_task = self._loop.create_task(self._receive_loop())
except aiohttp.ClientConnectorError as e:
raise GatewayException(
f"Failed to connect to Gateway (Connector Error): {e}"
) from e
except aiohttp.WSServerHandshakeError as e:
if e.status == 401: # Unauthorized during handshake
raise AuthenticationError(
f"Gateway handshake failed (401 Unauthorized): {e.message}. Check your bot token."
) from e
raise GatewayException(
f"Gateway handshake failed (Status: {e.status}): {e.message}"
) from e
except Exception as e: # Catch other potential errors during connection
raise GatewayException(
f"An unexpected error occurred during Gateway connection: {e}"
) from e
async def close(self, code: int = 1000):
"""Closes the Gateway connection."""
print(f"Closing Gateway connection with code {code}...")
if self._keep_alive_task and not self._keep_alive_task.done():
self._keep_alive_task.cancel()
try:
await self._keep_alive_task
except asyncio.CancelledError:
pass # Expected
if self._receive_task and not self._receive_task.done():
self._receive_task.cancel()
try:
await self._receive_task
except asyncio.CancelledError:
pass # Expected
if self._ws and not self._ws.closed:
await self._ws.close(code=code)
print("Gateway WebSocket closed.")
self._ws = None
# Do not reset session_id, last_sequence, or resume_gateway_url here
# if the close code indicates a resumable disconnect (e.g. 4000-4009, or server-initiated RECONNECT)
# The connect logic will decide whether to resume or re-identify.
# However, if it's a non-resumable close (e.g. Invalid Session non-resumable), clear them.
if code == 4009: # Invalid session, not resumable
print("Clearing session state due to non-resumable invalid session.")
self._session_id = None
self._last_sequence = None
self._resume_gateway_url = None # This might be re-fetched anyway

657
disagreement/http.py Normal file
View File

@ -0,0 +1,657 @@
# disagreement/http.py
"""
HTTP client for interacting with the Discord REST API.
"""
import asyncio
import aiohttp # pylint: disable=import-error
import json
from urllib.parse import quote
from typing import Optional, Dict, Any, Union, TYPE_CHECKING, List
from .errors import (
HTTPException,
RateLimitError,
AuthenticationError,
DisagreementException,
)
from . import __version__ # For User-Agent
if TYPE_CHECKING:
from .client import Client
from .models import Message
from .interactions import ApplicationCommand, InteractionResponsePayload, Snowflake
# Discord API constants
API_BASE_URL = "https://discord.com/api/v10" # Using API v10
class HTTPClient:
"""Handles HTTP requests to the Discord API."""
def __init__(
self,
token: str,
client_session: Optional[aiohttp.ClientSession] = None,
verbose: bool = False,
):
self.token = token
self._session: Optional[aiohttp.ClientSession] = (
client_session # Can be externally managed
)
self.user_agent = f"DiscordBot (https://github.com/yourusername/disagreement, {__version__})" # Customize URL
self.verbose = verbose
self._global_rate_limit_lock = asyncio.Event()
self._global_rate_limit_lock.set() # Initially unlocked
async def _ensure_session(self):
if self._session is None or self._session.closed:
self._session = aiohttp.ClientSession()
async def close(self):
"""Closes the underlying aiohttp.ClientSession."""
if self._session and not self._session.closed:
await self._session.close()
async def request(
self,
method: str,
endpoint: str,
payload: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
params: Optional[Dict[str, Any]] = None,
is_json: bool = True,
use_auth_header: bool = True,
custom_headers: Optional[Dict[str, str]] = None,
) -> Any:
"""Makes an HTTP request to the Discord API."""
await self._ensure_session()
url = f"{API_BASE_URL}{endpoint}"
final_headers: Dict[str, str] = { # Renamed to final_headers
"User-Agent": self.user_agent,
}
if use_auth_header:
final_headers["Authorization"] = f"Bot {self.token}"
if is_json and payload:
final_headers["Content-Type"] = "application/json"
if custom_headers: # Merge custom headers
final_headers.update(custom_headers)
if self.verbose:
print(f"HTTP REQUEST: {method} {url} | payload={payload} params={params}")
# Global rate limit handling
await self._global_rate_limit_lock.wait()
for attempt in range(5): # Max 5 retries for rate limits
assert self._session is not None, "ClientSession not initialized"
async with self._session.request(
method,
url,
json=payload if is_json else None,
data=payload if not is_json else None,
headers=final_headers,
params=params,
) as response:
data = None
try:
if response.headers.get("Content-Type", "").startswith(
"application/json"
):
data = await response.json()
else:
# For non-JSON responses, like fetching images or other files
# We might return the raw response or handle it differently
# For now, let's assume most API calls expect JSON
data = await response.text()
except (aiohttp.ContentTypeError, json.JSONDecodeError):
data = (
await response.text()
) # Fallback to text if JSON parsing fails
if self.verbose:
print(f"HTTP RESPONSE: {response.status} {url} | {data}")
if 200 <= response.status < 300:
if response.status == 204:
return None
return data
# Rate limit handling
if response.status == 429: # Rate limited
retry_after_str = response.headers.get("Retry-After", "1")
try:
retry_after = float(retry_after_str)
except ValueError:
retry_after = 1.0 # Default retry if header is malformed
is_global = (
response.headers.get("X-RateLimit-Global", "false").lower()
== "true"
)
error_message = f"Rate limited on {method} {endpoint}."
if data and isinstance(data, dict) and "message" in data:
error_message += f" Discord says: {data['message']}"
if is_global:
self._global_rate_limit_lock.clear()
await asyncio.sleep(retry_after)
self._global_rate_limit_lock.set()
else:
await asyncio.sleep(retry_after)
if attempt < 4: # Don't log on the last attempt before raising
print(
f"{error_message} Retrying after {retry_after}s (Attempt {attempt + 1}/5). Global: {is_global}"
)
continue # Retry the request
else: # Last attempt failed
raise RateLimitError(
response,
message=error_message,
retry_after=retry_after,
is_global=is_global,
)
# Other error handling
if response.status == 401: # Unauthorized
raise AuthenticationError(response, "Invalid token provided.")
if response.status == 403: # Forbidden
raise HTTPException(
response,
"Missing permissions or access denied.",
status=response.status,
text=str(data),
)
# General HTTP error
error_text = str(data) if data else "Unknown error"
discord_error_code = (
data.get("code") if isinstance(data, dict) else None
)
raise HTTPException(
response,
f"API Error on {method} {endpoint}: {error_text}",
status=response.status,
text=error_text,
error_code=discord_error_code,
)
# Should not be reached if retries are exhausted by RateLimitError
raise DisagreementException(
f"Failed request to {method} {endpoint} after multiple retries."
)
# --- Specific API call methods ---
async def get_gateway_bot(self) -> Dict[str, Any]:
"""Gets the WSS URL and sharding information for the Gateway."""
return await self.request("GET", "/gateway/bot")
async def send_message(
self,
channel_id: str,
content: Optional[str] = None,
tts: bool = False,
embeds: Optional[List[Dict[str, Any]]] = None,
components: Optional[List[Dict[str, Any]]] = None,
allowed_mentions: Optional[dict] = None,
message_reference: Optional[Dict[str, Any]] = None,
flags: Optional[int] = None,
) -> Dict[str, Any]:
"""Sends a message to a channel.
Returns the created message data as a dict.
"""
payload: Dict[str, Any] = {}
if content is not None: # Content is optional if embeds/components are present
payload["content"] = content
if tts:
payload["tts"] = True
if embeds:
payload["embeds"] = embeds
if components:
payload["components"] = components
if allowed_mentions:
payload["allowed_mentions"] = allowed_mentions
if flags:
payload["flags"] = flags
if message_reference:
payload["message_reference"] = message_reference
if not payload:
raise ValueError("Message must have content, embeds, or components.")
return await self.request(
"POST", f"/channels/{channel_id}/messages", payload=payload
)
async def edit_message(
self,
channel_id: str,
message_id: str,
payload: Dict[str, Any],
) -> Dict[str, Any]:
"""Edits a message in a channel."""
return await self.request(
"PATCH",
f"/channels/{channel_id}/messages/{message_id}",
payload=payload,
)
async def get_message(
self, channel_id: "Snowflake", message_id: "Snowflake"
) -> Dict[str, Any]:
"""Fetches a message from a channel."""
return await self.request(
"GET", f"/channels/{channel_id}/messages/{message_id}"
)
async def create_reaction(
self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
) -> None:
"""Adds a reaction to a message as the current user."""
encoded = quote(emoji)
await self.request(
"PUT",
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/@me",
)
async def delete_reaction(
self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
) -> None:
"""Removes the current user's reaction from a message."""
encoded = quote(emoji)
await self.request(
"DELETE",
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/@me",
)
async def get_reactions(
self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
) -> List[Dict[str, Any]]:
"""Fetches the users that reacted with a specific emoji."""
encoded = quote(emoji)
return await self.request(
"GET",
f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}",
)
async def delete_channel(
self, channel_id: str, reason: Optional[str] = None
) -> None:
"""Deletes a channel.
If the channel is a guild channel, requires the MANAGE_CHANNELS permission.
If the channel is a thread, requires the MANAGE_THREADS permission (if locked) or
be the thread creator (if not locked).
Deleting a category does not delete its child channels.
"""
custom_headers = {}
if reason:
custom_headers["X-Audit-Log-Reason"] = reason
await self.request(
"DELETE",
f"/channels/{channel_id}",
custom_headers=custom_headers if custom_headers else None,
)
async def get_channel(self, channel_id: str) -> Dict[str, Any]:
"""Fetches a channel by ID."""
return await self.request("GET", f"/channels/{channel_id}")
async def get_user(self, user_id: "Snowflake") -> Dict[str, Any]:
"""Fetches a user object for a given user ID."""
return await self.request("GET", f"/users/{user_id}")
async def get_guild_member(
self, guild_id: "Snowflake", user_id: "Snowflake"
) -> Dict[str, Any]:
"""Returns a guild member object for the specified user."""
return await self.request("GET", f"/guilds/{guild_id}/members/{user_id}")
async def kick_member(
self, guild_id: "Snowflake", user_id: "Snowflake", reason: Optional[str] = None
) -> None:
"""Kicks a member from the guild."""
headers = {"X-Audit-Log-Reason": reason} if reason else None
await self.request(
"DELETE",
f"/guilds/{guild_id}/members/{user_id}",
custom_headers=headers,
)
async def ban_member(
self,
guild_id: "Snowflake",
user_id: "Snowflake",
*,
delete_message_seconds: int = 0,
reason: Optional[str] = None,
) -> None:
"""Bans a member from the guild."""
payload = {}
if delete_message_seconds:
payload["delete_message_seconds"] = delete_message_seconds
headers = {"X-Audit-Log-Reason": reason} if reason else None
await self.request(
"PUT",
f"/guilds/{guild_id}/bans/{user_id}",
payload=payload if payload else None,
custom_headers=headers,
)
async def timeout_member(
self,
guild_id: "Snowflake",
user_id: "Snowflake",
*,
until: Optional[str],
reason: Optional[str] = None,
) -> Dict[str, Any]:
"""Times out a member until the given ISO8601 timestamp."""
payload = {"communication_disabled_until": until}
headers = {"X-Audit-Log-Reason": reason} if reason else None
return await self.request(
"PATCH",
f"/guilds/{guild_id}/members/{user_id}",
payload=payload,
custom_headers=headers,
)
async def get_guild_roles(self, guild_id: "Snowflake") -> List[Dict[str, Any]]:
"""Returns a list of role objects for the guild."""
return await self.request("GET", f"/guilds/{guild_id}/roles")
async def get_guild(self, guild_id: "Snowflake") -> Dict[str, Any]:
"""Fetches a guild object for a given guild ID."""
return await self.request("GET", f"/guilds/{guild_id}")
# Add other methods like:
# async def get_guild(self, guild_id: str) -> Dict[str, Any]: ...
# async def create_reaction(self, channel_id: str, message_id: str, emoji: str) -> None: ...
# etc.
# --- Application Command Endpoints ---
# Global Application Commands
async def get_global_application_commands(
self, application_id: "Snowflake", with_localizations: bool = False
) -> List["ApplicationCommand"]:
"""Fetches all global commands for your application."""
params = {"with_localizations": str(with_localizations).lower()}
data = await self.request(
"GET", f"/applications/{application_id}/commands", params=params
)
from .interactions import ApplicationCommand # Ensure constructor is available
return [ApplicationCommand(cmd_data) for cmd_data in data]
async def create_global_application_command(
self, application_id: "Snowflake", payload: Dict[str, Any]
) -> "ApplicationCommand":
"""Creates a new global command."""
data = await self.request(
"POST", f"/applications/{application_id}/commands", payload=payload
)
from .interactions import ApplicationCommand
return ApplicationCommand(data)
async def get_global_application_command(
self, application_id: "Snowflake", command_id: "Snowflake"
) -> "ApplicationCommand":
"""Fetches a specific global command."""
data = await self.request(
"GET", f"/applications/{application_id}/commands/{command_id}"
)
from .interactions import ApplicationCommand
return ApplicationCommand(data)
async def edit_global_application_command(
self,
application_id: "Snowflake",
command_id: "Snowflake",
payload: Dict[str, Any],
) -> "ApplicationCommand":
"""Edits a specific global command."""
data = await self.request(
"PATCH",
f"/applications/{application_id}/commands/{command_id}",
payload=payload,
)
from .interactions import ApplicationCommand
return ApplicationCommand(data)
async def delete_global_application_command(
self, application_id: "Snowflake", command_id: "Snowflake"
) -> None:
"""Deletes a specific global command."""
await self.request(
"DELETE", f"/applications/{application_id}/commands/{command_id}"
)
async def bulk_overwrite_global_application_commands(
self, application_id: "Snowflake", payload: List[Dict[str, Any]]
) -> List["ApplicationCommand"]:
"""Bulk overwrites all global commands for your application."""
data = await self.request(
"PUT", f"/applications/{application_id}/commands", payload=payload
)
from .interactions import ApplicationCommand
return [ApplicationCommand(cmd_data) for cmd_data in data]
# Guild Application Commands
async def get_guild_application_commands(
self,
application_id: "Snowflake",
guild_id: "Snowflake",
with_localizations: bool = False,
) -> List["ApplicationCommand"]:
"""Fetches all commands for your application for a specific guild."""
params = {"with_localizations": str(with_localizations).lower()}
data = await self.request(
"GET",
f"/applications/{application_id}/guilds/{guild_id}/commands",
params=params,
)
from .interactions import ApplicationCommand
return [ApplicationCommand(cmd_data) for cmd_data in data]
async def create_guild_application_command(
self,
application_id: "Snowflake",
guild_id: "Snowflake",
payload: Dict[str, Any],
) -> "ApplicationCommand":
"""Creates a new guild command."""
data = await self.request(
"POST",
f"/applications/{application_id}/guilds/{guild_id}/commands",
payload=payload,
)
from .interactions import ApplicationCommand
return ApplicationCommand(data)
async def get_guild_application_command(
self,
application_id: "Snowflake",
guild_id: "Snowflake",
command_id: "Snowflake",
) -> "ApplicationCommand":
"""Fetches a specific guild command."""
data = await self.request(
"GET",
f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}",
)
from .interactions import ApplicationCommand
return ApplicationCommand(data)
async def edit_guild_application_command(
self,
application_id: "Snowflake",
guild_id: "Snowflake",
command_id: "Snowflake",
payload: Dict[str, Any],
) -> "ApplicationCommand":
"""Edits a specific guild command."""
data = await self.request(
"PATCH",
f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}",
payload=payload,
)
from .interactions import ApplicationCommand
return ApplicationCommand(data)
async def delete_guild_application_command(
self,
application_id: "Snowflake",
guild_id: "Snowflake",
command_id: "Snowflake",
) -> None:
"""Deletes a specific guild command."""
await self.request(
"DELETE",
f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}",
)
async def bulk_overwrite_guild_application_commands(
self,
application_id: "Snowflake",
guild_id: "Snowflake",
payload: List[Dict[str, Any]],
) -> List["ApplicationCommand"]:
"""Bulk overwrites all commands for your application for a specific guild."""
data = await self.request(
"PUT",
f"/applications/{application_id}/guilds/{guild_id}/commands",
payload=payload,
)
from .interactions import ApplicationCommand
return [ApplicationCommand(cmd_data) for cmd_data in data]
# --- Interaction Response Endpoints ---
# Note: These methods return Dict[str, Any] representing the Message data.
# The caller (e.g., AppCommandHandler) will be responsible for constructing Message models
# if needed, as Message model instantiation requires a `client_instance`.
async def create_interaction_response(
self,
interaction_id: "Snowflake",
interaction_token: str,
payload: "InteractionResponsePayload",
*,
ephemeral: bool = False,
) -> None:
"""Creates a response to an Interaction.
Parameters
----------
ephemeral: bool
Ignored parameter for test compatibility.
"""
# Interaction responses do not use the bot token in the Authorization header.
# They are authenticated by the interaction_token in the URL.
await self.request(
"POST",
f"/interactions/{interaction_id}/{interaction_token}/callback",
payload=payload.to_dict(),
use_auth_header=False,
)
async def get_original_interaction_response(
self, application_id: "Snowflake", interaction_token: str
) -> Dict[str, Any]:
"""Gets the initial Interaction response."""
# This endpoint uses the bot token for auth.
return await self.request(
"GET", f"/webhooks/{application_id}/{interaction_token}/messages/@original"
)
async def edit_original_interaction_response(
self,
application_id: "Snowflake",
interaction_token: str,
payload: Dict[str, Any],
) -> Dict[str, Any]:
"""Edits the initial Interaction response."""
return await self.request(
"PATCH",
f"/webhooks/{application_id}/{interaction_token}/messages/@original",
payload=payload,
use_auth_header=False,
) # Docs imply webhook-style auth
async def delete_original_interaction_response(
self, application_id: "Snowflake", interaction_token: str
) -> None:
"""Deletes the initial Interaction response."""
await self.request(
"DELETE",
f"/webhooks/{application_id}/{interaction_token}/messages/@original",
use_auth_header=False,
) # Docs imply webhook-style auth
async def create_followup_message(
self,
application_id: "Snowflake",
interaction_token: str,
payload: Dict[str, Any],
) -> Dict[str, Any]:
"""Creates a followup message for an Interaction."""
# Followup messages are sent to a webhook endpoint.
return await self.request(
"POST",
f"/webhooks/{application_id}/{interaction_token}",
payload=payload,
use_auth_header=False,
) # Docs imply webhook-style auth
async def edit_followup_message(
self,
application_id: "Snowflake",
interaction_token: str,
message_id: "Snowflake",
payload: Dict[str, Any],
) -> Dict[str, Any]:
"""Edits a followup message for an Interaction."""
return await self.request(
"PATCH",
f"/webhooks/{application_id}/{interaction_token}/messages/{message_id}",
payload=payload,
use_auth_header=False,
) # Docs imply webhook-style auth
async def delete_followup_message(
self,
application_id: "Snowflake",
interaction_token: str,
message_id: "Snowflake",
) -> None:
"""Deletes a followup message for an Interaction."""
await self.request(
"DELETE",
f"/webhooks/{application_id}/{interaction_token}/messages/{message_id}",
use_auth_header=False,
)
async def trigger_typing(self, channel_id: str) -> None:
"""Sends a typing indicator to the specified channel."""
await self.request("POST", f"/channels/{channel_id}/typing")

View File

@ -0,0 +1,32 @@
"""Utility class for working with either command or app contexts."""
from __future__ import annotations
from typing import Any, Union
from .ext.commands.core import CommandContext
from .ext.app_commands.context import AppCommandContext
class HybridContext:
"""Wraps :class:`CommandContext` and :class:`AppCommandContext`.
Provides a single :meth:`send` method that proxies to ``reply`` for
prefix commands and to ``send`` for slash commands.
"""
def __init__(self, ctx: Union[CommandContext, AppCommandContext]):
self._ctx = ctx
async def send(self, *args: Any, **kwargs: Any):
if isinstance(self._ctx, AppCommandContext):
return await self._ctx.send(*args, **kwargs)
return await self._ctx.reply(*args, **kwargs)
async def edit(self, *args: Any, **kwargs: Any):
if hasattr(self._ctx, "edit"):
return await self._ctx.edit(*args, **kwargs)
raise AttributeError("Underlying context does not support editing.")
def __getattr__(self, name: str) -> Any:
return getattr(self._ctx, name)

22
disagreement/i18n.py Normal file
View File

@ -0,0 +1,22 @@
import json
from typing import Dict, Optional
_translations: Dict[str, Dict[str, str]] = {}
def set_translations(locale: str, mapping: Dict[str, str]) -> None:
"""Set translations for a locale."""
_translations[locale] = mapping
def load_translations(locale: str, file_path: str) -> None:
"""Load translations for *locale* from a JSON file."""
with open(file_path, "r", encoding="utf-8") as handle:
_translations[locale] = json.load(handle)
def translate(key: str, locale: str, *, default: Optional[str] = None) -> str:
"""Return the translated string for *key* in *locale*."""
return _translations.get(locale, {}).get(
key, default if default is not None else key
)

View File

@ -0,0 +1,572 @@
# disagreement/interactions.py
"""
Data models for Discord Interaction objects.
"""
from typing import Optional, List, Dict, Union, Any, TYPE_CHECKING
from .enums import (
ApplicationCommandType,
ApplicationCommandOptionType,
InteractionType,
InteractionCallbackType,
IntegrationType,
InteractionContextType,
ChannelType,
)
# Runtime imports for models used in this module
from .models import (
User,
Message,
Member,
Role,
Embed,
PartialChannel,
Attachment,
ActionRow,
Component,
AllowedMentions,
)
if TYPE_CHECKING:
# Import Client type only for type checking to avoid circular imports
from .client import Client
from .ui.modal import Modal
# MessageFlags, PartialAttachment can be added if/when defined
Snowflake = str
# Based on Application Command Option Choice Structure
class ApplicationCommandOptionChoice:
"""Represents a choice for an application command option."""
def __init__(self, data: dict):
self.name: str = data["name"]
self.value: Union[str, int, float] = data["value"]
self.name_localizations: Optional[Dict[str, str]] = data.get(
"name_localizations"
)
def __repr__(self) -> str:
return (
f"<ApplicationCommandOptionChoice name='{self.name}' value={self.value!r}>"
)
def to_dict(self) -> Dict[str, Any]:
payload: Dict[str, Any] = {"name": self.name, "value": self.value}
if self.name_localizations:
payload["name_localizations"] = self.name_localizations
return payload
# Based on Application Command Option Structure
class ApplicationCommandOption:
"""Represents an option for an application command."""
def __init__(self, data: dict):
self.type: ApplicationCommandOptionType = ApplicationCommandOptionType(
data["type"]
)
self.name: str = data["name"]
self.description: str = data["description"]
self.required: bool = data.get("required", False)
self.choices: Optional[List[ApplicationCommandOptionChoice]] = (
[ApplicationCommandOptionChoice(c) for c in data["choices"]]
if data.get("choices")
else None
)
self.options: Optional[List["ApplicationCommandOption"]] = (
[ApplicationCommandOption(o) for o in data["options"]]
if data.get("options")
else None
) # For subcommands/groups
self.channel_types: Optional[List[ChannelType]] = (
[ChannelType(ct) for ct in data.get("channel_types", [])]
if data.get("channel_types")
else None
)
self.min_value: Optional[Union[int, float]] = data.get("min_value")
self.max_value: Optional[Union[int, float]] = data.get("max_value")
self.min_length: Optional[int] = data.get("min_length")
self.max_length: Optional[int] = data.get("max_length")
self.autocomplete: bool = data.get("autocomplete", False)
self.name_localizations: Optional[Dict[str, str]] = data.get(
"name_localizations"
)
self.description_localizations: Optional[Dict[str, str]] = data.get(
"description_localizations"
)
def __repr__(self) -> str:
return f"<ApplicationCommandOption name='{self.name}' type={self.type!r}>"
def to_dict(self) -> Dict[str, Any]:
payload: Dict[str, Any] = {
"type": self.type.value,
"name": self.name,
"description": self.description,
}
if self.required: # Defaults to False, only include if True
payload["required"] = self.required
if self.choices:
payload["choices"] = [c.to_dict() for c in self.choices]
if self.options: # For subcommands/groups
payload["options"] = [o.to_dict() for o in self.options]
if self.channel_types:
payload["channel_types"] = [ct.value for ct in self.channel_types]
if self.min_value is not None:
payload["min_value"] = self.min_value
if self.max_value is not None:
payload["max_value"] = self.max_value
if self.min_length is not None:
payload["min_length"] = self.min_length
if self.max_length is not None:
payload["max_length"] = self.max_length
if self.autocomplete: # Defaults to False, only include if True
payload["autocomplete"] = self.autocomplete
if self.name_localizations:
payload["name_localizations"] = self.name_localizations
if self.description_localizations:
payload["description_localizations"] = self.description_localizations
return payload
# Based on Application Command Structure
class ApplicationCommand:
"""Represents an application command."""
def __init__(self, data: dict):
self.id: Optional[Snowflake] = data.get("id")
self.type: ApplicationCommandType = ApplicationCommandType(
data.get("type", 1)
) # Default to CHAT_INPUT
self.application_id: Optional[Snowflake] = data.get("application_id")
self.guild_id: Optional[Snowflake] = data.get("guild_id")
self.name: str = data["name"]
self.description: str = data.get(
"description", ""
) # Empty for USER/MESSAGE commands
self.options: Optional[List[ApplicationCommandOption]] = (
[ApplicationCommandOption(o) for o in data["options"]]
if data.get("options")
else None
)
self.default_member_permissions: Optional[str] = data.get(
"default_member_permissions"
)
self.dm_permission: Optional[bool] = data.get("dm_permission") # Deprecated
self.nsfw: bool = data.get("nsfw", False)
self.version: Optional[Snowflake] = data.get("version")
self.name_localizations: Optional[Dict[str, str]] = data.get(
"name_localizations"
)
self.description_localizations: Optional[Dict[str, str]] = data.get(
"description_localizations"
)
self.integration_types: Optional[List[IntegrationType]] = (
[IntegrationType(it) for it in data["integration_types"]]
if data.get("integration_types")
else None
)
self.contexts: Optional[List[InteractionContextType]] = (
[InteractionContextType(c) for c in data["contexts"]]
if data.get("contexts")
else None
)
def __repr__(self) -> str:
return (
f"<ApplicationCommand id='{self.id}' name='{self.name}' type={self.type!r}>"
)
# Based on Interaction Object's Resolved Data Structure
class ResolvedData:
"""Represents resolved data for an interaction."""
def __init__(
self, data: dict, client_instance: Optional["Client"] = None
): # client_instance for model hydration
# Models are now imported in TYPE_CHECKING block
users_data = data.get("users", {})
self.users: Dict[Snowflake, "User"] = {
uid: User(udata) for uid, udata in users_data.items()
}
self.members: Dict[Snowflake, "Member"] = {}
for mid, mdata in data.get("members", {}).items():
member_payload = dict(mdata)
member_payload.setdefault("id", mid)
if "user" not in member_payload and mid in users_data:
member_payload["user"] = users_data[mid]
self.members[mid] = Member(member_payload, client_instance=client_instance)
self.roles: Dict[Snowflake, "Role"] = {
rid: Role(rdata) for rid, rdata in data.get("roles", {}).items()
}
self.channels: Dict[Snowflake, "PartialChannel"] = {
cid: PartialChannel(cdata, client_instance=client_instance)
for cid, cdata in data.get("channels", {}).items()
}
self.messages: Dict[Snowflake, "Message"] = (
{
mid: Message(mdata, client_instance=client_instance) for mid, mdata in data.get("messages", {}).items() # type: ignore[misc]
}
if client_instance
else {}
) # Only hydrate if client is available
self.attachments: Dict[Snowflake, "Attachment"] = {
aid: Attachment(adata) for aid, adata in data.get("attachments", {}).items()
}
def __repr__(self) -> str:
return f"<ResolvedData users={len(self.users)} members={len(self.members)} roles={len(self.roles)} channels={len(self.channels)} messages={len(self.messages)} attachments={len(self.attachments)}>"
# Based on Interaction Object's Data Structure (for Application Commands)
class InteractionData:
"""Represents the data payload for an interaction."""
def __init__(self, data: dict, client_instance: Optional["Client"] = None):
self.id: Optional[Snowflake] = data.get("id") # Command ID
self.name: Optional[str] = data.get("name") # Command name
self.type: Optional[ApplicationCommandType] = (
ApplicationCommandType(data["type"]) if data.get("type") else None
)
self.resolved: Optional[ResolvedData] = (
ResolvedData(data["resolved"], client_instance=client_instance)
if data.get("resolved")
else None
)
# For CHAT_INPUT, this is List[ApplicationCommandInteractionDataOption]
# For USER/MESSAGE, this is not present or different.
# For now, storing as raw list of dicts. Parsing can happen in handler.
self.options: Optional[List[Dict[str, Any]]] = data.get("options")
# For message components
self.custom_id: Optional[str] = data.get("custom_id")
self.component_type: Optional[int] = data.get("component_type")
self.values: Optional[List[str]] = data.get("values")
self.guild_id: Optional[Snowflake] = data.get("guild_id")
self.target_id: Optional[Snowflake] = data.get(
"target_id"
) # For USER/MESSAGE commands
def __repr__(self) -> str:
return f"<InteractionData id='{self.id}' name='{self.name}' type={self.type!r}>"
# Based on Interaction Object Structure
class Interaction:
"""Represents an interaction from Discord."""
def __init__(self, data: dict, client_instance: "Client"):
self._client: "Client" = client_instance
self.id: Snowflake = data["id"]
self.application_id: Snowflake = data["application_id"]
self.type: InteractionType = InteractionType(data["type"])
self.data: Optional[InteractionData] = (
InteractionData(data["data"], client_instance=client_instance)
if data.get("data")
else None
)
self.guild_id: Optional[Snowflake] = data.get("guild_id")
self.channel_id: Optional[Snowflake] = data.get(
"channel_id"
) # Will be present on command invocations
member_data = data.get("member")
user_data_from_member = (
member_data.get("user") if isinstance(member_data, dict) else None
)
self.member: Optional["Member"] = (
Member(member_data, client_instance=self._client) if member_data else None
)
# User object is included within member if in guild, otherwise it's top-level
# If self.member was successfully hydrated, its .user attribute should be preferred if it exists.
# However, Member.__init__ handles setting User attributes.
# The primary source for User is data.get("user") or member_data.get("user").
if data.get("user"):
self.user: Optional["User"] = User(data["user"])
elif user_data_from_member:
self.user: Optional["User"] = User(user_data_from_member)
elif (
self.member
): # If member was hydrated and has user attributes (e.g. from Member(User) inheritance)
# This assumes Member correctly populates its User parts.
self.user: Optional["User"] = self.member # Member is a User subclass
else:
self.user: Optional["User"] = None
self.token: str = data["token"] # For responding to the interaction
self.version: int = data["version"]
self.message: Optional["Message"] = (
Message(data["message"], client_instance=client_instance)
if data.get("message")
else None
) # For component interactions
self.app_permissions: Optional[str] = data.get(
"app_permissions"
) # Bitwise set of permissions the app has in the source channel
self.locale: Optional[str] = data.get(
"locale"
) # Selected language of the invoking user
self.guild_locale: Optional[str] = data.get(
"guild_locale"
) # Guild's preferred language
self.response = InteractionResponse(self)
async def respond(
self,
content: Optional[str] = None,
*,
embed: Optional[Embed] = None,
embeds: Optional[List[Embed]] = None,
components: Optional[List[ActionRow]] = None,
ephemeral: bool = False,
tts: bool = False,
) -> None:
"""|coro|
Responds to this interaction.
Parameters:
content (Optional[str]): The content of the message.
embed (Optional[Embed]): A single embed to send.
embeds (Optional[List[Embed]]): A list of embeds to send.
components (Optional[List[ActionRow]]): A list of ActionRow components.
ephemeral (bool): Whether the response should be ephemeral (only visible to the user).
tts (bool): Whether the message should be sent with text-to-speech.
"""
if embed and embeds:
raise ValueError("Cannot provide both embed and embeds.")
data: Dict[str, Any] = {}
if tts:
data["tts"] = True
if content:
data["content"] = content
if embed:
data["embeds"] = [embed.to_dict()]
elif embeds:
data["embeds"] = [e.to_dict() for e in embeds]
if components:
data["components"] = [c.to_dict() for c in components]
if ephemeral:
data["flags"] = 1 << 6 # EPHEMERAL flag
payload = InteractionResponsePayload(
type=InteractionCallbackType.CHANNEL_MESSAGE_WITH_SOURCE,
data=InteractionCallbackData(data),
)
await self._client._http.create_interaction_response(
interaction_id=self.id,
interaction_token=self.token,
payload=payload,
)
async def respond_modal(self, modal: "Modal") -> None:
"""|coro| Send a modal in response to this interaction."""
from typing import Any, cast
payload = {
"type": InteractionCallbackType.MODAL.value,
"data": modal.to_dict(),
}
await self._client._http.create_interaction_response(
interaction_id=self.id,
interaction_token=self.token,
payload=cast(Any, payload),
)
async def edit(
self,
content: Optional[str] = None,
*,
embed: Optional[Embed] = None,
embeds: Optional[List[Embed]] = None,
components: Optional[List[ActionRow]] = None,
attachments: Optional[List[Any]] = None,
allowed_mentions: Optional[Dict[str, Any]] = None,
) -> None:
"""|coro|
Edits the original response to this interaction.
If the interaction is from a component, this will acknowledge the
interaction and update the message in one operation.
Parameters:
content (Optional[str]): The new message content.
embed (Optional[Embed]): A single embed to send. Ignored if
``embeds`` is provided.
embeds (Optional[List[Embed]]): A list of embeds to send.
components (Optional[List[ActionRow]]): Updated components for the
message.
attachments (Optional[List[Any]]): Attachments to include with the
message.
allowed_mentions (Optional[Dict[str, Any]]): Controls mentions in the
message.
"""
if embed and embeds:
raise ValueError("Cannot provide both embed and embeds.")
payload_data: Dict[str, Any] = {}
if content is not None:
payload_data["content"] = content
if embed:
payload_data["embeds"] = [embed.to_dict()]
elif embeds is not None:
payload_data["embeds"] = [e.to_dict() for e in embeds]
if components is not None:
payload_data["components"] = [c.to_dict() for c in components]
if attachments is not None:
payload_data["attachments"] = [
a.to_dict() if hasattr(a, "to_dict") else a for a in attachments
]
if allowed_mentions is not None:
payload_data["allowed_mentions"] = allowed_mentions
if self.type == InteractionType.MESSAGE_COMPONENT:
# For component interactions, we send an UPDATE_MESSAGE response
# to acknowledge the interaction and edit the message simultaneously.
payload = InteractionResponsePayload(
type=InteractionCallbackType.UPDATE_MESSAGE,
data=InteractionCallbackData(payload_data),
)
await self._client._http.create_interaction_response(
self.id, self.token, payload
)
else:
# For other interaction types (like an initial slash command response),
# we edit the original response via the webhook endpoint.
await self._client._http.edit_original_interaction_response(
application_id=self.application_id,
interaction_token=self.token,
payload=payload_data,
)
def __repr__(self) -> str:
return f"<Interaction id='{self.id}' type={self.type!r}>"
class InteractionResponse:
"""Helper for sending responses for an :class:`Interaction`."""
def __init__(self, interaction: "Interaction") -> None:
self._interaction = interaction
async def send_modal(self, modal: "Modal") -> None:
"""Sends a modal response."""
payload = InteractionResponsePayload(
type=InteractionCallbackType.MODAL,
data=InteractionCallbackData(modal.to_dict()),
)
await self._interaction._client._http.create_interaction_response(
self._interaction.id,
self._interaction.token,
payload,
)
# Based on Interaction Response Object's Data Structure
class InteractionCallbackData:
"""Data for an interaction response."""
def __init__(self, data: dict):
self.tts: Optional[bool] = data.get("tts")
self.content: Optional[str] = data.get("content")
self.embeds: Optional[List[Embed]] = (
[Embed(e) for e in data.get("embeds", [])] if data.get("embeds") else None
)
self.allowed_mentions: Optional[AllowedMentions] = (
AllowedMentions(data["allowed_mentions"])
if data.get("allowed_mentions")
else None
)
self.flags: Optional[int] = data.get("flags") # MessageFlags enum could be used
from .components import component_factory
self.components: Optional[List[Component]] = (
[component_factory(c) for c in data.get("components", [])]
if data.get("components")
else None
)
self.attachments: Optional[List[Attachment]] = (
[Attachment(a) for a in data.get("attachments", [])]
if data.get("attachments")
else None
)
def to_dict(self) -> dict:
# Helper to convert to dict for sending to Discord API
payload = {}
if self.tts is not None:
payload["tts"] = self.tts
if self.content is not None:
payload["content"] = self.content
if self.embeds is not None:
payload["embeds"] = [e.to_dict() for e in self.embeds]
if self.allowed_mentions is not None:
payload["allowed_mentions"] = self.allowed_mentions.to_dict()
if self.flags is not None:
payload["flags"] = self.flags
if self.components is not None:
payload["components"] = [c.to_dict() for c in self.components]
if self.attachments is not None:
payload["attachments"] = [a.to_dict() for a in self.attachments]
return payload
def __repr__(self) -> str:
return f"<InteractionCallbackData content='{self.content[:20] if self.content else None}'>"
# Based on Interaction Response Object Structure
class InteractionResponsePayload:
"""Payload for responding to an interaction."""
def __init__(
self,
type: InteractionCallbackType,
data: Optional[InteractionCallbackData] = None,
):
self.type: InteractionCallbackType = type
self.data: Optional[InteractionCallbackData] = data
def to_dict(self) -> Dict[str, Any]:
payload: Dict[str, Any] = {"type": self.type.value}
if self.data:
payload["data"] = self.data.to_dict()
return payload
def __repr__(self) -> str:
return f"<InteractionResponsePayload type={self.type!r}>"

View File

@ -0,0 +1,26 @@
import logging
from typing import Optional
def setup_logging(level: int, file: Optional[str] = None) -> None:
"""Configure logging for the library.
Parameters
----------
level:
Logging level from the :mod:`logging` module.
file:
Optional file path to write logs to. If ``None``, logs are sent to
standard output.
"""
handlers: list[logging.Handler] = []
if file is None:
handlers.append(logging.StreamHandler())
else:
handlers.append(logging.FileHandler(file))
logging.basicConfig(
level=level,
handlers=handlers,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)

1642
disagreement/models.py Normal file

File diff suppressed because it is too large Load Diff

109
disagreement/oauth.py Normal file
View File

@ -0,0 +1,109 @@
"""OAuth2 utilities."""
from __future__ import annotations
import aiohttp
from typing import List, Optional, Dict, Any, Union
from urllib.parse import urlencode
from .errors import HTTPException
def build_authorization_url(
client_id: str,
redirect_uri: str,
scope: Union[str, List[str]],
*,
state: Optional[str] = None,
response_type: str = "code",
prompt: Optional[str] = None,
) -> str:
"""Return the Discord OAuth2 authorization URL."""
if isinstance(scope, list):
scope = " ".join(scope)
params = {
"client_id": client_id,
"redirect_uri": redirect_uri,
"response_type": response_type,
"scope": scope,
}
if state is not None:
params["state"] = state
if prompt is not None:
params["prompt"] = prompt
return "https://discord.com/oauth2/authorize?" + urlencode(params)
async def exchange_code_for_token(
client_id: str,
client_secret: str,
code: str,
redirect_uri: str,
*,
session: Optional[aiohttp.ClientSession] = None,
) -> Dict[str, Any]:
"""Exchange an authorization code for an access token."""
close = False
if session is None:
session = aiohttp.ClientSession()
close = True
data = {
"client_id": client_id,
"client_secret": client_secret,
"grant_type": "authorization_code",
"code": code,
"redirect_uri": redirect_uri,
}
resp = await session.post(
"https://discord.com/api/v10/oauth2/token",
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
try:
json_data = await resp.json()
if resp.status != 200:
raise HTTPException(resp, message="OAuth token exchange failed")
finally:
if close:
await session.close()
return json_data
async def refresh_access_token(
refresh_token: str,
client_id: str,
client_secret: str,
*,
session: Optional[aiohttp.ClientSession] = None,
) -> Dict[str, Any]:
"""Refresh an access token using a refresh token."""
close = False
if session is None:
session = aiohttp.ClientSession()
close = True
data = {
"client_id": client_id,
"client_secret": client_secret,
"grant_type": "refresh_token",
"refresh_token": refresh_token,
}
resp = await session.post(
"https://discord.com/api/v10/oauth2/token",
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
try:
json_data = await resp.json()
if resp.status != 200:
raise HTTPException(resp, message="OAuth token refresh failed")
finally:
if close:
await session.close()
return json_data

View File

@ -0,0 +1,99 @@
"""Utility helpers for working with Discord permission bitmasks."""
from __future__ import annotations
from enum import IntFlag
from typing import Iterable, List
class Permissions(IntFlag):
"""Discord guild and channel permissions."""
CREATE_INSTANT_INVITE = 1 << 0
KICK_MEMBERS = 1 << 1
BAN_MEMBERS = 1 << 2
ADMINISTRATOR = 1 << 3
MANAGE_CHANNELS = 1 << 4
MANAGE_GUILD = 1 << 5
ADD_REACTIONS = 1 << 6
VIEW_AUDIT_LOG = 1 << 7
PRIORITY_SPEAKER = 1 << 8
STREAM = 1 << 9
VIEW_CHANNEL = 1 << 10
SEND_MESSAGES = 1 << 11
SEND_TTS_MESSAGES = 1 << 12
MANAGE_MESSAGES = 1 << 13
EMBED_LINKS = 1 << 14
ATTACH_FILES = 1 << 15
READ_MESSAGE_HISTORY = 1 << 16
MENTION_EVERYONE = 1 << 17
USE_EXTERNAL_EMOJIS = 1 << 18
VIEW_GUILD_INSIGHTS = 1 << 19
CONNECT = 1 << 20
SPEAK = 1 << 21
MUTE_MEMBERS = 1 << 22
DEAFEN_MEMBERS = 1 << 23
MOVE_MEMBERS = 1 << 24
USE_VAD = 1 << 25
CHANGE_NICKNAME = 1 << 26
MANAGE_NICKNAMES = 1 << 27
MANAGE_ROLES = 1 << 28
MANAGE_WEBHOOKS = 1 << 29
MANAGE_GUILD_EXPRESSIONS = 1 << 30
USE_APPLICATION_COMMANDS = 1 << 31
REQUEST_TO_SPEAK = 1 << 32
MANAGE_EVENTS = 1 << 33
MANAGE_THREADS = 1 << 34
CREATE_PUBLIC_THREADS = 1 << 35
CREATE_PRIVATE_THREADS = 1 << 36
USE_EXTERNAL_STICKERS = 1 << 37
SEND_MESSAGES_IN_THREADS = 1 << 38
USE_EMBEDDED_ACTIVITIES = 1 << 39
MODERATE_MEMBERS = 1 << 40
VIEW_CREATOR_MONETIZATION_ANALYTICS = 1 << 41
USE_SOUNDBOARD = 1 << 42
CREATE_GUILD_EXPRESSIONS = 1 << 43
CREATE_EVENTS = 1 << 44
USE_EXTERNAL_SOUNDS = 1 << 45
SEND_VOICE_MESSAGES = 1 << 46
def permissions_value(*perms: Permissions | int | Iterable[Permissions | int]) -> int:
"""Return a combined integer value for multiple permissions."""
value = 0
for perm in perms:
if isinstance(perm, Iterable) and not isinstance(perm, (Permissions, int)):
value |= permissions_value(*perm)
else:
value |= int(perm)
return value
def has_permissions(
current: int | str | Permissions,
*perms: Permissions | int | Iterable[Permissions | int],
) -> bool:
"""Return ``True`` if ``current`` includes all ``perms``."""
current_val = int(current)
needed = permissions_value(*perms)
return (current_val & needed) == needed
def missing_permissions(
current: int | str | Permissions,
*perms: Permissions | int | Iterable[Permissions | int],
) -> List[Permissions]:
"""Return the subset of ``perms`` not present in ``current``."""
current_val = int(current)
missing: List[Permissions] = []
for perm in perms:
if isinstance(perm, Iterable) and not isinstance(perm, (Permissions, int)):
missing.extend(missing_permissions(current_val, *perm))
else:
perm_val = int(perm)
if not current_val & perm_val:
missing.append(Permissions(perm_val))
return missing

View File

@ -0,0 +1,75 @@
"""Asynchronous rate limiter for Discord HTTP requests."""
from __future__ import annotations
import asyncio
import time
from typing import Dict, Mapping
class _Bucket:
def __init__(self) -> None:
self.remaining: int = 1
self.reset_at: float = 0.0
self.lock = asyncio.Lock()
class RateLimiter:
"""Rate limiter implementing per-route buckets and a global queue."""
def __init__(self) -> None:
self._buckets: Dict[str, _Bucket] = {}
self._global_event = asyncio.Event()
self._global_event.set()
def _get_bucket(self, route: str) -> _Bucket:
bucket = self._buckets.get(route)
if bucket is None:
bucket = _Bucket()
self._buckets[route] = bucket
return bucket
async def acquire(self, route: str) -> _Bucket:
bucket = self._get_bucket(route)
while True:
await self._global_event.wait()
async with bucket.lock:
now = time.monotonic()
if bucket.remaining <= 0 and now < bucket.reset_at:
await asyncio.sleep(bucket.reset_at - now)
continue
if bucket.remaining > 0:
bucket.remaining -= 1
return bucket
def release(self, route: str, headers: Mapping[str, str]) -> None:
bucket = self._get_bucket(route)
try:
remaining = int(headers.get("X-RateLimit-Remaining", bucket.remaining))
reset_after = float(headers.get("X-RateLimit-Reset-After", "0"))
bucket.remaining = remaining
bucket.reset_at = time.monotonic() + reset_after
except ValueError:
pass
if headers.get("X-RateLimit-Global", "false").lower() == "true":
retry_after = float(headers.get("Retry-After", "0"))
self._global_event.clear()
asyncio.create_task(self._lift_global(retry_after))
async def handle_rate_limit(
self, route: str, retry_after: float, is_global: bool
) -> None:
bucket = self._get_bucket(route)
bucket.remaining = 0
bucket.reset_at = time.monotonic() + retry_after
if is_global:
self._global_event.clear()
await asyncio.sleep(retry_after)
self._global_event.set()
else:
await asyncio.sleep(retry_after)
async def _lift_global(self, delay: float) -> None:
await asyncio.sleep(delay)
self._global_event.set()

View File

@ -0,0 +1,65 @@
# disagreement/shard_manager.py
"""Sharding utilities for managing multiple gateway connections."""
from __future__ import annotations
import asyncio
from typing import List, TYPE_CHECKING
from .gateway import GatewayClient
if TYPE_CHECKING: # pragma: no cover - for type checking only
from .client import Client
class Shard:
"""Represents a single gateway shard."""
def __init__(self, shard_id: int, shard_count: int, gateway: GatewayClient) -> None:
self.id: int = shard_id
self.count: int = shard_count
self.gateway: GatewayClient = gateway
async def connect(self) -> None:
"""Connects this shard's gateway."""
await self.gateway.connect()
async def close(self) -> None:
"""Closes this shard's gateway."""
await self.gateway.close()
class ShardManager:
"""Manages multiple :class:`Shard` instances."""
def __init__(self, client: "Client", shard_count: int) -> None:
self.client: "Client" = client
self.shard_count: int = shard_count
self.shards: List[Shard] = []
def _create_shards(self) -> None:
if self.shards:
return
for shard_id in range(self.shard_count):
gateway = GatewayClient(
http_client=self.client._http,
event_dispatcher=self.client._event_dispatcher,
token=self.client.token,
intents=self.client.intents,
client_instance=self.client,
verbose=self.client.verbose,
shard_id=shard_id,
shard_count=self.shard_count,
)
self.shards.append(Shard(shard_id, self.shard_count, gateway))
async def start(self) -> None:
"""Starts all shards."""
self._create_shards()
await asyncio.gather(*(s.connect() for s in self.shards))
async def close(self) -> None:
"""Closes all shards."""
await asyncio.gather(*(s.close() for s in self.shards))
self.shards.clear()

42
disagreement/typing.py Normal file
View File

@ -0,0 +1,42 @@
import asyncio
from contextlib import suppress
from typing import Optional, TYPE_CHECKING
from .errors import DisagreementException
if TYPE_CHECKING:
from .client import Client
if __name__ == "typing":
# For direct module execution testing
pass
class Typing:
"""Async context manager for Discord typing indicator."""
def __init__(self, client: "Client", channel_id: str) -> None:
self._client = client
self._channel_id = channel_id
self._task: Optional[asyncio.Task] = None
async def _run(self) -> None:
try:
while True:
await self._client._http.trigger_typing(self._channel_id)
await asyncio.sleep(5)
except asyncio.CancelledError:
pass
async def __aenter__(self) -> "Typing":
if self._client._closed:
raise DisagreementException("Client is closed.")
await self._client._http.trigger_typing(self._channel_id)
self._task = asyncio.create_task(self._run())
return self
async def __aexit__(self, exc_type, exc, tb) -> None:
if self._task:
self._task.cancel()
with suppress(asyncio.CancelledError):
await self._task

View File

@ -0,0 +1,17 @@
from .view import View
from .item import Item
from .button import Button, button
from .select import Select, select
from .modal import Modal, TextInput, text_input
__all__ = [
"View",
"Item",
"Button",
"button",
"Select",
"select",
"Modal",
"TextInput",
"text_input",
]

99
disagreement/ui/button.py Normal file
View File

@ -0,0 +1,99 @@
from __future__ import annotations
import asyncio
from typing import Any, Callable, Coroutine, Optional, TYPE_CHECKING
from .item import Item
from ..enums import ComponentType, ButtonStyle
from ..models import PartialEmoji, to_partial_emoji
if TYPE_CHECKING:
from ..interactions import Interaction
class Button(Item):
"""Represents a button component in a View.
Args:
style (ButtonStyle): The style of the button.
label (Optional[str]): The text that appears on the button.
emoji (Optional[str | PartialEmoji]): The emoji that appears on the button.
custom_id (Optional[str]): The developer-defined identifier for the button.
url (Optional[str]): The URL for the button.
disabled (bool): Whether the button is disabled.
row (Optional[int]): The row the button should be placed in, from 0 to 4.
"""
def __init__(
self,
*,
style: ButtonStyle = ButtonStyle.SECONDARY,
label: Optional[str] = None,
emoji: Optional[str | PartialEmoji] = None,
custom_id: Optional[str] = None,
url: Optional[str] = None,
disabled: bool = False,
row: Optional[int] = None,
):
super().__init__(type=ComponentType.BUTTON)
if not label and not emoji:
raise ValueError("A button must have a label and/or an emoji.")
if url and custom_id:
raise ValueError("A button cannot have both a URL and a custom_id.")
self.style = style
self.label = label
self.emoji = to_partial_emoji(emoji)
self.custom_id = custom_id
self.url = url
self.disabled = disabled
self._row = row
def to_dict(self) -> dict[str, Any]:
"""Converts the button to a dictionary that can be sent to Discord."""
payload = {
"type": ComponentType.BUTTON.value,
"style": self.style.value,
"disabled": self.disabled,
}
if self.label:
payload["label"] = self.label
if self.emoji:
payload["emoji"] = self.emoji.to_dict()
if self.url:
payload["url"] = self.url
if self.custom_id:
payload["custom_id"] = self.custom_id
return payload
def button(
*,
label: Optional[str] = None,
custom_id: Optional[str] = None,
style: ButtonStyle = ButtonStyle.SECONDARY,
emoji: Optional[str | PartialEmoji] = None,
url: Optional[str] = None,
disabled: bool = False,
row: Optional[int] = None,
) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Button]:
"""A decorator to create a button in a View."""
def decorator(func: Callable[..., Coroutine[Any, Any, Any]]) -> Button:
if not asyncio.iscoroutinefunction(func):
raise TypeError("Button callback must be a coroutine function.")
item = Button(
label=label,
custom_id=custom_id,
style=style,
emoji=emoji,
url=url,
disabled=disabled,
row=row,
)
item.callback = func
return item
return decorator

38
disagreement/ui/item.py Normal file
View File

@ -0,0 +1,38 @@
from __future__ import annotations
from typing import Any, Callable, Coroutine, Optional, TYPE_CHECKING
from ..models import Component
if TYPE_CHECKING:
from .view import View
from ..interactions import Interaction
class Item(Component):
"""Represents a UI item that can be placed in a View.
This is a base class and is not meant to be used directly.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._view: Optional[View] = None
self._row: Optional[int] = None
# This is the callback associated with this item.
self.callback: Optional[
Callable[["View", Interaction], Coroutine[Any, Any, Any]]
] = None
@property
def view(self) -> Optional[View]:
return self._view
@property
def row(self) -> Optional[int]:
return self._row
def _refresh_from_data(self, data: dict[str, Any]) -> None:
# This is used to update the item's state from incoming interaction data.
# For example, a button's disabled state could be updated here.
pass

132
disagreement/ui/modal.py Normal file
View File

@ -0,0 +1,132 @@
from __future__ import annotations
from typing import Any, Callable, Coroutine, Optional, List, TYPE_CHECKING
import asyncio
from .item import Item
from .view import View
from ..enums import ComponentType, TextInputStyle
from ..models import ActionRow
if TYPE_CHECKING: # pragma: no cover - for type hints only
from ..interactions import Interaction
class TextInput(Item):
"""Represents a text input component inside a modal."""
def __init__(
self,
*,
label: str,
custom_id: Optional[str] = None,
style: TextInputStyle = TextInputStyle.SHORT,
placeholder: Optional[str] = None,
required: bool = True,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
row: Optional[int] = None,
) -> None:
super().__init__(type=ComponentType.TEXT_INPUT)
self.label = label
self.custom_id = custom_id
self.style = style
self.placeholder = placeholder
self.required = required
self.min_length = min_length
self.max_length = max_length
self._row = row
def to_dict(self) -> dict[str, Any]:
payload = {
"type": ComponentType.TEXT_INPUT.value,
"style": self.style.value,
"label": self.label,
"required": self.required,
}
if self.custom_id:
payload["custom_id"] = self.custom_id
if self.placeholder:
payload["placeholder"] = self.placeholder
if self.min_length is not None:
payload["min_length"] = self.min_length
if self.max_length is not None:
payload["max_length"] = self.max_length
return payload
def text_input(
*,
label: str,
custom_id: Optional[str] = None,
style: TextInputStyle = TextInputStyle.SHORT,
placeholder: Optional[str] = None,
required: bool = True,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
row: Optional[int] = None,
) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], TextInput]:
"""Decorator to define a text input callback inside a :class:`Modal`."""
def decorator(func: Callable[..., Coroutine[Any, Any, Any]]) -> TextInput:
if not asyncio.iscoroutinefunction(func):
raise TypeError("TextInput callback must be a coroutine function.")
item = TextInput(
label=label,
custom_id=custom_id,
style=style,
placeholder=placeholder,
required=required,
min_length=min_length,
max_length=max_length,
row=row,
)
item.callback = func
return item
return decorator
class Modal:
"""Represents a modal dialog."""
def __init__(self, *, title: str, custom_id: str) -> None:
self.title = title
self.custom_id = custom_id
self._children: List[TextInput] = []
for item in self.__class__.__dict__.values():
if isinstance(item, TextInput):
self.add_item(item)
@property
def children(self) -> List[TextInput]:
return self._children
def add_item(self, item: TextInput) -> None:
if not isinstance(item, TextInput):
raise TypeError("Only TextInput items can be added to a Modal.")
if len(self._children) >= 5:
raise ValueError("A modal can only have up to 5 text inputs.")
item._view = None # Not part of a view but reuse item base
self._children.append(item)
def to_components(self) -> List[ActionRow]:
rows: List[ActionRow] = []
for child in self.children:
row = ActionRow(components=[child])
rows.append(row)
return rows
def to_dict(self) -> dict[str, Any]:
return {
"title": self.title,
"custom_id": self.custom_id,
"components": [r.to_dict() for r in self.to_components()],
}
async def callback(
self, interaction: Interaction
) -> None: # pragma: no cover - default
pass

92
disagreement/ui/select.py Normal file
View File

@ -0,0 +1,92 @@
from __future__ import annotations
import asyncio
from typing import Any, Callable, Coroutine, List, Optional, TYPE_CHECKING
from .item import Item
from ..enums import ComponentType
from ..models import SelectOption
if TYPE_CHECKING:
from ..interactions import Interaction
class Select(Item):
"""Represents a select menu component in a View.
Args:
custom_id (str): The developer-defined identifier for the select menu.
options (List[SelectOption]): The choices in the select menu.
placeholder (Optional[str]): The placeholder text that is shown if nothing is selected.
min_values (int): The minimum number of items that must be chosen.
max_values (int): The maximum number of items that can be chosen.
disabled (bool): Whether the select menu is disabled.
row (Optional[int]): The row the select menu should be placed in, from 0 to 4.
"""
def __init__(
self,
*,
custom_id: str,
options: List[SelectOption],
placeholder: Optional[str] = None,
min_values: int = 1,
max_values: int = 1,
disabled: bool = False,
row: Optional[int] = None,
):
super().__init__(type=ComponentType.STRING_SELECT)
self.custom_id = custom_id
self.options = options
self.placeholder = placeholder
self.min_values = min_values
self.max_values = max_values
self.disabled = disabled
self._row = row
def to_dict(self) -> dict[str, Any]:
"""Converts the select menu to a dictionary that can be sent to Discord."""
payload = {
"type": ComponentType.STRING_SELECT.value,
"custom_id": self.custom_id,
"options": [option.to_dict() for option in self.options],
"disabled": self.disabled,
}
if self.placeholder:
payload["placeholder"] = self.placeholder
if self.min_values is not None:
payload["min_values"] = self.min_values
if self.max_values is not None:
payload["max_values"] = self.max_values
return payload
def select(
*,
custom_id: str,
options: List[SelectOption],
placeholder: Optional[str] = None,
min_values: int = 1,
max_values: int = 1,
disabled: bool = False,
row: Optional[int] = None,
) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Select]:
"""A decorator to create a select menu in a View."""
def decorator(func: Callable[..., Coroutine[Any, Any, Any]]) -> Select:
if not asyncio.iscoroutinefunction(func):
raise TypeError("Select callback must be a coroutine function.")
item = Select(
custom_id=custom_id,
options=options,
placeholder=placeholder,
min_values=min_values,
max_values=max_values,
disabled=disabled,
row=row,
)
item.callback = func
return item
return decorator

165
disagreement/ui/view.py Normal file
View File

@ -0,0 +1,165 @@
from __future__ import annotations
import asyncio
import uuid
from typing import Any, Callable, Coroutine, Dict, List, Optional, TYPE_CHECKING
from ..models import ActionRow
from .item import Item
if TYPE_CHECKING:
from ..client import Client
from ..interactions import Interaction
class View:
"""Represents a container for UI components that can be sent with a message.
Args:
timeout (Optional[float]): The number of seconds to wait for an interaction before the view times out.
Defaults to 180.
"""
def __init__(self, *, timeout: Optional[float] = 180.0):
self.timeout = timeout
self.id = str(uuid.uuid4())
self.__children: List[Item] = []
self.__stopped = asyncio.Event()
self._client: Optional[Client] = None
self._message_id: Optional[str] = None
for item in self.__class__.__dict__.values():
if isinstance(item, Item):
self.add_item(item)
@property
def children(self) -> List[Item]:
return self.__children
def add_item(self, item: Item):
"""Adds an item to the view."""
if not isinstance(item, Item):
raise TypeError("Only instances of 'Item' can be added to a View.")
if len(self.__children) >= 25:
raise ValueError("A view can only have a maximum of 25 components.")
item._view = self
self.__children.append(item)
@property
def message_id(self) -> Optional[str]:
return self._message_id
@message_id.setter
def message_id(self, value: str):
self._message_id = value
def to_components(self) -> List[ActionRow]:
"""Converts the view's children into a list of ActionRow components.
This retains the original, simple layout behaviour where each item is
placed in its own :class:`ActionRow` to ensure backward compatibility.
"""
rows: List[ActionRow] = []
for item in self.children:
if item.custom_id is None:
item.custom_id = (
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
)
rows.append(ActionRow(components=[item]))
return rows
def layout_components_advanced(self) -> List[ActionRow]:
"""Group compatible components into rows following Discord rules."""
rows: List[ActionRow] = []
for item in self.children:
if item.custom_id is None:
item.custom_id = (
f"{self.id}:{item.__class__.__name__}:{len(self.__children)}"
)
target_row = item.row
if target_row is not None:
if not 0 <= target_row <= 4:
raise ValueError("Row index must be between 0 and 4.")
while len(rows) <= target_row:
if len(rows) >= 5:
raise ValueError("A view can have at most 5 action rows.")
rows.append(ActionRow())
rows[target_row].add_component(item)
continue
placed = False
for row in rows:
try:
row.add_component(item)
placed = True
break
except ValueError:
continue
if not placed:
if len(rows) >= 5:
raise ValueError("A view can have at most 5 action rows.")
new_row = ActionRow([item])
rows.append(new_row)
return rows
def to_components_payload(self) -> List[Dict[str, Any]]:
"""Converts the view's children into a list of component dictionaries
that can be sent to the Discord API."""
return [row.to_dict() for row in self.to_components()]
async def _dispatch(self, interaction: Interaction):
"""Called by the client to dispatch an interaction to the correct item."""
if self.timeout is not None:
self.__stopped.set() # Reset the timeout on each interaction
self.__stopped.clear()
if interaction.data:
custom_id = interaction.data.custom_id
for child in self.children:
if child.custom_id == custom_id:
if child.callback:
await child.callback(self, interaction)
break
async def wait(self) -> bool:
"""Waits until the view has stopped interacting."""
return await self.__stopped.wait()
def stop(self):
"""Stops the view from listening to interactions."""
if not self.__stopped.is_set():
self.__stopped.set()
async def on_timeout(self):
"""Called when the view times out."""
pass # User can override this
async def _start(self, client: Client):
"""Starts the view's internal listener."""
self._client = client
if self.timeout is not None:
asyncio.create_task(self._timeout_task())
async def _timeout_task(self):
"""The task that waits for the timeout and then stops the view."""
try:
await asyncio.wait_for(self.wait(), timeout=self.timeout)
except asyncio.TimeoutError:
self.stop()
await self.on_timeout()
if self._client and self._message_id:
# Remove the view from the client's listeners
self._client._views.pop(self._message_id, None)

View File

@ -0,0 +1,120 @@
# disagreement/voice_client.py
"""Voice gateway and UDP audio client."""
from __future__ import annotations
import asyncio
import contextlib
import socket
from typing import Optional, Sequence
import aiohttp
class VoiceClient:
"""Handles the Discord voice WebSocket connection and UDP streaming."""
def __init__(
self,
endpoint: str,
session_id: str,
token: str,
guild_id: int,
user_id: int,
*,
ws=None,
udp: Optional[socket.socket] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
verbose: bool = False,
) -> None:
self.endpoint = endpoint
self.session_id = session_id
self.token = token
self.guild_id = str(guild_id)
self.user_id = str(user_id)
self._ws: Optional[aiohttp.ClientWebSocketResponse] = ws
self._udp = udp
self._session: Optional[aiohttp.ClientSession] = None
self._heartbeat_task: Optional[asyncio.Task] = None
self._heartbeat_interval: Optional[float] = None
self._loop = loop or asyncio.get_event_loop()
self.verbose = verbose
self.ssrc: Optional[int] = None
self.secret_key: Optional[Sequence[int]] = None
self._server_ip: Optional[str] = None
self._server_port: Optional[int] = None
async def connect(self) -> None:
if self._ws is None:
self._session = aiohttp.ClientSession()
self._ws = await self._session.ws_connect(self.endpoint)
hello = await self._ws.receive_json()
self._heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000
self._heartbeat_task = self._loop.create_task(self._heartbeat())
await self._ws.send_json(
{
"op": 0,
"d": {
"server_id": self.guild_id,
"user_id": self.user_id,
"session_id": self.session_id,
"token": self.token,
},
}
)
ready = await self._ws.receive_json()
data = ready["d"]
self.ssrc = data["ssrc"]
self._server_ip = data["ip"]
self._server_port = data["port"]
if self._udp is None:
self._udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self._udp.connect((self._server_ip, self._server_port))
await self._ws.send_json(
{
"op": 1,
"d": {
"protocol": "udp",
"data": {
"address": self._udp.getsockname()[0],
"port": self._udp.getsockname()[1],
"mode": "xsalsa20_poly1305",
},
},
}
)
session_desc = await self._ws.receive_json()
self.secret_key = session_desc["d"].get("secret_key")
async def _heartbeat(self) -> None:
assert self._ws is not None
assert self._heartbeat_interval is not None
try:
while True:
await self._ws.send_json({"op": 3, "d": int(self._loop.time() * 1000)})
await asyncio.sleep(self._heartbeat_interval)
except asyncio.CancelledError:
pass
async def send_audio_frame(self, frame: bytes) -> None:
if not self._udp:
raise RuntimeError("UDP socket not initialised")
self._udp.send(frame)
async def close(self) -> None:
if self._heartbeat_task:
self._heartbeat_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._heartbeat_task
if self._ws:
await self._ws.close()
if self._session:
await self._session.close()
if self._udp:
self._udp.close()

18
docs/caching.md Normal file
View File

@ -0,0 +1,18 @@
# Caching
Disagreement ships with a simple in-memory cache used by the HTTP and Gateway clients. Cached objects reduce API requests and improve performance.
The client automatically caches guilds, channels and users as they are received from events or HTTP calls. You can access cached data through lookup helpers such as `Client.get_guild`.
The cache can be cleared manually if needed:
```python
client.cache.clear()
```
## Next Steps
- [Components](using_components.md)
- [Slash Commands](slash_commands.md)
- [Voice Features](voice_features.md)

51
docs/commands.md Normal file
View File

@ -0,0 +1,51 @@
# Commands Extension
This guide covers the built-in prefix command system.
## Help Command
The command handler registers a `help` command automatically. Use it to list all available commands or get information about a single command.
```
!help # lists commands
!help ping # shows help for the "ping" command
```
The help command will show each command's brief description if provided.
## Checks
Use `commands.check` to prevent a command from running unless a predicate
returns ``True``. Checks may be regular or async callables that accept a
`CommandContext`.
```python
from disagreement.ext.commands import command, check, CheckFailure
def is_owner(ctx):
return ctx.author.id == "1"
@command()
@check(is_owner)
async def secret(ctx):
await ctx.send("Only for the owner!")
```
When a check fails a :class:`CheckFailure` is raised and dispatched through the
command error handler.
## Cooldowns
Commands can be rate limited using the ``cooldown`` decorator. The example
below restricts usage to once every three seconds per user:
```python
from disagreement.ext.commands import command, cooldown
@command()
@cooldown(1, 3.0)
async def ping(ctx):
await ctx.send("Pong!")
```
Invoking a command while it is on cooldown raises :class:`CommandOnCooldown`.

21
docs/context_menus.md Normal file
View File

@ -0,0 +1,21 @@
# Context Menu Commands
`disagreement` supports Discord's user and message context menu commands. Use the
`user_command` and `message_command` decorators from `ext.app_commands` to
define them.
```python
from disagreement.ext.app_commands import user_command, message_command, AppCommandContext
from disagreement.models import User, Message
@user_command(name="User Info")
async def user_info(ctx: AppCommandContext, user: User) -> None:
await ctx.send(f"User: {user.username}#{user.discriminator}")
@message_command(name="Quote")
async def quote(ctx: AppCommandContext, message: Message) -> None:
await ctx.send(message.content)
```
Add the commands to your client's handler and run `sync_commands()` to register
them with Discord.

25
docs/converters.md Normal file
View File

@ -0,0 +1,25 @@
# Command Argument Converters
`disagreement.ext.commands` provides a number of built in converters that will parse string arguments into richer objects. These converters are automatically used when a command callback annotates its parameters with one of the supported types.
## Supported Types
- `int`, `float`, `bool`, and `str`
- `Member` resolves a user mention or ID to a `Member` object for the current guild
- `Role` resolves a role mention or ID to a `Role` object
- `Guild` resolves a guild ID to a `Guild` object
## Example
```python
from disagreement.ext.commands import command
from disagreement.ext.commands.core import CommandContext
from disagreement.models import Member
@command()
async def kick(ctx: CommandContext, target: Member):
await target.kick()
await ctx.send(f"Kicked {target.display_name}")
```
The framework will automatically convert the first argument to a `Member` using the mention or ID provided by the user.

25
docs/events.md Normal file
View File

@ -0,0 +1,25 @@
# Events
Disagreement dispatches Gateway events to asynchronous callbacks. Handlers can be registered with `@client.event` or `client.on_event`.
Listeners may be removed later using `EventDispatcher.unregister(event_name, coro)`.
## PRESENCE_UPDATE
Triggered when a user's presence changes. The callback receives a `PresenceUpdate` model.
```python
@client.event
async def on_presence_update(presence: disagreement.PresenceUpdate):
...
```
## TYPING_START
Dispatched when a user begins typing in a channel. The callback receives a `TypingStart` model.
```python
@client.event
async def on_typing_start(typing: disagreement.TypingStart):
...
```

17
docs/gateway.md Normal file
View File

@ -0,0 +1,17 @@
# Gateway Connection and Reconnection
`GatewayClient` manages the library's WebSocket connection. When the connection drops unexpectedly, it will now automatically attempt to reconnect using an exponential backoff strategy with jitter.
The default behaviour tries up to five reconnect attempts, doubling the delay each time up to a configurable maximum. A small random jitter is added to spread out reconnect attempts when multiple clients restart at once.
You can control the maximum number of retries and the backoff cap when constructing `Client`:
```python
bot = Client(
token="your-token",
gateway_max_retries=10,
gateway_max_backoff=120.0,
)
```
These values are passed to `GatewayClient` and applied whenever the connection needs to be re-established.

36
docs/i18n.md Normal file
View File

@ -0,0 +1,36 @@
# Internationalization
Disagreement can translate command names, descriptions and other text using JSON translation files.
## Providing Translations
Use `disagreement.i18n.load_translations` to load a JSON file for a locale.
```python
from disagreement import i18n
i18n.load_translations("es", "path/to/es.json")
```
The JSON file should map translation keys to translated strings:
```json
{
"greet": "Hola",
"description": "Comando de saludo"
}
```
You can also set translations programmatically with `i18n.set_translations`.
## Using with Commands
Pass a `locale` argument when defining an `AppCommand` or using the decorators. The command name and description will be looked up using the loaded translations.
```python
@slash_command(name="greet", description="description", locale="es")
async def greet(ctx):
await ctx.send(i18n.translate("greet", ctx.locale or "es"))
```
If a translation is missing the key itself is returned.

48
docs/oauth2.md Normal file
View File

@ -0,0 +1,48 @@
# OAuth2 Setup
This guide explains how to perform a basic OAuth2 flow with `disagreement`.
1. Generate the authorization URL:
```python
from disagreement.oauth import build_authorization_url
url = build_authorization_url(
client_id="YOUR_CLIENT_ID",
redirect_uri="https://your.app/callback",
scope=["identify"],
)
print(url)
```
2. After the user authorizes your application and you receive a code, exchange it for a token:
```python
import aiohttp
from disagreement.oauth import exchange_code_for_token
async def get_token(code: str):
return await exchange_code_for_token(
client_id="YOUR_CLIENT_ID",
client_secret="YOUR_CLIENT_SECRET",
code=code,
redirect_uri="https://your.app/callback",
)
```
`exchange_code_for_token` returns the JSON payload from Discord which includes
`access_token`, `refresh_token` and expiry information.
3. When the access token expires, you can refresh it using the provided refresh
token:
```python
from disagreement.oauth import refresh_access_token
async def refresh(token: str):
return await refresh_access_token(
refresh_token=token,
client_id="YOUR_CLIENT_ID",
client_secret="YOUR_CLIENT_SECRET",
)
```

62
docs/permissions.md Normal file
View File

@ -0,0 +1,62 @@
# Permission Helpers
The `disagreement.permissions` module defines an :class:`~enum.IntFlag`
`Permissions` enumeration along with helper functions for working with
the Discord permission bitmask.
## Permissions Enum
Each attribute of ``Permissions`` represents a single permission bit. The value
is a power of two so multiple permissions can be combined using bitwise OR.
```python
from disagreement.permissions import Permissions
value = Permissions.SEND_MESSAGES | Permissions.MANAGE_MESSAGES
```
## Helper Functions
### ``permissions_value``
```python
permissions_value(*perms) -> int
```
Return an integer bitmask from one or more ``Permissions`` values. Nested
iterables are flattened automatically.
### ``has_permissions``
```python
has_permissions(current, *perms) -> bool
```
Return ``True`` if ``current`` (an ``int`` or ``Permissions``) contains all of
the provided permissions.
### ``missing_permissions``
```python
missing_permissions(current, *perms) -> List[Permissions]
```
Return a list of permissions that ``current`` does not contain.
## Example
```python
from disagreement.permissions import (
Permissions,
has_permissions,
missing_permissions,
)
current = Permissions.SEND_MESSAGES | Permissions.MANAGE_MESSAGES
if has_permissions(current, Permissions.SEND_MESSAGES):
print("Can send messages")
print(missing_permissions(current, Permissions.ADMINISTRATOR))
```

29
docs/presence.md Normal file
View File

@ -0,0 +1,29 @@
# Updating Presence
The `Client.change_presence` method allows you to update the bot's status and displayed activity.
## Status Strings
- `online` show the bot as online
- `idle` mark the bot as away
- `dnd` do not disturb
- `invisible` appear offline
## Activity Types
An activity dictionary must include a `name` and a `type` field. The type value corresponds to Discord's activity types:
| Type | Meaning |
|-----:|--------------|
| `0` | Playing |
| `1` | Streaming |
| `2` | Listening |
| `3` | Watching |
| `4` | Custom |
| `5` | Competing |
Example:
```python
await client.change_presence(status="idle", activity={"name": "with Discord", "type": 0})
```

32
docs/reactions.md Normal file
View File

@ -0,0 +1,32 @@
# Handling Reactions
`disagreement` provides simple helpers for adding and removing message reactions.
## HTTP Methods
Use the `HTTPClient` methods directly if you need lower level control:
```python
await client._http.create_reaction(channel_id, message_id, "👍")
await client._http.delete_reaction(channel_id, message_id, "👍")
users = await client._http.get_reactions(channel_id, message_id, "👍")
```
You can also use the higher level helpers on :class:`Client`:
```python
await client.create_reaction(channel_id, message_id, "👍")
await client.delete_reaction(channel_id, message_id, "👍")
users = await client.get_reactions(channel_id, message_id, "👍")
```
## Reaction Events
Register listeners for `MESSAGE_REACTION_ADD` and `MESSAGE_REACTION_REMOVE`.
Each listener receives a `Reaction` model instance.
```python
@client.on_event("MESSAGE_REACTION_ADD")
async def on_reaction(reaction: disagreement.Reaction):
print(f"{reaction.user_id} reacted with {reaction.emoji}")
```

22
docs/slash_commands.md Normal file
View File

@ -0,0 +1,22 @@
# Using Slash Commands
The library provides a slash command framework via the `ext.app_commands` package. Define commands with decorators and register them with Discord.
```python
from disagreement.ext.app_commands import AppCommandGroup
bot_commands = AppCommandGroup("bot", "Bot commands")
@bot_commands.command(name="ping")
async def ping(ctx):
await ctx.respond("Pong!")
```
Use `AppCommandGroup` to group related commands. See the [components guide](using_components.md) for building interactive responses.
## Next Steps
- [Components](using_components.md)
- [Caching](caching.md)
- [Voice Features](voice_features.md)

15
docs/task_loop.md Normal file
View File

@ -0,0 +1,15 @@
# Task Loops
The tasks extension allows you to run functions periodically. Decorate an async function with `@tasks.loop(seconds=...)` and start it using `.start()`.
```python
from disagreement.ext import tasks
@tasks.loop(seconds=5.0)
async def announce():
print("Hello from a loop")
announce.start()
```
Stop the loop with `.stop()` when you no longer need it.

16
docs/typing_indicator.md Normal file
View File

@ -0,0 +1,16 @@
# Typing Indicator
The library exposes an async context manager to send the typing indicator for a channel.
```python
import asyncio
import disagreement
client = disagreement.Client(token="YOUR_TOKEN")
async def indicate(channel_id: str):
async with client.typing(channel_id):
await long_running_task()
```
This uses the underlying HTTP endpoint `/channels/{channel_id}/typing`.

168
docs/using_components.md Normal file
View File

@ -0,0 +1,168 @@
# Using Message Components
This guide explains how to work with the `disagreement` message component models. These examples are up to date with the current code base.
## Enabling the New Component System
Messages that use the component system must include the flag `IS_COMPONENTS_V2` (value `1 << 15`). Once this flag is set on a message it cannot be removed.
## Component Categories
The library exposes three broad categories of components:
- **Layout Components** organize the placement of other components.
- **Content Components** display static text or media.
- **Interactive Components** allow the user to interact with your message.
## Action Row
`ActionRow` is a layout container. It may hold up to five buttons or a single select menu.
```python
from disagreement.models import ActionRow, Button
from disagreement.enums import ButtonStyle
row = ActionRow(components=[
Button(style=ButtonStyle.PRIMARY, label="Click", custom_id="btn")
])
```
## Button
Buttons provide a clickable UI element.
```python
from disagreement.models import Button
from disagreement.enums import ButtonStyle
button = Button(
style=ButtonStyle.SUCCESS,
label="Confirm",
custom_id="confirm_button",
)
```
## Select Menus
`SelectMenu` lets the user choose one or more options. The `type` parameter controls the menu variety (`STRING_SELECT`, `USER_SELECT`, `ROLE_SELECT`, `MENTIONABLE_SELECT`, `CHANNEL_SELECT`).
```python
from disagreement.models import SelectMenu, SelectOption
from disagreement.enums import ComponentType, ChannelType
menu = SelectMenu(
custom_id="example",
options=[
SelectOption(label="Option 1", value="1"),
SelectOption(label="Option 2", value="2"),
],
placeholder="Choose an option",
min_values=1,
max_values=1,
type=ComponentType.STRING_SELECT,
)
```
For channel selects you may pass `channel_types` with a list of allowed `ChannelType` values.
## Section
`Section` groups one or more `TextDisplay` components and can include an accessory `Button` or `Thumbnail`.
```python
from disagreement.models import Section, TextDisplay, Thumbnail, UnfurledMediaItem
section = Section(
components=[
TextDisplay(content="## Section Title"),
TextDisplay(content="Sections can hold multiple text displays."),
],
accessory=Thumbnail(media=UnfurledMediaItem(url="https://example.com/img.png")),
)
```
## Text Display
`TextDisplay` simply renders markdown text.
```python
from disagreement.models import TextDisplay
text_display = TextDisplay(content="**Bold text**")
```
## Thumbnail
`Thumbnail` shows a small image. Set `spoiler=True` to hide the image until clicked.
```python
from disagreement.models import Thumbnail, UnfurledMediaItem
thumb = Thumbnail(
media=UnfurledMediaItem(url="https://example.com/image.png"),
description="A picture",
spoiler=False,
)
```
## Media Gallery
`MediaGallery` holds multiple `MediaGalleryItem` objects.
```python
from disagreement.models import MediaGallery, MediaGalleryItem, UnfurledMediaItem
gallery = MediaGallery(
items=[
MediaGalleryItem(media=UnfurledMediaItem(url="https://example.com/1.png")),
MediaGalleryItem(media=UnfurledMediaItem(url="https://example.com/2.png")),
]
)
```
## File
`File` displays an uploaded file. Use `spoiler=True` to mark it as a spoiler.
```python
from disagreement.models import File, UnfurledMediaItem
file_component = File(
file=UnfurledMediaItem(url="attachment://file.zip"),
spoiler=False,
)
```
## Separator
`Separator` adds vertical spacing or an optional divider line between components.
```python
from disagreement.models import Separator
separator = Separator(divider=True, spacing=2)
```
## Container
`Container` visually groups a set of components and can apply an accent colour or spoiler.
```python
from disagreement.models import Container, TextDisplay
container = Container(
components=[TextDisplay(content="Inside a container")],
accent_color=0xFF0000,
spoiler=False,
)
```
A container can itself contain layout and content components, letting you build complex messages.
## Next Steps
- [Slash Commands](slash_commands.md)
- [Caching](caching.md)
- [Voice Features](voice_features.md)

29
docs/voice_client.md Normal file
View File

@ -0,0 +1,29 @@
# VoiceClient
`VoiceClient` provides a minimal interface to Discord's voice gateway. It handles the WebSocket handshake and lets you stream audio over UDP.
## Basic Usage
```python
import asyncio
import os
import disagreement
vc = disagreement.VoiceClient(
os.environ["DISCORD_VOICE_ENDPOINT"],
os.environ["DISCORD_SESSION_ID"],
os.environ["DISCORD_VOICE_TOKEN"],
int(os.environ["DISCORD_GUILD_ID"]),
int(os.environ["DISCORD_USER_ID"]),
)
asyncio.run(vc.connect())
```
After connecting you can send raw Opus frames:
```python
await vc.send_audio_frame(opus_bytes)
```
Call `await vc.close()` when finished.

17
docs/voice_features.md Normal file
View File

@ -0,0 +1,17 @@
# Voice Features
Disagreement includes experimental support for connecting to voice channels. You can join a voice channel and play audio using an FFmpeg subprocess.
```python
voice = await client.join_voice(guild_id, channel_id)
voice.play_file("welcome.mp3")
```
Voice support is optional and may require additional system dependencies such as FFmpeg.
## Next Steps
- [Components](using_components.md)
- [Slash Commands](slash_commands.md)
- [Caching](caching.md)

34
docs/webhooks.md Normal file
View File

@ -0,0 +1,34 @@
# Working with Webhooks
The `HTTPClient` includes helper methods for creating, editing and deleting Discord webhooks.
## Create a webhook
```python
from disagreement.http import HTTPClient
http = HTTPClient(token="TOKEN")
payload = {"name": "My Webhook"}
webhook_data = await http.create_webhook("123", payload)
```
## Edit a webhook
```python
await http.edit_webhook("456", {"name": "Renamed"})
```
## Delete a webhook
```python
await http.delete_webhook("456")
```
The methods return the raw webhook JSON. You can construct a `Webhook` model if needed:
```python
from disagreement.models import Webhook
webhook = Webhook(webhook_data)
print(webhook.id, webhook.name)
```

218
examples/basic_bot.py Normal file
View File

@ -0,0 +1,218 @@
# examples/basic_bot.py
"""
A basic example bot using the Disagreement library.
To run this bot:
1. Make sure you have the 'disagreement' library installed or accessible in your PYTHONPATH.
If running from the project root, it should be discoverable.
2. Set the DISCORD_BOT_TOKEN environment variable to your bot's token.
e.g., export DISCORD_BOT_TOKEN="your_actual_token_here" (Linux/macOS)
set DISCORD_BOT_TOKEN="your_actual_token_here" (Windows CMD)
$env:DISCORD_BOT_TOKEN="your_actual_token_here" (Windows PowerShell)
3. Run this script: python examples/basic_bot.py
"""
import asyncio
import os
import logging # Optional: for more detailed logging
# Assuming the 'disagreement' package is in the parent directory or installed
import sys
import traceback
# Add project root to path if running script directly from examples folder
# and disagreement is not installed
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
try:
import disagreement
from disagreement.ext import commands # Import the new commands extension
except ImportError:
print(
"Failed to import disagreement. Make sure it's installed or PYTHONPATH is set correctly."
)
print(
"If running from the 'examples' directory, try running from the project root: python -m examples.basic_bot"
)
sys.exit(1)
from dotenv import load_dotenv
load_dotenv()
# Optional: Configure logging for more insight, especially for gateway events
# logging.basicConfig(level=logging.DEBUG) # For very verbose output
# logging.getLogger('disagreement.gateway').setLevel(logging.INFO) # Or DEBUG
# logging.getLogger('disagreement.http').setLevel(logging.INFO)
# --- Bot Configuration ---
BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
# --- Intents Configuration ---
# Define the intents your bot needs. For basic message reading and responding:
intents = (
disagreement.GatewayIntent.GUILDS
| disagreement.GatewayIntent.GUILD_MESSAGES
| disagreement.GatewayIntent.MESSAGE_CONTENT
) # MESSAGE_CONTENT is privileged!
# If you don't need message content and only react to commands/mentions,
# you might not need MESSAGE_CONTENT intent.
# intents = disagreement.GatewayIntent.default() # A good starting point without privileged intents
# intents |= disagreement.GatewayIntent.MESSAGE_CONTENT # Add if needed
# --- Initialize the Client ---
if not BOT_TOKEN:
print("Error: The DISCORD_BOT_TOKEN environment variable is not set.")
print("Please set it before running the bot.")
sys.exit(1)
# Initialize Client with a command prefix
client = disagreement.Client(token=BOT_TOKEN, intents=intents, command_prefix="!")
# --- Define a Cog for example commands ---
class ExampleCog(commands.Cog): # Ensuring this uses commands.Cog
def __init__(
self, bot_client
): # Renamed client to bot_client to avoid conflict with self.client
super().__init__(bot_client) # Pass the client instance to the base Cog
@commands.command(name="hello", aliases=["hi"])
async def hello_command(self, ctx: commands.CommandContext, *, who: str = "world"):
"""Greets someone."""
await ctx.reply(f"Hello {ctx.author.mention} and {who}!")
print(f"Executed 'hello' command for {ctx.author.username}, greeting {who}")
@commands.command()
async def ping(self, ctx: commands.CommandContext):
"""Responds with Pong!"""
await ctx.reply("Pong!")
print(f"Executed 'ping' command for {ctx.author.username}")
@commands.command()
async def me(self, ctx: commands.CommandContext):
"""Shows information about the invoking user."""
reply_content = (
f"Hello {ctx.author.mention}!\n"
f"Your User ID is: {ctx.author.id}\n"
f"Your Username: {ctx.author.username}#{ctx.author.discriminator}\n"
f"Are you a bot? {'Yes' if ctx.author.bot else 'No'}"
)
await ctx.reply(reply_content)
print(f"Executed 'me' command for {ctx.author.username}")
@commands.command(name="add")
async def add_numbers(self, ctx: commands.CommandContext, num1: int, num2: int):
"""Adds two numbers."""
result = num1 + num2
await ctx.reply(f"The sum of {num1} and {num2} is {result}.")
print(
f"Executed 'add' command for {ctx.author.username}: {num1} + {num2} = {result}"
)
@commands.command(name="say")
async def say_something(self, ctx: commands.CommandContext, *, text_to_say: str):
"""Repeats the text you provide."""
await ctx.reply(f"You said: {text_to_say}")
print(
f"Executed 'say' command for {ctx.author.username}, saying: {text_to_say}"
)
@commands.command(name="quit")
async def quit_command(self, ctx: commands.CommandContext):
"""Shuts down the bot (requires YOUR_USER_ID to be set)."""
# Replace YOUR_USER_ID with your actual Discord User ID for a safe shutdown command
your_user_id = "YOUR_USER_ID_REPLACE_ME" # IMPORTANT: Replace this
if str(ctx.author.id) == your_user_id:
print("Quit command received. Shutting down...")
await ctx.reply("Shutting down...")
await self.client.close() # Access client via self.client from Cog
else:
await ctx.reply("You are not authorized to use this command.")
print(
f"Unauthorized quit attempt by {ctx.author.username} ({ctx.author.id})"
)
# --- Event Handlers ---
@client.event
async def on_ready():
"""Called when the bot is ready and connected to Discord."""
if client.user:
print(
f"Bot is ready! Logged in as {client.user.username}#{client.user.discriminator}"
)
print(f"User ID: {client.user.id}")
else:
print("Bot is ready, but client.user is missing!")
print("------")
print("Disagreement Bot is operational.")
print("Listening for commands...")
@client.event
async def on_message(message: disagreement.Message):
"""Called when a message is created and received."""
# Command processing is now handled by the CommandHandler via client._process_message_for_commands
# This on_message can be used for other message-related logic if needed,
# or removed if all message handling is command-based.
# Example: Log all messages (excluding bot's own, if client.user was available)
# if client.user and message.author.id == client.user.id:
# return
print(
f"General on_message: #{message.channel_id} from {message.author.username}: {message.content}"
)
# The old if/elif command structure is no longer needed here.
@client.on_event(
"GUILD_CREATE"
) # Example of listening to a specific event by its Discord name
async def on_guild_available(guild_data: dict): # Receives raw data for now
# In a real scenario, guild_data would be parsed into a Guild model
print(f"Guild available: {guild_data.get('name')} (ID: {guild_data.get('id')})")
# --- Main Execution ---
async def main():
print("Starting Disagreement Bot...")
try:
# Add the Cog to the client
client.add_cog(ExampleCog(client)) # Pass client instance to Cog constructor
# client.add_cog is synchronous, but it schedules cog.cog_load() if it's async.
await client.run()
except disagreement.AuthenticationError:
print(
"Authentication failed. Please check your bot token and ensure it's correct."
)
except disagreement.DisagreementException as e:
print(f"A Disagreement library error occurred: {e}")
except KeyboardInterrupt:
print("Bot shutting down due to KeyboardInterrupt...")
except Exception as e:
print(f"An unexpected error occurred: {e}")
traceback.print_exc()
finally:
if not client.is_closed():
print("Ensuring client is closed...")
await client.close()
print("Bot has been shut down.")
if __name__ == "__main__":
# Note: On Windows, the default asyncio event loop policy might not support add_signal_handler.
# If you encounter issues with Ctrl+C not working as expected,
# you might need to adjust the event loop policy or handle shutdown differently.
# For example, for Windows:
# if os.name == 'nt':
# asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
asyncio.run(main())

292
examples/component_bot.py Normal file
View File

@ -0,0 +1,292 @@
import os
import asyncio
from typing import Union, Optional
from disagreement import Client, ui, HybridContext
from disagreement.models import (
Message,
SelectOption,
User,
Member,
Role,
Attachment,
Channel,
ActionRow,
Button,
Section,
TextDisplay,
Thumbnail,
UnfurledMediaItem,
MediaGallery,
MediaGalleryItem,
Container,
)
from disagreement.enums import (
ButtonStyle,
GatewayIntent,
ChannelType,
MessageFlags,
InteractionCallbackType,
MessageFlags,
)
from disagreement.ext.commands.cog import Cog
from disagreement.ext.commands.core import CommandContext
from disagreement.ext.app_commands.decorators import hybrid_command, slash_command
from disagreement.ext.app_commands.context import AppCommandContext
from disagreement.interactions import (
Interaction,
InteractionResponsePayload,
InteractionCallbackData,
)
from dotenv import load_dotenv
load_dotenv()
# Get the bot token and application ID from the environment variables
token = os.getenv("DISCORD_BOT_TOKEN")
application_id = os.getenv("DISCORD_APPLICATION_ID")
if not token:
raise ValueError("Bot token not found in environment variables")
if not application_id:
raise ValueError("Application ID not found in environment variables")
# Define the intents
intents = GatewayIntent.default() | GatewayIntent.MESSAGE_CONTENT
# Create a new client
client = Client(
token=token,
intents=intents,
command_prefix="!",
mention_replies=True,
)
# Simple stock data used for the stock cycling example
STOCKS = [
{"symbol": "AAPL", "name": "Apple Inc.", "price": "$175"},
{"symbol": "MSFT", "name": "Microsoft Corp.", "price": "$315"},
{"symbol": "GOOGL", "name": "Alphabet Inc.", "price": "$128"},
]
# Define a View class that contains our components
class MyView(ui.View):
def __init__(self):
super().__init__(timeout=180) # 180-second timeout
self.click_count = 0
@ui.button(label="Click Me!", style=ButtonStyle.SUCCESS, emoji="🖱️")
async def click_me(self, interaction: Interaction):
self.click_count += 1
await interaction.respond(
content=f"You've clicked the button {self.click_count} times!",
ephemeral=True,
)
@ui.select(
custom_id="string_select",
placeholder="Choose an option",
options=[
SelectOption(
label="Option 1", value="opt1", description="This is the first option"
),
SelectOption(
label="Option 2", value="opt2", description="This is the second option"
),
],
)
async def select_menu(self, interaction: Interaction):
if interaction.data and interaction.data.values:
await interaction.respond(
content=f"You selected: {interaction.data.values[0]}",
ephemeral=True,
)
async def on_timeout(self):
# This method is called when the view times out.
# You can use this to edit the original message, for example.
print("View has timed out!")
# View for cycling through available stocks
class StockView(ui.View):
def __init__(self):
super().__init__(timeout=180)
self.index = 0
@ui.button(label="Next Stock", style=ButtonStyle.PRIMARY)
async def next_stock(self, interaction: Interaction):
self.index = (self.index + 1) % len(STOCKS)
stock = STOCKS[self.index]
# Edit the message by responding to the interaction with an update
await interaction.edit(
content=f"**{stock['symbol']}** - {stock['name']}\nPrice: {stock['price']}",
components=self.to_components(),
# Preserve the original reply mention
allowed_mentions={"replied_user": True},
)
class ComponentCommandsCog(Cog):
def __init__(self, client: Client):
super().__init__(client)
@hybrid_command(name="components", description="Sends interactive components.")
async def components_command(self, ctx: Union[CommandContext, AppCommandContext]):
# Send a message with the view using a hybrid context helper
hybrid = HybridContext(ctx)
await hybrid.send("Here are your components:", view=MyView())
@hybrid_command(
name="stocks", description="Shows stock information with navigation."
)
async def stocks_command(self, ctx: Union[CommandContext, AppCommandContext]):
# Show the first stock and attach a button to cycle through them
first = STOCKS[0]
hybrid = HybridContext(ctx)
await hybrid.send(
f"**{first['symbol']}** - {first['name']}\nPrice: {first['price']}",
view=StockView(),
)
@hybrid_command(name="sectiondemo", description="Shows a section layout.")
async def section_demo(self, ctx: Union[CommandContext, AppCommandContext]) -> None:
section = Section(
components=[
TextDisplay(content="## Advanced Components"),
TextDisplay(content="Sections group text with accessories."),
],
accessory=Thumbnail(
media=UnfurledMediaItem(url="https://placehold.co/100x100.png")
),
)
container = Container(components=[section], accent_color=0x5865F2)
hybrid = HybridContext(ctx)
await hybrid.send(
components=[container],
flags=MessageFlags.IS_COMPONENTS_V2.value,
)
@hybrid_command(name="gallerydemo", description="Shows a media gallery.")
async def gallery_demo(self, ctx: Union[CommandContext, AppCommandContext]) -> None:
gallery = MediaGallery(
items=[
MediaGalleryItem(
media=UnfurledMediaItem(url="https://placehold.co/600x400.png")
),
MediaGalleryItem(
media=UnfurledMediaItem(url="https://placehold.co/600x400.jpg")
),
]
)
hybrid = HybridContext(ctx)
await hybrid.send(
components=[gallery],
flags=MessageFlags.IS_COMPONENTS_V2.value,
)
@hybrid_command(
name="complex_components",
description="Shows a complex layout with multiple containers.",
)
async def complex_components(
self, ctx: Union[CommandContext, AppCommandContext]
) -> None:
container1 = Container(
components=[
Section(
components=[
TextDisplay(content="## Complex Layout Example"),
TextDisplay(
content="This container has an accent color and includes a section with an action row of buttons. There is a thumbnail accessory to the right."
),
],
accessory=Thumbnail(
media=UnfurledMediaItem(url="https://placehold.co/100x100.png")
),
),
ActionRow(
components=[
Button(
style=ButtonStyle.PRIMARY,
label="Primary",
custom_id="complex_primary",
),
Button(
style=ButtonStyle.SUCCESS,
label="Success",
custom_id="complex_success",
),
Button(
style=ButtonStyle.DANGER,
label="Destructive",
custom_id="complex_destructive",
),
]
),
],
accent_color=0x5865F2,
)
container2 = Container(
components=[
TextDisplay(
content="## Another Container\nThis container has no accent color and includes a media gallery."
),
MediaGallery(
items=[
MediaGalleryItem(
media=UnfurledMediaItem(
url="https://placehold.co/300x200.png"
)
),
MediaGalleryItem(
media=UnfurledMediaItem(
url="https://placehold.co/300x200.jpg"
)
),
MediaGalleryItem(
media=UnfurledMediaItem(
url="https://placehold.co/300x200.gif"
)
),
]
),
]
)
hybrid = HybridContext(ctx)
await hybrid.send(
components=[container1, container2],
flags=MessageFlags.IS_COMPONENTS_V2.value,
)
async def main():
@client.event
async def on_ready():
if client.user:
print(f"Logged in as {client.user.username}")
if client.application_id:
try:
print("Attempting to sync application commands...")
await client.app_command_handler.sync_commands(
application_id=client.application_id
)
print("Application commands synced successfully.")
except Exception as e:
print(f"Error syncing application commands: {e}")
else:
print(
"Client's application ID is not set. Skipping application command sync."
)
client.add_cog(ComponentCommandsCog(client))
await client.run()
if __name__ == "__main__":
asyncio.run(main())

71
examples/context_menus.py Normal file
View File

@ -0,0 +1,71 @@
"""Examples showing how to use context menu commands."""
import os
import sys
# Allow running example from repository root
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from disagreement.client import Client
from disagreement.ext.app_commands import (
user_command,
message_command,
AppCommandContext,
)
from disagreement.models import User, Message
from dotenv import load_dotenv
load_dotenv()
BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "")
APP_ID = os.environ.get("DISCORD_APPLICATION_ID", "")
client = Client(token=BOT_TOKEN, application_id=APP_ID)
@client.event
async def on_ready():
"""Called when the bot is ready and connected to Discord."""
if client.user:
print(f"Bot is ready! Logged in as {client.user.username}")
print("Attempting to sync application commands...")
try:
if client.application_id:
await client.app_command_handler.sync_commands(
application_id=client.application_id
)
print("Application commands synced successfully.")
else:
print("Skipping command sync: application ID is not set.")
except Exception as e:
print(f"Error syncing application commands: {e}")
else:
print("Bot is ready, but client.user is missing!")
print("------")
@user_command(name="User Info")
async def user_info(ctx: AppCommandContext, user: User) -> None:
await ctx.send(
f"Selected user: {user.username}#{user.discriminator}", ephemeral=True
)
@message_command(name="Quote")
async def quote(ctx: AppCommandContext, message: Message) -> None:
await ctx.send(f"Quoted message: {message.content}", ephemeral=True)
client.app_command_handler.add_command(user_info)
client.app_command_handler.add_command(quote)
async def main() -> None:
await client.run()
if __name__ == "__main__":
import asyncio
asyncio.run(main())

315
examples/hybrid_bot.py Normal file
View File

@ -0,0 +1,315 @@
import asyncio
import os
import logging
from typing import Any, Optional, Literal, Union
from disagreement import HybridContext
from disagreement.client import Client
from disagreement.ext.commands.cog import Cog
from disagreement.ext.commands.core import CommandContext
from disagreement.ext.app_commands.decorators import (
slash_command,
user_command,
message_command,
hybrid_command,
)
from disagreement.ext.app_commands.commands import (
AppCommandGroup,
) # For defining groups
from disagreement.ext.app_commands.context import AppCommandContext # Added
from disagreement.models import (
User,
Member,
Role,
Attachment,
Message,
Channel,
) # For type hints
from disagreement.enums import (
ChannelType,
) # For channel option type hints, assuming it exists
# from disagreement.interactions import Interaction # Replaced by AppCommandContext
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from dotenv import load_dotenv
load_dotenv()
# --- Define a Test Cog ---
class TestCog(Cog):
def __init__(self, client: Client):
super().__init__(client)
@slash_command(name="greet", description="Sends a greeting.")
async def greet_slash(self, ctx: AppCommandContext, name: str):
await ctx.send(f"Hello, {name}! (Slash)")
@user_command(name="Show User Info")
async def show_user_info_user(
self, ctx: AppCommandContext, user: User
): # Target user is in ctx.interaction.data.target_id and resolved
target_user = (
ctx.interaction.data.resolved.users.get(ctx.interaction.data.target_id)
if ctx.interaction.data
and ctx.interaction.data.resolved
and ctx.interaction.data.target_id
else user
)
if target_user:
await ctx.send(
f"User: {target_user.username}#{target_user.discriminator} (ID: {target_user.id}) (User Cmd)",
ephemeral=True,
)
else:
await ctx.send("Could not find user information.", ephemeral=True)
@message_command(name="Quote Message")
async def quote_message_msg(
self, ctx: AppCommandContext, message: Message
): # Target message is in ctx.interaction.data.target_id and resolved
target_message = (
ctx.interaction.data.resolved.messages.get(ctx.interaction.data.target_id)
if ctx.interaction.data
and ctx.interaction.data.resolved
and ctx.interaction.data.target_id
else message
)
if target_message:
await ctx.send(
f'Quoting {target_message.author.username}: "{target_message.content}" (Message Cmd)',
ephemeral=True,
)
else:
await ctx.send("Could not find message to quote.", ephemeral=True)
@hybrid_command(name="ping", description="Checks bot latency.", aliases=["pong"])
async def ping_hybrid(
self, ctx: Union[CommandContext, AppCommandContext], arg: Optional[str] = None
):
# latency = self.client.latency # Assuming client has latency attribute from gateway - Commented out for now
latency_ms = "N/A" # Placeholder
hybrid = HybridContext(ctx)
await hybrid.send(f"Pong! Arg: {arg} (Hybrid)")
@slash_command(name="options_test", description="Tests various option types.")
async def options_test_slash(
self,
ctx: AppCommandContext,
text: str,
integer: int,
boolean: bool,
number: float,
user_option: User,
role_option: Role,
attachment_option: Attachment,
choice_option_str: Literal["apple", "banana", "cherry"],
# Channel and member options as well as numeric Literal choices are
# not yet exercised in tests pending full library support.
):
response_parts = [
f"Text: {text}",
f"Integer: {integer}",
f"Boolean: {boolean}",
f"Number: {number}",
f"User: {user_option.username}#{user_option.discriminator}",
f"Role: {role_option.name}",
f"Attachment: {attachment_option.filename} (URL: {attachment_option.url})",
f"Choice Str: {choice_option_str}",
]
await ctx.send("\n".join(response_parts), ephemeral=True)
# --- Subcommand Group Test ---
# Define the group as a class attribute.
# The AppCommandHandler's discovery mechanism (via Cog) should pick up AppCommandGroup instances.
settings_group = AppCommandGroup(
name="settings",
description="Manage bot settings.",
# guild_ids can be added here if the group is guild-specific
)
@slash_command(
name="show", description="Shows current setting values.", parent=settings_group
)
async def settings_show(
self, ctx: AppCommandContext, setting_name: Optional[str] = None
):
if setting_name:
await ctx.send(
f"Showing value for setting: {setting_name} (Value: Placeholder)",
ephemeral=True,
)
else:
await ctx.send(
"Showing all settings: (Placeholder for all settings)", ephemeral=True
)
@slash_command(
name="update", description="Updates a setting.", parent=settings_group
)
async def settings_update(
self, ctx: AppCommandContext, setting_name: str, value: str
):
await ctx.send(
f"Updated setting: {setting_name} to value: {value}", ephemeral=True
)
# The Cog's metaclass or command registration logic should handle adding `settings_group`
# (and its subcommands) to the client's AppCommandHandler.
# The decorators now handle associating subcommands with their parent group.
@slash_command(
name="numeric_choices_test", description="Tests integer and float choices."
)
async def numeric_choices_test_slash(
self,
ctx: AppCommandContext,
int_choice: Literal[10, 20, 30, 42],
float_choice: float,
):
response = (
f"Integer Choice: {int_choice} (Type: {type(int_choice).__name__})\n"
f"Float Choice: {float_choice} (Type: {type(float_choice).__name__})"
)
await ctx.send(response, ephemeral=True)
@slash_command(
name="numeric_choices_extended",
description="Tests additional integer and float choice handling.",
)
async def numeric_choices_extended_slash(
self,
ctx: AppCommandContext,
int_choice: Literal[-5, 0, 5],
float_choice: float,
):
response = (
f"Int Choice: {int_choice} (Type: {type(int_choice).__name__})\n"
f"Float Choice: {float_choice} (Type: {type(float_choice).__name__})"
)
await ctx.send(response, ephemeral=True)
@slash_command(
name="channel_member_test",
description="Tests channel and member options.",
)
async def channel_member_test_slash(
self,
ctx: AppCommandContext,
channel: Channel,
member: Member,
):
response = (
f"Channel: {channel.name} (Type: {channel.type.name})\n"
f"Member: {member.username}#{member.discriminator}"
)
await ctx.send(response, ephemeral=True)
@slash_command(
name="channel_types_test",
description="Demonstrates multiple channel type options.",
)
async def channel_types_test_slash(
self,
ctx: AppCommandContext,
text_channel: Channel,
voice_channel: Channel,
category_channel: Channel,
):
response = (
f"Text: {text_channel.type.name}\n"
f"Voice: {voice_channel.type.name}\n"
f"Category: {category_channel.type.name}"
)
await ctx.send(response, ephemeral=True)
# --- Main Bot Script ---
async def main():
bot_token = os.getenv("DISCORD_BOT_TOKEN")
application_id = os.getenv("DISCORD_APPLICATION_ID")
if not bot_token:
logger.error("Error: DISCORD_BOT_TOKEN environment variable not set.")
return
if not application_id:
logger.error("Error: DISCORD_APPLICATION_ID environment variable not set.")
return
client = Client(token=bot_token, command_prefix="!", application_id=application_id)
@client.event
async def on_ready():
if client.user:
logger.info(
f"Bot logged in as {client.user.username}#{client.user.discriminator}"
)
else:
logger.error(
"Client ready, but client.user is not populated! This should not happen."
)
return # Avoid proceeding if basic client info isn't there
if client.application_id:
logger.info(f"Application ID is: {client.application_id}")
# Sync application commands (global in this case)
try:
logger.info("Attempting to sync application commands...")
# Ensure application_id is not None before passing
app_id_to_sync = client.application_id
if (
app_id_to_sync is not None
): # Redundant due to outer if, but good for clarity
await client.app_command_handler.sync_commands(
application_id=app_id_to_sync
)
logger.info("Application commands synced successfully.")
else: # Should not be reached if outer if client.application_id is true
logger.error(
"Application ID was None despite initial check. Skipping sync."
)
except Exception as e:
logger.error(f"Error syncing application commands: {e}", exc_info=True)
else:
# This case should be less likely now that Client gets it from READY.
# If DISCORD_APPLICATION_ID was critical as a fallback, that logic would be here.
# For now, we rely on the READY event.
logger.warning(
"Client's application ID is not set after READY. Skipping application command sync."
)
# Check if the environment variable was provided, as a diagnostic.
if not application_id:
logger.warning(
"DISCORD_APPLICATION_ID environment variable was also not provided."
)
client.add_cog(TestCog(client))
try:
await client.run()
except KeyboardInterrupt:
logger.info("Bot shutting down...")
except Exception as e:
logger.error(
f"An error occurred in the bot's main run loop: {e}", exc_info=True
)
finally:
if not client.is_closed():
await client.close()
logger.info("Bot has been closed.")
if __name__ == "__main__":
# For Windows, to allow graceful shutdown with Ctrl+C
if os.name == "nt":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Main loop interrupted. Exiting.")

65
examples/modal_command.py Normal file
View File

@ -0,0 +1,65 @@
"""Example showing how to present a modal using a slash command."""
import os
import asyncio
from dotenv import load_dotenv
from disagreement import Client, ui
from disagreement.enums import GatewayIntent, TextInputStyle
from disagreement.ext.app_commands.decorators import slash_command
from disagreement.ext.app_commands.context import AppCommandContext
load_dotenv()
token = os.getenv("DISCORD_BOT_TOKEN", "")
application_id = os.getenv("DISCORD_APPLICATION_ID", "")
client = Client(
token=token, application_id=application_id, intents=GatewayIntent.default()
)
class FeedbackModal(ui.Modal):
def __init__(self) -> None:
super().__init__(title="Feedback", custom_id="feedback")
@ui.text_input(label="Your feedback", style=TextInputStyle.PARAGRAPH)
async def feedback(self, interaction):
await interaction.respond(content="Thanks for your feedback!", ephemeral=True)
@slash_command(name="feedback", description="Send feedback via a modal")
async def feedback_command(ctx: AppCommandContext):
await ctx.interaction.respond_modal(FeedbackModal())
client.app_command_handler.add_command(feedback_command)
@client.event
async def on_ready():
"""Called when the bot is ready and connected to Discord."""
if client.user:
print(f"Bot is ready! Logged in as {client.user.username}")
print("Attempting to sync application commands...")
try:
if client.application_id:
await client.app_command_handler.sync_commands(
application_id=client.application_id
)
print("Application commands synced successfully.")
else:
print("Skipping command sync: application ID is not set.")
except Exception as e:
print(f"Error syncing application commands: {e}")
else:
print("Bot is ready, but client.user is missing!")
print("------")
async def main():
await client.run()
if __name__ == "__main__":
asyncio.run(main())

63
examples/modal_send.py Normal file
View File

@ -0,0 +1,63 @@
"""Example showing how to send a modal."""
import os
import sys
from dotenv import load_dotenv
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from disagreement import Client, GatewayIntent, ui # type: ignore
from disagreement.ext.app_commands.decorators import slash_command
from disagreement.ext.app_commands.context import AppCommandContext
load_dotenv()
TOKEN = os.getenv("DISCORD_BOT_TOKEN", "")
APP_ID = os.getenv("DISCORD_APPLICATION_ID", "")
if not TOKEN:
print("DISCORD_BOT_TOKEN not set")
sys.exit(1)
client = Client(token=TOKEN, intents=GatewayIntent.default(), application_id=APP_ID)
class NameModal(ui.Modal):
def __init__(self):
super().__init__(title="Your Name", custom_id="name_modal")
self.name = ui.TextInput(label="Name", custom_id="name")
@slash_command(name="namemodal", description="Shows a modal")
async def _namemodal(ctx: AppCommandContext):
await ctx.interaction.response.send_modal(NameModal())
client.app_command_handler.add_command(_namemodal)
@client.event
async def on_ready():
"""Called when the bot is ready and connected to Discord."""
if client.user:
print(f"Bot is ready! Logged in as {client.user.username}")
print("Attempting to sync application commands...")
try:
if client.application_id:
await client.app_command_handler.sync_commands(
application_id=client.application_id
)
print("Application commands synced successfully.")
else:
print("Skipping command sync: application ID is not set.")
except Exception as e:
print(f"Error syncing application commands: {e}")
else:
print("Bot is ready, but client.user is missing!")
print("------")
if __name__ == "__main__":
import asyncio
asyncio.run(client.run())

36
examples/sharded_bot.py Normal file
View File

@ -0,0 +1,36 @@
"""Example bot demonstrating gateway sharding."""
import asyncio
import os
import sys
# Ensure local package is importable when running from the examples directory
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import disagreement
TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
if not TOKEN:
raise RuntimeError("DISCORD_BOT_TOKEN environment variable not set")
client = disagreement.Client(token=TOKEN, shard_count=2)
@client.event
async def on_ready():
if client.user:
print(f"Shard bot ready as {client.user.username}#{client.user.discriminator}")
else:
print("Shard bot ready")
async def main():
if not TOKEN:
print("DISCORD_BOT_TOKEN environment variable not set")
return
await client.run()
if __name__ == "__main__":
asyncio.run(main())

30
examples/task_loop.py Normal file
View File

@ -0,0 +1,30 @@
"""Example showing the tasks extension."""
import asyncio
import os
import sys
# Allow running from the examples folder without installing
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from disagreement.ext import tasks
counter = 0
@tasks.loop(seconds=1.0)
async def ticker() -> None:
global counter
counter += 1
print(f"Tick {counter}")
async def main() -> None:
ticker.start()
await asyncio.sleep(5)
ticker.stop()
if __name__ == "__main__":
asyncio.run(main())

61
examples/voice_bot.py Normal file
View File

@ -0,0 +1,61 @@
"""Example bot demonstrating VoiceClient usage."""
import os
import asyncio
import sys
# If running from the examples directory
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from typing import cast
from dotenv import load_dotenv
import disagreement
load_dotenv()
_VOICE_ENDPOINT = os.getenv("DISCORD_VOICE_ENDPOINT")
_VOICE_TOKEN = os.getenv("DISCORD_VOICE_TOKEN")
_VOICE_SESSION_ID = os.getenv("DISCORD_SESSION_ID")
_GUILD_ID = os.getenv("DISCORD_GUILD_ID")
_USER_ID = os.getenv("DISCORD_USER_ID")
if not all([_VOICE_ENDPOINT, _VOICE_TOKEN, _VOICE_SESSION_ID, _GUILD_ID, _USER_ID]):
print("Missing one or more required environment variables for voice connection")
sys.exit(1)
assert _VOICE_ENDPOINT
assert _VOICE_TOKEN
assert _VOICE_SESSION_ID
assert _GUILD_ID
assert _USER_ID
VOICE_ENDPOINT = cast(str, _VOICE_ENDPOINT)
VOICE_TOKEN = cast(str, _VOICE_TOKEN)
VOICE_SESSION_ID = cast(str, _VOICE_SESSION_ID)
GUILD_ID = int(cast(str, _GUILD_ID))
USER_ID = int(cast(str, _USER_ID))
async def main() -> None:
vc = disagreement.VoiceClient(
VOICE_ENDPOINT,
VOICE_SESSION_ID,
VOICE_TOKEN,
GUILD_ID,
USER_ID,
)
await vc.connect()
try:
# Send silence frame as an example
await vc.send_audio_frame(b"\xf8\xff\xfe")
await asyncio.sleep(1)
finally:
await vc.close()
if __name__ == "__main__":
asyncio.run(main())

55
pyproject.toml Normal file
View File

@ -0,0 +1,55 @@
[project]
name = "disagreement"
version = "0.0.1"
description = "A Python library for the Discord API."
readme = "README.md"
requires-python = ">=3.11"
license = {text = "BSD 3-Clause"}
authors = [
{name = "Slipstream", email = "me@slipstreamm.dev"}
]
keywords = ["discord", "api", "bot", "async", "aiohttp"]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Internet",
]
dependencies = [
"aiohttp>=3.9.0,<4.0.0",
]
[project.optional-dependencies]
test = [
"pytest>=8.0.0",
"pytest-asyncio>=1.0.0",
"hypothesis>=6.89.0",
]
dev = [
"dotenv>=0.0.5",
]
[project.urls]
Homepage = "https://github.com/Slipstreamm/disagreement"
Issues = "https://github.com/Slipstreamm/disagreement/issues"
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Optional: for linting/formatting, e.g., Ruff
# [tool.ruff]
# line-length = 88
# select = ["E", "W", "F", "I", "UP", "C4", "B"] # Example rule set
# ignore = []
# [tool.ruff.format]
# quote-style = "double"

9
pyrightconfig.json Normal file
View File

@ -0,0 +1,9 @@
{
"include": ["."],
"exclude": ["**/node_modules", "**/__pycache__", "**/.venv", "**/.git", "**/dist", "**/build", "**/tests/**", "tavilytool.py"],
"ignore": [],
"reportMissingImports": true,
"reportMissingTypeStubs": false,
"pythonVersion": "3.13",
"typeCheckingMode": "standard"
}

1
requirements.txt Normal file
View File

@ -0,0 +1 @@
-e .[test,dev]

2
setup.cfg Normal file
View File

@ -0,0 +1,2 @@
[options]
packages = find:

171
tavilytool.py Normal file
View File

@ -0,0 +1,171 @@
#!/usr/bin/env python3
"""
Tavily API Script for AI Agents
Execute with: python tavily.py "your search query"
"""
import os
import sys
import json
import requests # type: ignore
import argparse
from typing import Dict, List, Optional
class TavilyAPI:
def __init__(self, api_key: str):
self.api_key = api_key
self.base_url = "https://api.tavily.com"
def search(
self,
query: str,
search_depth: str = "basic",
include_answer: bool = True,
include_images: bool = False,
include_raw_content: bool = False,
max_results: int = 5,
include_domains: Optional[List[str]] = None,
exclude_domains: Optional[List[str]] = None,
) -> Dict:
"""
Perform a search using Tavily API
Args:
query: Search query string
search_depth: "basic" or "advanced"
include_answer: Include AI-generated answer
include_images: Include images in results
include_raw_content: Include raw HTML content
max_results: Maximum number of results (1-20)
include_domains: List of domains to include
exclude_domains: List of domains to exclude
Returns:
Dictionary containing search results
"""
url = f"{self.base_url}/search"
payload = {
"api_key": self.api_key,
"query": query,
"search_depth": search_depth,
"include_answer": include_answer,
"include_images": include_images,
"include_raw_content": include_raw_content,
"max_results": max_results,
}
if include_domains:
payload["include_domains"] = include_domains
if exclude_domains:
payload["exclude_domains"] = exclude_domains
try:
response = requests.post(url, json=payload, timeout=30)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
return {"error": f"API request failed: {str(e)}"}
except json.JSONDecodeError:
return {"error": "Invalid JSON response from API"}
def format_results(results: Dict) -> str:
"""Format search results for display"""
if "error" in results:
return f"❌ Error: {results['error']}"
output = []
# Add answer if available
if results.get("answer"):
output.append("🤖 AI Answer:")
output.append(f" {results['answer']}")
output.append("")
# Add search results
if results.get("results"):
output.append("🔍 Search Results:")
for i, result in enumerate(results["results"], 1):
output.append(f" {i}. {result.get('title', 'No title')}")
output.append(f" URL: {result.get('url', 'No URL')}")
if result.get("content"):
# Truncate content to first 200 chars
content = (
result["content"][:200] + "..."
if len(result["content"]) > 200
else result["content"]
)
output.append(f" Content: {content}")
output.append("")
# Add images if available
if results.get("images"):
output.append("🖼️ Images:")
for img in results["images"][:3]: # Show first 3 images
output.append(f" {img}")
output.append("")
return "\n".join(output)
def main():
parser = argparse.ArgumentParser(description="Search using Tavily API")
parser.add_argument("query", help="Search query")
parser.add_argument(
"--depth",
choices=["basic", "advanced"],
default="basic",
help="Search depth (default: basic)",
)
parser.add_argument(
"--max-results",
type=int,
default=5,
help="Maximum number of results (default: 5)",
)
parser.add_argument(
"--include-images", action="store_true", help="Include images in results"
)
parser.add_argument(
"--no-answer", action="store_true", help="Don't include AI-generated answer"
)
parser.add_argument(
"--include-domains", nargs="+", help="Include only these domains"
)
parser.add_argument("--exclude-domains", nargs="+", help="Exclude these domains")
parser.add_argument("--raw", action="store_true", help="Output raw JSON response")
args = parser.parse_args()
# Get API key from environment
api_key = os.getenv("TAVILY_API_KEY")
if not api_key:
print("❌ Error: TAVILY_API_KEY environment variable not set")
print("Set it with: export TAVILY_API_KEY='your-api-key-here'")
sys.exit(1)
# Initialize Tavily API
tavily = TavilyAPI(api_key)
# Perform search
results = tavily.search(
query=args.query,
search_depth=args.depth,
include_answer=not args.no_answer,
include_images=args.include_images,
max_results=args.max_results,
include_domains=args.include_domains,
exclude_domains=args.exclude_domains,
)
# Output results
if args.raw:
print(json.dumps(results, indent=2))
else:
print(format_results(results))
if __name__ == "__main__":
main()

111
tests/conftest.py Normal file
View File

@ -0,0 +1,111 @@
import pytest
from unittest.mock import AsyncMock
from disagreement.interactions import Interaction
from disagreement.enums import InteractionType
from disagreement.models import Message
class DummyHTTP:
def __init__(self):
self.create_interaction_response = AsyncMock()
self.create_followup_message = AsyncMock(
return_value={
"id": "123",
"channel_id": "c",
"author": {"id": "1", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
)
self.edit_original_interaction_response = AsyncMock(return_value={"id": "123"})
self.edit_followup_message = AsyncMock(
return_value={
"id": "123",
"channel_id": "c",
"author": {"id": "1", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
)
class DummyBot:
def __init__(self):
self._http = DummyHTTP()
self.application_id = "app123"
self._guilds = {}
self._channels = {}
def get_guild(self, gid):
return self._guilds.get(gid)
async def fetch_channel(self, cid):
return self._channels.get(cid)
class DummyCommandHTTP:
def __init__(self):
self.edit_message = AsyncMock(return_value={"id": "321", "channel_id": "c"})
class DummyCommandBot:
def __init__(self):
self._http = DummyCommandHTTP()
async def edit_message(self, channel_id, message_id, *, content=None, **kwargs):
return await self._http.edit_message(
channel_id, message_id, {"content": content, **kwargs}
)
class DummyClient:
pass
class DummyInteraction:
data = None
@pytest.fixture()
def dummy_bot():
return DummyBot()
@pytest.fixture()
def interaction(dummy_bot):
data = {
"id": "1",
"application_id": dummy_bot.application_id,
"type": InteractionType.APPLICATION_COMMAND.value,
"token": "tok",
"version": 1,
}
return Interaction(data, client_instance=dummy_bot)
@pytest.fixture()
def command_bot():
return DummyCommandBot()
@pytest.fixture()
def message(command_bot):
message_data = {
"id": "1",
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
return Message(message_data, client_instance=command_bot)
@pytest.fixture()
def dummy_client():
return DummyClient()
@pytest.fixture()
def basic_interaction():
return DummyInteraction()

View File

@ -0,0 +1,172 @@
import pytest
from disagreement.ext.commands.converters import run_converters
from disagreement.ext.commands.core import CommandContext, Command
from disagreement.ext.commands.errors import BadArgument
from disagreement.models import Message, Member, Role, Guild
from disagreement.enums import (
VerificationLevel,
MessageNotificationLevel,
ExplicitContentFilterLevel,
MFALevel,
GuildNSFWLevel,
PremiumTier,
)
class DummyBot:
def __init__(self, guild: Guild):
self._guilds = {guild.id: guild}
def get_guild(self, gid):
return self._guilds.get(gid)
async def fetch_member(self, gid, mid):
guild = self._guilds.get(gid)
return guild.get_member(mid) if guild else None
async def fetch_role(self, gid, rid):
guild = self._guilds.get(gid)
return guild.get_role(rid) if guild else None
async def fetch_guild(self, gid):
return self._guilds.get(gid)
@pytest.fixture()
def guild_objects():
guild_data = {
"id": "1",
"name": "g",
"owner_id": "2",
"afk_timeout": 60,
"verification_level": VerificationLevel.NONE.value,
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
"roles": [],
"emojis": [],
"features": [],
"mfa_level": MFALevel.NONE.value,
"system_channel_flags": 0,
"premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
}
guild = Guild(guild_data, client_instance=None)
member = Member(
{
"user": {"id": "3", "username": "m", "discriminator": "0001"},
"joined_at": "t",
"roles": [],
},
None,
)
member.guild_id = guild.id
role = Role(
{
"id": "5",
"name": "r",
"color": 0,
"hoist": False,
"position": 0,
"permissions": "0",
"managed": False,
"mentionable": True,
}
)
guild._members[member.id] = member
guild.roles.append(role)
return guild, member, role
@pytest.fixture()
def command_context(guild_objects):
guild, member, role = guild_objects
bot = DummyBot(guild)
message_data = {
"id": "10",
"channel_id": "20",
"guild_id": guild.id,
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
msg = Message(message_data, client_instance=bot)
async def dummy(ctx):
pass
cmd = Command(dummy)
return CommandContext(
message=msg, bot=bot, prefix="!", command=cmd, invoked_with="dummy"
)
@pytest.mark.asyncio
async def test_member_converter(command_context, guild_objects):
_, member, _ = guild_objects
mention = f"<@!{member.id}>"
result = await run_converters(command_context, Member, mention)
assert result is member
result = await run_converters(command_context, Member, member.id)
assert result is member
@pytest.mark.asyncio
async def test_role_converter(command_context, guild_objects):
_, _, role = guild_objects
mention = f"<@&{role.id}>"
result = await run_converters(command_context, Role, mention)
assert result is role
result = await run_converters(command_context, Role, role.id)
assert result is role
@pytest.mark.asyncio
async def test_guild_converter(command_context, guild_objects):
guild, _, _ = guild_objects
result = await run_converters(command_context, Guild, guild.id)
assert result is guild
@pytest.mark.asyncio
async def test_member_converter_no_guild():
guild_data = {
"id": "99",
"name": "g",
"owner_id": "2",
"afk_timeout": 60,
"verification_level": VerificationLevel.NONE.value,
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
"roles": [],
"emojis": [],
"features": [],
"mfa_level": MFALevel.NONE.value,
"system_channel_flags": 0,
"premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
}
guild = Guild(guild_data, client_instance=None)
bot = DummyBot(guild)
message_data = {
"id": "11",
"channel_id": "20",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
msg = Message(message_data, client_instance=bot)
async def dummy(ctx):
pass
ctx = CommandContext(
message=msg, bot=bot, prefix="!", command=Command(dummy), invoked_with="dummy"
)
with pytest.raises(BadArgument):
await run_converters(ctx, Member, "<@!1>")

17
tests/test_cache.py Normal file
View File

@ -0,0 +1,17 @@
import time
from disagreement.cache import Cache
def test_cache_store_and_get():
cache = Cache()
cache.set("a", 123)
assert cache.get("a") == 123
def test_cache_ttl_expiry():
cache = Cache(ttl=0.01)
cache.set("b", 1)
assert cache.get("b") == 1
time.sleep(0.02)
assert cache.get("b") is None

View File

@ -0,0 +1,33 @@
import asyncio
import pytest
from unittest.mock import AsyncMock
from disagreement.client import Client
@pytest.mark.asyncio
async def test_client_async_context_closes(monkeypatch):
client = Client(token="t")
monkeypatch.setattr(client, "connect", AsyncMock())
monkeypatch.setattr(client._http, "close", AsyncMock())
async with client:
client.connect.assert_awaited_once()
client._http.close.assert_awaited_once()
assert client.is_closed()
@pytest.mark.asyncio
async def test_client_async_context_closes_on_exception(monkeypatch):
client = Client(token="t")
monkeypatch.setattr(client, "connect", AsyncMock())
monkeypatch.setattr(client._http, "close", AsyncMock())
with pytest.raises(ValueError):
async with client:
raise ValueError("boom")
client.connect.assert_awaited_once()
client._http.close.assert_awaited_once()
assert client.is_closed()

View File

@ -0,0 +1,51 @@
import asyncio
import pytest
from disagreement.ext.commands.core import Command, CommandContext
from disagreement.ext.commands.decorators import check, cooldown
from disagreement.ext.commands.errors import CheckFailure, CommandOnCooldown
@pytest.mark.asyncio
async def test_check_decorator_blocks(message):
async def cb(ctx):
pass
cmd = Command(check(lambda c: False)(cb))
ctx = CommandContext(
message=message,
bot=message._client,
prefix="!",
command=cmd,
invoked_with="test",
)
with pytest.raises(CheckFailure):
await cmd.invoke(ctx)
@pytest.mark.asyncio
async def test_cooldown_per_user(message):
uses = []
@cooldown(1, 0.05)
async def cb(ctx):
uses.append(1)
cmd = Command(cb)
ctx = CommandContext(
message=message,
bot=message._client,
prefix="!",
command=cmd,
invoked_with="test",
)
await cmd.invoke(ctx)
with pytest.raises(CommandOnCooldown):
await cmd.invoke(ctx)
await asyncio.sleep(0.05)
await cmd.invoke(ctx)
assert len(uses) == 2

View File

@ -0,0 +1,29 @@
from disagreement.components import component_factory
from disagreement.enums import ComponentType, ButtonStyle
def test_component_factory_button():
data = {
"type": ComponentType.BUTTON.value,
"style": ButtonStyle.PRIMARY.value,
"label": "Click",
"custom_id": "x",
}
comp = component_factory(data)
assert comp.to_dict()["label"] == "Click"
def test_component_factory_action_row():
data = {
"type": ComponentType.ACTION_ROW.value,
"components": [
{
"type": ComponentType.BUTTON.value,
"style": ButtonStyle.PRIMARY.value,
"label": "A",
"custom_id": "b",
}
],
}
row = component_factory(data)
assert len(row.components) == 1

199
tests/test_context.py Normal file
View File

@ -0,0 +1,199 @@
import asyncio
from unittest.mock import AsyncMock
import pytest
from disagreement.ext.app_commands.context import AppCommandContext
from disagreement.interactions import Interaction
from disagreement.enums import InteractionType, MessageFlags
from disagreement.models import (
Embed,
ActionRow,
Button,
Container,
TextDisplay,
)
from disagreement.enums import ButtonStyle, ComponentType
class DummyHTTP:
def __init__(self):
self.create_interaction_response = AsyncMock()
self.create_followup_message = AsyncMock(return_value={"id": "123"})
self.edit_original_interaction_response = AsyncMock(return_value={"id": "123"})
self.edit_followup_message = AsyncMock(return_value={"id": "123"})
class DummyBot:
def __init__(self):
self._http = DummyHTTP()
self.application_id = "app123"
self._guilds = {}
self._channels = {}
def get_guild(self, gid):
return self._guilds.get(gid)
async def fetch_channel(self, cid):
return self._channels.get(cid)
from disagreement.ext.commands.core import CommandContext, Command
from disagreement.enums import MessageFlags, ButtonStyle, ComponentType
from disagreement.models import ActionRow, Button, Container, TextDisplay
@pytest.mark.asyncio
async def test_sends_extra_payload(dummy_bot, interaction):
ctx = AppCommandContext(dummy_bot, interaction)
button = Button(style=ButtonStyle.PRIMARY, label="Click", custom_id="a")
row = ActionRow([button])
await ctx.send(
content="hi",
tts=True,
components=[row],
allowed_mentions={"parse": []},
files=[{"id": 1, "filename": "f.txt"}],
ephemeral=True,
)
dummy_bot._http.create_interaction_response.assert_called_once()
payload = dummy_bot._http.create_interaction_response.call_args.kwargs[
"payload"
].data.to_dict()
assert payload["tts"] is True
assert payload["components"]
assert payload["allowed_mentions"] == {"parse": []}
assert payload["attachments"] == [{"id": 1, "filename": "f.txt"}]
assert payload["flags"] == MessageFlags.EPHEMERAL.value
await ctx.send_followup(content="again")
dummy_bot._http.create_followup_message.assert_called_once()
@pytest.mark.asyncio
async def test_second_send_is_followup(dummy_bot, interaction):
ctx = AppCommandContext(dummy_bot, interaction)
await ctx.send(content="first")
await ctx.send_followup(content="second")
assert dummy_bot._http.create_interaction_response.call_count == 1
assert dummy_bot._http.create_followup_message.call_count == 1
@pytest.mark.asyncio
async def test_edit_with_components_and_attachments(dummy_bot, interaction):
ctx = AppCommandContext(dummy_bot, interaction)
await ctx.send(content="orig")
row = ActionRow([Button(style=ButtonStyle.PRIMARY, label="B", custom_id="b")])
await ctx.edit(content="new", components=[row], attachments=[{"id": 1}])
dummy_bot._http.edit_original_interaction_response.assert_called_once()
payload = dummy_bot._http.edit_original_interaction_response.call_args.kwargs[
"payload"
]
assert payload["components"]
assert payload["attachments"] == [{"id": 1}]
@pytest.mark.asyncio
async def test_send_with_flags(dummy_bot, interaction):
ctx = AppCommandContext(dummy_bot, interaction)
await ctx.send(content="hi", flags=MessageFlags.IS_COMPONENTS_V2.value)
payload = dummy_bot._http.create_interaction_response.call_args.kwargs[
"payload"
].data.to_dict()
assert payload["flags"] == MessageFlags.IS_COMPONENTS_V2.value
@pytest.mark.asyncio
async def test_send_container_component(dummy_bot, interaction):
ctx = AppCommandContext(dummy_bot, interaction)
container = Container(components=[TextDisplay(content="hi")])
await ctx.send(components=[container], flags=MessageFlags.IS_COMPONENTS_V2.value)
payload = dummy_bot._http.create_interaction_response.call_args.kwargs[
"payload"
].data.to_dict()
assert payload["components"][0]["type"] == ComponentType.CONTAINER.value
assert payload["flags"] == MessageFlags.IS_COMPONENTS_V2.value
@pytest.mark.asyncio
async def test_command_context_edit(command_bot, message):
async def dummy(ctx):
pass
cmd = Command(dummy)
ctx = CommandContext(
message=message, bot=command_bot, prefix="!", command=cmd, invoked_with="dummy"
)
await ctx.edit(message.id, content="new")
command_bot._http.edit_message.assert_called_once()
args = command_bot._http.edit_message.call_args[0]
assert args[0] == message.channel_id
assert args[1] == message.id
assert args[2]["content"] == "new"
@pytest.mark.asyncio
async def test_send_http_error_propagates(dummy_bot, interaction):
ctx = AppCommandContext(dummy_bot, interaction)
dummy_bot._http.create_interaction_response.side_effect = RuntimeError("boom")
with pytest.raises(RuntimeError):
await ctx.send(content="hi")
@pytest.mark.asyncio
async def test_concurrent_send_only_initial_once(dummy_bot, interaction):
ctx = AppCommandContext(dummy_bot, interaction)
async def send_msg(i: int):
if i == 0:
await ctx.send(content=str(i))
else:
await ctx.send_followup(content=str(i))
await asyncio.gather(*(send_msg(i) for i in range(50)))
assert dummy_bot._http.create_interaction_response.call_count == 1
assert dummy_bot._http.create_followup_message.call_count == 49
@pytest.mark.asyncio
async def test_send_with_flags_2():
bot = DummyBot()
interaction = Interaction(
{
"id": "1",
"application_id": bot.application_id,
"type": InteractionType.APPLICATION_COMMAND.value,
"token": "tok",
"version": 1,
},
client_instance=bot,
)
ctx = AppCommandContext(bot, interaction)
await ctx.send(content="hi", flags=MessageFlags.IS_COMPONENTS_V2.value)
payload = bot._http.create_interaction_response.call_args[1][
"payload"
].data.to_dict()
assert payload["flags"] == MessageFlags.IS_COMPONENTS_V2.value
@pytest.mark.asyncio
async def test_send_container_component_2():
bot = DummyBot()
interaction = Interaction(
{
"id": "1",
"application_id": bot.application_id,
"type": InteractionType.APPLICATION_COMMAND.value,
"token": "tok",
"version": 1,
},
client_instance=bot,
)
ctx = AppCommandContext(bot, interaction)
container = Container(components=[TextDisplay(content="hi")])
await ctx.send(
components=[container],
flags=MessageFlags.IS_COMPONENTS_V2.value,
)
payload = bot._http.create_interaction_response.call_args[1][
"payload"
].data.to_dict()
assert payload["components"][0]["type"] == ComponentType.CONTAINER.value
assert payload["flags"] == MessageFlags.IS_COMPONENTS_V2.value

View File

@ -0,0 +1,80 @@
import pytest
from disagreement.ext.app_commands.handler import AppCommandHandler
from disagreement.ext.app_commands.decorators import user_command, message_command
from disagreement.enums import ApplicationCommandType, InteractionType
from disagreement.interactions import Interaction
from disagreement.models import User, Message
@pytest.mark.asyncio
async def test_user_context_menu_invokes(dummy_bot):
handler = AppCommandHandler(dummy_bot)
captured = {}
@user_command(name="Info")
async def info(ctx, user: User):
captured["user"] = user
handler.add_command(info)
data = {
"id": "cmd",
"name": "Info",
"type": ApplicationCommandType.USER.value,
"target_id": "42",
"resolved": {
"users": {"42": {"id": "42", "username": "Bob", "discriminator": "0001"}}
},
}
payload = {
"id": "1",
"application_id": dummy_bot.application_id,
"type": InteractionType.APPLICATION_COMMAND.value,
"token": "tok",
"version": 1,
"data": data,
}
interaction = Interaction(payload, client_instance=dummy_bot)
await handler.process_interaction(interaction)
assert isinstance(captured.get("user"), User)
assert captured["user"].id == "42"
@pytest.mark.asyncio
async def test_message_context_menu_invokes(dummy_bot):
handler = AppCommandHandler(dummy_bot)
captured = {}
@message_command(name="Quote")
async def quote(ctx, message: Message):
captured["msg"] = message
handler.add_command(quote)
msg_data = {
"id": "99",
"channel_id": "c",
"author": {"id": "2", "username": "Ann", "discriminator": "0001"},
"content": "Hello",
"timestamp": "t",
}
data = {
"id": "cmd",
"name": "Quote",
"type": ApplicationCommandType.MESSAGE.value,
"target_id": "99",
"resolved": {"messages": {"99": msg_data}},
}
payload = {
"id": "1",
"application_id": dummy_bot.application_id,
"type": InteractionType.APPLICATION_COMMAND.value,
"token": "tok",
"version": 1,
"data": data,
}
interaction = Interaction(payload, client_instance=dummy_bot)
await handler.process_interaction(interaction)
assert isinstance(captured.get("msg"), Message)
assert captured["msg"].id == "99"

View File

@ -0,0 +1,30 @@
import pytest
from disagreement.ext.app_commands.handler import AppCommandHandler
from disagreement.ext.app_commands.converters import Converter
class MyType:
def __init__(self, value):
self.value = value
class MyConverter(Converter[MyType]):
async def convert(self, interaction, value):
return MyType(f"converted-{value}")
@pytest.mark.asyncio
async def test_custom_converter_registration(dummy_client):
handler = AppCommandHandler(client=dummy_client)
handler.register_converter(MyType, MyConverter)
assert handler.get_converter(MyType) is MyConverter
result = await handler._resolve_value(
value="example",
expected_type=MyType,
resolved_data=None,
guild_id=None,
)
assert isinstance(result, MyType)
assert result.value == "converted-example"

113
tests/test_converters.py Normal file
View File

@ -0,0 +1,113 @@
import asyncio
import pytest
from hypothesis import given, strategies as st
from disagreement.ext.app_commands.converters import run_converters
from disagreement.enums import ApplicationCommandOptionType
from disagreement.errors import AppCommandOptionConversionError
from conftest import DummyInteraction, DummyClient
@pytest.mark.asyncio
@pytest.mark.parametrize(
"py_type, option_type, input_value, expected",
[
(str, ApplicationCommandOptionType.STRING, "hello", "hello"),
(int, ApplicationCommandOptionType.INTEGER, "42", 42),
(bool, ApplicationCommandOptionType.BOOLEAN, "true", True),
(float, ApplicationCommandOptionType.NUMBER, "3.14", pytest.approx(3.14)),
],
)
async def test_basic_type_converters(
basic_interaction, dummy_client, py_type, option_type, input_value, expected
):
result = await run_converters(
basic_interaction, py_type, option_type, input_value, dummy_client
)
assert result == expected
@pytest.mark.asyncio
async def test_run_converters_error_cases(basic_interaction, dummy_client):
with pytest.raises(AppCommandOptionConversionError):
await run_converters(
basic_interaction,
bool,
ApplicationCommandOptionType.BOOLEAN,
"maybe",
dummy_client,
)
with pytest.raises(AppCommandOptionConversionError):
await run_converters(
basic_interaction,
list,
ApplicationCommandOptionType.MENTIONABLE,
"x",
dummy_client,
)
@given(st.text())
def test_string_roundtrip(value):
interaction = DummyInteraction()
client = DummyClient()
result = asyncio.run(
run_converters(
interaction,
str,
ApplicationCommandOptionType.STRING,
value,
client,
)
)
assert result == value
@given(st.integers())
def test_integer_roundtrip(value):
interaction = DummyInteraction()
client = DummyClient()
result = asyncio.run(
run_converters(
interaction,
int,
ApplicationCommandOptionType.INTEGER,
str(value),
client,
)
)
assert result == value
@given(st.booleans())
def test_boolean_roundtrip(value):
interaction = DummyInteraction()
client = DummyClient()
raw = "true" if value else "false"
result = asyncio.run(
run_converters(
interaction,
bool,
ApplicationCommandOptionType.BOOLEAN,
raw,
client,
)
)
assert result is value
@given(st.floats(allow_nan=False, allow_infinity=False))
def test_number_roundtrip(value):
interaction = DummyInteraction()
client = DummyClient()
result = asyncio.run(
run_converters(
interaction,
float,
ApplicationCommandOptionType.NUMBER,
str(value),
client,
)
)
assert result == pytest.approx(value)

View File

@ -0,0 +1,38 @@
import asyncio
import logging
from types import SimpleNamespace
import pytest
from disagreement.error_handler import setup_global_error_handler
@pytest.mark.asyncio
async def test_handle_exception_logs_error(monkeypatch, capsys):
loop = asyncio.new_event_loop()
records = []
def fake_error(msg, *args, **kwargs):
records.append(msg % args if args else msg)
monkeypatch.setattr(logging, "error", fake_error)
setup_global_error_handler(loop)
exc = RuntimeError("boom")
loop.call_exception_handler({"exception": exc})
assert any("Unhandled exception" in r for r in records)
loop.close()
@pytest.mark.asyncio
async def test_handle_message_logs_error(monkeypatch):
loop = asyncio.new_event_loop()
logged = {}
def fake_error(msg, *args, **kwargs):
logged["msg"] = msg % args if args else msg
monkeypatch.setattr(logging, "error", fake_error)
setup_global_error_handler(loop)
loop.call_exception_handler({"message": "oops"})
assert "oops" in logged["msg"]
loop.close()

23
tests/test_errors.py Normal file
View File

@ -0,0 +1,23 @@
import pytest
from disagreement.errors import (
HTTPException,
RateLimitError,
AppCommandOptionConversionError,
)
def test_http_exception_message():
exc = HTTPException(message="Bad", status=400)
assert str(exc) == "HTTP 400: Bad"
def test_rate_limit_error_inherits_httpexception():
exc = RateLimitError(response=None, retry_after=1.0, is_global=True)
assert isinstance(exc, HTTPException)
assert "Rate limited" in str(exc)
def test_app_command_option_conversion_error():
exc = AppCommandOptionConversionError("bad", option_name="opt", original_value="x")
assert "opt" in str(exc) and "x" in str(exc)

View File

@ -0,0 +1,68 @@
import asyncio
import pytest
from disagreement.event_dispatcher import EventDispatcher
class DummyClient:
def __init__(self):
self.parsed = {}
def parse_message(self, data):
self.parsed["message"] = True
return data
def parse_guild(self, data):
self.parsed["guild"] = True
return data
def parse_channel(self, data):
self.parsed["channel"] = True
return data
@pytest.mark.asyncio
async def test_dispatch_calls_listener():
client = DummyClient()
dispatcher = EventDispatcher(client)
called = {}
async def listener(payload):
called["data"] = payload
dispatcher.register("MESSAGE_CREATE", listener)
await dispatcher.dispatch("MESSAGE_CREATE", {"id": 1})
assert called["data"] == {"id": 1}
assert client.parsed.get("message")
@pytest.mark.asyncio
async def test_dispatch_listener_no_args():
client = DummyClient()
dispatcher = EventDispatcher(client)
called = False
async def listener():
nonlocal called
called = True
dispatcher.register("GUILD_CREATE", listener)
await dispatcher.dispatch("GUILD_CREATE", {"id": 123})
assert called
@pytest.mark.asyncio
async def test_unregister_listener():
client = DummyClient()
dispatcher = EventDispatcher(client)
called = False
async def listener(_):
nonlocal called
called = True
dispatcher.register("MESSAGE_CREATE", listener)
dispatcher.unregister("MESSAGE_CREATE", listener)
await dispatcher.dispatch("MESSAGE_CREATE", {"id": 1})
assert not called

View File

@ -0,0 +1,42 @@
import pytest
from unittest.mock import AsyncMock
from disagreement.event_dispatcher import EventDispatcher
class DummyClient:
pass
@pytest.mark.asyncio
async def test_dispatch_error_hook_called():
dispatcher = EventDispatcher(DummyClient())
hook = AsyncMock()
dispatcher.on_dispatch_error = hook
async def listener(_):
raise RuntimeError("boom")
dispatcher.register("TEST_EVENT", listener)
await dispatcher.dispatch("TEST_EVENT", {})
hook.assert_awaited_once()
args = hook.call_args.args
assert args[0] == "TEST_EVENT"
assert isinstance(args[1], RuntimeError)
assert args[2] is listener
@pytest.mark.asyncio
async def test_dispatch_error_hook_not_called_when_ok():
dispatcher = EventDispatcher(DummyClient())
hook = AsyncMock()
dispatcher.on_dispatch_error = hook
async def listener(_):
return
dispatcher.register("TEST_EVENT", listener)
await dispatcher.dispatch("TEST_EVENT", {})
hook.assert_not_awaited()

View File

@ -0,0 +1,44 @@
import sys
import types
import pytest
from disagreement.ext import loader
def create_dummy_module(name):
mod = types.ModuleType(name)
called = {"setup": False, "teardown": False}
def setup():
called["setup"] = True
def teardown():
called["teardown"] = True
mod.setup = setup
mod.teardown = teardown
sys.modules[name] = mod
return called
def test_load_and_unload_extension():
called = create_dummy_module("dummy_ext")
module = loader.load_extension("dummy_ext")
assert module is sys.modules["dummy_ext"]
assert called["setup"] is True
loader.unload_extension("dummy_ext")
assert called["teardown"] is True
assert "dummy_ext" not in loader._loaded_extensions
assert "dummy_ext" not in sys.modules
def test_load_extension_twice_raises():
called = create_dummy_module("repeat_ext")
loader.load_extension("repeat_ext")
with pytest.raises(ValueError):
loader.load_extension("repeat_ext")
loader.unload_extension("repeat_ext")
assert called["teardown"] is True

View File

@ -0,0 +1,67 @@
import asyncio
from unittest.mock import AsyncMock
import pytest
from disagreement.gateway import GatewayClient, GatewayException
from disagreement.client import Client
class DummyHTTP:
async def get_gateway_bot(self):
return {"url": "ws://example"}
async def _ensure_session(self):
self._session = AsyncMock()
self._session.ws_connect = AsyncMock()
class DummyDispatcher:
async def dispatch(self, *_):
pass
class DummyClient:
def __init__(self):
self.loop = asyncio.get_event_loop()
self.application_id = None # Mock application_id for Client.connect
@pytest.mark.asyncio
async def test_client_connect_backoff(monkeypatch):
http = DummyHTTP()
# Mock the GatewayClient's connect method to simulate failures and then success
mock_gateway_connect = AsyncMock(
side_effect=[GatewayException("boom"), GatewayException("boom"), None]
)
# Create a dummy client instance
client = Client(
token="test_token",
intents=0,
loop=asyncio.get_event_loop(),
command_prefix="!",
verbose=False,
mention_replies=False,
shard_count=None,
)
# Patch the internal _gateway attribute after client initialization
# This ensures _initialize_gateway is called and _gateway is set
await client._initialize_gateway()
monkeypatch.setattr(client._gateway, "connect", mock_gateway_connect)
# Mock wait_until_ready to prevent it from blocking the test
monkeypatch.setattr(client, "wait_until_ready", AsyncMock())
delays = []
async def fake_sleep(d):
delays.append(d)
monkeypatch.setattr(asyncio, "sleep", fake_sleep)
# Call the client's connect method, which contains the backoff logic
await client.connect()
# Assert that GatewayClient.connect was called the correct number of times
assert mock_gateway_connect.call_count == 3
# Assert the delays experienced due to exponential backoff
assert delays == [5, 10]

View File

@ -0,0 +1,57 @@
import pytest
from disagreement.ext.commands.core import CommandHandler, Command
from disagreement.models import Message
class DummyBot:
def __init__(self):
self.sent = []
async def send_message(self, channel_id, content, **kwargs):
self.sent.append(content)
return {"id": "1", "channel_id": channel_id, "content": content}
@pytest.mark.asyncio
async def test_help_lists_commands():
bot = DummyBot()
handler = CommandHandler(client=bot, prefix="!")
async def foo(ctx):
pass
handler.add_command(Command(foo, name="foo", brief="Foo cmd"))
msg_data = {
"id": "1",
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "!help",
"timestamp": "t",
}
msg = Message(msg_data, client_instance=bot)
await handler.process_commands(msg)
assert any("foo" in m for m in bot.sent)
@pytest.mark.asyncio
async def test_help_specific_command():
bot = DummyBot()
handler = CommandHandler(client=bot, prefix="!")
async def bar(ctx):
pass
handler.add_command(Command(bar, name="bar", description="Bar desc"))
msg_data = {
"id": "1",
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "!help bar",
"timestamp": "t",
}
msg = Message(msg_data, client_instance=bot)
await handler.process_commands(msg)
assert any("Bar desc" in m for m in bot.sent)

View File

@ -0,0 +1,57 @@
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock
from disagreement.client import Client
from disagreement.errors import DisagreementException
from disagreement.models import User
@pytest.mark.asyncio
async def test_create_reaction_calls_http():
http = SimpleNamespace(create_reaction=AsyncMock())
client = Client.__new__(Client)
client._http = http
client._closed = False
await client.create_reaction("1", "2", "😀")
http.create_reaction.assert_called_once_with("1", "2", "😀")
@pytest.mark.asyncio
async def test_create_reaction_closed():
http = SimpleNamespace(create_reaction=AsyncMock())
client = Client.__new__(Client)
client._http = http
client._closed = True
with pytest.raises(DisagreementException):
await client.create_reaction("1", "2", "😀")
@pytest.mark.asyncio
async def test_delete_reaction_calls_http():
http = SimpleNamespace(delete_reaction=AsyncMock())
client = Client.__new__(Client)
client._http = http
client._closed = False
await client.delete_reaction("1", "2", "😀")
http.delete_reaction.assert_called_once_with("1", "2", "😀")
@pytest.mark.asyncio
async def test_get_reactions_parses_users():
users_payload = [{"id": "1", "username": "u", "discriminator": "0001"}]
http = SimpleNamespace(get_reactions=AsyncMock(return_value=users_payload))
client = Client.__new__(Client)
client._http = http
client._closed = False
client._users = {}
users = await client.get_reactions("1", "2", "😀")
http.get_reactions.assert_called_once_with("1", "2", "😀")
assert isinstance(users[0], User)

View File

@ -0,0 +1,44 @@
import asyncio
import pytest
from disagreement.hybrid_context import HybridContext
from disagreement.ext.app_commands.context import AppCommandContext
class DummyCommandCtx:
def __init__(self):
self.sent = []
async def reply(self, *a, **kw):
self.sent.append(("reply", a, kw))
async def edit(self, *a, **kw):
self.sent.append(("edit", a, kw))
class DummyAppCtx(AppCommandContext):
def __init__(self):
self.sent = []
async def send(self, *a, **kw):
self.sent.append(("send", a, kw))
@pytest.mark.asyncio
async def test_send_routes_based_on_context():
cctx = DummyCommandCtx()
actx = DummyAppCtx()
await HybridContext(cctx).send("hi")
await HybridContext(actx).send("hi")
assert cctx.sent[0][0] == "reply"
assert actx.sent[0][0] == "send"
@pytest.mark.asyncio
async def test_edit_delegation_and_error():
cctx = DummyCommandCtx()
hctx = HybridContext(cctx)
await hctx.edit("m")
assert cctx.sent[0][0] == "edit"
with pytest.raises(AttributeError):
await HybridContext(DummyAppCtx()).edit("m")

Some files were not shown because too many files have changed in this diff Show More