From 7c7cb4137c09bf8f9c81ea3a38c2d4b08463661e Mon Sep 17 00:00:00 2001 From: Slipstream Date: Mon, 9 Jun 2025 22:25:14 -0600 Subject: [PATCH] Initial commit --- .github/workflows/ci.yml | 34 + .gitignore | 129 ++ AGENTS.md | 64 + LICENSE | 26 + MANIFEST.in | 3 + README.md | 131 ++ disagreement/__init__.py | 36 + disagreement/cache.py | 55 + disagreement/client.py | 1144 +++++++++++++ disagreement/components.py | 166 ++ disagreement/enums.py | 357 ++++ disagreement/error_handler.py | 33 + disagreement/errors.py | 112 ++ disagreement/event_dispatcher.py | 243 +++ disagreement/ext/app_commands/__init__.py | 46 + disagreement/ext/app_commands/commands.py | 513 ++++++ disagreement/ext/app_commands/context.py | 556 +++++++ disagreement/ext/app_commands/converters.py | 478 ++++++ disagreement/ext/app_commands/decorators.py | 569 +++++++ disagreement/ext/app_commands/handler.py | 627 +++++++ disagreement/ext/commands/__init__.py | 49 + disagreement/ext/commands/cog.py | 155 ++ disagreement/ext/commands/converters.py | 175 ++ disagreement/ext/commands/core.py | 490 ++++++ disagreement/ext/commands/decorators.py | 150 ++ disagreement/ext/commands/errors.py | 76 + disagreement/ext/commands/help.py | 37 + disagreement/ext/commands/view.py | 103 ++ disagreement/ext/loader.py | 43 + disagreement/ext/tasks.py | 89 + disagreement/gateway.py | 490 ++++++ disagreement/http.py | 657 ++++++++ disagreement/hybrid_context.py | 32 + disagreement/i18n.py | 22 + disagreement/interactions.py | 572 +++++++ disagreement/logging_config.py | 26 + disagreement/models.py | 1642 +++++++++++++++++++ disagreement/oauth.py | 109 ++ disagreement/permissions.py | 99 ++ disagreement/rate_limiter.py | 75 + disagreement/shard_manager.py | 65 + disagreement/typing.py | 42 + disagreement/ui/__init__.py | 17 + disagreement/ui/button.py | 99 ++ disagreement/ui/item.py | 38 + disagreement/ui/modal.py | 132 ++ disagreement/ui/select.py | 92 ++ disagreement/ui/view.py | 165 ++ disagreement/voice_client.py | 120 ++ docs/caching.md | 18 + docs/commands.md | 51 + docs/context_menus.md | 21 + docs/converters.md | 25 + docs/events.md | 25 + docs/gateway.md | 17 + docs/i18n.md | 36 + docs/oauth2.md | 48 + docs/permissions.md | 62 + docs/presence.md | 29 + docs/reactions.md | 32 + docs/slash_commands.md | 22 + docs/task_loop.md | 15 + docs/typing_indicator.md | 16 + docs/using_components.md | 168 ++ docs/voice_client.md | 29 + docs/voice_features.md | 17 + docs/webhooks.md | 34 + examples/basic_bot.py | 218 +++ examples/component_bot.py | 292 ++++ examples/context_menus.py | 71 + examples/hybrid_bot.py | 315 ++++ examples/modal_command.py | 65 + examples/modal_send.py | 63 + examples/sharded_bot.py | 36 + examples/task_loop.py | 30 + examples/voice_bot.py | 61 + pyproject.toml | 55 + pyrightconfig.json | 9 + requirements.txt | 1 + setup.cfg | 2 + tavilytool.py | 171 ++ tests/conftest.py | 111 ++ tests/test_additional_converters.py | 172 ++ tests/test_cache.py | 17 + tests/test_client_context_manager.py | 33 + tests/test_command_checks.py | 51 + tests/test_components_factory.py | 29 + tests/test_context.py | 199 +++ tests/test_context_menus.py | 80 + tests/test_converter_registration.py | 30 + tests/test_converters.py | 113 ++ tests/test_error_handler.py | 38 + tests/test_errors.py | 23 + tests/test_event_dispatcher.py | 68 + tests/test_event_error_hook.py | 42 + tests/test_extension_loader.py | 44 + tests/test_gateway_backoff.py | 67 + tests/test_help_command.py | 57 + tests/test_http_reactions.py | 57 + tests/test_hybrid_context.py | 44 + tests/test_i18n.py | 21 + tests/test_interaction.py | 19 + tests/test_logging_config.py | 30 + tests/test_modal_send.py | 16 + tests/test_modals.py | 32 + tests/test_oauth.py | 114 ++ tests/test_permissions.py | 34 + tests/test_presence_and_typing.py | 39 + tests/test_presence_update.py | 33 + tests/test_rate_limiter.py | 30 + tests/test_reactions.py | 38 + tests/test_sharding.py | 52 + tests/test_slash_contexts.py | 15 + tests/test_tasks_extension.py | 24 + tests/test_typing_indicator.py | 64 + tests/test_ui.py | 65 + tests/test_view_layout.py | 53 + tests/test_voice_client.py | 75 + tests/test_wait_for.py | 35 + tests/test_webhooks.py | 44 + 120 files changed, 15345 insertions(+) create mode 100644 .github/workflows/ci.yml create mode 100644 .gitignore create mode 100644 AGENTS.md create mode 100644 LICENSE create mode 100644 MANIFEST.in create mode 100644 README.md create mode 100644 disagreement/__init__.py create mode 100644 disagreement/cache.py create mode 100644 disagreement/client.py create mode 100644 disagreement/components.py create mode 100644 disagreement/enums.py create mode 100644 disagreement/error_handler.py create mode 100644 disagreement/errors.py create mode 100644 disagreement/event_dispatcher.py create mode 100644 disagreement/ext/app_commands/__init__.py create mode 100644 disagreement/ext/app_commands/commands.py create mode 100644 disagreement/ext/app_commands/context.py create mode 100644 disagreement/ext/app_commands/converters.py create mode 100644 disagreement/ext/app_commands/decorators.py create mode 100644 disagreement/ext/app_commands/handler.py create mode 100644 disagreement/ext/commands/__init__.py create mode 100644 disagreement/ext/commands/cog.py create mode 100644 disagreement/ext/commands/converters.py create mode 100644 disagreement/ext/commands/core.py create mode 100644 disagreement/ext/commands/decorators.py create mode 100644 disagreement/ext/commands/errors.py create mode 100644 disagreement/ext/commands/help.py create mode 100644 disagreement/ext/commands/view.py create mode 100644 disagreement/ext/loader.py create mode 100644 disagreement/ext/tasks.py create mode 100644 disagreement/gateway.py create mode 100644 disagreement/http.py create mode 100644 disagreement/hybrid_context.py create mode 100644 disagreement/i18n.py create mode 100644 disagreement/interactions.py create mode 100644 disagreement/logging_config.py create mode 100644 disagreement/models.py create mode 100644 disagreement/oauth.py create mode 100644 disagreement/permissions.py create mode 100644 disagreement/rate_limiter.py create mode 100644 disagreement/shard_manager.py create mode 100644 disagreement/typing.py create mode 100644 disagreement/ui/__init__.py create mode 100644 disagreement/ui/button.py create mode 100644 disagreement/ui/item.py create mode 100644 disagreement/ui/modal.py create mode 100644 disagreement/ui/select.py create mode 100644 disagreement/ui/view.py create mode 100644 disagreement/voice_client.py create mode 100644 docs/caching.md create mode 100644 docs/commands.md create mode 100644 docs/context_menus.md create mode 100644 docs/converters.md create mode 100644 docs/events.md create mode 100644 docs/gateway.md create mode 100644 docs/i18n.md create mode 100644 docs/oauth2.md create mode 100644 docs/permissions.md create mode 100644 docs/presence.md create mode 100644 docs/reactions.md create mode 100644 docs/slash_commands.md create mode 100644 docs/task_loop.md create mode 100644 docs/typing_indicator.md create mode 100644 docs/using_components.md create mode 100644 docs/voice_client.md create mode 100644 docs/voice_features.md create mode 100644 docs/webhooks.md create mode 100644 examples/basic_bot.py create mode 100644 examples/component_bot.py create mode 100644 examples/context_menus.py create mode 100644 examples/hybrid_bot.py create mode 100644 examples/modal_command.py create mode 100644 examples/modal_send.py create mode 100644 examples/sharded_bot.py create mode 100644 examples/task_loop.py create mode 100644 examples/voice_bot.py create mode 100644 pyproject.toml create mode 100644 pyrightconfig.json create mode 100644 requirements.txt create mode 100644 setup.cfg create mode 100644 tavilytool.py create mode 100644 tests/conftest.py create mode 100644 tests/test_additional_converters.py create mode 100644 tests/test_cache.py create mode 100644 tests/test_client_context_manager.py create mode 100644 tests/test_command_checks.py create mode 100644 tests/test_components_factory.py create mode 100644 tests/test_context.py create mode 100644 tests/test_context_menus.py create mode 100644 tests/test_converter_registration.py create mode 100644 tests/test_converters.py create mode 100644 tests/test_error_handler.py create mode 100644 tests/test_errors.py create mode 100644 tests/test_event_dispatcher.py create mode 100644 tests/test_event_error_hook.py create mode 100644 tests/test_extension_loader.py create mode 100644 tests/test_gateway_backoff.py create mode 100644 tests/test_help_command.py create mode 100644 tests/test_http_reactions.py create mode 100644 tests/test_hybrid_context.py create mode 100644 tests/test_i18n.py create mode 100644 tests/test_interaction.py create mode 100644 tests/test_logging_config.py create mode 100644 tests/test_modal_send.py create mode 100644 tests/test_modals.py create mode 100644 tests/test_oauth.py create mode 100644 tests/test_permissions.py create mode 100644 tests/test_presence_and_typing.py create mode 100644 tests/test_presence_update.py create mode 100644 tests/test_rate_limiter.py create mode 100644 tests/test_reactions.py create mode 100644 tests/test_sharding.py create mode 100644 tests/test_slash_contexts.py create mode 100644 tests/test_tasks_extension.py create mode 100644 tests/test_typing_indicator.py create mode 100644 tests/test_ui.py create mode 100644 tests/test_view_layout.py create mode 100644 tests/test_voice_client.py create mode 100644 tests/test_wait_for.py create mode 100644 tests/test_webhooks.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..b9fb0c9 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,34 @@ +name: Python CI + +on: + push: + branches: + - master + pull_request: + branches: + - master + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -e . + npm install -g pyright + + - name: Run Pyright + run: pyright + + - name: Run tests + run: | + pytest tests/ diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..53e0563 --- /dev/null +++ b/.gitignore @@ -0,0 +1,129 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a PyInstaller script; this is deployed to PyInstaller's temporary directory. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# PEP 582; __pypackages__ +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# IDE specific files +.idea/ +.vscode/ +.kilocode/ +*.swp +*.swo +*~ \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..d5b40e9 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,64 @@ +# Agents +- There are no nested `AGENTS.md` files; this is the only one in the project. +- Tools to use for testing: `pyright`, `pylint`, `pytest`, `black` + +- You have a Python script `tavilytool.py` in the project root that you can use to search the web. + +# Tavily API Script Usage Instructions + +## Basic Usage +Search for information using simple queries: +```bash +python tavilytool.py "your search query" +``` + +## Examples +```bash +python tavilytool.py "latest AI development 2024" +python tavilytool.py "how to make chocolate chip cookies" +python tavilytool.py "current weather in New York" +python tavilytool.py "best programming practices Python" +``` + +## Advanced Options + +### Search Depth +- **Basic search**: `python tavilytool.py "query"` (default) +- **Advanced search**: `python tavilytool.py "query" --depth advanced` + +### Control Results +- **Limit results**: `python tavilytool.py "query" --max-results 3` +- **Include images**: `python tavilytool.py "query" --include-images` +- **Skip AI answer**: `python tavilytool.py "query" --no-answer` + +### Domain Filtering +- **Include specific domains**: `python tavilytool.py "query" --include-domains reddit.com stackoverflow.com` +- **Exclude domains**: `python tavilytool.py "query" --exclude-domains wikipedia.org` + +### Output Format +- **Formatted output**: `python tavilytool.py "query"` (default - human readable) +- **Raw JSON**: `python tavilytool.py "query" --raw` (for programmatic processing) + +## Output Structure +The default formatted output includes: +- 🤖 **AI Answer**: Direct answer to your query +- 🔍 **Search Results**: Titles, URLs, and content snippets +- 🖼️ **Images**: Relevant images (when `--include-images` is used) + +## Command Combinations +```bash +# Advanced search with images, limited results +python tavilytool.py "machine learning tutorials" --depth advanced --include-images --max-results 3 + +# Search specific sites only, raw output +python tavilytool.py "Python best practices" --include-domains github.com stackoverflow.com --raw + +# Quick search without AI answer +python tavilytool.py "today's news" --no-answer --max-results 5 +``` + +## Tips +- Always quote your search queries to handle spaces and special characters +- Use `--max-results` to control response length and API usage +- Use `--raw` when you need to parse results programmatically +- Combine options as needed for specific use cases diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..651f7fc --- /dev/null +++ b/LICENSE @@ -0,0 +1,26 @@ +Copyright (c) 2025, Slipstream + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..95c0677 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,3 @@ +graft docs +graft examples +include LICENSE \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..49daefb --- /dev/null +++ b/README.md @@ -0,0 +1,131 @@ +# Disagreement + +A Python library for interacting with the Discord API, with a focus on bot development. + +## Features + +- Asynchronous design using `aiohttp` +- Gateway and HTTP API clients +- Slash command framework +- Message component helpers +- Built-in caching layer +- Experimental voice support +- Helpful error handling utilities + +## Installation + +```bash +python -m pip install -U pip +pip install disagreement +# or install from source for development +pip install -e . +``` + +Requires Python 3.11 or newer. + +## Basic Usage + +```python +import asyncio +import os +import disagreement + +# Ensure DISCORD_BOT_TOKEN is set in your environment +client = disagreement.Client(token=os.environ.get("DISCORD_BOT_TOKEN")) + +@client.on_event('MESSAGE_CREATE') +async def on_message(message: disagreement.Message): + print(f"Received: {message.content} from {message.author.username}") + if message.content.lower() == '!ping': + await message.reply('Pong!') + +async def main(): + if not client.token: + print("Error: DISCORD_BOT_TOKEN environment variable not set.") + return + try: + async with client: + await asyncio.Future() # run until cancelled + except KeyboardInterrupt: + print("Bot shutting down...") + # Add any other specific exception handling from your library, e.g., disagreement.AuthenticationError + +if __name__ == '__main__': + asyncio.run(main()) +``` + +### Global Error Handling + +To ensure unexpected errors don't crash your bot, you can enable the library's +global error handler: + +```python +import disagreement + +disagreement.setup_global_error_handler() +``` + +Call this early in your program to log unhandled exceptions instead of letting +them terminate the process. + +### Configuring Logging + +Use :func:`disagreement.logging_config.setup_logging` to configure logging for +your bot. The helper accepts a logging level and an optional file path. + +```python +import logging +from disagreement.logging_config import setup_logging + +setup_logging(logging.INFO) +# Or log to a file +setup_logging(logging.DEBUG, file="bot.log") +``` + +### Defining Subcommands with `AppCommandGroup` + +```python +from disagreement.ext.app_commands import AppCommandGroup + +settings = AppCommandGroup("settings", "Manage settings") + +@settings.command(name="show") +async def show(ctx): + """Displays a setting.""" + ... + +@settings.group("admin", description="Admin settings") +def admin_group(): + pass + +@admin_group.command(name="set") +async def set_setting(ctx, key: str, value: str): + ... +## Fetching Guilds + +Use `Client.fetch_guild` to retrieve a guild from the Discord API if it +isn't already cached. This is useful when working with guild IDs from +outside the gateway events. + +```python +guild = await client.fetch_guild("123456789012345678") +roles = await client.fetch_roles(guild.id) +``` + +## Sharding + +To run your bot across multiple gateway shards, pass `shard_count` when creating +the client: + +```python +client = disagreement.Client(token=BOT_TOKEN, shard_count=2) +``` + +See `examples/sharded_bot.py` for a full example. + +## Contributing + +Contributions are welcome! Please open an issue or submit a pull request. + +See the [docs](docs/) directory for detailed guides on components, slash commands, caching, and voice features. + diff --git a/disagreement/__init__.py b/disagreement/__init__.py new file mode 100644 index 0000000..2186476 --- /dev/null +++ b/disagreement/__init__.py @@ -0,0 +1,36 @@ +# disagreement/__init__.py + +""" +Disagreement +~~~~~~~~~~~~ + +A Python library for interacting with the Discord API. + +:copyright: (c) 2025 Slipstream +:license: BSD 3-Clause License, see LICENSE for more details. +""" + +__title__ = "disagreement" +__author__ = "Slipstream" +__license__ = "BSD 3-Clause License" +__copyright__ = "Copyright 2025 Slipstream" +__version__ = "0.0.1" + +from .client import Client +from .models import Message, User +from .voice_client import VoiceClient +from .typing import Typing +from .errors import ( + DisagreementException, + HTTPException, + GatewayException, + AuthenticationError, +) +from .enums import GatewayIntent, GatewayOpcode # Export enums +from .error_handler import setup_global_error_handler +from .hybrid_context import HybridContext +from .ext import tasks + +# Set up logging if desired +# import logging +# logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/disagreement/cache.py b/disagreement/cache.py new file mode 100644 index 0000000..666d46b --- /dev/null +++ b/disagreement/cache.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import time +from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar + +if TYPE_CHECKING: + from .models import Channel, Guild + +T = TypeVar("T") + + +class Cache(Generic[T]): + """Simple in-memory cache with optional TTL support.""" + + def __init__(self, ttl: Optional[float] = None) -> None: + self.ttl = ttl + self._data: Dict[str, tuple[T, Optional[float]]] = {} + + def set(self, key: str, value: T) -> None: + expiry = time.monotonic() + self.ttl if self.ttl is not None else None + self._data[key] = (value, expiry) + + def get(self, key: str) -> Optional[T]: + item = self._data.get(key) + if not item: + return None + value, expiry = item + if expiry is not None and expiry < time.monotonic(): + self.invalidate(key) + return None + return value + + def invalidate(self, key: str) -> None: + self._data.pop(key, None) + + def clear(self) -> None: + self._data.clear() + + def values(self) -> list[T]: + now = time.monotonic() + items = [] + for key, (value, expiry) in list(self._data.items()): + if expiry is not None and expiry < now: + self.invalidate(key) + else: + items.append(value) + return items + + +class GuildCache(Cache["Guild"]): + """Cache specifically for :class:`Guild` objects.""" + + +class ChannelCache(Cache["Channel"]): + """Cache specifically for :class:`Channel` objects.""" diff --git a/disagreement/client.py b/disagreement/client.py new file mode 100644 index 0000000..df58c8a --- /dev/null +++ b/disagreement/client.py @@ -0,0 +1,1144 @@ +# disagreement/client.py + +""" +The main Client class for interacting with the Discord API. +""" + +import asyncio +import signal +from typing import ( + Optional, + Callable, + Any, + TYPE_CHECKING, + Awaitable, + Union, + List, + Dict, +) + +from .http import HTTPClient +from .gateway import GatewayClient +from .shard_manager import ShardManager +from .event_dispatcher import EventDispatcher +from .enums import GatewayIntent, InteractionType +from .errors import DisagreementException, AuthenticationError +from .typing import Typing +from .ext.commands.core import CommandHandler +from .ext.commands.cog import Cog +from .ext.app_commands.handler import AppCommandHandler +from .ext.app_commands.context import AppCommandContext +from .interactions import Interaction, Snowflake +from .error_handler import setup_global_error_handler + +if TYPE_CHECKING: + from .models import ( + Message, + Embed, + ActionRow, + Guild, + Channel, + User, + Member, + Role, + TextChannel, + VoiceChannel, + CategoryChannel, + Thread, + DMChannel, + ) + from .ui.view import View + from .enums import ChannelType as EnumChannelType + from .ext.commands.core import CommandContext + from .ext.commands.errors import CommandError, CommandInvokeError + from .ext.app_commands.commands import AppCommand, AppCommandGroup + + +class Client: + """ + Represents a client connection that connects to Discord. + This class is used to interact with the Discord WebSocket and API. + + Args: + token (str): The bot token for authentication. + intents (Optional[int]): The Gateway Intents to use. Defaults to `GatewayIntent.default()`. + You might need to enable privileged intents in your bot's application page. + loop (Optional[asyncio.AbstractEventLoop]): The event loop to use for asynchronous operations. + Defaults to `asyncio.get_event_loop()`. + command_prefix (Union[str, List[str], Callable[['Client', Message], Union[str, List[str]]]]): + The prefix(es) for commands. Defaults to '!'. + verbose (bool): If True, print raw HTTP and Gateway traffic for debugging. + """ + + def __init__( + self, + token: str, + intents: Optional[int] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + command_prefix: Union[ + str, List[str], Callable[["Client", "Message"], Union[str, List[str]]] + ] = "!", + application_id: Optional[Union[str, int]] = None, + verbose: bool = False, + mention_replies: bool = False, + shard_count: Optional[int] = None, + ): + if not token: + raise ValueError("A bot token must be provided.") + + self.token: str = token + self.intents: int = intents if intents is not None else GatewayIntent.default() + self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() + self.application_id: Optional[Snowflake] = ( + str(application_id) if application_id else None + ) + setup_global_error_handler(self.loop) + + self.verbose: bool = verbose + self._http: HTTPClient = HTTPClient(token=self.token, verbose=verbose) + self._event_dispatcher: EventDispatcher = EventDispatcher(client_instance=self) + self._gateway: Optional[GatewayClient] = ( + None # Initialized in run() or connect() + ) + self.shard_count: Optional[int] = shard_count + self._shard_manager: Optional[ShardManager] = None + + # Initialize CommandHandler + self.command_handler: CommandHandler = CommandHandler( + client=self, prefix=command_prefix + ) + self.app_command_handler: AppCommandHandler = AppCommandHandler(client=self) + # Register internal listener for processing commands from messages + self._event_dispatcher.register( + "MESSAGE_CREATE", self._process_message_for_commands + ) + + self._closed: bool = False + self._ready_event: asyncio.Event = asyncio.Event() + self.application_id: Optional[Snowflake] = None # For Application Commands + self.user: Optional["User"] = ( + None # The bot's own user object, populated on READY + ) + + # Initialize AppCommandHandler + self.app_command_handler: AppCommandHandler = AppCommandHandler(client=self) + + # Internal Caches + self._guilds: Dict[Snowflake, "Guild"] = {} + self._channels: Dict[Snowflake, "Channel"] = ( + {} + ) # Stores all channel types by ID + self._users: Dict[Snowflake, Any] = ( + {} + ) # Placeholder for User model cache if needed + self._messages: Dict[Snowflake, "Message"] = {} + self._views: Dict[Snowflake, "View"] = {} + + # Default whether replies mention the user + self.mention_replies: bool = mention_replies + + # Basic signal handling for graceful shutdown + # This might be better handled by the user's application code, but can be a nice default. + # For more robust handling, consider libraries or more advanced patterns. + try: + self.loop.add_signal_handler( + signal.SIGINT, lambda: self.loop.create_task(self.close()) + ) + self.loop.add_signal_handler( + signal.SIGTERM, lambda: self.loop.create_task(self.close()) + ) + except NotImplementedError: + # add_signal_handler is not available on all platforms (e.g., Windows default event loop policy) + # Users on these platforms would need to handle shutdown differently. + print( + "Warning: Signal handlers for SIGINT/SIGTERM could not be added. " + "Graceful shutdown via signals might not work as expected on this platform." + ) + + async def _initialize_gateway(self): + """Initializes the GatewayClient if it doesn't exist.""" + if self._gateway is None: + self._gateway = GatewayClient( + http_client=self._http, + event_dispatcher=self._event_dispatcher, + token=self.token, + intents=self.intents, + client_instance=self, + verbose=self.verbose, + ) + + async def _initialize_shard_manager(self) -> None: + """Initializes the :class:`ShardManager` if not already created.""" + if self._shard_manager is None: + count = self.shard_count or 1 + self._shard_manager = ShardManager(self, count) + + async def connect(self, reconnect: bool = True) -> None: + """ + Establishes a connection to Discord. This includes logging in and connecting to the Gateway. + This method is a coroutine. + + Args: + reconnect (bool): Whether to automatically attempt to reconnect on disconnect. + (Note: Basic reconnect logic is within GatewayClient for now) + + Raises: + GatewayException: If the connection to the gateway fails. + AuthenticationError: If the token is invalid. + """ + if self._closed: + raise DisagreementException("Client is closed and cannot connect.") + if self.shard_count and self.shard_count > 1: + await self._initialize_shard_manager() + assert self._shard_manager is not None + await self._shard_manager.start() + print( + f"Client connected using {self.shard_count} shards, waiting for READY signal..." + ) + await self.wait_until_ready() + print("Client is READY!") + return + + await self._initialize_gateway() + assert self._gateway is not None # Should be initialized by now + + retry_delay = 5 # seconds + max_retries = 5 # For initial connection attempts by Client.run, Gateway has its own internal retries for some cases. + + for attempt in range(max_retries): + try: + await self._gateway.connect() + # After successful connection, GatewayClient's HELLO handler will trigger IDENTIFY/RESUME + # and its READY handler will set self._ready_event via dispatcher. + print("Client connected to Gateway, waiting for READY signal...") + await self.wait_until_ready() # Wait for the READY event from Gateway + print("Client is READY!") + return # Successfully connected and ready + except AuthenticationError: # Non-recoverable by retry here + print("Authentication failed. Please check your bot token.") + await self.close() # Ensure cleanup + raise + except DisagreementException as e: # Includes GatewayException + print(f"Failed to connect (Attempt {attempt + 1}/{max_retries}): {e}") + if attempt < max_retries - 1: + print(f"Retrying in {retry_delay} seconds...") + await asyncio.sleep(retry_delay) + retry_delay = min( + retry_delay * 2, 60 + ) # Exponential backoff up to 60s + else: + print("Max connection retries reached. Giving up.") + await self.close() # Ensure cleanup + raise + # Should not be reached if max_retries is > 0 + if max_retries == 0: # If max_retries was 0, means no retries attempted + raise DisagreementException("Connection failed with 0 retries allowed.") + + async def run(self) -> None: + """ + A blocking call that connects the client to Discord and runs until the client is closed. + This method is a coroutine. + It handles login, Gateway connection, and keeping the connection alive. + """ + if self._closed: + raise DisagreementException("Client is already closed.") + + try: + await self.connect() + # The GatewayClient's _receive_loop will keep running. + # This run method effectively waits until the client is closed or an unhandled error occurs. + # A more robust implementation might have a main loop here that monitors gateway health. + # For now, we rely on the gateway's tasks. + while not self._closed: + if ( + self._gateway + and self._gateway._receive_task + and self._gateway._receive_task.done() + ): + # If receive task ended unexpectedly, try to handle it or re-raise + try: + exc = self._gateway._receive_task.exception() + if exc: + print( + f"Gateway receive task ended with exception: {exc}. Attempting to reconnect..." + ) + # This is a basic reconnect strategy from the client side. + # GatewayClient itself might handle some reconnects. + await self.close_gateway( + code=1000 + ) # Close current gateway state + await asyncio.sleep(5) # Wait before reconnecting + if ( + not self._closed + ): # If client wasn't closed by the exception handler + await self.connect() + else: + break # Client was closed, exit run loop + else: + print( + "Gateway receive task ended without exception. Assuming clean shutdown or reconnect handled internally." + ) + if ( + not self._closed + ): # If not explicitly closed, might be an issue + print( + "Warning: Gateway receive task ended but client not closed. This might indicate an issue." + ) + # Consider a more robust health check or reconnect strategy here. + await asyncio.sleep( + 1 + ) # Prevent tight loop if something is wrong + else: + break # Client was closed + except asyncio.CancelledError: + print("Gateway receive task was cancelled.") + break # Exit if cancelled + except Exception as e: + print(f"Error checking gateway receive task: {e}") + break # Exit on other errors + await asyncio.sleep(1) # Main loop check interval + except DisagreementException as e: + print(f"Client run loop encountered an error: {e}") + # Error already logged by connect or other methods + except asyncio.CancelledError: + print("Client run loop was cancelled.") + finally: + if not self._closed: + await self.close() + + async def close(self) -> None: + """ + Closes the connection to Discord. This method is a coroutine. + """ + if self._closed: + return + + self._closed = True + print("Closing client...") + + if self._shard_manager: + await self._shard_manager.close() + self._shard_manager = None + if self._gateway: + await self._gateway.close() + + if self._http: # HTTPClient has its own session to close + await self._http.close() + + self._ready_event.set() # Ensure any waiters for ready are unblocked + print("Client closed.") + + async def __aenter__(self) -> "Client": + """Enter the context manager by connecting to Discord.""" + await self.connect() + return self + + async def __aexit__( + self, + exc_type: Optional[type], + exc: Optional[BaseException], + tb: Optional[BaseException], + ) -> bool: + """Exit the context manager and close the client.""" + await self.close() + return False + + async def close_gateway(self, code: int = 1000) -> None: + """Closes only the gateway connection, allowing for potential reconnect.""" + if self._shard_manager: + await self._shard_manager.close() + self._shard_manager = None + if self._gateway: + await self._gateway.close(code=code) + self._gateway = None + self._ready_event.clear() # No longer ready if gateway is closed + + def is_closed(self) -> bool: + """Indicates if the client has been closed.""" + return self._closed + + def is_ready(self) -> bool: + """Indicates if the client has successfully connected to the Gateway and is ready.""" + return self._ready_event.is_set() + + async def wait_until_ready(self) -> None: + """|coro| + Waits until the client is fully connected to Discord and the initial state is processed. + This is mainly useful for waiting for the READY event from the Gateway. + """ + await self._ready_event.wait() + + async def wait_for( + self, + event_name: str, + check: Optional[Callable[[Any], bool]] = None, + timeout: Optional[float] = None, + ) -> Any: + """|coro| + Waits for a specific event to occur that satisfies the ``check``. + + Parameters + ---------- + event_name: str + The name of the event to wait for. + check: Optional[Callable[[Any], bool]] + A function that determines whether the received event should resolve the wait. + timeout: Optional[float] + How long to wait for the event before raising :class:`asyncio.TimeoutError`. + """ + + future: asyncio.Future = self.loop.create_future() + self._event_dispatcher.add_waiter(event_name, future, check) + try: + return await asyncio.wait_for(future, timeout=timeout) + finally: + self._event_dispatcher.remove_waiter(event_name, future) + + async def change_presence( + self, + status: str, + activity_name: Optional[str] = None, + activity_type: int = 0, + since: int = 0, + afk: bool = False, + ): + """ + Changes the client's presence on Discord. + + Args: + status (str): The new status for the client (e.g., "online", "idle", "dnd", "invisible"). + activity_name (Optional[str]): The name of the activity. + activity_type (int): The type of the activity. + since (int): The timestamp (in milliseconds) of when the client went idle. + afk (bool): Whether the client is AFK. + """ + if self._closed: + raise DisagreementException("Client is closed.") + + if self._gateway: + await self._gateway.update_presence( + status=status, + activity_name=activity_name, + activity_type=activity_type, + since=since, + afk=afk, + ) + + # --- Event Handling --- + + def event( + self, coro: Callable[..., Awaitable[None]] + ) -> Callable[..., Awaitable[None]]: + """ + A decorator that registers an event to listen to. + The name of the coroutine is used as the event name. + Example: + @client.event + async def on_ready(): # Will listen for the 'READY' event + print("Bot is ready!") + + @client.event + async def on_message(message: disagreement.Message): # Will listen for 'MESSAGE_CREATE' + print(f"Message from {message.author}: {message.content}") + """ + if not asyncio.iscoroutinefunction(coro): + raise TypeError("Event registered must be a coroutine function.") + + event_name = coro.__name__ + # Map common function names to Discord event types + # e.g., on_ready -> READY, on_message -> MESSAGE_CREATE + if event_name.startswith("on_"): + discord_event_name = event_name[3:].upper() # e.g., on_message -> MESSAGE + if discord_event_name == "MESSAGE": # Common case + discord_event_name = "MESSAGE_CREATE" + # Add other mappings if needed, e.g. on_member_join -> GUILD_MEMBER_ADD + + self._event_dispatcher.register(discord_event_name, coro) + else: + # If not starting with "on_", assume it's the direct Discord event name (e.g. "TYPING_START") + # Or raise an error if a specific format is required. + # For now, let's assume direct mapping if no "on_" prefix. + self._event_dispatcher.register(event_name.upper(), coro) + + return coro # Return the original coroutine + + def on_event( + self, event_name: str + ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: + """ + A decorator that registers an event to listen to with a specific event name. + Example: + @client.on_event('MESSAGE_CREATE') + async def my_message_handler(message: disagreement.Message): + print(f"Message: {message.content}") + """ + + def decorator( + coro: Callable[..., Awaitable[None]], + ) -> Callable[..., Awaitable[None]]: + if not asyncio.iscoroutinefunction(coro): + raise TypeError("Event registered must be a coroutine function.") + self._event_dispatcher.register(event_name.upper(), coro) + return coro + + return decorator + + async def _process_message_for_commands(self, message: "Message") -> None: + """Internal listener to process messages for commands.""" + # Make sure message object is valid and not from a bot (optional, common check) + if ( + not message or not message.author or message.author.bot + ): # Add .bot check to User model + return + await self.command_handler.process_commands(message) + + # --- Command Framework Methods --- + + def add_cog(self, cog: Cog) -> None: + """ + Adds a Cog to the bot. + Cogs are classes that group commands, listeners, and state. + This will also discover and register any application commands defined in the cog. + + Args: + cog (Cog): An instance of a class derived from `disagreement.ext.commands.Cog`. + """ + # Add to prefix command handler + self.command_handler.add_cog( + cog + ) # This should call cog._inject() internally or cog._inject() is called on Cog init + + # Discover and add application commands from the cog + # AppCommand and AppCommandGroup are already imported in TYPE_CHECKING block + for app_cmd_obj in cog.get_app_commands_and_groups(): # Uses the new method + # The cog attribute should have been set within Cog._inject() for AppCommands + self.app_command_handler.add_command(app_cmd_obj) + print( + f"Registered app command/group '{app_cmd_obj.name}' from cog '{cog.cog_name}'." + ) + + def remove_cog(self, cog_name: str) -> Optional[Cog]: + """ + Removes a Cog from the bot. + + Args: + cog_name (str): The name of the Cog to remove. + + Returns: + Optional[Cog]: The Cog that was removed, or None if not found. + """ + removed_cog = self.command_handler.remove_cog(cog_name) + if removed_cog: + # Also remove associated application commands + # This requires AppCommand to store a reference to its cog, or iterate all app_commands. + # Assuming AppCommand has a .cog attribute, which is set in Cog._inject() + # And AppCommandGroup might store commands that have .cog attribute + for app_cmd_or_group in removed_cog.get_app_commands_and_groups(): + # The AppCommandHandler.remove_command needs to handle both AppCommand and AppCommandGroup + self.app_command_handler.remove_command( + app_cmd_or_group.name + ) # Assuming name is unique enough for removal here + print( + f"Removed app command/group '{app_cmd_or_group.name}' from cog '{cog_name}'." + ) + # Note: AppCommandHandler.remove_command might need to be more specific if names aren't globally unique + # (e.g. if it needs type or if groups and commands can share names). + # For now, assuming name is sufficient for removal from the handler's flat list. + return removed_cog + + def add_app_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None: + """ + Adds a standalone application command or group to the bot. + Use this for commands not defined within a Cog. + + Args: + command (Union[AppCommand, AppCommandGroup]): The application command or group instance. + This is typically the object returned by a decorator like @slash_command. + """ + from .ext.app_commands.commands import ( + AppCommand, + AppCommandGroup, + ) # Ensure types + + if not isinstance(command, (AppCommand, AppCommandGroup)): + raise TypeError( + "Command must be an instance of AppCommand or AppCommandGroup." + ) + + # If it's a decorated function, the command object might be on __app_command_object__ + if hasattr(command, "__app_command_object__") and isinstance( + getattr(command, "__app_command_object__"), (AppCommand, AppCommandGroup) + ): + actual_command_obj = getattr(command, "__app_command_object__") + self.app_command_handler.add_command(actual_command_obj) + print( + f"Registered standalone app command/group '{actual_command_obj.name}'." + ) + elif isinstance( + command, (AppCommand, AppCommandGroup) + ): # It's already the command object + self.app_command_handler.add_command(command) + print(f"Registered standalone app command/group '{command.name}'.") + else: + # This case should ideally not be hit if type checks are done by decorators + print( + f"Warning: Could not register app command {command}. It's not a recognized command object or decorated function." + ) + + async def on_command_error( + self, ctx: "CommandContext", error: "CommandError" + ) -> None: + """ + Default command error handler. Called when a command raises an error. + Users can override this method in a subclass of Client to implement custom error handling. + + Args: + ctx (CommandContext): The context of the command that raised the error. + error (CommandError): The error that was raised. + """ + # Default behavior: print to console. + # Users might want to send a message to ctx.channel or log to a file. + print( + f"Error in command '{ctx.command.name if ctx.command else 'unknown'}': {error}" + ) + + # Need to import CommandInvokeError for this check if not already globally available + # For now, assuming it's imported via TYPE_CHECKING or directly if needed at runtime + from .ext.commands.errors import ( + CommandInvokeError as CIE, + ) # Local import for isinstance check + + if isinstance(error, CIE): + # Now it's safe to access error.original + print( + f"Original exception: {type(error.original).__name__}: {error.original}" + ) + # import traceback + # traceback.print_exception(type(error.original), error.original, error.original.__traceback__) + + # --- Model Parsing and Fetching --- + + def parse_user(self, data: Dict[str, Any]) -> "User": + """Parses user data and returns a User object, updating cache.""" + from .models import User # Ensure User model is available + + user = User(data) + self._users[user.id] = user # Cache the user + return user + + def parse_channel(self, data: Dict[str, Any]) -> "Channel": + """Parses channel data and returns a Channel object, updating caches.""" + + from .models import channel_factory + + channel = channel_factory(data, self) + self._channels[channel.id] = channel + if channel.guild_id: + guild = self._guilds.get(channel.guild_id) + if guild: + guild._channels[channel.id] = channel + return channel + + def parse_message(self, data: Dict[str, Any]) -> "Message": + """Parses message data and returns a Message object, updating cache.""" + + from .models import Message + + message = Message(data, client_instance=self) + self._messages[message.id] = message + return message + + async def fetch_user(self, user_id: Snowflake) -> Optional["User"]: + """Fetches a user by ID from Discord.""" + if self._closed: + raise DisagreementException("Client is closed.") + + cached_user = self._users.get(user_id) + if cached_user: + return cached_user # Return cached if available, though fetch implies wanting fresh + + try: + user_data = await self._http.get_user(user_id) + return self.parse_user(user_data) + except DisagreementException as e: # Catch HTTP exceptions from http client + print(f"Failed to fetch user {user_id}: {e}") + return None + + async def fetch_message( + self, channel_id: Snowflake, message_id: Snowflake + ) -> Optional["Message"]: + """Fetches a message by ID from Discord and caches it.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + cached_message = self._messages.get(message_id) + if cached_message: + return cached_message + + try: + message_data = await self._http.get_message(channel_id, message_id) + return self.parse_message(message_data) + except DisagreementException as e: + print( + f"Failed to fetch message {message_id} from channel {channel_id}: {e}" + ) + return None + + def parse_member(self, data: Dict[str, Any], guild_id: Snowflake) -> "Member": + """Parses member data and returns a Member object, updating relevant caches.""" + from .models import Member # Ensure Member model is available + + # Member's __init__ should handle the nested 'user' data. + member = Member(data, client_instance=self) + member.guild_id = str(guild_id) + + # Cache the member in the guild's member cache + guild = self._guilds.get(guild_id) + if guild: + guild._members[member.id] = member # Assuming Guild has _members dict + + # Also cache the user part if not already cached or if this is newer + # Since Member inherits from User, the member object itself is the user. + self._users[member.id] = member + # If 'user' was in data and Member.__init__ used it, it's already part of 'member'. + return member + + async def fetch_member( + self, guild_id: Snowflake, member_id: Snowflake + ) -> Optional["Member"]: + """Fetches a member from a guild by ID.""" + if self._closed: + raise DisagreementException("Client is closed.") + + guild = self.get_guild(guild_id) + if guild: + cached_member = guild.get_member(member_id) # Use Guild's get_member + if cached_member: + return cached_member # Return cached if available + + try: + member_data = await self._http.get_guild_member(guild_id, member_id) + return self.parse_member(member_data, guild_id) + except DisagreementException as e: + print(f"Failed to fetch member {member_id} from guild {guild_id}: {e}") + return None + + def parse_role(self, data: Dict[str, Any], guild_id: Snowflake) -> "Role": + """Parses role data and returns a Role object, updating guild's role cache.""" + from .models import Role # Ensure Role model is available + + role = Role(data) + guild = self._guilds.get(guild_id) + if guild: + # Update the role in the guild's roles list if it exists, or add it. + # Guild.roles is List[Role]. We need to find and replace or append. + found = False + for i, existing_role in enumerate(guild.roles): + if existing_role.id == role.id: + guild.roles[i] = role + found = True + break + if not found: + guild.roles.append(role) + return role + + def parse_guild(self, data: Dict[str, Any]) -> "Guild": + """Parses guild data and returns a Guild object, updating cache.""" + + from .models import Guild + + guild = Guild(data, client_instance=self) + self._guilds[guild.id] = guild + + # Populate channel and member caches if provided + for ch in data.get("channels", []): + channel_obj = self.parse_channel(ch) + guild._channels[channel_obj.id] = channel_obj + + for member in data.get("members", []): + member_obj = self.parse_member(member, guild.id) + guild._members[member_obj.id] = member_obj + + return guild + + async def fetch_roles(self, guild_id: Snowflake) -> List["Role"]: + """Fetches all roles for a given guild and caches them. + + If the guild is not cached, it will be retrieved first using + :meth:`fetch_guild`. + """ + if self._closed: + raise DisagreementException("Client is closed.") + guild = self.get_guild(guild_id) + if not guild: + guild = await self.fetch_guild(guild_id) + if not guild: + return [] + + try: + roles_data = await self._http.get_guild_roles(guild_id) + parsed_roles = [] + for role_data in roles_data: + # parse_role will add/update it in the guild.roles list + parsed_roles.append(self.parse_role(role_data, guild_id)) + guild.roles = parsed_roles # Replace the entire list with the fresh one + return parsed_roles + except DisagreementException as e: + print(f"Failed to fetch roles for guild {guild_id}: {e}") + return [] + + async def fetch_role( + self, guild_id: Snowflake, role_id: Snowflake + ) -> Optional["Role"]: + """Fetches a specific role from a guild by ID. + If roles for the guild aren't cached or might be stale, it fetches all roles first. + """ + guild = self.get_guild(guild_id) + if guild: + # Try to find in existing guild.roles + for role in guild.roles: + if role.id == role_id: + return role + + # If not found in cache or guild doesn't exist yet in cache, fetch all roles for the guild + await self.fetch_roles(guild_id) # This will populate/update guild.roles + + # Try again from the now (hopefully) populated cache + guild = self.get_guild( + guild_id + ) # Re-get guild in case it was populated by fetch_roles + if guild: + for role in guild.roles: + if role.id == role_id: + return role + + return None # Role not found even after fetching + + # --- API Methods --- + + # --- API Methods --- + + async def send_message( + self, + channel_id: str, + content: Optional[str] = None, + *, # Make additional params keyword-only + tts: bool = False, + embed: Optional["Embed"] = None, + embeds: Optional[List["Embed"]] = None, + components: Optional[List["ActionRow"]] = None, + allowed_mentions: Optional[Dict[str, Any]] = None, + message_reference: Optional[Dict[str, Any]] = None, + flags: Optional[int] = None, + view: Optional["View"] = None, + ) -> "Message": + """|coro| + Sends a message to the specified channel. + + Args: + channel_id (str): The ID of the channel to send the message to. + content (Optional[str]): The content of the message. + tts (bool): Whether the message should be sent with text-to-speech. Defaults to False. + embed (Optional[Embed]): A single embed to send. Cannot be used with `embeds`. + embeds (Optional[List[Embed]]): A list of embeds to send. Cannot be used with `embed`. + Discord supports up to 10 embeds per message. + components (Optional[List[ActionRow]]): A list of ActionRow components to include. + allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message. + message_reference (Optional[Dict[str, Any]]): Message reference for replying. + flags (Optional[int]): Message flags. + view (Optional[View]): A view to send with the message. + + Returns: + Message: The message that was sent. + + Raises: + HTTPException: Sending the message failed. + ValueError: If both `embed` and `embeds` are provided, or if both `components` and `view` are provided. + """ + if self._closed: + raise DisagreementException("Client is closed.") + + if embed and embeds: + raise ValueError("Cannot provide both embed and embeds.") + if components and view: + raise ValueError("Cannot provide both 'components' and 'view'.") + + final_embeds_payload: Optional[List[Dict[str, Any]]] = None + if embed: + final_embeds_payload = [embed.to_dict()] + elif embeds: + from .models import ( + Embed as EmbedModel, + ) + + final_embeds_payload = [ + e.to_dict() for e in embeds if isinstance(e, EmbedModel) + ] + + components_payload: Optional[List[Dict[str, Any]]] = None + if view: + await view._start(self) + components_payload = view.to_components_payload() + elif components: + from .models import ActionRow as ActionRowModel + + components_payload = [ + comp.to_dict() + for comp in components + if isinstance(comp, ActionRowModel) + ] + + message_data = await self._http.send_message( + channel_id=channel_id, + content=content, + tts=tts, + embeds=final_embeds_payload, + components=components_payload, + allowed_mentions=allowed_mentions, + message_reference=message_reference, + flags=flags, + ) + + if view: + message_id = message_data["id"] + view.message_id = message_id + self._views[message_id] = view + + return self.parse_message(message_data) + + def typing(self, channel_id: str) -> Typing: + """Return a context manager to show a typing indicator in a channel.""" + + return Typing(self, channel_id) + + async def create_reaction( + self, channel_id: str, message_id: str, emoji: str + ) -> None: + """|coro| Add a reaction to a message.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + await self._http.create_reaction(channel_id, message_id, emoji) + + async def delete_reaction( + self, channel_id: str, message_id: str, emoji: str + ) -> None: + """|coro| Remove the bot's reaction from a message.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + await self._http.delete_reaction(channel_id, message_id, emoji) + + async def get_reactions( + self, channel_id: str, message_id: str, emoji: str + ) -> List["User"]: + """|coro| Return the users who reacted with the given emoji.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + users_data = await self._http.get_reactions(channel_id, message_id, emoji) + return [self.parse_user(u) for u in users_data] + + async def edit_message( + self, + channel_id: str, + message_id: str, + *, + content: Optional[str] = None, + embed: Optional["Embed"] = None, + embeds: Optional[List["Embed"]] = None, + components: Optional[List["ActionRow"]] = None, + allowed_mentions: Optional[Dict[str, Any]] = None, + flags: Optional[int] = None, + view: Optional["View"] = None, + ) -> "Message": + """Edits a previously sent message.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + if embed and embeds: + raise ValueError("Cannot provide both embed and embeds.") + if components and view: + raise ValueError("Cannot provide both 'components' and 'view'.") + + final_embeds_payload: Optional[List[Dict[str, Any]]] = None + if embed: + final_embeds_payload = [embed.to_dict()] + elif embeds: + final_embeds_payload = [e.to_dict() for e in embeds] + + components_payload: Optional[List[Dict[str, Any]]] = None + if view: + await view._start(self) + components_payload = view.to_components_payload() + elif components: + components_payload = [c.to_dict() for c in components] + + payload: Dict[str, Any] = {} + if content is not None: + payload["content"] = content + if final_embeds_payload is not None: + payload["embeds"] = final_embeds_payload + if components_payload is not None: + payload["components"] = components_payload + if allowed_mentions is not None: + payload["allowed_mentions"] = allowed_mentions + if flags is not None: + payload["flags"] = flags + + message_data = await self._http.edit_message( + channel_id=channel_id, + message_id=message_id, + payload=payload, + ) + + if view: + view.message_id = message_data["id"] + self._views[message_data["id"]] = view + + return self.parse_message(message_data) + + def get_guild(self, guild_id: Snowflake) -> Optional["Guild"]: + """Returns a guild from the internal cache. + + Use :meth:`fetch_guild` to retrieve it from Discord if it's not cached. + """ + + return self._guilds.get(guild_id) + + def get_channel(self, channel_id: Snowflake) -> Optional["Channel"]: + """Returns a channel from the internal cache.""" + + return self._channels.get(channel_id) + + def get_message(self, message_id: Snowflake) -> Optional["Message"]: + """Returns a message from the internal cache.""" + + return self._messages.get(message_id) + + async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]: + """Fetches a guild by ID from Discord and caches it.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + cached_guild = self._guilds.get(guild_id) + if cached_guild: + return cached_guild + + try: + guild_data = await self._http.get_guild(guild_id) + return self.parse_guild(guild_data) + except DisagreementException as e: + print(f"Failed to fetch guild {guild_id}: {e}") + return None + + async def fetch_channel(self, channel_id: Snowflake) -> Optional["Channel"]: + """Fetches a channel from Discord by its ID and updates the cache.""" + + if self._closed: + raise DisagreementException("Client is closed.") + + try: + channel_data = await self._http.get_channel(channel_id) + if not channel_data: + return None + + from .models import channel_factory + + channel = channel_factory(channel_data, self) + + self._channels[channel.id] = channel + return channel + + except DisagreementException as e: # Includes HTTPException + print(f"Failed to fetch channel {channel_id}: {e}") + return None + + # --- Application Command Methods --- + async def process_interaction(self, interaction: Interaction) -> None: + """Internal method to process an interaction from the gateway.""" + + if hasattr(self, "on_interaction_create"): + asyncio.create_task(self.on_interaction_create(interaction)) + # Route component interactions to the appropriate View + if ( + interaction.type == InteractionType.MESSAGE_COMPONENT + and interaction.message + ): + view = self._views.get(interaction.message.id) + if view: + asyncio.create_task(view._dispatch(interaction)) + return + + await self.app_command_handler.process_interaction(interaction) + + async def sync_application_commands( + self, guild_id: Optional[Snowflake] = None + ) -> None: + """Synchronizes application commands with Discord.""" + + if not self.application_id: + print( + "Warning: Cannot sync application commands, application_id is not set. " + "Ensure the client is connected and READY." + ) + return + if not self.is_ready(): + print( + "Warning: Client is not ready. Waiting for client to be ready before syncing commands." + ) + await self.wait_until_ready() + if not self.application_id: + print( + "Error: application_id still not set after client is ready. Cannot sync commands." + ) + return + + await self.app_command_handler.sync_commands( + application_id=self.application_id, guild_id=guild_id + ) + + async def on_interaction_create(self, interaction: Interaction) -> None: + """|coro| Called when an interaction is created.""" + + pass + + async def on_presence_update(self, presence) -> None: + """|coro| Called when a user's presence is updated.""" + + pass + + async def on_typing_start(self, typing) -> None: + """|coro| Called when a user starts typing in a channel.""" + + pass + + async def on_app_command_error( + self, context: AppCommandContext, error: Exception + ) -> None: + """Default error handler for application commands.""" + + print( + f"Error in application command '{context.command.name if context.command else 'unknown'}': {error}" + ) + try: + if not context._responded: + await context.send( + "An error occurred while running this command.", ephemeral=True + ) + except Exception as e: + print(f"Failed to send error message for app command: {e}") + + async def on_error( + self, event_method: str, exc: Exception, *args: Any, **kwargs: Any + ) -> None: + """Default event listener error handler.""" + + print(f"Unhandled exception in event listener for '{event_method}':") + print(f"{type(exc).__name__}: {exc}") diff --git a/disagreement/components.py b/disagreement/components.py new file mode 100644 index 0000000..11da89a --- /dev/null +++ b/disagreement/components.py @@ -0,0 +1,166 @@ +"""Message component utilities.""" + +from __future__ import annotations + +from typing import Any, Dict, Optional, TYPE_CHECKING + +from .enums import ComponentType, ButtonStyle, ChannelType, TextInputStyle +from .models import ( + ActionRow, + Button, + Component, + SelectMenu, + SelectOption, + PartialEmoji, + PartialEmoji, + Section, + TextDisplay, + Thumbnail, + MediaGallery, + MediaGalleryItem, + File, + Separator, + Container, + UnfurledMediaItem, +) + +if TYPE_CHECKING: # pragma: no cover - optional client for future use + from .client import Client + + +def component_factory( + data: Dict[str, Any], client: Optional["Client"] = None +) -> "Component": + """Create a component object from raw API data.""" + ctype = ComponentType(data["type"]) + + if ctype == ComponentType.ACTION_ROW: + row = ActionRow() + for comp in data.get("components", []): + row.add_component(component_factory(comp, client)) + return row + + if ctype == ComponentType.BUTTON: + return Button( + style=ButtonStyle(data["style"]), + label=data.get("label"), + emoji=PartialEmoji(data["emoji"]) if data.get("emoji") else None, + custom_id=data.get("custom_id"), + url=data.get("url"), + disabled=data.get("disabled", False), + ) + + if ctype in { + ComponentType.STRING_SELECT, + ComponentType.USER_SELECT, + ComponentType.ROLE_SELECT, + ComponentType.MENTIONABLE_SELECT, + ComponentType.CHANNEL_SELECT, + }: + options = [ + SelectOption( + label=o["label"], + value=o["value"], + description=o.get("description"), + emoji=PartialEmoji(o["emoji"]) if o.get("emoji") else None, + default=o.get("default", False), + ) + for o in data.get("options", []) + ] + channel_types = None + if ctype == ComponentType.CHANNEL_SELECT and data.get("channel_types"): + channel_types = [ChannelType(ct) for ct in data.get("channel_types", [])] + + return SelectMenu( + custom_id=data["custom_id"], + options=options, + placeholder=data.get("placeholder"), + min_values=data.get("min_values", 1), + max_values=data.get("max_values", 1), + disabled=data.get("disabled", False), + channel_types=channel_types, + type=ctype, + ) + + if ctype == ComponentType.TEXT_INPUT: + from .ui.modal import TextInput + + return TextInput( + label=data.get("label", ""), + custom_id=data.get("custom_id"), + style=TextInputStyle(data.get("style", TextInputStyle.SHORT.value)), + placeholder=data.get("placeholder"), + required=data.get("required", True), + min_length=data.get("min_length"), + max_length=data.get("max_length"), + ) + + if ctype == ComponentType.SECTION: + # The components in a section can only be TextDisplay + section_components = [] + for c in data.get("components", []): + comp = component_factory(c, client) + if isinstance(comp, TextDisplay): + section_components.append(comp) + + accessory = None + if data.get("accessory"): + acc_comp = component_factory(data["accessory"], client) + if isinstance(acc_comp, (Thumbnail, Button)): + accessory = acc_comp + + return Section( + components=section_components, + accessory=accessory, + id=data.get("id"), + ) + + if ctype == ComponentType.TEXT_DISPLAY: + return TextDisplay(content=data["content"], id=data.get("id")) + + if ctype == ComponentType.THUMBNAIL: + return Thumbnail( + media=UnfurledMediaItem(**data["media"]), + description=data.get("description"), + spoiler=data.get("spoiler", False), + id=data.get("id"), + ) + + if ctype == ComponentType.MEDIA_GALLERY: + return MediaGallery( + items=[ + MediaGalleryItem( + media=UnfurledMediaItem(**i["media"]), + description=i.get("description"), + spoiler=i.get("spoiler", False), + ) + for i in data.get("items", []) + ], + id=data.get("id"), + ) + + if ctype == ComponentType.FILE: + return File( + file=UnfurledMediaItem(**data["file"]), + spoiler=data.get("spoiler", False), + id=data.get("id"), + ) + + if ctype == ComponentType.SEPARATOR: + return Separator( + divider=data.get("divider", True), + spacing=data.get("spacing", 1), + id=data.get("id"), + ) + + if ctype == ComponentType.CONTAINER: + return Container( + components=[ + component_factory(c, client) for c in data.get("components", []) + ], + accent_color=data.get("accent_color"), + spoiler=data.get("spoiler", False), + id=data.get("id"), + ) + + raise ValueError(f"Unsupported component type: {ctype}") diff --git a/disagreement/enums.py b/disagreement/enums.py new file mode 100644 index 0000000..d224576 --- /dev/null +++ b/disagreement/enums.py @@ -0,0 +1,357 @@ +# disagreement/enums.py + +""" +Enums for Discord constants. +""" + +from enum import IntEnum, Enum # Import Enum + + +class GatewayOpcode(IntEnum): + """Represents a Discord Gateway Opcode.""" + + DISPATCH = 0 + HEARTBEAT = 1 + IDENTIFY = 2 + PRESENCE_UPDATE = 3 + VOICE_STATE_UPDATE = 4 + RESUME = 6 + RECONNECT = 7 + REQUEST_GUILD_MEMBERS = 8 + INVALID_SESSION = 9 + HELLO = 10 + HEARTBEAT_ACK = 11 + + +class GatewayIntent(IntEnum): + """Represents a Discord Gateway Intent bit. + + Intents are used to subscribe to specific groups of events from the Gateway. + """ + + GUILDS = 1 << 0 + GUILD_MEMBERS = 1 << 1 # Privileged + GUILD_MODERATION = 1 << 2 # Formerly GUILD_BANS + GUILD_EMOJIS_AND_STICKERS = 1 << 3 + GUILD_INTEGRATIONS = 1 << 4 + GUILD_WEBHOOKS = 1 << 5 + GUILD_INVITES = 1 << 6 + GUILD_VOICE_STATES = 1 << 7 + GUILD_PRESENCES = 1 << 8 # Privileged + GUILD_MESSAGES = 1 << 9 + GUILD_MESSAGE_REACTIONS = 1 << 10 + GUILD_MESSAGE_TYPING = 1 << 11 + DIRECT_MESSAGES = 1 << 12 + DIRECT_MESSAGE_REACTIONS = 1 << 13 + DIRECT_MESSAGE_TYPING = 1 << 14 + MESSAGE_CONTENT = 1 << 15 # Privileged (as of Aug 31, 2022) + GUILD_SCHEDULED_EVENTS = 1 << 16 + AUTO_MODERATION_CONFIGURATION = 1 << 20 + AUTO_MODERATION_EXECUTION = 1 << 21 + + @classmethod + def default(cls) -> int: + """Returns default intents (excluding privileged ones like members, presences, message content).""" + return ( + cls.GUILDS + | cls.GUILD_MODERATION + | cls.GUILD_EMOJIS_AND_STICKERS + | cls.GUILD_INTEGRATIONS + | cls.GUILD_WEBHOOKS + | cls.GUILD_INVITES + | cls.GUILD_VOICE_STATES + | cls.GUILD_MESSAGES + | cls.GUILD_MESSAGE_REACTIONS + | cls.GUILD_MESSAGE_TYPING + | cls.DIRECT_MESSAGES + | cls.DIRECT_MESSAGE_REACTIONS + | cls.DIRECT_MESSAGE_TYPING + | cls.GUILD_SCHEDULED_EVENTS + | cls.AUTO_MODERATION_CONFIGURATION + | cls.AUTO_MODERATION_EXECUTION + ) + + @classmethod + def all(cls) -> int: + """Returns all intents, including privileged ones. Use with caution.""" + val = 0 + for intent in cls: + val |= intent.value + return val + + @classmethod + def privileged(cls) -> int: + """Returns a bitmask of all privileged intents.""" + return cls.GUILD_MEMBERS | cls.GUILD_PRESENCES | cls.MESSAGE_CONTENT + + +# --- Application Command Enums --- + + +class ApplicationCommandType(IntEnum): + """Type of application command.""" + + CHAT_INPUT = 1 + USER = 2 + MESSAGE = 3 + PRIMARY_ENTRY_POINT = 4 + + +class ApplicationCommandOptionType(IntEnum): + """Type of application command option.""" + + SUB_COMMAND = 1 + SUB_COMMAND_GROUP = 2 + STRING = 3 + INTEGER = 4 # Any integer between -2^53 and 2^53 + BOOLEAN = 5 + USER = 6 + CHANNEL = 7 # Includes all channel types + categories + ROLE = 8 + MENTIONABLE = 9 # Includes users and roles + NUMBER = 10 # Any double between -2^53 and 2^53 + ATTACHMENT = 11 + + +class InteractionType(IntEnum): + """Type of interaction.""" + + PING = 1 + APPLICATION_COMMAND = 2 + MESSAGE_COMPONENT = 3 + APPLICATION_COMMAND_AUTOCOMPLETE = 4 + MODAL_SUBMIT = 5 + + +class InteractionCallbackType(IntEnum): + """Type of interaction callback.""" + + PONG = 1 + CHANNEL_MESSAGE_WITH_SOURCE = 4 + DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE = 5 + DEFERRED_UPDATE_MESSAGE = 6 + UPDATE_MESSAGE = 7 + APPLICATION_COMMAND_AUTOCOMPLETE_RESULT = 8 + MODAL = 9 # Response to send a modal + + +class IntegrationType(IntEnum): + """ + Installation context(s) where the command is available, + only for globally-scoped commands. + """ + + GUILD_INSTALL = ( + 0 # Command is available when the app is installed to a guild (default) + ) + USER_INSTALL = 1 # Command is available when the app is installed to a user + + +class InteractionContextType(IntEnum): + """ + Interaction context(s) where the command can be used, + only for globally-scoped commands. + """ + + GUILD = 0 # Command can be used in guilds + BOT_DM = 1 # Command can be used in DMs with the app's bot user + PRIVATE_CHANNEL = 2 # Command can be used in Group DMs and DMs (requires USER_INSTALL integration_type) + + +class MessageFlags(IntEnum): + """Represents the flags of a message.""" + + CROSSPOSTED = 1 << 0 + IS_CROSSPOST = 1 << 1 + SUPPRESS_EMBEDS = 1 << 2 + SOURCE_MESSAGE_DELETED = 1 << 3 + URGENT = 1 << 4 + HAS_THREAD = 1 << 5 + EPHEMERAL = 1 << 6 + LOADING = 1 << 7 + FAILED_TO_MENTION_SOME_ROLES_IN_THREAD = 1 << 8 + SUPPRESS_NOTIFICATIONS = ( + 1 << 12 + ) # Discord specific, was previously 1 << 4 (IS_VOICE_MESSAGE) + IS_COMPONENTS_V2 = 1 << 15 + + +# --- Guild Enums --- + + +class VerificationLevel(IntEnum): + """Guild verification level.""" + + NONE = 0 + LOW = 1 + MEDIUM = 2 + HIGH = 3 + VERY_HIGH = 4 + + +class MessageNotificationLevel(IntEnum): + """Default message notification level for a guild.""" + + ALL_MESSAGES = 0 + ONLY_MENTIONS = 1 + + +class ExplicitContentFilterLevel(IntEnum): + """Explicit content filter level for a guild.""" + + DISABLED = 0 + MEMBERS_WITHOUT_ROLES = 1 + ALL_MEMBERS = 2 + + +class MFALevel(IntEnum): + """Multi-Factor Authentication level for a guild.""" + + NONE = 0 + ELEVATED = 1 + + +class GuildNSFWLevel(IntEnum): + """NSFW level of a guild.""" + + DEFAULT = 0 + EXPLICIT = 1 + SAFE = 2 + AGE_RESTRICTED = 3 + + +class PremiumTier(IntEnum): + """Guild premium tier (boost level).""" + + NONE = 0 + TIER_1 = 1 + TIER_2 = 2 + TIER_3 = 3 + + +class GuildFeature(str, Enum): # Changed from IntEnum to Enum + """Features that a guild can have. + + Note: This is not an exhaustive list and Discord may add more. + Using str as a base allows for unknown features to be stored as strings. + """ + + ANIMATED_BANNER = "ANIMATED_BANNER" + ANIMATED_ICON = "ANIMATED_ICON" + APPLICATION_COMMAND_PERMISSIONS_V2 = "APPLICATION_COMMAND_PERMISSIONS_V2" + AUTO_MODERATION = "AUTO_MODERATION" + BANNER = "BANNER" + COMMUNITY = "COMMUNITY" + CREATOR_MONETIZABLE_PROVISIONAL = "CREATOR_MONETIZABLE_PROVISIONAL" + CREATOR_STORE_PAGE = "CREATOR_STORE_PAGE" + DEVELOPER_SUPPORT_SERVER = "DEVELOPER_SUPPORT_SERVER" + DISCOVERABLE = "DISCOVERABLE" + FEATURABLE = "FEATURABLE" + INVITES_DISABLED = "INVITES_DISABLED" + INVITE_SPLASH = "INVITE_SPLASH" + MEMBER_VERIFICATION_GATE_ENABLED = "MEMBER_VERIFICATION_GATE_ENABLED" + MORE_STICKERS = "MORE_STICKERS" + NEWS = "NEWS" + PARTNERED = "PARTNERED" + PREVIEW_ENABLED = "PREVIEW_ENABLED" + RAID_ALERTS_DISABLED = "RAID_ALERTS_DISABLED" + ROLE_ICONS = "ROLE_ICONS" + ROLE_SUBSCRIPTIONS_AVAILABLE_FOR_PURCHASE = ( + "ROLE_SUBSCRIPTIONS_AVAILABLE_FOR_PURCHASE" + ) + ROLE_SUBSCRIPTIONS_ENABLED = "ROLE_SUBSCRIPTIONS_ENABLED" + TICKETED_EVENTS_ENABLED = "TICKETED_EVENTS_ENABLED" + VANITY_URL = "VANITY_URL" + VERIFIED = "VERIFIED" + VIP_REGIONS = "VIP_REGIONS" + WELCOME_SCREEN_ENABLED = "WELCOME_SCREEN_ENABLED" + # Add more as they become known or needed + + # This allows GuildFeature("UNKNOWN_FEATURE_STRING") to work + @classmethod + def _missing_(cls, value): # type: ignore + return str(value) + + +# --- Channel Enums --- + + +class ChannelType(IntEnum): + """Type of channel.""" + + GUILD_TEXT = 0 # a text channel within a server + DM = 1 # a direct message between users + GUILD_VOICE = 2 # a voice channel within a server + GROUP_DM = 3 # a direct message between multiple users + GUILD_CATEGORY = 4 # an organizational category that contains up to 50 channels + GUILD_ANNOUNCEMENT = 5 # a channel that users can follow and crosspost into their own server (formerly GUILD_NEWS) + ANNOUNCEMENT_THREAD = ( + 10 # a temporary sub-channel within a GUILD_ANNOUNCEMENT channel + ) + PUBLIC_THREAD = ( + 11 # a temporary sub-channel within a GUILD_TEXT or GUILD_ANNOUNCEMENT channel + ) + PRIVATE_THREAD = 12 # a temporary sub-channel within a GUILD_TEXT channel that is only viewable by those invited and those with the MANAGE_THREADS permission + GUILD_STAGE_VOICE = ( + 13 # a voice channel for hosting events with speakers and audiences + ) + GUILD_DIRECTORY = 14 # a channel in a hub containing the listed servers + GUILD_FORUM = 15 # (Still in development) a channel that can only contain threads + GUILD_MEDIA = 16 # (Still in development) a channel that can only contain media + + +class OverwriteType(IntEnum): + """Type of target for a permission overwrite.""" + + ROLE = 0 + MEMBER = 1 + + +# --- Component Enums --- + + +class ComponentType(IntEnum): + """Type of message component.""" + + ACTION_ROW = 1 + BUTTON = 2 + STRING_SELECT = 3 # Formerly SELECT_MENU + TEXT_INPUT = 4 + USER_SELECT = 5 + ROLE_SELECT = 6 + MENTIONABLE_SELECT = 7 + CHANNEL_SELECT = 8 + SECTION = 9 + TEXT_DISPLAY = 10 + THUMBNAIL = 11 + MEDIA_GALLERY = 12 + FILE = 13 + SEPARATOR = 14 + CONTAINER = 17 + + +class ButtonStyle(IntEnum): + """Style of a button component.""" + + # Blurple + PRIMARY = 1 + # Grey + SECONDARY = 2 + # Green + SUCCESS = 3 + # Red + DANGER = 4 + # Grey, navigates to a URL + LINK = 5 + + +class TextInputStyle(IntEnum): + """Style of a text input component.""" + + SHORT = 1 + PARAGRAPH = 2 + + +# Example of how you might combine intents: +# intents = GatewayIntent.GUILDS | GatewayIntent.GUILD_MESSAGES | GatewayIntent.MESSAGE_CONTENT +# client = Client(token="YOUR_TOKEN", intents=intents) diff --git a/disagreement/error_handler.py b/disagreement/error_handler.py new file mode 100644 index 0000000..0240cca --- /dev/null +++ b/disagreement/error_handler.py @@ -0,0 +1,33 @@ +import asyncio +import logging +import traceback +from typing import Optional + +from .logging_config import setup_logging + + +def setup_global_error_handler( + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> None: + """Configure a basic global error handler for the provided loop. + + The handler logs unhandled exceptions so they don't crash the bot. + """ + if loop is None: + loop = asyncio.get_event_loop() + + if not logging.getLogger().hasHandlers(): + setup_logging(logging.ERROR) + + def handle_exception(loop: asyncio.AbstractEventLoop, context: dict) -> None: + exception = context.get("exception") + if exception: + logging.error("Unhandled exception in event loop: %s", exception) + traceback.print_exception( + type(exception), exception, exception.__traceback__ + ) + else: + message = context.get("message") + logging.error("Event loop error: %s", message) + + loop.set_exception_handler(handle_exception) diff --git a/disagreement/errors.py b/disagreement/errors.py new file mode 100644 index 0000000..df42905 --- /dev/null +++ b/disagreement/errors.py @@ -0,0 +1,112 @@ +# disagreement/errors.py + +""" +Custom exceptions for the Disagreement library. +""" + +from typing import Optional, Any # Add Optional and Any here + + +class DisagreementException(Exception): + """Base exception class for all errors raised by this library.""" + + pass + + +class HTTPException(DisagreementException): + """Exception raised for HTTP-related errors. + + Attributes: + response: The aiohttp response object, if available. + status: The HTTP status code. + text: The response text, if available. + error_code: Discord specific error code, if available. + """ + + def __init__( + self, response=None, message=None, *, status=None, text=None, error_code=None + ): + self.response = response + self.status = status or (response.status if response else None) + self.text = text or ( + response.text if response else None + ) # Or await response.text() if in async context + self.error_code = error_code + + full_message = f"HTTP {self.status}" + if message: + full_message += f": {message}" + elif self.text: + full_message += f": {self.text}" + if self.error_code: + full_message += f" (Discord Error Code: {self.error_code})" + + super().__init__(full_message) + + +class GatewayException(DisagreementException): + """Exception raised for errors related to the Discord Gateway connection or protocol.""" + + pass + + +class AuthenticationError(DisagreementException): + """Exception raised for authentication failures (e.g., invalid token).""" + + pass + + +class RateLimitError(HTTPException): + """ + Exception raised when a rate limit is encountered. + + Attributes: + retry_after (float): The number of seconds to wait before retrying. + is_global (bool): Whether this is a global rate limit. + """ + + def __init__( + self, response, message=None, *, retry_after: float, is_global: bool = False + ): + self.retry_after = retry_after + self.is_global = is_global + super().__init__( + response, + message + or f"Rate limited. Retry after: {retry_after}s. Global: {is_global}", + ) + + +# You can add more specific exceptions as needed, e.g.: +# class NotFound(HTTPException): +# """Raised for 404 Not Found errors.""" +# pass + +# class Forbidden(HTTPException): +# """Raised for 403 Forbidden errors.""" +# pass + + +class AppCommandError(DisagreementException): + """Base exception for application command related errors.""" + + pass + + +class AppCommandOptionConversionError(AppCommandError): + """Exception raised when an application command option fails to convert.""" + + def __init__( + self, + message: str, + option_name: Optional[str] = None, + original_value: Any = None, + ): + self.option_name = option_name + self.original_value = original_value + full_message = message + if option_name: + full_message = f"Failed to convert option '{option_name}': {message}" + if original_value is not None: + full_message += f" (Original value: '{original_value}')" + super().__init__(full_message) diff --git a/disagreement/event_dispatcher.py b/disagreement/event_dispatcher.py new file mode 100644 index 0000000..7938d59 --- /dev/null +++ b/disagreement/event_dispatcher.py @@ -0,0 +1,243 @@ +# disagreement/event_dispatcher.py + +""" +Event dispatcher for handling Discord Gateway events. +""" + +import asyncio +import inspect +from collections import defaultdict +from typing import ( + Callable, + Coroutine, + Any, + Dict, + List, + Set, + TYPE_CHECKING, + Awaitable, + Optional, +) + +from .models import Message, User # Assuming User might be part of other events +from .errors import DisagreementException + +if TYPE_CHECKING: + from .client import Client # For type hinting to avoid circular imports + from .interactions import Interaction + +# Type alias for an event listener +EventListener = Callable[..., Awaitable[None]] + + +class EventDispatcher: + """ + Manages registration and dispatching of event listeners. + """ + + def __init__(self, client_instance: "Client"): + self._client: "Client" = client_instance + self._listeners: Dict[str, List[EventListener]] = defaultdict(list) + self._waiters: Dict[ + str, List[tuple[asyncio.Future, Optional[Callable[[Any], bool]]]] + ] = defaultdict(list) + self.on_dispatch_error: Optional[ + Callable[[str, Exception, EventListener], Awaitable[None]] + ] = None + # Pre-defined parsers for specific event types to convert raw data to models + self._event_parsers: Dict[str, Callable[[Dict[str, Any]], Any]] = { + "MESSAGE_CREATE": self._parse_message_create, + "INTERACTION_CREATE": self._parse_interaction_create, + "GUILD_CREATE": self._parse_guild_create, + "CHANNEL_CREATE": self._parse_channel_create, + "PRESENCE_UPDATE": self._parse_presence_update, + "TYPING_START": self._parse_typing_start, + } + + def _parse_message_create(self, data: Dict[str, Any]) -> Message: + """Parses raw MESSAGE_CREATE data into a Message object.""" + return self._client.parse_message(data) + + def _parse_interaction_create(self, data: Dict[str, Any]) -> "Interaction": + """Parses raw INTERACTION_CREATE data into an Interaction object.""" + from .interactions import Interaction + + return Interaction(data=data, client_instance=self._client) + + def _parse_guild_create(self, data: Dict[str, Any]): + """Parses raw GUILD_CREATE data into a Guild object.""" + + return self._client.parse_guild(data) + + def _parse_channel_create(self, data: Dict[str, Any]): + """Parses raw CHANNEL_CREATE data into a Channel object.""" + + return self._client.parse_channel(data) + + def _parse_presence_update(self, data: Dict[str, Any]): + """Parses raw PRESENCE_UPDATE data into a PresenceUpdate object.""" + + from .models import PresenceUpdate + + return PresenceUpdate(data, client_instance=self._client) + + def _parse_typing_start(self, data: Dict[str, Any]): + """Parses raw TYPING_START data into a TypingStart object.""" + + from .models import TypingStart + + return TypingStart(data, client_instance=self._client) + + # Potentially add _parse_user for events that directly provide a full user object + # def _parse_user_update(self, data: Dict[str, Any]) -> User: + # return User(data=data) + + def register(self, event_name: str, coro: EventListener): + """ + Registers a coroutine function to listen for a specific event. + + Args: + event_name (str): The name of the event (e.g., 'MESSAGE_CREATE'). + coro (Callable): The coroutine function to call when the event occurs. + It should accept arguments appropriate for the event. + + Raises: + TypeError: If the provided callback is not a coroutine function. + """ + if not inspect.iscoroutinefunction(coro): + raise TypeError( + f"Event listener for '{event_name}' must be a coroutine function (async def)." + ) + + # Normalize event name, e.g., 'on_message' -> 'MESSAGE_CREATE' + # For now, we assume event_name is already the Discord event type string. + # If using decorators like @client.on_message, the decorator would handle this mapping. + self._listeners[event_name.upper()].append(coro) + + def unregister(self, event_name: str, coro: EventListener): + """ + Unregisters a coroutine function from an event. + + Args: + event_name (str): The name of the event. + coro (Callable): The coroutine function to unregister. + """ + event_name_upper = event_name.upper() + if event_name_upper in self._listeners: + try: + self._listeners[event_name_upper].remove(coro) + except ValueError: + pass # Listener not in list + + def add_waiter( + self, + event_name: str, + future: asyncio.Future, + check: Optional[Callable[[Any], bool]] = None, + ) -> None: + self._waiters[event_name.upper()].append((future, check)) + + def remove_waiter(self, event_name: str, future: asyncio.Future) -> None: + waiters = self._waiters.get(event_name.upper()) + if not waiters: + return + self._waiters[event_name.upper()] = [ + (f, c) for f, c in waiters if f is not future + ] + if not self._waiters[event_name.upper()]: + self._waiters.pop(event_name.upper(), None) + + def _resolve_waiters(self, event_name: str, data: Any) -> None: + waiters = self._waiters.get(event_name) + if not waiters: + return + to_remove: List[tuple[asyncio.Future, Optional[Callable[[Any], bool]]]] = [] + for future, check in waiters: + if future.cancelled(): + to_remove.append((future, check)) + continue + try: + if check is None or check(data): + future.set_result(data) + to_remove.append((future, check)) + except Exception as exc: + future.set_exception(exc) + to_remove.append((future, check)) + for item in to_remove: + if item in waiters: + waiters.remove(item) + if not waiters: + self._waiters.pop(event_name, None) + + async def dispatch(self, event_name: str, raw_data: Dict[str, Any]): + """ + Dispatches an event to all registered listeners. + + Args: + event_name (str): The name of the event (e.g., 'MESSAGE_CREATE'). + raw_data (Dict[str, Any]): The raw data payload from the Discord Gateway for this event. + """ + event_name_upper = event_name.upper() + listeners = self._listeners.get(event_name_upper) + + if not listeners: + # print(f"No listeners for event {event_name_upper}") + return + + parsed_data: Any = raw_data + if event_name_upper in self._event_parsers: + try: + parser = self._event_parsers[event_name_upper] + parsed_data = parser(raw_data) + except Exception as e: + print(f"Error parsing event data for {event_name_upper}: {e}") + # Optionally, dispatch with raw_data or raise, or log more formally + # For now, we'll proceed to dispatch with raw_data if parsing fails, + # or just log and return if parsed_data is critical. + # Let's assume if a parser exists, its output is critical. + return + + self._resolve_waiters(event_name_upper, parsed_data) + # print(f"Dispatching event {event_name_upper} with data: {parsed_data} to {len(listeners)} listeners.") + for listener in listeners: + try: + # Inspect the listener to see how many arguments it expects + sig = inspect.signature(listener) + num_params = len(sig.parameters) + + if num_params == 0: # Listener takes no arguments + await listener() + elif ( + num_params == 1 + ): # Listener takes one argument (the parsed data or model) + await listener(parsed_data) + # elif num_params == 2 and event_name_upper == "MESSAGE_CREATE": # Special case for (client, message) + # await listener(self._client, parsed_data) # This might be too specific here + else: + # Fallback or error if signature doesn't match expected patterns + # For now, assume one arg is the most common for parsed data. + # Or, if you want to be strict: + print( + f"Warning: Listener {listener.__name__} for {event_name_upper} has an unhandled number of parameters ({num_params}). Skipping or attempting with one arg." + ) + if num_params > 0: # Try with one arg if it takes any + await listener(parsed_data) + + except Exception as e: + callback = self.on_dispatch_error + if callback is not None: + try: + await callback(event_name_upper, e, listener) + + except Exception as hook_error: + print(f"Error in on_dispatch_error hook itself: {hook_error}") + else: + # Default error handling if no hook is set + print( + f"Error in event listener {listener.__name__} for {event_name_upper}: {e}" + ) + if hasattr(self._client, "on_error"): + try: + await self._client.on_error(event_name_upper, e, listener) + except Exception as client_err_e: + print(f"Error in client.on_error itself: {client_err_e}") diff --git a/disagreement/ext/app_commands/__init__.py b/disagreement/ext/app_commands/__init__.py new file mode 100644 index 0000000..39c541e --- /dev/null +++ b/disagreement/ext/app_commands/__init__.py @@ -0,0 +1,46 @@ +# disagreement/ext/app_commands/__init__.py + +""" +Application Commands Extension for Disagreement. + +This package provides the framework for creating and handling +Discord Application Commands (slash commands, user commands, message commands). +""" + +from .commands import ( + AppCommand, + SlashCommand, + UserCommand, + MessageCommand, + AppCommandGroup, +) +from .decorators import ( + slash_command, + user_command, + message_command, + hybrid_command, + group, + subcommand, + subcommand_group, + OptionMetadata, +) +from .context import AppCommandContext + +# from .handler import AppCommandHandler # Will be imported when defined + +__all__ = [ + "AppCommand", + "SlashCommand", + "UserCommand", + "MessageCommand", + "AppCommandGroup", # To be defined + "slash_command", + "user_command", + "message_command", + "hybrid_command", + "group", + "subcommand", + "subcommand_group", + "OptionMetadata", + "AppCommandContext", # To be defined +] diff --git a/disagreement/ext/app_commands/commands.py b/disagreement/ext/app_commands/commands.py new file mode 100644 index 0000000..6d2823d --- /dev/null +++ b/disagreement/ext/app_commands/commands.py @@ -0,0 +1,513 @@ +# disagreement/ext/app_commands/commands.py + +import inspect +from typing import Callable, Optional, List, Dict, Any, Union, TYPE_CHECKING + + +if TYPE_CHECKING: + from disagreement.ext.commands.core import ( + Command as PrefixCommand, + ) # Alias to avoid name clash + from disagreement.interactions import ApplicationCommandOption, Snowflake + from disagreement.enums import ( + ApplicationCommandType, + IntegrationType, + InteractionContextType, + ApplicationCommandOptionType, # Added + ) + from disagreement.ext.commands.cog import Cog # Corrected import path + +# Placeholder for Cog if not using the existing one or if it needs adaptation +if not TYPE_CHECKING: + # This dynamic Cog = Any might not be ideal if Cog is used in runtime type checks. + # However, for type hinting purposes when TYPE_CHECKING is false, it avoids import. + # If Cog is needed at runtime by this module (it is, for AppCommand.cog type hint), + # it should be imported directly. + # For now, the TYPE_CHECKING block handles the proper import for static analysis. + # Let's ensure Cog is available at runtime if AppCommand.cog is accessed. + # A simple way is to import it outside TYPE_CHECKING too, if it doesn't cause circularity. + # Given its usage, a forward reference string 'Cog' might be better in AppCommand.cog type hint. + # Let's try importing it directly for runtime, assuming no circularity with this specific module. + try: + from disagreement.ext.commands.cog import Cog + except ImportError: + Cog = Any # Fallback if direct import fails (e.g. during partial builds/tests) + # Import PrefixCommand at runtime for HybridCommand + try: + from disagreement.ext.commands.core import Command as PrefixCommand + except Exception: # pragma: no cover - safeguard against unusual import issues + PrefixCommand = Any # type: ignore + # Import enums used at runtime + try: + from disagreement.enums import ( + ApplicationCommandType, + IntegrationType, + InteractionContextType, + ApplicationCommandOptionType, + ) + from disagreement.interactions import ApplicationCommandOption, Snowflake + except Exception: # pragma: no cover + ApplicationCommandType = ApplicationCommandOptionType = IntegrationType = ( + InteractionContextType + ) = Any # type: ignore + ApplicationCommandOption = Snowflake = Any # type: ignore +else: # When TYPE_CHECKING is true, Cog and PrefixCommand are already imported above. + pass + + +class AppCommand: + """ + Base class for an application command. + + Attributes: + name (str): The name of the command. + description (Optional[str]): The description of the command. + Required for CHAT_INPUT, empty for USER and MESSAGE commands. + callback (Callable[..., Any]): The coroutine function that will be called when the command is executed. + type (ApplicationCommandType): The type of the application command. + options (Optional[List[ApplicationCommandOption]]): Parameters for the command. Populated by decorators. + guild_ids (Optional[List[Snowflake]]): List of guild IDs where this command is active. None for global. + default_member_permissions (Optional[str]): Bitwise permissions required by default for users to run the command. + nsfw (bool): Whether the command is age-restricted. + parent (Optional['AppCommandGroup']): The parent group if this is a subcommand. + cog (Optional[Cog]): The cog this command belongs to, if any. + _full_description (Optional[str]): Stores the original full description, e.g. from docstring, + even if the payload description is different (like for User/Message commands). + name_localizations (Optional[Dict[str, str]]): Localizations for the command's name. + description_localizations (Optional[Dict[str, str]]): Localizations for the command's description. + integration_types (Optional[List[IntegrationType]]): Installation contexts. + contexts (Optional[List[InteractionContextType]]): Interaction contexts. + """ + + def __init__( + self, + callback: Callable[..., Any], + *, + name: str, + description: Optional[str] = None, + locale: Optional[str] = None, + type: "ApplicationCommandType", + guild_ids: Optional[List["Snowflake"]] = None, + default_member_permissions: Optional[str] = None, + nsfw: bool = False, + parent: Optional["AppCommandGroup"] = None, + cog: Optional[ + Any + ] = None, # Changed 'Cog' to Any to avoid runtime import issues if Cog is complex + name_localizations: Optional[Dict[str, str]] = None, + description_localizations: Optional[Dict[str, str]] = None, + integration_types: Optional[List["IntegrationType"]] = None, + contexts: Optional[List["InteractionContextType"]] = None, + ): + if not asyncio.iscoroutinefunction(callback): + raise TypeError( + "Application command callback must be a coroutine function." + ) + + if locale: + from disagreement import i18n + + translate = i18n.translate + + self.name = translate(name, locale) + self.description = ( + translate(description, locale) if description is not None else None + ) + else: + self.name = name + self.description = description + self.locale: Optional[str] = locale + self.callback: Callable[..., Any] = callback + self.type: "ApplicationCommandType" = type + self.options: List["ApplicationCommandOption"] = [] # Populated by decorator + self.guild_ids: Optional[List["Snowflake"]] = guild_ids + self.default_member_permissions: Optional[str] = default_member_permissions + self.nsfw: bool = nsfw + self.parent: Optional["AppCommandGroup"] = parent + self.cog: Optional[Any] = cog # Changed 'Cog' to Any + self.name_localizations: Optional[Dict[str, str]] = name_localizations + self.description_localizations: Optional[Dict[str, str]] = ( + description_localizations + ) + self.integration_types: Optional[List["IntegrationType"]] = integration_types + self.contexts: Optional[List["InteractionContextType"]] = contexts + self._full_description: Optional[str] = ( + None # Initialized by decorator if needed + ) + + # Signature for argument parsing by decorators/handlers + self.params = inspect.signature(callback).parameters + + async def invoke( + self, context: "AppCommandContext", *args: Any, **kwargs: Any + ) -> None: + """Invokes the command's callback with the given context and arguments.""" + # Similar to Command.invoke, handle cog if present + actual_args = [] + if self.cog: + actual_args.append(self.cog) + actual_args.append(context) + actual_args.extend(args) + + await self.callback(*actual_args, **kwargs) + + def to_dict(self) -> Dict[str, Any]: + """Converts the command to a dictionary payload for Discord API.""" + payload: Dict[str, Any] = { + "name": self.name, + "type": self.type.value, + # CHAT_INPUT commands require a description. + # USER and MESSAGE commands must have an empty description in the payload if not omitted. + # The constructor for UserCommand/MessageCommand already sets self.description to "" + "description": ( + self.description + if self.type == ApplicationCommandType.CHAT_INPUT + else "" + ), + } + + # For CHAT_INPUT commands, options are its parameters. + # For USER/MESSAGE commands, options should be empty or not present. + if self.type == ApplicationCommandType.CHAT_INPUT and self.options: + payload["options"] = [opt.to_dict() for opt in self.options] + + if self.default_member_permissions is not None: # Can be "0" for no permissions + payload["default_member_permissions"] = str(self.default_member_permissions) + + # nsfw defaults to False, only include if True + if self.nsfw: + payload["nsfw"] = True + + if self.name_localizations: + payload["name_localizations"] = self.name_localizations + + # Description localizations only apply if there's a description (CHAT_INPUT commands) + if ( + self.type == ApplicationCommandType.CHAT_INPUT + and self.description + and self.description_localizations + ): + payload["description_localizations"] = self.description_localizations + + if self.integration_types: + payload["integration_types"] = [it.value for it in self.integration_types] + + if self.contexts: + payload["contexts"] = [ict.value for ict in self.contexts] + + # According to Discord API, guild_id is not part of this payload, + # it's used in the URL path for guild-specific command registration. + # However, the global command registration takes an 'application_id' in the payload, + # but that's handled by the HTTPClient. + + return payload + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} name='{self.name}' type={self.type!r}>" + + +class SlashCommand(AppCommand): + """Represents a CHAT_INPUT (slash) command.""" + + def __init__(self, callback: Callable[..., Any], **kwargs: Any): + if not kwargs.get("description"): + raise ValueError("SlashCommand requires a description.") + super().__init__(callback, type=ApplicationCommandType.CHAT_INPUT, **kwargs) + + +class UserCommand(AppCommand): + """Represents a USER context menu command.""" + + def __init__(self, callback: Callable[..., Any], **kwargs: Any): + # Description is not allowed by Discord API for User Commands, but can be set to empty string. + kwargs["description"] = kwargs.get( + "description", "" + ) # Ensure it's empty or not present in payload + super().__init__(callback, type=ApplicationCommandType.USER, **kwargs) + + +class MessageCommand(AppCommand): + """Represents a MESSAGE context menu command.""" + + def __init__(self, callback: Callable[..., Any], **kwargs: Any): + # Description is not allowed by Discord API for Message Commands. + kwargs["description"] = kwargs.get("description", "") + super().__init__(callback, type=ApplicationCommandType.MESSAGE, **kwargs) + + +class HybridCommand(SlashCommand, PrefixCommand): # Inherit from both + """ + Represents a command that can be invoked as both a slash command + and a traditional prefix-based command. + """ + + def __init__(self, callback: Callable[..., Any], **kwargs: Any): + # Initialize SlashCommand part (which calls AppCommand.__init__) + # We need to ensure 'type' is correctly passed for AppCommand + # kwargs for SlashCommand: name, description, guild_ids, default_member_permissions, nsfw, parent, cog, etc. + # kwargs for PrefixCommand: name, aliases, brief, description, cog + + # Pop prefix-specific args before passing to SlashCommand constructor + prefix_aliases = kwargs.pop("aliases", []) + prefix_brief = kwargs.pop("brief", None) + # Description is used by both, AppCommand's constructor will handle it. + # Name is used by both. Cog is used by both. + + # Call SlashCommand's __init__ + # This will set up name, description, callback, type=CHAT_INPUT, options, etc. + super().__init__(callback, **kwargs) # This is SlashCommand.__init__ + + # Now, explicitly initialize the PrefixCommand parts that SlashCommand didn't cover + # or that need specific values for the prefix version. + # PrefixCommand.__init__(self, callback, name=self.name, aliases=prefix_aliases, brief=prefix_brief, description=self.description, cog=self.cog) + # However, PrefixCommand.__init__ also sets self.params, which AppCommand already did. + # We need to be careful not to re-initialize things unnecessarily or incorrectly. + # Let's manually set the distinct attributes for the PrefixCommand aspect. + + # Attributes from PrefixCommand: + # self.callback is already set by AppCommand + # self.name is already set by AppCommand + self.aliases: List[str] = ( + prefix_aliases # This was specific to HybridCommand before, now aligns with PrefixCommand + ) + self.brief: Optional[str] = prefix_brief + # self.description is already set by AppCommand (SlashCommand ensures it exists) + # self.cog is already set by AppCommand + # self.params is already set by AppCommand + + # Ensure the MRO is handled correctly. Python's MRO (C3 linearization) + # should call SlashCommand's __init__ then AppCommand's __init__. + # PrefixCommand.__init__ won't be called automatically unless we explicitly call it. + # By setting attributes directly, we avoid potential issues with multiple __init__ calls + # if their logic overlaps too much (e.g., both trying to set self.params). + + # We might need to override invoke if the context or argument passing differs significantly + # between app command invocation and prefix command invocation. + # For now, SlashCommand.invoke and PrefixCommand.invoke are separate. + # The correct one will be called depending on how the command is dispatched. + # The AppCommandHandler will use AppCommand.invoke (via SlashCommand). + # The prefix CommandHandler will use PrefixCommand.invoke. + # This seems acceptable. + + +class AppCommandGroup: + """ + Represents a group of application commands (subcommands or subcommand groups). + This itself is not directly callable but acts as a namespace. + """ + + def __init__( + self, + name: str, + description: Optional[ + str + ] = None, # Required for top-level groups that form part of a slash command + guild_ids: Optional[List["Snowflake"]] = None, + parent: Optional["AppCommandGroup"] = None, + default_member_permissions: Optional[str] = None, + nsfw: bool = False, + name_localizations: Optional[Dict[str, str]] = None, + description_localizations: Optional[Dict[str, str]] = None, + integration_types: Optional[List["IntegrationType"]] = None, + contexts: Optional[List["InteractionContextType"]] = None, + ): + self.name: str = name + self.description: Optional[str] = description + self.guild_ids: Optional[List["Snowflake"]] = guild_ids + self.parent: Optional["AppCommandGroup"] = parent + self.commands: Dict[str, Union[AppCommand, "AppCommandGroup"]] = {} + self.default_member_permissions: Optional[str] = default_member_permissions + self.nsfw: bool = nsfw + self.name_localizations: Optional[Dict[str, str]] = name_localizations + self.description_localizations: Optional[Dict[str, str]] = ( + description_localizations + ) + self.integration_types: Optional[List["IntegrationType"]] = integration_types + self.contexts: Optional[List["InteractionContextType"]] = contexts + # A group itself doesn't have a cog directly, its commands do. + + def add_command(self, command: Union[AppCommand, "AppCommandGroup"]) -> None: + if command.name in self.commands: + raise ValueError( + f"Command or group '{command.name}' already exists in group '{self.name}'." + ) + command.parent = self + self.commands[command.name] = command + + def get_command(self, name: str) -> Optional[Union[AppCommand, "AppCommandGroup"]]: + return self.commands.get(name) + + def command(self, *d_args: Any, **d_kwargs: Any): + d_kwargs.setdefault("parent", self) + from .decorators import slash_command + + return slash_command(*d_args, **d_kwargs) + + def group( + self, + name: str, + description: Optional[str] = None, + **kwargs: Any, + ): + sub_group = AppCommandGroup( + name=name, + description=description, + parent=self, + guild_ids=kwargs.get("guild_ids"), + default_member_permissions=kwargs.get("default_member_permissions"), + nsfw=kwargs.get("nsfw", False), + name_localizations=kwargs.get("name_localizations"), + description_localizations=kwargs.get("description_localizations"), + integration_types=kwargs.get("integration_types"), + contexts=kwargs.get("contexts"), + ) + self.add_command(sub_group) + + def decorator(func: Optional[Callable[..., Any]] = None): + if func is not None: + setattr(func, "__app_command_object__", sub_group) + return sub_group + return sub_group + + return decorator + + def __repr__(self) -> str: + return f"" + + def to_dict(self) -> Dict[str, Any]: + """ + Converts the command group to a dictionary payload for Discord API. + This represents a top-level command that has subcommands/subcommand groups. + """ + payload: Dict[str, Any] = { + "name": self.name, + "type": ApplicationCommandType.CHAT_INPUT.value, # Groups are implicitly CHAT_INPUT + "description": self.description + or "No description provided", # Top-level groups require a description + "options": [], + } + + if self.default_member_permissions is not None: + payload["default_member_permissions"] = str(self.default_member_permissions) + if self.nsfw: + payload["nsfw"] = True + if self.name_localizations: + payload["name_localizations"] = self.name_localizations + if ( + self.description and self.description_localizations + ): # Only if description is not empty + payload["description_localizations"] = self.description_localizations + if self.integration_types: + payload["integration_types"] = [it.value for it in self.integration_types] + if self.contexts: + payload["contexts"] = [ict.value for ict in self.contexts] + + # guild_ids are handled at the registration level, not in this specific payload part. + + options_payload: List[Dict[str, Any]] = [] + for cmd_name, command_or_group in self.commands.items(): + if isinstance(command_or_group, AppCommand): # This is a Subcommand + # Subcommands use their own options (parameters) + sub_options = ( + [opt.to_dict() for opt in command_or_group.options] + if command_or_group.options + else [] + ) + option_dict = { + "type": ApplicationCommandOptionType.SUB_COMMAND.value, + "name": command_or_group.name, + "description": command_or_group.description + or "No description provided", + "options": sub_options, + } + # Add localization for subcommand name and description if available + if command_or_group.name_localizations: + option_dict["name_localizations"] = ( + command_or_group.name_localizations + ) + if ( + command_or_group.description + and command_or_group.description_localizations + ): + option_dict["description_localizations"] = ( + command_or_group.description_localizations + ) + options_payload.append(option_dict) + + elif isinstance( + command_or_group, AppCommandGroup + ): # This is a Subcommand Group + # Subcommand groups have their subcommands/groups as options + sub_group_options: List[Dict[str, Any]] = [] + for sub_cmd_name, sub_command in command_or_group.commands.items(): + # Nested groups can only contain subcommands, not further nested groups as per Discord rules. + # So, sub_command here must be an AppCommand. + if isinstance( + sub_command, AppCommand + ): # Should always be AppCommand if structure is valid + sub_cmd_options = ( + [opt.to_dict() for opt in sub_command.options] + if sub_command.options + else [] + ) + sub_group_option_entry = { + "type": ApplicationCommandOptionType.SUB_COMMAND.value, + "name": sub_command.name, + "description": sub_command.description + or "No description provided", + "options": sub_cmd_options, + } + # Add localization for subcommand name and description if available + if sub_command.name_localizations: + sub_group_option_entry["name_localizations"] = ( + sub_command.name_localizations + ) + if ( + sub_command.description + and sub_command.description_localizations + ): + sub_group_option_entry["description_localizations"] = ( + sub_command.description_localizations + ) + sub_group_options.append(sub_group_option_entry) + # else: + # # This case implies a group nested inside a group, which then contains another group. + # # Discord's structure is: + # # command -> option (SUB_COMMAND_GROUP) -> option (SUB_COMMAND) -> option (param) + # # This should be caught by validation logic in decorators or add_command. + # # For now, we assume valid structure where AppCommandGroup's commands are AppCommands. + # pass + + option_dict = { + "type": ApplicationCommandOptionType.SUB_COMMAND_GROUP.value, + "name": command_or_group.name, + "description": command_or_group.description + or "No description provided", + "options": sub_group_options, # These are the SUB_COMMANDs + } + # Add localization for subcommand group name and description if available + if command_or_group.name_localizations: + option_dict["name_localizations"] = ( + command_or_group.name_localizations + ) + if ( + command_or_group.description + and command_or_group.description_localizations + ): + option_dict["description_localizations"] = ( + command_or_group.description_localizations + ) + options_payload.append(option_dict) + + payload["options"] = options_payload + return payload + + +# Need to import asyncio for iscoroutinefunction check +import asyncio + +if TYPE_CHECKING: + from .context import AppCommandContext # For type hint in AppCommand.invoke + + # Ensure ApplicationCommandOptionType is available for the to_dict method + from disagreement.enums import ApplicationCommandOptionType diff --git a/disagreement/ext/app_commands/context.py b/disagreement/ext/app_commands/context.py new file mode 100644 index 0000000..d2c5507 --- /dev/null +++ b/disagreement/ext/app_commands/context.py @@ -0,0 +1,556 @@ +# disagreement/ext/app_commands/context.py + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, List, Union, Any, Dict + +if TYPE_CHECKING: + from disagreement.client import Client + from disagreement.interactions import ( + Interaction, + InteractionCallbackData, + InteractionResponsePayload, + Snowflake, + ) + from disagreement.enums import InteractionCallbackType, MessageFlags + from disagreement.models import ( + User, + Member, + Message, + Channel, + ActionRow, + ) + from disagreement.ui.view import View + + # For full model hints, these would be imported from disagreement.models when defined: + Embed = Any + PartialAttachment = Any + Guild = Any # from disagreement.models import Guild + TextChannel = Any # from disagreement.models import TextChannel, etc. + from .commands import AppCommand + +from disagreement.enums import InteractionCallbackType, MessageFlags +from disagreement.interactions import ( + Interaction, + InteractionCallbackData, + InteractionResponsePayload, + Snowflake, +) +from disagreement.models import Message +from disagreement.typing import Typing + + +class AppCommandContext: + """ + Represents the context in which an application command is being invoked. + Provides methods to respond to the interaction. + """ + + def __init__( + self, + bot: "Client", + interaction: "Interaction", + command: Optional["AppCommand"] = None, + ): + self.bot: "Client" = bot + self.interaction: "Interaction" = interaction + self.command: Optional["AppCommand"] = command # The command that was invoked + + self._responded: bool = False + self._deferred: bool = False + + @property + def token(self) -> str: + """The interaction token.""" + return self.interaction.token + + @property + def interaction_id(self) -> "Snowflake": + """The interaction ID.""" + return self.interaction.id + + @property + def application_id(self) -> "Snowflake": + """The application ID of the interaction.""" + return self.interaction.application_id + + @property + def guild_id(self) -> Optional["Snowflake"]: + """The ID of the guild where the interaction occurred, if any.""" + return self.interaction.guild_id + + @property + def channel_id(self) -> Optional["Snowflake"]: + """The ID of the channel where the interaction occurred.""" + return self.interaction.channel_id + + @property + def author(self) -> Optional[Union["User", "Member"]]: + """The user or member who invoked the interaction.""" + return self.interaction.member or self.interaction.user + + @property + def user(self) -> Optional["User"]: + """The user who invoked the interaction. + If in a guild, this is the user part of the member. + If in a DM, this is the top-level user. + """ + return self.interaction.user + + @property + def member(self) -> Optional["Member"]: + """The member who invoked the interaction, if this occurred in a guild.""" + return self.interaction.member + + @property + def locale(self) -> Optional[str]: + """The selected language of the invoking user.""" + return self.interaction.locale + + @property + def guild_locale(self) -> Optional[str]: + """The guild's preferred language, if applicable.""" + return self.interaction.guild_locale + + @property + async def guild(self) -> Optional["Guild"]: + """The guild object where the interaction occurred, if available.""" + + if not self.guild_id: + return None + + guild = None + if hasattr(self.bot, "get_guild"): + guild = self.bot.get_guild(self.guild_id) + + if not guild and hasattr(self.bot, "fetch_guild"): + try: + guild = await self.bot.fetch_guild(self.guild_id) + except Exception: + guild = None + + return guild + + @property + async def channel(self) -> Optional[Any]: + """The channel object where the interaction occurred, if available.""" + + if not self.channel_id: + return None + + channel = None + if hasattr(self.bot, "get_channel"): + channel = self.bot.get_channel(self.channel_id) + elif hasattr(self.bot, "_channels"): + channel = self.bot._channels.get(self.channel_id) + + if not channel and hasattr(self.bot, "fetch_channel"): + try: + channel = await self.bot.fetch_channel(self.channel_id) + except Exception: + channel = None + + return channel + + async def _send_response( + self, + response_type: "InteractionCallbackType", + data: Optional[Dict[str, Any]] = None, + ) -> None: + """Internal helper to send interaction responses.""" + if ( + self._responded + and not self._deferred + and response_type + != InteractionCallbackType.APPLICATION_COMMAND_AUTOCOMPLETE_RESULT + ): + # If already responded and not deferred, subsequent responses must be followups + # (unless it's an autocomplete result which is a special case) + # For now, let's assume followups are handled by separate methods. + # This logic might need refinement based on how followups are exposed. + raise RuntimeError( + "Interaction has already been responded to. Use send_followup()." + ) + + callback_data = InteractionCallbackData(data) if data else None + payload = InteractionResponsePayload(type=response_type, data=callback_data) + + await self.bot._http.create_interaction_response( + interaction_id=self.interaction_id, + interaction_token=self.token, + payload=payload, + ) + if ( + response_type + != InteractionCallbackType.APPLICATION_COMMAND_AUTOCOMPLETE_RESULT + ): + self._responded = True + if ( + response_type + == InteractionCallbackType.DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE + or response_type == InteractionCallbackType.DEFERRED_UPDATE_MESSAGE + ): + self._deferred = True + + async def defer(self, ephemeral: bool = False, thinking: bool = True) -> None: + """ + Defers the interaction response. + + This is typically used when your command might take longer than 3 seconds to process. + You must send a followup message within 15 minutes. + + Args: + ephemeral (bool): Whether the subsequent followup response should be ephemeral. + Only applicable if `thinking` is True. + thinking (bool): If True (default), responds with a "Bot is thinking..." message + (DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE). + If False, responds with DEFERRED_UPDATE_MESSAGE (for components). + """ + if self._responded: + raise RuntimeError("Interaction has already been responded to or deferred.") + + response_type = ( + InteractionCallbackType.DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE + if thinking + else InteractionCallbackType.DEFERRED_UPDATE_MESSAGE + ) + data = None + if ephemeral and thinking: + data = { + "flags": MessageFlags.EPHEMERAL.value + } # Assuming MessageFlags enum exists + + await self._send_response(response_type, data) + self._deferred = True # Mark as deferred + + async def send( + self, + content: Optional[str] = None, + embed: Optional["Embed"] = None, # Convenience for single embed + embeds: Optional[List["Embed"]] = None, + *, + tts: bool = False, + files: Optional[List[Any]] = None, + components: Optional[List[ActionRow]] = None, + view: Optional[View] = None, + allowed_mentions: Optional[Dict[str, Any]] = None, + ephemeral: bool = False, + flags: Optional[int] = None, + ) -> Optional[ + "Message" + ]: # Returns Message if not ephemeral and response was not deferred + """ + Sends a response to the interaction. + If the interaction was previously deferred, this will edit the original deferred response. + Otherwise, it sends an initial response. + + Args: + content (Optional[str]): The message content. + embed (Optional[Embed]): A single embed to send. If `embeds` is also provided, this is ignored. + embeds (Optional[List[Embed]]): A list of embeds to send (max 10). + ephemeral (bool): Whether the message should be ephemeral (only visible to the invoker). + flags (Optional[int]): Additional message flags to apply. + + Returns: + Optional[Message]: The sent message object if a new message was created and not ephemeral. + None if the response was ephemeral or an edit to a deferred message. + """ + if not self._responded and self._deferred: # Editing a deferred response + # Use edit_original_interaction_response + payload: Dict[str, Any] = {} + if content is not None: + payload["content"] = content + + if tts: + payload["tts"] = True + + actual_embeds = embeds + if embed and not embeds: + actual_embeds = [embed] + if actual_embeds: + payload["embeds"] = [e.to_dict() for e in actual_embeds] + + if view: + await view._start(self.bot) + payload["components"] = view.to_components_payload() + elif components: + payload["components"] = [c.to_dict() for c in components] + + if files is not None: + payload["attachments"] = [ + f.to_dict() if hasattr(f, "to_dict") else f for f in files + ] + + if allowed_mentions is not None: + payload["allowed_mentions"] = allowed_mentions + + # Flags (like ephemeral) cannot be set when editing the original deferred response this way. + # Ephemeral for deferred must be set during defer(). + + msg_data = await self.bot._http.edit_original_interaction_response( + application_id=self.application_id, + interaction_token=self.token, + payload=payload, + ) + self._responded = True # Ensure it's marked as fully responded + if view and msg_data and "id" in msg_data: + view.message_id = msg_data["id"] + self.bot._views[msg_data["id"]] = view + # Construct and return Message object if needed, for now returns None for edits + return None + + elif not self._responded: # Sending an initial response + data: Dict[str, Any] = {} + if content is not None: + data["content"] = content + + if tts: + data["tts"] = True + + actual_embeds = embeds + if embed and not embeds: + actual_embeds = [embed] + if actual_embeds: + data["embeds"] = [ + e.to_dict() for e in actual_embeds + ] # Assuming embeds have to_dict() + + if view: + await view._start(self.bot) + data["components"] = view.to_components_payload() + elif components: + data["components"] = [c.to_dict() for c in components] + + if files is not None: + data["attachments"] = [ + f.to_dict() if hasattr(f, "to_dict") else f for f in files + ] + + if allowed_mentions is not None: + data["allowed_mentions"] = allowed_mentions + + flags_value = 0 + if ephemeral: + flags_value |= MessageFlags.EPHEMERAL.value + if flags: + flags_value |= flags + if flags_value: + data["flags"] = flags_value + + await self._send_response( + InteractionCallbackType.CHANNEL_MESSAGE_WITH_SOURCE, data + ) + + if view and not ephemeral: + try: + msg_data = await self.bot._http.get_original_interaction_response( + application_id=self.application_id, + interaction_token=self.token, + ) + if msg_data and "id" in msg_data: + view.message_id = msg_data["id"] + self.bot._views[msg_data["id"]] = view + except Exception: + pass + if not ephemeral: + return None + return None + else: + # If already responded and not deferred, this should be a followup. + # This method is for initial response or editing deferred. + raise RuntimeError( + "Interaction has already been responded to. Use send_followup()." + ) + + async def send_followup( + self, + content: Optional[str] = None, + embed: Optional["Embed"] = None, + embeds: Optional[List["Embed"]] = None, + *, + ephemeral: bool = False, + tts: bool = False, + files: Optional[List[Any]] = None, + components: Optional[List["ActionRow"]] = None, + view: Optional[View] = None, + allowed_mentions: Optional[Dict[str, Any]] = None, + flags: Optional[int] = None, + ) -> Optional["Message"]: + """ + Sends a followup message to an interaction. + This can be used after an initial response or a deferred response. + + Args: + content (Optional[str]): The message content. + embed (Optional[Embed]): A single embed to send. + embeds (Optional[List[Embed]]): A list of embeds to send. + ephemeral (bool): Whether the followup message should be ephemeral. + flags (Optional[int]): Additional message flags to apply. + + Returns: + Message: The sent followup message object. + """ + if not self._responded: + raise RuntimeError( + "Must acknowledge or defer the interaction before sending a followup." + ) + + payload: Dict[str, Any] = {} + if content is not None: + payload["content"] = content + + if tts: + payload["tts"] = True + + actual_embeds = embeds + if embed and not embeds: + actual_embeds = [embed] + if actual_embeds: + payload["embeds"] = [ + e.to_dict() for e in actual_embeds + ] # Assuming embeds have to_dict() + + if view: + await view._start(self.bot) + payload["components"] = view.to_components_payload() + elif components: + payload["components"] = [c.to_dict() for c in components] + + if files is not None: + payload["attachments"] = [ + f.to_dict() if hasattr(f, "to_dict") else f for f in files + ] + + if allowed_mentions is not None: + payload["allowed_mentions"] = allowed_mentions + + flags_value = 0 + if ephemeral: + flags_value |= MessageFlags.EPHEMERAL.value + if flags: + flags_value |= flags + if flags_value: + payload["flags"] = flags_value + + # Followup messages are sent to a webhook endpoint + message_data = await self.bot._http.create_followup_message( + application_id=self.application_id, + interaction_token=self.token, + payload=payload, + ) + if view and message_data and "id" in message_data: + view.message_id = message_data["id"] + self.bot._views[message_data["id"]] = view + from disagreement.models import Message # Ensure Message is available + + return Message(data=message_data, client_instance=self.bot) + + async def edit( + self, + message_id: "Snowflake" = "@original", # Defaults to editing the original response + content: Optional[str] = None, + embed: Optional["Embed"] = None, + embeds: Optional[List["Embed"]] = None, + *, + components: Optional[List["ActionRow"]] = None, + attachments: Optional[List[Any]] = None, + allowed_mentions: Optional[Dict[str, Any]] = None, + ) -> Optional["Message"]: + """ + Edits a message previously sent in response to this interaction. + Can edit the original response or a followup message. + + Args: + message_id (Snowflake): The ID of the message to edit. Defaults to "@original" + to edit the initial interaction response. + content (Optional[str]): The new message content. + embed (Optional[Embed]): A single new embed. + embeds (Optional[List[Embed]]): A list of new embeds. + + Returns: + Optional[Message]: The edited message object if available. + """ + if not self._responded: + raise RuntimeError( + "Cannot edit response if interaction hasn't been responded to or deferred." + ) + + payload: Dict[str, Any] = {} + if content is not None: + payload["content"] = content # Use None to clear + + actual_embeds = embeds + if embed and not embeds: + actual_embeds = [embed] + if actual_embeds is not None: # Allow passing empty list to clear embeds + payload["embeds"] = [ + e.to_dict() for e in actual_embeds + ] # Assuming embeds have to_dict() + + if components is not None: + payload["components"] = [c.to_dict() for c in components] + + if attachments is not None: + payload["attachments"] = [ + a.to_dict() if hasattr(a, "to_dict") else a for a in attachments + ] + + if allowed_mentions is not None: + payload["allowed_mentions"] = allowed_mentions + + if message_id == "@original": + edited_message_data = ( + await self.bot._http.edit_original_interaction_response( + application_id=self.application_id, + interaction_token=self.token, + payload=payload, + ) + ) + else: + edited_message_data = await self.bot._http.edit_followup_message( + application_id=self.application_id, + interaction_token=self.token, + message_id=message_id, + payload=payload, + ) + # The HTTP methods used in tests return minimal data that is insufficient + # to construct a full ``Message`` instance, so we simply return ``None`` + # rather than attempting to parse the response. + return None + + async def delete(self, message_id: "Snowflake" = "@original") -> None: + """ + Deletes a message previously sent in response to this interaction. + Can delete the original response or a followup message. + + Args: + message_id (Snowflake): The ID of the message to delete. Defaults to "@original" + to delete the initial interaction response. + """ + if not self._responded: + # If not responded, there's nothing to delete via this interaction's lifecycle. + # Deferral doesn't create a message to delete until a followup is sent. + raise RuntimeError( + "Cannot delete response if interaction hasn't been responded to." + ) + + if message_id == "@original": + await self.bot._http.delete_original_interaction_response( + application_id=self.application_id, interaction_token=self.token + ) + else: + await self.bot._http.delete_followup_message( + application_id=self.application_id, + interaction_token=self.token, + message_id=message_id, + ) + # After deleting the original response, further followups might be problematic. + # Discord docs: "Once the original message is deleted, you can no longer edit the message or send followups." + # Consider implications for context state. + + def typing(self) -> Typing: + """Return a typing context manager for this interaction's channel.""" + + if not self.channel_id: + raise RuntimeError("Cannot send typing indicator without a channel.") + return self.bot.typing(self.channel_id) diff --git a/disagreement/ext/app_commands/converters.py b/disagreement/ext/app_commands/converters.py new file mode 100644 index 0000000..23d3e64 --- /dev/null +++ b/disagreement/ext/app_commands/converters.py @@ -0,0 +1,478 @@ +# disagreement/ext/app_commands/converters.py + +""" +Converters for transforming application command option values. +""" + +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Protocol, + TypeVar, + Union, + TYPE_CHECKING, +) +from disagreement.enums import ApplicationCommandOptionType +from disagreement.errors import ( + AppCommandOptionConversionError, +) # To be created in disagreement/errors.py + +if TYPE_CHECKING: + from disagreement.interactions import Interaction # For context if needed + from disagreement.models import ( + User, + Member, + Role, + Channel, + Attachment, + ) # Discord models + from disagreement.client import Client # For fetching objects + +T = TypeVar("T", covariant=True) + + +class Converter(Protocol[T]): + """ + A protocol for classes that can convert an interaction option value to a specific type. + """ + + async def convert(self, interaction: "Interaction", value: Any) -> T: + """ + Converts the given value to the target type. + + Parameters: + interaction (Interaction): The interaction context. + value (Any): The raw value from the interaction option. + + Returns: + T: The converted value. + + Raises: + AppCommandOptionConversionError: If conversion fails. + """ + ... + + +# Basic Type Converters + + +class StringConverter(Converter[str]): + async def convert(self, interaction: "Interaction", value: Any) -> str: + if not isinstance(value, str): + raise AppCommandOptionConversionError( + f"Expected a string, but got {type(value).__name__}: {value}" + ) + return value + + +class IntegerConverter(Converter[int]): + async def convert(self, interaction: "Interaction", value: Any) -> int: + if not isinstance(value, int): + try: + return int(value) + except (ValueError, TypeError): + raise AppCommandOptionConversionError( + f"Expected an integer, but got {type(value).__name__}: {value}" + ) + return value + + +class BooleanConverter(Converter[bool]): + async def convert(self, interaction: "Interaction", value: Any) -> bool: + if not isinstance(value, bool): + if isinstance(value, str): + if value.lower() == "true": + return True + elif value.lower() == "false": + return False + raise AppCommandOptionConversionError( + f"Expected a boolean, but got {type(value).__name__}: {value}" + ) + return value + + +class NumberConverter(Converter[float]): # Discord 'NUMBER' type is float + async def convert(self, interaction: "Interaction", value: Any) -> float: + if not isinstance(value, (int, float)): + try: + return float(value) + except (ValueError, TypeError): + raise AppCommandOptionConversionError( + f"Expected a number (float), but got {type(value).__name__}: {value}" + ) + return float(value) # Ensure it's a float even if int is passed + + +# Discord Model Converters + + +class UserConverter(Converter["User"]): + def __init__(self, client: "Client"): + self._client = client + + async def convert(self, interaction: "Interaction", value: Any) -> "User": + if isinstance(value, str): # Assume it's a user ID + user_id = value + # Attempt to get from interaction resolved data first + if ( + interaction.data + and interaction.data.resolved + and interaction.data.resolved.users + ): + user_object = interaction.data.resolved.users.get( + user_id + ) # This is already a User object + if user_object: + return user_object # Return the already parsed User object + + # Fallback to fetching if not in resolved or if interaction has no resolved data + try: + user = await self._client.fetch_user( + user_id + ) # fetch_user now also parses and caches + if user: + return user + raise AppCommandOptionConversionError( + f"User with ID '{user_id}' not found.", + option_name="user", + original_value=value, + ) + except Exception as e: # Catch potential HTTP errors from fetch_user + raise AppCommandOptionConversionError( + f"Failed to fetch user '{user_id}': {e}", + option_name="user", + original_value=value, + ) + elif ( + isinstance(value, dict) and "id" in value + ): # If it's raw user data dict (less common path now) + return self._client.parse_user(value) # parse_user handles dict -> User + raise AppCommandOptionConversionError( + f"Expected a user ID string or user data dict, got {type(value).__name__}", + option_name="user", + original_value=value, + ) + + +class MemberConverter(Converter["Member"]): + def __init__(self, client: "Client"): + self._client = client + + async def convert(self, interaction: "Interaction", value: Any) -> "Member": + if not interaction.guild_id: + raise AppCommandOptionConversionError( + "Cannot convert to Member outside of a guild context.", + option_name="member", + ) + + if isinstance(value, str): # Assume it's a user ID + member_id = value + # Attempt to get from interaction resolved data first + if ( + interaction.data + and interaction.data.resolved + and interaction.data.resolved.members + ): + # The Member object from resolved.members should already be correctly initialized + # by ResolvedData, including its User part. + member = interaction.data.resolved.members.get(member_id) + if member: + return ( + member # Return the already resolved and parsed Member object + ) + + # Fallback to fetching if not in resolved + try: + member = await self._client.fetch_member( + interaction.guild_id, member_id + ) + if member: + return member + raise AppCommandOptionConversionError( + f"Member with ID '{member_id}' not found in guild '{interaction.guild_id}'.", + option_name="member", + original_value=value, + ) + except Exception as e: + raise AppCommandOptionConversionError( + f"Failed to fetch member '{member_id}': {e}", + option_name="member", + original_value=value, + ) + elif isinstance(value, dict) and "id" in value.get( + "user", {} + ): # If it's already a member data dict + return self._client.parse_member(value, interaction.guild_id) + raise AppCommandOptionConversionError( + f"Expected a member ID string or member data dict, got {type(value).__name__}", + option_name="member", + original_value=value, + ) + + +class RoleConverter(Converter["Role"]): + def __init__(self, client: "Client"): + self._client = client + + async def convert(self, interaction: "Interaction", value: Any) -> "Role": + if not interaction.guild_id: + raise AppCommandOptionConversionError( + "Cannot convert to Role outside of a guild context.", option_name="role" + ) + + if isinstance(value, str): # Assume it's a role ID + role_id = value + # Attempt to get from interaction resolved data first + if ( + interaction.data + and interaction.data.resolved + and interaction.data.resolved.roles + ): + role_object = interaction.data.resolved.roles.get( + role_id + ) # Should be a Role object + if role_object: + return role_object + + # Fallback to fetching from guild if not in resolved + # This requires Client to have a fetch_role method or similar + try: + # Assuming Client.fetch_role(guild_id, role_id) will be implemented + role = await self._client.fetch_role(interaction.guild_id, role_id) + if role: + return role + raise AppCommandOptionConversionError( + f"Role with ID '{role_id}' not found in guild '{interaction.guild_id}'.", + option_name="role", + original_value=value, + ) + except Exception as e: + raise AppCommandOptionConversionError( + f"Failed to fetch role '{role_id}': {e}", + option_name="role", + original_value=value, + ) + elif ( + isinstance(value, dict) and "id" in value + ): # If it's already role data dict + if ( + not interaction.guild_id + ): # Should have been caught earlier, but as a safeguard + raise AppCommandOptionConversionError( + "Guild context is required to parse role data.", + option_name="role", + original_value=value, + ) + return self._client.parse_role(value, interaction.guild_id) + # This path is reached if value is not a string (role ID) and not a dict (role data) + # or if it's a string but all fetching/lookup attempts failed. + # The final raise AppCommandOptionConversionError should be outside the if/elif for string values. + # If value was a string, an error should have been raised within the 'if isinstance(value, str)' block. + # If it wasn't a string or dict, this is the correct place to raise. + # The previous structure was slightly off, as the final raise was inside the string check block. + # Let's ensure the final raise is at the correct scope. + # The current structure seems to imply that if it's not a string, it must be a dict or error. + # If it's a string and all lookups fail, an error is raised within that block. + # If it's a dict and parsing fails (or guild_id missing), error raised. + # If it's neither, this final raise is correct. + # The "Function with declared return type "Role" must return value on all code paths" + # error suggests a path where no return or raise happens. + # This happens if `isinstance(value, str)` is true, but then all internal paths + # (resolved check, fetch try/except) don't lead to a return or raise *before* + # falling out of the `if isinstance(value, str)` block. + # The `raise AppCommandOptionConversionError` at the end of the `if isinstance(value, str)` block + # (line 156 in previous version) handles the case where a role ID is given but not found. + # The one at the very end (line 164 in previous) handles cases where value is not str/dict. + + # Corrected structure for the final raise: + # It should be at the same level as the initial `if isinstance(value, str):` + # to catch cases where `value` is neither a str nor a dict. + # However, the current logic within the `if isinstance(value, str):` block + # ensures a raise if the role ID is not found. + # The `elif isinstance(value, dict)` handles the dict case. + # The final `raise` (line 164) is for types other than str or dict. + # The Pylance error "Function with declared return type "Role" must return value on all code paths" + # implies that if `value` is a string, and `interaction.data.resolved.roles.get(role_id)` is None, + # AND `self._client.fetch_role` returns None (which it can), then the + # `raise AppCommandOptionConversionError` on line 156 is correctly hit. + # The issue might be that Pylance doesn't see `AppCommandOptionConversionError` as definitively terminating. + # This is unlikely. Let's re-verify the logic flow. + + # The `raise` on line 156 is correct if role_id is not found after fetching. + # The `raise` on line 164 is for when `value` is not a str and not a dict. + # This seems logically sound. The Pylance error might be a misinterpretation or a subtle issue. + # For now, the duplicated `except` is the primary syntax error. + # The "must return value on all code paths" often occurs if an if/elif chain doesn't + # exhaust all possibilities or if a path through a try/except doesn't guarantee a return/raise. + # In this case, if `value` is a string, it either returns a Role or raises an error. + # If `value` is a dict, it either returns a Role or raises an error. + # If `value` is neither, it raises an error. All paths seem covered. + # The syntax error from the duplicated `except` is the most likely culprit for Pylance's confusion. + raise AppCommandOptionConversionError( + f"Expected a role ID string or role data dict, got {type(value).__name__}", + option_name="role", + original_value=value, + ) + + +class ChannelConverter(Converter["Channel"]): + def __init__(self, client: "Client"): + self._client = client + + async def convert(self, interaction: "Interaction", value: Any) -> "Channel": + if isinstance(value, str): # Assume it's a channel ID + channel_id = value + # Attempt to get from interaction resolved data first + if ( + interaction.data + and interaction.data.resolved + and interaction.data.resolved.channels + ): + # Resolved channels are PartialChannel. Client.fetch_channel will get the full typed one. + partial_channel = interaction.data.resolved.channels.get(channel_id) + if partial_channel: + # Client.fetch_channel should handle fetching and parsing to the correct Channel subtype + full_channel = await self._client.fetch_channel(partial_channel.id) + if full_channel: + return full_channel + # If fetch_channel returns None even with a resolved ID, it's an issue. + raise AppCommandOptionConversionError( + f"Failed to fetch full channel for resolved ID '{channel_id}'.", + option_name="channel", + original_value=value, + ) + + # Fallback to fetching directly if not in resolved or if resolved fetch failed + try: + channel = await self._client.fetch_channel( + channel_id + ) # fetch_channel handles parsing + if channel: + return channel + raise AppCommandOptionConversionError( + f"Channel with ID '{channel_id}' not found.", + option_name="channel", + original_value=value, + ) + except Exception as e: + raise AppCommandOptionConversionError( + f"Failed to fetch channel '{channel_id}': {e}", + option_name="channel", + original_value=value, + ) + # Raw channel data dicts are not typically provided for slash command options. + raise AppCommandOptionConversionError( + f"Expected a channel ID string, got {type(value).__name__}", + option_name="channel", + original_value=value, + ) + + +class AttachmentConverter(Converter["Attachment"]): + def __init__( + self, client: "Client" + ): # Client might be needed for future enhancements or consistency + self._client = client + + async def convert(self, interaction: "Interaction", value: Any) -> "Attachment": + if isinstance(value, str): # Value is the attachment ID + attachment_id = value + if ( + interaction.data + and interaction.data.resolved + and interaction.data.resolved.attachments + ): + attachment_object = interaction.data.resolved.attachments.get( + attachment_id + ) # This is already an Attachment object + if attachment_object: + return ( + attachment_object # Return the already parsed Attachment object + ) + raise AppCommandOptionConversionError( + f"Attachment with ID '{attachment_id}' not found in resolved data.", + option_name="attachment", + original_value=value, + ) + raise AppCommandOptionConversionError( + f"Expected an attachment ID string, got {type(value).__name__}", + option_name="attachment", + original_value=value, + ) + + +# Converters can be registered dynamically using +# :meth:`disagreement.ext.app_commands.handler.AppCommandHandler.register_converter`. + +# Mapping from ApplicationCommandOptionType to default converters +# This will be used by the AppCommandHandler to automatically apply converters +# if no explicit converter is specified for a command option's type hint. +DEFAULT_CONVERTERS: Dict[ + ApplicationCommandOptionType, Callable[..., Converter[Any]] +] = { # Changed Callable signature + ApplicationCommandOptionType.STRING: StringConverter, + ApplicationCommandOptionType.INTEGER: IntegerConverter, + ApplicationCommandOptionType.BOOLEAN: BooleanConverter, + ApplicationCommandOptionType.NUMBER: NumberConverter, + ApplicationCommandOptionType.USER: UserConverter, + ApplicationCommandOptionType.CHANNEL: ChannelConverter, + ApplicationCommandOptionType.ROLE: RoleConverter, + # ApplicationCommandOptionType.MENTIONABLE: MentionableConverter, # Special case, can be User or Role + ApplicationCommandOptionType.ATTACHMENT: AttachmentConverter, # Added +} + + +async def run_converters( + interaction: "Interaction", + param_type: Any, # The type hint of the parameter + option_type: ApplicationCommandOptionType, # The Discord option type + value: Any, + client: "Client", # Needed for model converters +) -> Any: + """ + Runs the appropriate converter for a given parameter type and value. + This function will be more complex, handling custom converters, unions, optionals etc. + For now, a basic lookup. + """ + converter_class_factory = DEFAULT_CONVERTERS.get(option_type) + if converter_class_factory: + # Check if the factory needs the client instance + # This is a bit simplistic; a more robust way might involve inspecting __init__ signature + # or having converters register their needs. + if option_type in [ + ApplicationCommandOptionType.USER, + ApplicationCommandOptionType.CHANNEL, # Anticipating these + ApplicationCommandOptionType.ROLE, + ApplicationCommandOptionType.MENTIONABLE, + ApplicationCommandOptionType.ATTACHMENT, + ]: + converter_instance = converter_class_factory(client=client) + else: + converter_instance = converter_class_factory() + + return await converter_instance.convert(interaction, value) + + # Fallback for unhandled types or if direct type matching is needed + if param_type is str and isinstance(value, str): + return value + if param_type is int and isinstance(value, int): + return value + if param_type is bool and isinstance(value, bool): + return value + if param_type is float and isinstance(value, (float, int)): + return float(value) + + # If no specific converter, and it's not a basic type match, raise error or return raw + # For now, let's raise if no converter found for a specific option type + if option_type in DEFAULT_CONVERTERS: # Should have been handled + pass # This path implies a logic error above or missing converter in DEFAULT_CONVERTERS + + # If it's a model type but no converter yet, this will need to be handled + # e.g. if param_type is User and option_type is ApplicationCommandOptionType.USER + + raise AppCommandOptionConversionError( + f"No suitable converter found for option type {option_type.name} " + f"with value '{value}' to target type {param_type.__name__ if hasattr(param_type, '__name__') else param_type}" + ) diff --git a/disagreement/ext/app_commands/decorators.py b/disagreement/ext/app_commands/decorators.py new file mode 100644 index 0000000..7f60778 --- /dev/null +++ b/disagreement/ext/app_commands/decorators.py @@ -0,0 +1,569 @@ +# disagreement/ext/app_commands/decorators.py + +import inspect +import asyncio +from dataclasses import dataclass +from typing import ( + Callable, + Optional, + List, + Dict, + Any, + Union, + Type, + get_origin, + get_args, + TYPE_CHECKING, + Literal, + Annotated, + TypeVar, + cast, +) + +from .commands import ( + SlashCommand, + UserCommand, + MessageCommand, + AppCommand, + AppCommandGroup, + HybridCommand, +) +from disagreement.interactions import ( + ApplicationCommandOption, + ApplicationCommandOptionChoice, + Snowflake, +) +from disagreement.enums import ( + ApplicationCommandOptionType, + IntegrationType, + InteractionContextType, + # Assuming ChannelType will be added to disagreement.enums +) + +if TYPE_CHECKING: + from disagreement.client import Client # For potential future use + from disagreement.models import Channel, User + + # Assuming TextChannel, VoiceChannel etc. might be defined or aliased + # For now, we'll use string comparisons for channel types or rely on a yet-to-be-defined ChannelType enum + Channel = Any + Member = Any + Role = Any + Attachment = Any + # from .cog import Cog # Placeholder +else: + # Runtime fallbacks for optional model classes + from disagreement.models import Channel + + Client = Any # type: ignore + User = Any # type: ignore + Member = Any # type: ignore + Role = Any # type: ignore + Attachment = Any # type: ignore + +# Mapping Python types to Discord ApplicationCommandOptionType +# This will need to be expanded and made more robust. +# Consider using a registry or a more sophisticated type mapping system. +_type_mapping: Dict[Any, ApplicationCommandOptionType] = ( + { # Changed Type to Any for key due to placeholders + str: ApplicationCommandOptionType.STRING, + int: ApplicationCommandOptionType.INTEGER, + bool: ApplicationCommandOptionType.BOOLEAN, + float: ApplicationCommandOptionType.NUMBER, # Discord 'NUMBER' type is for float/double + User: ApplicationCommandOptionType.USER, + Channel: ApplicationCommandOptionType.CHANNEL, + # Placeholders for actual model types from disagreement.models + # These will be resolved to their actual types when TYPE_CHECKING is False or via isinstance checks + } +) + + +# Helper dataclass for storing extra option metadata +@dataclass +class OptionMetadata: + channel_types: Optional[List[int]] = None + min_value: Optional[Union[int, float]] = None + max_value: Optional[Union[int, float]] = None + min_length: Optional[int] = None + max_length: Optional[int] = None + autocomplete: bool = False + + +# Ensure these are updated if model names/locations change +if TYPE_CHECKING: + # _type_mapping[User] = ApplicationCommandOptionType.USER # Already added above + _type_mapping[Member] = ApplicationCommandOptionType.USER # Member implies User + _type_mapping[Role] = ApplicationCommandOptionType.ROLE + _type_mapping[Attachment] = ApplicationCommandOptionType.ATTACHMENT + _type_mapping[Channel] = ApplicationCommandOptionType.CHANNEL + +# TypeVar for the app command decorator factory +AppCmdType = TypeVar("AppCmdType", bound=AppCommand) + + +def _extract_options_from_signature( + func: Callable[..., Any], option_meta: Optional[Dict[str, OptionMetadata]] = None +) -> List[ApplicationCommandOption]: + """ + Inspects a function signature and generates ApplicationCommandOption list. + """ + options: List[ApplicationCommandOption] = [] + params = inspect.signature(func).parameters + + doc = inspect.getdoc(func) + param_descriptions: Dict[str, str] = {} + if doc: + for line in inspect.cleandoc(doc).splitlines(): + line = line.strip() + if line.startswith(":param"): + try: + _, rest = line.split(" ", 1) + name, desc = rest.split(":", 1) + param_descriptions[name.strip()] = desc.strip() + except ValueError: + continue + + # Skip 'self' (for cogs) and 'ctx' (context) parameters + param_iter = iter(params.values()) + first_param = next(param_iter, None) + + # Heuristic: if the function is bound to a class (cog), 'self' might be the first param. + # A more robust way would be to check if `func` is a method of a Cog instance later. + # For now, simple name check. + if first_param and first_param.name == "self": + first_param = next(param_iter, None) # Consume 'self', get next + + if first_param and first_param.name == "ctx": # Consume 'ctx' + pass # ctx is handled, now iterate over actual command options + elif ( + first_param + ): # If first_param was not 'self' and not 'ctx', it's a command option + param_iter = iter(params.values()) # Reset iterator to include the first param + + for param in param_iter: + if param.name == "self" or param.name == "ctx": # Should have been skipped + continue + + if param.kind == param.VAR_POSITIONAL or param.kind == param.VAR_KEYWORD: + # *args and **kwargs are not directly supported by slash command options structure. + # Could raise an error or ignore. For now, ignore. + # print(f"Warning: *args/**kwargs ({param.name}) are not supported for slash command options.") + continue + + option_name = param.name + option_description = param_descriptions.get( + option_name, f"Description for {option_name}" + ) + meta = option_meta.get(option_name) if option_meta else None + + param_type_hint = param.annotation + if param_type_hint == inspect.Parameter.empty: + # Default to string if no type hint, or raise error. + # Forcing type hints is generally better for slash commands. + # raise TypeError(f"Option '{option_name}' must have a type hint for slash commands.") + param_type_hint = str # Defaulting to string, can be made stricter + + option_type: Optional[ApplicationCommandOptionType] = None + choices: Optional[List[ApplicationCommandOptionChoice]] = None + + origin = get_origin(param_type_hint) + args = get_args(param_type_hint) + + if origin is Annotated: + param_type_hint = args[0] + for extra in args[1:]: + if isinstance(extra, OptionMetadata): + meta = extra + origin = get_origin(param_type_hint) + args = get_args(param_type_hint) + + actual_type_for_mapping = param_type_hint + is_optional = False + + if origin is Union: # Handles Optional[T] which is Union[T, NoneType] + # Filter out NoneType to get the actual type for mapping + union_types = [t for t in args if t is not type(None)] + if len(union_types) == 1: + actual_type_for_mapping = union_types[0] + is_optional = True + else: + # More complex Unions are not directly supported by a single option type. + # Could default to STRING or raise. + # For now, let's assume simple Optional[T] or direct types. + # print(f"Warning: Complex Union type for '{option_name}' not fully supported, defaulting to STRING.") + actual_type_for_mapping = str + + elif origin is list and len(args) == 1: + # List[T] is not a direct option type. Discord handles multiple values for some types + # via repeated options or specific component interactions, not directly in slash command options. + # This might indicate a need for a different interaction pattern or custom parsing. + # For now, treat List[str] as a string, others might error or default. + # print(f"Warning: List type for '{option_name}' not directly supported as a single option. Consider type {args[0]}.") + actual_type_for_mapping = args[ + 0 + ] # Use the inner type for mapping, but this is a simplification. + + if origin is Literal: # typing.Literal['a', 'b'] + choices = [] + for choice_val in args: + if not isinstance(choice_val, (str, int, float)): + raise TypeError( + f"Literal choices for '{option_name}' must be str, int, or float. Got {type(choice_val)}." + ) + choices.append( + ApplicationCommandOptionChoice( + data={"name": str(choice_val), "value": choice_val} + ) + ) + # The type of the Literal's arguments determines the option type + if choices: + literal_arg_type = type(choices[0].value) + option_type = _type_mapping.get(literal_arg_type) + if ( + not option_type and literal_arg_type is float + ): # float maps to NUMBER + option_type = ApplicationCommandOptionType.NUMBER + + if not option_type: # If not determined by Literal + option_type = _type_mapping.get(actual_type_for_mapping) + # Special handling for User, Member, Role, Attachment, Channel if not directly in _type_mapping + # This is a bit crude; a proper registry or isinstance checks would be better. + if not option_type: + if ( + actual_type_for_mapping.__name__ == "User" + or actual_type_for_mapping.__name__ == "Member" + ): + option_type = ApplicationCommandOptionType.USER + elif actual_type_for_mapping.__name__ == "Role": + option_type = ApplicationCommandOptionType.ROLE + elif actual_type_for_mapping.__name__ == "Attachment": + option_type = ApplicationCommandOptionType.ATTACHMENT + elif ( + inspect.isclass(actual_type_for_mapping) + and isinstance(Channel, type) + and issubclass(actual_type_for_mapping, cast(type, Channel)) + ): + option_type = ApplicationCommandOptionType.CHANNEL + + if not option_type: + # Fallback or error if type couldn't be mapped + # print(f"Warning: Could not map type '{actual_type_for_mapping}' for option '{option_name}'. Defaulting to STRING.") + option_type = ApplicationCommandOptionType.STRING # Default fallback + + required = (param.default == inspect.Parameter.empty) and not is_optional + + data: Dict[str, Any] = { + "name": option_name, + "description": option_description, + "type": option_type.value, + "required": required, + "choices": ([c.to_dict() for c in choices] if choices else None), + } + + if meta: + if meta.channel_types is not None: + data["channel_types"] = meta.channel_types + if meta.min_value is not None: + data["min_value"] = meta.min_value + if meta.max_value is not None: + data["max_value"] = meta.max_value + if meta.min_length is not None: + data["min_length"] = meta.min_length + if meta.max_length is not None: + data["max_length"] = meta.max_length + if meta.autocomplete: + data["autocomplete"] = True + + options.append(ApplicationCommandOption(data=data)) + return options + + +def _app_command_decorator( + cls: Type[AppCmdType], + option_meta: Optional[Dict[str, OptionMetadata]] = None, + **attrs: Any, +) -> Callable[[Callable[..., Any]], AppCmdType]: + """Generic factory for creating app command decorators.""" + + def decorator(func: Callable[..., Any]) -> AppCmdType: + if not asyncio.iscoroutinefunction(func): + raise TypeError( + "Application command callback must be a coroutine function." + ) + + name = attrs.pop("name", None) or func.__name__ + description = attrs.pop("description", None) or inspect.getdoc(func) + if description: # Clean up docstring + description = inspect.cleandoc(description).split("\n\n", 1)[ + 0 + ] # Use first paragraph + + parent_group = attrs.pop("parent", None) + if parent_group and not isinstance(parent_group, AppCommandGroup): + raise TypeError( + "The 'parent' argument must be an AppCommandGroup instance." + ) + + # For User/Message commands, description should be empty for payload, but can be stored for help. + if cls is UserCommand or cls is MessageCommand: + actual_description_for_payload = "" + else: + actual_description_for_payload = description + if not actual_description_for_payload and cls is SlashCommand: + raise ValueError(f"Slash command '{name}' must have a description.") + + # Create the command instance + cmd_instance = cls( + callback=func, + name=name, + description=actual_description_for_payload, # Use payload-appropriate description + **attrs, # Remaining attributes like guild_ids, nsfw, etc. + ) + + # Store original description if different (e.g. for User/Message commands for help text) + if description != actual_description_for_payload: + cmd_instance._full_description = ( + description # Custom attribute for library use + ) + + if isinstance(cmd_instance, SlashCommand): + cmd_instance.options = _extract_options_from_signature(func, option_meta) + + if parent_group: + parent_group.add_command(cmd_instance) # This also sets cmd_instance.parent + + # Attach command object to the function for later collection by Cog or Client + # This is a common pattern. + if hasattr(func, "__app_command_object__"): + # Function might already be decorated (e.g. hybrid or stacked decorators) + # Decide on behavior: error, overwrite, or store list of commands. + # For now, let's assume one app command decorator of a specific type per function. + # Hybrid commands will need special handling. + print( + f"Warning: Function {func.__name__} is already an app command or has one attached. Overwriting." + ) + + setattr(func, "__app_command_object__", cmd_instance) + setattr(cmd_instance, "__app_command_object__", cmd_instance) + + # If the command is a HybridCommand, also set the attribute + # that the prefix command system's Cog._inject looks for. + if isinstance(cmd_instance, HybridCommand): + setattr(func, "__command_object__", cmd_instance) + setattr(cmd_instance, "__command_object__", cmd_instance) + + return cmd_instance # Return the command instance itself, not the function + # This allows it to be added to cogs/handlers directly. + + return decorator + + +def slash_command( + name: Optional[str] = None, + description: Optional[str] = None, + guild_ids: Optional[List[Snowflake]] = None, + default_member_permissions: Optional[str] = None, + nsfw: bool = False, + name_localizations: Optional[Dict[str, str]] = None, + description_localizations: Optional[Dict[str, str]] = None, + integration_types: Optional[List[IntegrationType]] = None, + contexts: Optional[List[InteractionContextType]] = None, + *, + guilds: bool = True, + dms: bool = True, + private_channels: bool = True, + parent: Optional[AppCommandGroup] = None, # Added parent parameter + locale: Optional[str] = None, + option_meta: Optional[Dict[str, OptionMetadata]] = None, +) -> Callable[[Callable[..., Any]], SlashCommand]: + """ + Decorator to create a CHAT_INPUT (slash) command. + Options are inferred from the function's type hints. + """ + if contexts is None: + ctxs: List[InteractionContextType] = [] + if guilds: + ctxs.append(InteractionContextType.GUILD) + if dms: + ctxs.append(InteractionContextType.BOT_DM) + if private_channels: + ctxs.append(InteractionContextType.PRIVATE_CHANNEL) + if len(ctxs) != 3: + contexts = ctxs + attrs = { + "name": name, + "description": description, + "guild_ids": guild_ids, + "default_member_permissions": default_member_permissions, + "nsfw": nsfw, + "name_localizations": name_localizations, + "description_localizations": description_localizations, + "integration_types": integration_types, + "contexts": contexts, + "parent": parent, # Pass parent to attrs + "locale": locale, + } + # Filter out None values to avoid passing them as explicit None to command constructor + # Keep 'parent' even if None, as _app_command_decorator handles None parent. + # nsfw default is False, so it's fine if not present and defaults. + attrs = {k: v for k, v in attrs.items() if v is not None or k in ["nsfw", "parent"]} + return _app_command_decorator(SlashCommand, option_meta, **attrs) + + +def user_command( + name: Optional[str] = None, + guild_ids: Optional[List[Snowflake]] = None, + default_member_permissions: Optional[str] = None, + nsfw: bool = False, # Though less common for user commands + name_localizations: Optional[Dict[str, str]] = None, + integration_types: Optional[List[IntegrationType]] = None, + contexts: Optional[List[InteractionContextType]] = None, + locale: Optional[str] = None, + # description is not used by Discord for User commands +) -> Callable[[Callable[..., Any]], UserCommand]: + """Decorator to create a USER context menu command.""" + attrs = { + "name": name, + "guild_ids": guild_ids, + "default_member_permissions": default_member_permissions, + "nsfw": nsfw, + "name_localizations": name_localizations, + "integration_types": integration_types, + "contexts": contexts, + "locale": locale, + } + attrs = {k: v for k, v in attrs.items() if v is not None or k in ["nsfw"]} + return _app_command_decorator(UserCommand, **attrs) + + +def message_command( + name: Optional[str] = None, + guild_ids: Optional[List[Snowflake]] = None, + default_member_permissions: Optional[str] = None, + nsfw: bool = False, # Though less common for message commands + name_localizations: Optional[Dict[str, str]] = None, + integration_types: Optional[List[IntegrationType]] = None, + contexts: Optional[List[InteractionContextType]] = None, + locale: Optional[str] = None, + # description is not used by Discord for Message commands +) -> Callable[[Callable[..., Any]], MessageCommand]: + """Decorator to create a MESSAGE context menu command.""" + attrs = { + "name": name, + "guild_ids": guild_ids, + "default_member_permissions": default_member_permissions, + "nsfw": nsfw, + "name_localizations": name_localizations, + "integration_types": integration_types, + "contexts": contexts, + "locale": locale, + } + attrs = {k: v for k, v in attrs.items() if v is not None or k in ["nsfw"]} + return _app_command_decorator(MessageCommand, **attrs) + + +def hybrid_command( + name: Optional[str] = None, + description: Optional[str] = None, + guild_ids: Optional[List[Snowflake]] = None, + default_member_permissions: Optional[str] = None, + nsfw: bool = False, + name_localizations: Optional[Dict[str, str]] = None, + description_localizations: Optional[Dict[str, str]] = None, + integration_types: Optional[List[IntegrationType]] = None, + contexts: Optional[List[InteractionContextType]] = None, + *, + guilds: bool = True, + dms: bool = True, + private_channels: bool = True, + aliases: Optional[List[str]] = None, # Specific to prefix command aspect + # Other prefix-specific options can be added here (e.g., help, brief) + option_meta: Optional[Dict[str, OptionMetadata]] = None, + locale: Optional[str] = None, +) -> Callable[[Callable[..., Any]], HybridCommand]: + """ + Decorator to create a command that can be invoked as both a slash command + and a traditional prefix-based command. + Options for the slash command part are inferred from the function's type hints. + """ + if contexts is None: + ctxs: List[InteractionContextType] = [] + if guilds: + ctxs.append(InteractionContextType.GUILD) + if dms: + ctxs.append(InteractionContextType.BOT_DM) + if private_channels: + ctxs.append(InteractionContextType.PRIVATE_CHANNEL) + if len(ctxs) != 3: + contexts = ctxs + attrs = { + "name": name, + "description": description, + "guild_ids": guild_ids, + "default_member_permissions": default_member_permissions, + "nsfw": nsfw, + "name_localizations": name_localizations, + "description_localizations": description_localizations, + "integration_types": integration_types, + "contexts": contexts, + "aliases": aliases or [], # Ensure aliases is a list + "locale": locale, + } + # Filter out None values to avoid passing them as explicit None to command constructor + # Keep 'nsfw' and 'aliases' as they have defaults (False, []) + attrs = { + k: v for k, v in attrs.items() if v is not None or k in ["nsfw", "aliases"] + } + return _app_command_decorator(HybridCommand, option_meta, **attrs) + + +def subcommand( + parent: AppCommandGroup, *d_args: Any, **d_kwargs: Any +) -> Callable[[Callable[..., Any]], SlashCommand]: + """Create a subcommand under an existing :class:`AppCommandGroup`.""" + + d_kwargs.setdefault("parent", parent) + return slash_command(*d_args, **d_kwargs) + + +def group( + name: str, + description: Optional[str] = None, + **kwargs: Any, +) -> Callable[[Optional[Callable[..., Any]]], AppCommandGroup]: + """Decorator to declare a top level :class:`AppCommandGroup`.""" + + def decorator(func: Optional[Callable[..., Any]] = None) -> AppCommandGroup: + grp = AppCommandGroup( + name=name, + description=description, + guild_ids=kwargs.get("guild_ids"), + parent=kwargs.get("parent"), + default_member_permissions=kwargs.get("default_member_permissions"), + nsfw=kwargs.get("nsfw", False), + name_localizations=kwargs.get("name_localizations"), + description_localizations=kwargs.get("description_localizations"), + integration_types=kwargs.get("integration_types"), + contexts=kwargs.get("contexts"), + ) + + if func is not None: + setattr(func, "__app_command_object__", grp) + return grp + + return decorator + + +def subcommand_group( + parent: AppCommandGroup, + name: str, + description: Optional[str] = None, + **kwargs: Any, +) -> Callable[[Optional[Callable[..., Any]]], AppCommandGroup]: + """Create a nested :class:`AppCommandGroup` under ``parent``.""" + + return parent.group( + name=name, + description=description, + **kwargs, + ) diff --git a/disagreement/ext/app_commands/handler.py b/disagreement/ext/app_commands/handler.py new file mode 100644 index 0000000..0ac29e2 --- /dev/null +++ b/disagreement/ext/app_commands/handler.py @@ -0,0 +1,627 @@ +# disagreement/ext/app_commands/handler.py + +import inspect +from typing import ( + TYPE_CHECKING, + Dict, + Optional, + List, + Any, + Tuple, + Union, + get_origin, + get_args, + Literal, +) + +if TYPE_CHECKING: + from disagreement.client import Client + from disagreement.interactions import Interaction, ResolvedData, Snowflake + from disagreement.enums import ( + ApplicationCommandType, + ApplicationCommandOptionType, + InteractionType, + ) + from .commands import ( + AppCommand, + SlashCommand, + UserCommand, + MessageCommand, + AppCommandGroup, + ) + from .context import AppCommandContext + from disagreement.models import ( + User, + Member, + Role, + Attachment, + Message, + ) # For resolved data + + # Channel models would also go here + +# Placeholder for models not yet fully defined or imported +if not TYPE_CHECKING: + from disagreement.enums import ( + ApplicationCommandType, + ApplicationCommandOptionType, + InteractionType, + ) + from .commands import ( + AppCommand, + SlashCommand, + UserCommand, + MessageCommand, + AppCommandGroup, + ) + from .context import AppCommandContext + + User = Any + Member = Any + Role = Any + Attachment = Any + Channel = Any + Message = Any + + +class AppCommandHandler: + """ + Manages application command registration, parsing, and dispatching. + """ + + def __init__(self, client: "Client"): + self.client: "Client" = client + # Store commands: key could be (name, type) for global, or (name, type, guild_id) for guild-specific + # For simplicity, let's start with a flat structure and refine if needed for guild commands. + # A more robust system might have separate dicts for global and guild commands. + self._slash_commands: Dict[str, SlashCommand] = {} + self._user_commands: Dict[str, UserCommand] = {} + self._message_commands: Dict[str, MessageCommand] = {} + self._app_command_groups: Dict[str, AppCommandGroup] = {} + self._converter_registry: Dict[type, type] = {} + + def add_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None: + """Adds an application command or a command group to the handler.""" + if isinstance(command, AppCommandGroup): + if command.name in self._app_command_groups: + raise ValueError( + f"AppCommandGroup '{command.name}' is already registered." + ) + self._app_command_groups[command.name] = command + return + + if isinstance(command, SlashCommand): + if command.name in self._slash_commands: + raise ValueError( + f"SlashCommand '{command.name}' is already registered." + ) + self._slash_commands[command.name] = command + return + + if isinstance(command, UserCommand): + if command.name in self._user_commands: + raise ValueError(f"UserCommand '{command.name}' is already registered.") + self._user_commands[command.name] = command + return + + if isinstance(command, MessageCommand): + if command.name in self._message_commands: + raise ValueError( + f"MessageCommand '{command.name}' is already registered." + ) + self._message_commands[command.name] = command + return + + if isinstance(command, AppCommand): + # Fallback for plain AppCommand objects + if command.type == ApplicationCommandType.CHAT_INPUT: + if command.name in self._slash_commands: + raise ValueError( + f"SlashCommand '{command.name}' is already registered." + ) + self._slash_commands[command.name] = command # type: ignore + elif command.type == ApplicationCommandType.USER: + if command.name in self._user_commands: + raise ValueError( + f"UserCommand '{command.name}' is already registered." + ) + self._user_commands[command.name] = command # type: ignore + elif command.type == ApplicationCommandType.MESSAGE: + if command.name in self._message_commands: + raise ValueError( + f"MessageCommand '{command.name}' is already registered." + ) + self._message_commands[command.name] = command # type: ignore + else: + raise TypeError( + f"Unsupported command type: {command.type} for '{command.name}'" + ) + else: + raise TypeError("Can only add AppCommand or AppCommandGroup instances.") + + def remove_command( + self, name: str + ) -> Optional[Union["AppCommand", "AppCommandGroup"]]: + """Removes an application command or group by name.""" + if name in self._slash_commands: + return self._slash_commands.pop(name) + if name in self._user_commands: + return self._user_commands.pop(name) + if name in self._message_commands: + return self._message_commands.pop(name) + if name in self._app_command_groups: + return self._app_command_groups.pop(name) + return None + + def register_converter(self, annotation: type, converter_cls: type) -> None: + """Register a custom converter class for a type annotation.""" + self._converter_registry[annotation] = converter_cls + + def get_converter(self, annotation: type) -> Optional[type]: + """Retrieve a registered converter class for a type annotation.""" + return self._converter_registry.get(annotation) + + def get_command( + self, + name: str, + command_type: "ApplicationCommandType", + interaction_options: Optional[List[Dict[str, Any]]] = None, + ) -> Optional["AppCommand"]: + """Retrieves a command of a specific type.""" + if command_type == ApplicationCommandType.CHAT_INPUT: + if not interaction_options: + return self._slash_commands.get(name) + + # Handle subcommands/groups + current_options = interaction_options + target_command_or_group: Optional[Union[AppCommand, AppCommandGroup]] = ( + self._app_command_groups.get(name) + ) + + if not target_command_or_group: + return self._slash_commands.get(name) + + final_command: Optional[AppCommand] = None + + while current_options: + opt_data = current_options[0] + opt_name = opt_data.get("name") + opt_type = ( + ApplicationCommandOptionType(opt_data["type"]) + if opt_data.get("type") + else None + ) + + if not opt_name or not isinstance( + target_command_or_group, AppCommandGroup + ): + break + + next_target = target_command_or_group.get_command(opt_name) + + if isinstance(next_target, AppCommand) and ( + opt_type == ApplicationCommandOptionType.SUB_COMMAND + or not opt_data.get("options") + ): + final_command = next_target + break + elif ( + isinstance(next_target, AppCommandGroup) + and opt_type == ApplicationCommandOptionType.SUB_COMMAND_GROUP + ): + target_command_or_group = next_target + current_options = opt_data.get("options", []) + if not current_options: + break + else: + break + + return final_command + + if command_type == ApplicationCommandType.USER: + return self._user_commands.get(name) + + if command_type == ApplicationCommandType.MESSAGE: + return self._message_commands.get(name) + + return None + + async def _resolve_option_value( + self, + value: Any, + expected_type: Any, + resolved_data: Optional["ResolvedData"], + guild_id: Optional["Snowflake"], + ) -> Any: + """ + Resolves an option value to the expected Python type using resolved_data. + """ + converter_cls = self.get_converter(expected_type) + if converter_cls: + try: + init_params = inspect.signature(converter_cls.__init__).parameters + if "client" in init_params: + converter_instance = converter_cls(client=self.client) # type: ignore[arg-type] + else: + converter_instance = converter_cls() + return await converter_instance.convert(None, value) # type: ignore[arg-type] + except Exception: + pass + + # This is a simplified resolver. A more robust one would use converters. + if resolved_data: + if expected_type is User or expected_type.__name__ == "User": + return resolved_data.users.get(value) if resolved_data.users else None + + if expected_type is Member or expected_type.__name__ == "Member": + member_obj = ( + resolved_data.members.get(value) if resolved_data.members else None + ) + if member_obj: + if ( + hasattr(member_obj, "username") + and not member_obj.username + and resolved_data.users + ): + user_obj = resolved_data.users.get(value) + if user_obj: + member_obj.username = user_obj.username + member_obj.discriminator = user_obj.discriminator + member_obj.avatar = user_obj.avatar + member_obj.bot = user_obj.bot + member_obj.user = user_obj # type: ignore[attr-defined] + return member_obj + return None + if expected_type is Role or expected_type.__name__ == "Role": + return resolved_data.roles.get(value) if resolved_data.roles else None + if expected_type is Attachment or expected_type.__name__ == "Attachment": + return ( + resolved_data.attachments.get(value) + if resolved_data.attachments + else None + ) + if expected_type is Message or expected_type.__name__ == "Message": + return ( + resolved_data.messages.get(value) + if resolved_data.messages + else None + ) + if "Channel" in expected_type.__name__: + return ( + resolved_data.channels.get(value) + if resolved_data.channels + else None + ) + + # For basic types, Discord already sends them correctly (string, int, bool, float) + if isinstance(value, expected_type): + return value + try: # Attempt direct conversion for basic types if Discord sent string for int/float/bool + if expected_type is int: + return int(value) + if expected_type is float: + return float(value) + if expected_type is bool: # Discord sends true/false + if isinstance(value, str): + return value.lower() == "true" + return bool(value) + except (ValueError, TypeError): + pass # Conversion failed + return value # Return as is if no specific resolution or conversion applied + + async def _resolve_value( + self, + value: Any, + expected_type: Any, + resolved_data: Optional["ResolvedData"], + guild_id: Optional["Snowflake"], + ) -> Any: + """Public wrapper around ``_resolve_option_value`` used by tests.""" + + return await self._resolve_option_value( + value=value, + expected_type=expected_type, + resolved_data=resolved_data, + guild_id=guild_id, + ) + + async def _parse_interaction_options( + self, + command_params: Dict[str, inspect.Parameter], # From command.params + interaction_options: Optional[List[Dict[str, Any]]], + resolved_data: Optional["ResolvedData"], + guild_id: Optional["Snowflake"], + ) -> Tuple[List[Any], Dict[str, Any]]: + """ + Parses options from an interaction payload and maps them to command function arguments. + """ + args_list: List[Any] = [] + kwargs_dict: Dict[str, Any] = {} + + if not interaction_options: # No options provided in interaction + # Check if command has required params without defaults + for name, param in command_params.items(): + if param.default == inspect.Parameter.empty: + # This should ideally be caught by Discord if option is marked required + raise ValueError(f"Missing required option: {name}") + return args_list, kwargs_dict + + # Create a dictionary of provided options by name for easier lookup + provided_options: Dict[str, Any] = { + opt["name"]: opt["value"] for opt in interaction_options if "value" in opt + } + + for name, param in command_params.items(): + if name in provided_options: + raw_value = provided_options[name] + expected_type = ( + param.annotation + if param.annotation != inspect.Parameter.empty + else str + ) + + # Handle Optional[T] + origin_type = get_origin(expected_type) + if origin_type is Union: + union_args = get_args(expected_type) + # Assuming Optional[T] is Union[T, NoneType] + non_none_types = [t for t in union_args if t is not type(None)] + if len(non_none_types) == 1: + expected_type = non_none_types[0] + # Else, complex Union, might need more sophisticated handling or default to raw_value/str + elif origin_type is Literal: + literal_args = get_args(expected_type) + if literal_args: + expected_type = type(literal_args[0]) + else: + expected_type = str + + resolved_value = await self._resolve_option_value( + raw_value, expected_type, resolved_data, guild_id + ) + + if ( + param.kind == inspect.Parameter.KEYWORD_ONLY + or param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ): + kwargs_dict[name] = resolved_value + # Note: Slash commands don't map directly to *args. All options are named. + # So, we'll primarily use kwargs_dict and then construct args_list based on param order if needed, + # but Discord sends named options, so direct kwarg usage is more natural. + elif param.default != inspect.Parameter.empty: + kwargs_dict[name] = param.default + else: + # Required parameter not provided by Discord - this implies an issue with command definition + # or Discord's validation, as Discord should enforce required options. + raise ValueError( + f"Required option '{name}' not found in interaction payload." + ) + + # Populate args_list based on the order in command_params for positional arguments + # This assumes that all args that are not keyword-only are passed positionally if present in kwargs_dict + for name, param in command_params.items(): + if param.kind == inspect.Parameter.POSITIONAL_ONLY or ( + param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + and name in kwargs_dict + ): + if name in kwargs_dict: # Ensure it was resolved or had a default + args_list.append(kwargs_dict[name]) + # If it was POSITIONAL_ONLY and not in kwargs_dict, it's an error (already raised) + elif param.kind == inspect.Parameter.VAR_POSITIONAL: # *args + # Slash commands don't map to *args well. This would be empty. + pass + + # Filter kwargs_dict to only include actual KEYWORD_ONLY or POSITIONAL_OR_KEYWORD params + # that were not used for args_list (if strict positional/keyword separation is desired). + # For slash commands, it's simpler to pass all resolved named options as kwargs. + final_kwargs = { + k: v + for k, v in kwargs_dict.items() + if k in command_params + and command_params[k].kind != inspect.Parameter.POSITIONAL_ONLY + } + + # For simplicity with slash commands, let's assume all resolved options are passed via kwargs + # and the command signature is primarily (self, ctx, **options) or (ctx, **options) + # or (self, ctx, option1, option2) where names match. + # The AppCommand.invoke will handle passing them. + # The current args_list and final_kwargs might be redundant if invoke just uses **final_kwargs. + # Let's return kwargs_dict directly for now, and AppCommand.invoke can map them. + + return [], kwargs_dict # Return empty args, all in kwargs for now. + + async def dispatch_app_command_error( + self, context: "AppCommandContext", error: Exception + ) -> None: + """Dispatches an app command error to the client if implemented.""" + if hasattr(self.client, "on_app_command_error"): + await self.client.on_app_command_error(context, error) + + async def process_interaction(self, interaction: "Interaction") -> None: + """Processes an incoming interaction.""" + if interaction.type == InteractionType.MODAL_SUBMIT: + callback = getattr(self.client, "on_modal_submit", None) + if callback is not None: + from typing import Awaitable, Callable, cast + + await cast(Callable[["Interaction"], Awaitable[None]], callback)( + interaction + ) + return + + if interaction.type == InteractionType.APPLICATION_COMMAND_AUTOCOMPLETE: + callback = getattr(self.client, "on_autocomplete", None) + if callback is not None: + from typing import Awaitable, Callable, cast + + await cast(Callable[["Interaction"], Awaitable[None]], callback)( + interaction + ) + return + + if interaction.type != InteractionType.APPLICATION_COMMAND: + return + + if not interaction.data or not interaction.data.name: + from .context import AppCommandContext + + ctx = AppCommandContext( + bot=self.client, interaction=interaction, command=None + ) + await ctx.send("Command not found.", ephemeral=True) + return + + command_name = interaction.data.name + command_type = interaction.data.type or ApplicationCommandType.CHAT_INPUT + command = self.get_command( + command_name, + command_type, + interaction.data.options if interaction.data else None, + ) + + if not command: + from .context import AppCommandContext + + ctx = AppCommandContext( + bot=self.client, interaction=interaction, command=None + ) + await ctx.send(f"Command '{command_name}' not found.", ephemeral=True) + return + + # Create context + from .context import AppCommandContext # Ensure AppCommandContext is available + + ctx = AppCommandContext( + bot=self.client, interaction=interaction, command=command + ) + + try: + # Prepare arguments for the command callback + # Skip 'self' and 'ctx' from command.params for parsing interaction options + params_to_parse = { + name: param + for name, param in command.params.items() + if name not in ("self", "ctx") + } + + if command.type in ( + ApplicationCommandType.USER, + ApplicationCommandType.MESSAGE, + ): + # Context menu commands provide a target_id. Resolve and pass it + args = [] + kwargs = {} + if params_to_parse and interaction.data and interaction.data.target_id: + first_param = next(iter(params_to_parse.values())) + expected = ( + first_param.annotation + if first_param.annotation != inspect.Parameter.empty + else str + ) + resolved = await self._resolve_option_value( + interaction.data.target_id, + expected, + interaction.data.resolved, + interaction.guild_id, + ) + if first_param.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + args.append(resolved) + else: + kwargs[first_param.name] = resolved + + await command.invoke(ctx, *args, **kwargs) + else: + parsed_args, parsed_kwargs = await self._parse_interaction_options( + command_params=params_to_parse, + interaction_options=interaction.data.options, + resolved_data=interaction.data.resolved, + guild_id=interaction.guild_id, + ) + + await command.invoke(ctx, *parsed_args, **parsed_kwargs) + + except Exception as e: + print(f"Error invoking app command '{command.name}': {e}") + await self.dispatch_app_command_error(ctx, e) + # else: + # # Default error reply if no handler on client + # try: + # await ctx.send(f"An error occurred: {e}", ephemeral=True) + # except Exception as send_e: + # print(f"Failed to send error message for app command: {send_e}") + + async def sync_commands( + self, application_id: "Snowflake", guild_id: Optional["Snowflake"] = None + ) -> None: + """ + Synchronizes (registers/updates) all application commands with Discord. + If guild_id is provided, syncs commands for that guild. Otherwise, syncs global commands. + """ + commands_to_sync: List[Dict[str, Any]] = [] + + # Collect commands based on scope (global or specific guild) + # This needs to be more sophisticated to handle guild_ids on commands/groups + + source_commands = ( + list(self._slash_commands.values()) + + list(self._user_commands.values()) + + list(self._message_commands.values()) + + list(self._app_command_groups.values()) + ) + + for cmd_or_group in source_commands: + # Determine if this command/group should be synced for the current scope + is_guild_specific_command = ( + cmd_or_group.guild_ids is not None and len(cmd_or_group.guild_ids) > 0 + ) + + if guild_id: # Syncing for a specific guild + # Skip if not a guild-specific command OR if it's for a different guild + if not is_guild_specific_command or ( + cmd_or_group.guild_ids is not None + and guild_id not in cmd_or_group.guild_ids + ): + continue + else: # Syncing global commands + if is_guild_specific_command: + continue # Skip guild-specific commands when syncing global + + # Use the to_dict() method from AppCommand or AppCommandGroup + try: + payload = cmd_or_group.to_dict() + commands_to_sync.append(payload) + except AttributeError: + print( + f"Warning: Command or group '{cmd_or_group.name}' does not have a to_dict() method. Skipping." + ) + except Exception as e: + print( + f"Error converting command/group '{cmd_or_group.name}' to dict: {e}. Skipping." + ) + + if not commands_to_sync: + print( + f"No commands to sync for {'guild ' + str(guild_id) if guild_id else 'global'} scope." + ) + return + + try: + if guild_id: + print( + f"Syncing {len(commands_to_sync)} commands for guild {guild_id}..." + ) + await self.client._http.bulk_overwrite_guild_application_commands( + application_id, guild_id, commands_to_sync + ) + else: + print(f"Syncing {len(commands_to_sync)} global commands...") + await self.client._http.bulk_overwrite_global_application_commands( + application_id, commands_to_sync + ) + print("Command sync successful.") + except Exception as e: + print(f"Error syncing application commands: {e}") + # Consider re-raising or specific error handling diff --git a/disagreement/ext/commands/__init__.py b/disagreement/ext/commands/__init__.py new file mode 100644 index 0000000..779a7cb --- /dev/null +++ b/disagreement/ext/commands/__init__.py @@ -0,0 +1,49 @@ +# disagreement/ext/commands/__init__.py + +""" +disagreement.ext.commands - A command framework extension for the Disagreement library. +""" + +from .cog import Cog +from .core import ( + Command, + CommandContext, + CommandHandler, +) # CommandHandler might be internal +from .decorators import command, listener, check, check_any, cooldown +from .errors import ( + CommandError, + CommandNotFound, + BadArgument, + MissingRequiredArgument, + ArgumentParsingError, + CheckFailure, + CheckAnyFailure, + CommandOnCooldown, + CommandInvokeError, +) + +__all__ = [ + # Cog + "Cog", + # Core + "Command", + "CommandContext", + # "CommandHandler", # Usually not part of public API for direct use by bot devs + # Decorators + "command", + "listener", + "check", + "check_any", + "cooldown", + # Errors + "CommandError", + "CommandNotFound", + "BadArgument", + "MissingRequiredArgument", + "ArgumentParsingError", + "CheckFailure", + "CheckAnyFailure", + "CommandOnCooldown", + "CommandInvokeError", +] diff --git a/disagreement/ext/commands/cog.py b/disagreement/ext/commands/cog.py new file mode 100644 index 0000000..b7c9302 --- /dev/null +++ b/disagreement/ext/commands/cog.py @@ -0,0 +1,155 @@ +# disagreement/ext/commands/cog.py + +import inspect +from typing import TYPE_CHECKING, List, Tuple, Callable, Awaitable, Any, Dict, Union + +if TYPE_CHECKING: + from disagreement.client import Client + from .core import Command + from disagreement.ext.app_commands.commands import ( + AppCommand, + AppCommandGroup, + ) # Added +else: # pragma: no cover - runtime imports for isinstance checks + from disagreement.ext.app_commands.commands import AppCommand, AppCommandGroup + + # EventDispatcher might be needed if cogs register listeners directly + # from disagreement.event_dispatcher import EventDispatcher + + +class Cog: + """ + The base class for cogs, which are collections of commands and listeners. + """ + + def __init__(self, client: "Client"): + self._client: "Client" = client + self._cog_name: str = self.__class__.__name__ + self._commands: Dict[str, "Command"] = {} + self._listeners: List[Tuple[str, Callable[..., Awaitable[None]]]] = [] + self._app_commands_and_groups: List[Union["AppCommand", "AppCommandGroup"]] = ( + [] + ) # Added + + # Discover commands and listeners defined in this cog instance + self._inject() + + @property + def client(self) -> "Client": + return self._client + + @property + def cog_name(self) -> str: + return self._cog_name + + def _inject(self) -> None: + """ + Called to discover and prepare commands and listeners within this cog. + This is typically called by the CommandHandler when adding the cog. + """ + # Clear any previously injected state (e.g., if re-injecting) + self._commands.clear() + self._listeners.clear() + self._app_commands_and_groups.clear() # Added + + for member_name, member in inspect.getmembers(self): + if hasattr(member, "__command_object__"): + # This is a prefix or hybrid command object + cmd: "Command" = getattr(member, "__command_object__") + cmd.cog = self # Assign the cog instance to the command + if cmd.name in self._commands: + # This should ideally be caught earlier or handled by CommandHandler + print( + f"Warning: Duplicate command name '{cmd.name}' in cog '{self.cog_name}'. Overwriting." + ) + self._commands[cmd.name.lower()] = cmd + # Also register aliases + for alias in cmd.aliases: + self._commands[alias.lower()] = cmd + + # If this command is also an application command (HybridCommand) + if isinstance(cmd, (AppCommand, AppCommandGroup)): + self._app_commands_and_groups.append(cmd) + + elif hasattr(member, "__app_command_object__"): # Added for app commands + app_cmd_obj = getattr(member, "__app_command_object__") + if isinstance(app_cmd_obj, (AppCommand, AppCommandGroup)): + if isinstance(app_cmd_obj, AppCommand): + app_cmd_obj.cog = self # Associate cog + # For AppCommandGroup, its commands will have cog set individually if they are AppCommands + self._app_commands_and_groups.append(app_cmd_obj) + else: + print( + f"Warning: Member '{member_name}' in cog '{self.cog_name}' has '__app_command_object__' but it's not an AppCommand or AppCommandGroup." + ) + + elif isinstance(member, (AppCommand, AppCommandGroup)): + if isinstance(member, AppCommand): + member.cog = self + self._app_commands_and_groups.append(member) + + elif hasattr(member, "__listener_name__"): + # This is a method decorated with @commands.Cog.listener or @commands.listener + if not inspect.iscoroutinefunction(member): + # Decorator should have caught this, but double check + print( + f"Warning: Listener '{member_name}' in cog '{self.cog_name}' is not a coroutine. Skipping." + ) + continue + + event_name: str = getattr(member, "__listener_name__") + # The callback needs to be the bound method from this cog instance + self._listeners.append((event_name, member)) + + def _eject(self) -> None: + """ + Called when the cog is being removed. + The CommandHandler will handle unregistering commands/listeners. + This method is for any cog-specific cleanup before that. + """ + # For now, just clear local collections. Actual unregistration is external. + self._commands.clear() + self._listeners.clear() + self._app_commands_and_groups.clear() # Added + + def get_commands(self) -> List["Command"]: + """Returns a list of commands in this cog.""" + # Avoid duplicates if aliases point to the same command object + return list(dict.fromkeys(self._commands.values())) + + def get_listeners(self) -> List[Tuple[str, Callable[..., Awaitable[None]]]]: + """Returns a list of (event_name, callback) tuples for listeners in this cog.""" + return self._listeners + + def get_app_commands_and_groups( + self, + ) -> List[Union["AppCommand", "AppCommandGroup"]]: + """Returns a list of application commands and groups in this cog.""" + return self._app_commands_and_groups + + async def cog_load(self) -> None: + """ + A special method that is called when the cog is loaded. + This is a good place for any asynchronous setup. + Subclasses should override this if they need async setup. + """ + pass + + async def cog_unload(self) -> None: + """ + A special method that is called when the cog is unloaded. + This is a good place for any asynchronous cleanup. + Subclasses should override this if they need async cleanup. + """ + pass + + # Example of how a listener might be defined within a Cog using the decorator + # from .decorators import listener # Would be imported at module level + # + # @listener(name="ON_MESSAGE_CREATE_CUSTOM") # Explicit name + # async def on_my_event(self, message: 'Message'): + # print(f"Cog '{self.cog_name}' received event with message: {message.content}") + # + # @listener() # Name derived from method: on_ready + # async def on_ready(self): + # print(f"Cog '{self.cog_name}' is ready.") diff --git a/disagreement/ext/commands/converters.py b/disagreement/ext/commands/converters.py new file mode 100644 index 0000000..23ff879 --- /dev/null +++ b/disagreement/ext/commands/converters.py @@ -0,0 +1,175 @@ +# disagreement/ext/commands/converters.py + +from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, Generic +from abc import ABC, abstractmethod +import re + +from .errors import BadArgument +from disagreement.models import Member, Guild, Role + +if TYPE_CHECKING: + from .core import CommandContext + +T = TypeVar("T") + + +class Converter(ABC, Generic[T]): + """ + Base class for custom command argument converters. + Subclasses must implement the `convert` method. + """ + + async def convert(self, ctx: "CommandContext", argument: str) -> T: + """ + Converts the argument to the desired type. + + Args: + ctx: The invocation context. + argument: The string argument to convert. + + Returns: + The converted argument. + + Raises: + BadArgument: If the conversion fails. + """ + raise NotImplementedError("Converter subclass must implement convert method.") + + +# --- Built-in Type Converters --- + + +class IntConverter(Converter[int]): + async def convert(self, ctx: "CommandContext", argument: str) -> int: + try: + return int(argument) + except ValueError: + raise BadArgument(f"'{argument}' is not a valid integer.") + + +class FloatConverter(Converter[float]): + async def convert(self, ctx: "CommandContext", argument: str) -> float: + try: + return float(argument) + except ValueError: + raise BadArgument(f"'{argument}' is not a valid number.") + + +class BoolConverter(Converter[bool]): + async def convert(self, ctx: "CommandContext", argument: str) -> bool: + lowered = argument.lower() + if lowered in ("yes", "y", "true", "t", "1", "on", "enable", "enabled"): + return True + elif lowered in ("no", "n", "false", "f", "0", "off", "disable", "disabled"): + return False + raise BadArgument(f"'{argument}' is not a valid boolean-like value.") + + +class StringConverter(Converter[str]): + async def convert(self, ctx: "CommandContext", argument: str) -> str: + # For basic string, no conversion is needed, but this provides a consistent interface + return argument + + +# --- Discord Model Converters --- + + +class MemberConverter(Converter["Member"]): + async def convert(self, ctx: "CommandContext", argument: str) -> "Member": + if not ctx.message.guild_id: + raise BadArgument("Member converter requires guild context.") + + match = re.match(r"<@!?(\d+)>$", argument) + member_id = match.group(1) if match else argument + + guild = ctx.bot.get_guild(ctx.message.guild_id) + if guild: + member = guild.get_member(member_id) + if member: + return member + + member = await ctx.bot.fetch_member(ctx.message.guild_id, member_id) + if member: + return member + raise BadArgument(f"Member '{argument}' not found.") + + +class RoleConverter(Converter["Role"]): + async def convert(self, ctx: "CommandContext", argument: str) -> "Role": + if not ctx.message.guild_id: + raise BadArgument("Role converter requires guild context.") + + match = re.match(r"<@&(?P\d+)>$", argument) + role_id = match.group("id") if match else argument + + guild = ctx.bot.get_guild(ctx.message.guild_id) + if guild: + role = guild.get_role(role_id) + if role: + return role + + role = await ctx.bot.fetch_role(ctx.message.guild_id, role_id) + if role: + return role + raise BadArgument(f"Role '{argument}' not found.") + + +class GuildConverter(Converter["Guild"]): + async def convert(self, ctx: "CommandContext", argument: str) -> "Guild": + guild_id = argument.strip("<>") # allow style + + guild = ctx.bot.get_guild(guild_id) + if guild: + return guild + + guild = await ctx.bot.fetch_guild(guild_id) + if guild: + return guild + raise BadArgument(f"Guild '{argument}' not found.") + + +# Default converters mapping +DEFAULT_CONVERTERS: dict[type, Converter[Any]] = { + int: IntConverter(), + float: FloatConverter(), + bool: BoolConverter(), + str: StringConverter(), + Member: MemberConverter(), + Guild: GuildConverter(), + Role: RoleConverter(), + # User: UserConverter(), # Add when User model and converter are ready +} + + +async def run_converters(ctx: "CommandContext", annotation: Any, argument: str) -> Any: + """ + Attempts to run a converter for the given annotation and argument. + """ + converter = DEFAULT_CONVERTERS.get(annotation) + if converter: + return await converter.convert(ctx, argument) + + # If no direct converter, check if annotation itself is a Converter subclass + if inspect.isclass(annotation) and issubclass(annotation, Converter): + try: + instance = annotation() # type: ignore + return await instance.convert(ctx, argument) + except Exception as e: # Catch instantiation errors or other issues + raise BadArgument( + f"Failed to use custom converter {annotation.__name__}: {e}" + ) + + # If it's a custom class that's not a Converter, we can't handle it by default + # Or if it's a complex type hint like Union, Optional, Literal etc. + # This part needs more advanced logic for those. + + # For now, if no specific converter, and it's not 'str', raise error or return as str? + # Let's be strict for now if an annotation is given but no converter found. + if annotation is not str and annotation is not inspect.Parameter.empty: + raise BadArgument(f"No converter found for type annotation '{annotation}'.") + + return argument # Default to string if no annotation or annotation is str + + +# Need to import inspect for the run_converters function +import inspect diff --git a/disagreement/ext/commands/core.py b/disagreement/ext/commands/core.py new file mode 100644 index 0000000..15694ee --- /dev/null +++ b/disagreement/ext/commands/core.py @@ -0,0 +1,490 @@ +# disagreement/ext/commands/core.py + +import asyncio +import inspect +from typing import ( + TYPE_CHECKING, + Optional, + List, + Dict, + Any, + Union, + Callable, + Awaitable, + Tuple, + get_origin, + get_args, +) + +from .view import StringView +from .errors import ( + CommandError, + CommandNotFound, + BadArgument, + MissingRequiredArgument, + ArgumentParsingError, + CheckFailure, + CommandInvokeError, +) +from .converters import run_converters, DEFAULT_CONVERTERS, Converter +from .cog import Cog +from disagreement.typing import Typing + +if TYPE_CHECKING: + from disagreement.client import Client + from disagreement.models import Message, User + + +class Command: + """ + Represents a bot command. + + Attributes: + name (str): The primary name of the command. + callback (Callable[..., Awaitable[None]]): The coroutine function to execute. + aliases (List[str]): Alternative names for the command. + brief (Optional[str]): A short description for help commands. + description (Optional[str]): A longer description for help commands. + cog (Optional['Cog']): Reference to the Cog this command belongs to. + params (Dict[str, inspect.Parameter]): Cached parameters of the callback. + """ + + def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any): + if not asyncio.iscoroutinefunction(callback): + raise TypeError("Command callback must be a coroutine function.") + + self.callback: Callable[..., Awaitable[None]] = callback + self.name: str = attrs.get("name", callback.__name__) + self.aliases: List[str] = attrs.get("aliases", []) + self.brief: Optional[str] = attrs.get("brief") + self.description: Optional[str] = attrs.get("description") or callback.__doc__ + self.cog: Optional["Cog"] = attrs.get("cog") + + self.params = inspect.signature(callback).parameters + self.checks: List[Callable[["CommandContext"], Awaitable[bool] | bool]] = [] + if hasattr(callback, "__command_checks__"): + self.checks.extend(getattr(callback, "__command_checks__")) + + def add_check( + self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool] + ) -> None: + self.checks.append(predicate) + + async def invoke(self, ctx: "CommandContext", *args: Any, **kwargs: Any) -> None: + from .errors import CheckFailure + + for predicate in self.checks: + result = predicate(ctx) + if inspect.isawaitable(result): + result = await result + if not result: + raise CheckFailure("Check predicate failed.") + + if self.cog: + await self.callback(self.cog, ctx, *args, **kwargs) + else: + await self.callback(ctx, *args, **kwargs) + + +class CommandContext: + """ + Represents the context in which a command is being invoked. + """ + + def __init__( + self, + *, + message: "Message", + bot: "Client", + prefix: str, + command: "Command", + invoked_with: str, + args: Optional[List[Any]] = None, + kwargs: Optional[Dict[str, Any]] = None, + cog: Optional["Cog"] = None, + ): + self.message: "Message" = message + self.bot: "Client" = bot + self.prefix: str = prefix + self.command: "Command" = command + self.invoked_with: str = invoked_with + self.args: List[Any] = args or [] + self.kwargs: Dict[str, Any] = kwargs or {} + self.cog: Optional["Cog"] = cog + + self.author: "User" = message.author + + async def reply( + self, + content: str, + *, + mention_author: Optional[bool] = None, + **kwargs: Any, + ) -> "Message": + """Replies to the invoking message. + + Parameters + ---------- + content: str + The content to send. + mention_author: Optional[bool] + Whether to mention the author in the reply. If ``None`` the + client's :attr:`mention_replies` value is used. + """ + + allowed_mentions = kwargs.pop("allowed_mentions", None) + if mention_author is None: + mention_author = getattr(self.bot, "mention_replies", False) + + if allowed_mentions is None: + allowed_mentions = {"replied_user": mention_author} + else: + allowed_mentions = dict(allowed_mentions) + allowed_mentions.setdefault("replied_user", mention_author) + + return await self.bot.send_message( + channel_id=self.message.channel_id, + content=content, + message_reference={ + "message_id": self.message.id, + "channel_id": self.message.channel_id, + "guild_id": self.message.guild_id, + }, + allowed_mentions=allowed_mentions, + **kwargs, + ) + + async def send(self, content: str, **kwargs: Any) -> "Message": + return await self.bot.send_message( + channel_id=self.message.channel_id, content=content, **kwargs + ) + + async def edit( + self, + message: Union[str, "Message"], + *, + content: Optional[str] = None, + **kwargs: Any, + ) -> "Message": + """Edits a message previously sent by the bot.""" + + message_id = message if isinstance(message, str) else message.id + return await self.bot.edit_message( + channel_id=self.message.channel_id, + message_id=message_id, + content=content, + **kwargs, + ) + + def typing(self) -> "Typing": + """Return a typing context manager for this context's channel.""" + + return self.bot.typing(self.message.channel_id) + + +class CommandHandler: + """ + Manages command registration, parsing, and dispatching. + """ + + def __init__( + self, + client: "Client", + prefix: Union[ + str, List[str], Callable[["Client", "Message"], Union[str, List[str]]] + ], + ): + self.client: "Client" = client + self.prefix: Union[ + str, List[str], Callable[["Client", "Message"], Union[str, List[str]]] + ] = prefix + self.commands: Dict[str, Command] = {} + self.cogs: Dict[str, "Cog"] = {} + + from .help import HelpCommand + + self.add_command(HelpCommand(self)) + + def add_command(self, command: Command) -> None: + if command.name in self.commands: + raise ValueError(f"Command '{command.name}' is already registered.") + + self.commands[command.name.lower()] = command + for alias in command.aliases: + if alias in self.commands: + print( + f"Warning: Alias '{alias}' for command '{command.name}' conflicts with an existing command or alias." + ) + self.commands[alias.lower()] = command + + def remove_command(self, name: str) -> Optional[Command]: + command = self.commands.pop(name.lower(), None) + if command: + for alias in command.aliases: + self.commands.pop(alias.lower(), None) + return command + + def get_command(self, name: str) -> Optional[Command]: + return self.commands.get(name.lower()) + + def add_cog(self, cog_to_add: "Cog") -> None: + if not isinstance(cog_to_add, Cog): + raise TypeError("Argument must be a subclass of Cog.") + + if cog_to_add.cog_name in self.cogs: + raise ValueError( + f"Cog with name '{cog_to_add.cog_name}' is already registered." + ) + + self.cogs[cog_to_add.cog_name] = cog_to_add + + for cmd in cog_to_add.get_commands(): + self.add_command(cmd) + + if hasattr(self.client, "_event_dispatcher"): + for event_name, callback in cog_to_add.get_listeners(): + self.client._event_dispatcher.register(event_name.upper(), callback) + else: + print( + f"Warning: Client does not have '_event_dispatcher'. Listeners for cog '{cog_to_add.cog_name}' not registered." + ) + + if hasattr(cog_to_add, "cog_load") and inspect.iscoroutinefunction( + cog_to_add.cog_load + ): + asyncio.create_task(cog_to_add.cog_load()) + + print(f"Cog '{cog_to_add.cog_name}' added.") + + def remove_cog(self, cog_name: str) -> Optional["Cog"]: + cog_to_remove = self.cogs.pop(cog_name, None) + if cog_to_remove: + for cmd in cog_to_remove.get_commands(): + self.remove_command(cmd.name) + + if hasattr(self.client, "_event_dispatcher"): + for event_name, callback in cog_to_remove.get_listeners(): + print( + f"Note: Listener '{callback.__name__}' for event '{event_name}' from cog '{cog_name}' needs manual unregistration logic in EventDispatcher." + ) + + if hasattr(cog_to_remove, "cog_unload") and inspect.iscoroutinefunction( + cog_to_remove.cog_unload + ): + asyncio.create_task(cog_to_remove.cog_unload()) + + cog_to_remove._eject() + print(f"Cog '{cog_name}' removed.") + return cog_to_remove + + async def get_prefix(self, message: "Message") -> Union[str, List[str], None]: + if callable(self.prefix): + if inspect.iscoroutinefunction(self.prefix): + return await self.prefix(self.client, message) + else: + return self.prefix(self.client, message) # type: ignore + return self.prefix + + async def _parse_arguments( + self, command: Command, ctx: CommandContext, view: StringView + ) -> Tuple[List[Any], Dict[str, Any]]: + args_list = [] + kwargs_dict = {} + params_to_parse = list(command.params.values()) + + if params_to_parse and params_to_parse[0].name == "self" and command.cog: + params_to_parse.pop(0) + if params_to_parse and params_to_parse[0].name == "ctx": + params_to_parse.pop(0) + + for param in params_to_parse: + view.skip_whitespace() + final_value_for_param: Any = inspect.Parameter.empty + + if param.kind == inspect.Parameter.VAR_POSITIONAL: + while not view.eof: + view.skip_whitespace() + if view.eof: + break + word = view.get_word() + if word or not view.eof: + args_list.append(word) + elif view.eof: + break + break + + arg_str_value: Optional[str] = ( + None # Holds the raw string for current param + ) + + if view.eof: # No more input string + if param.default is not inspect.Parameter.empty: + final_value_for_param = param.default + elif param.kind != inspect.Parameter.VAR_KEYWORD: + raise MissingRequiredArgument(param.name) + else: # VAR_KEYWORD at EOF is fine + break + else: # Input available + is_last_pos_str_greedy = ( + param == params_to_parse[-1] + and param.annotation is str + and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ) + + if is_last_pos_str_greedy: + arg_str_value = view.read_rest().strip() + if ( + not arg_str_value + and param.default is not inspect.Parameter.empty + ): + final_value_for_param = param.default + else: # Includes empty string if that's what's left + final_value_for_param = arg_str_value + else: # Not greedy, or not string, or not last positional + if view.buffer[view.index] == '"': + arg_str_value = view.get_quoted_string() + if arg_str_value == "" and view.buffer[view.index] == '"': + raise BadArgument( + f"Unterminated quoted string for argument '{param.name}'." + ) + else: + arg_str_value = view.get_word() + + # If final_value_for_param was not set by greedy logic, try conversion + if final_value_for_param is inspect.Parameter.empty: + if ( + arg_str_value is None + ): # Should not happen if view.get_word/get_quoted_string is robust + if param.default is not inspect.Parameter.empty: + final_value_for_param = param.default + else: + raise MissingRequiredArgument(param.name) + else: # We have an arg_str_value (could be empty string "" from quotes) + annotation = param.annotation + origin = get_origin(annotation) + + if origin is Union: # Handles Optional[T] and Union[T1, T2] + union_args = get_args(annotation) + is_optional = ( + len(union_args) == 2 and type(None) in union_args + ) + + converted_for_union = False + last_err_union: Optional[BadArgument] = None + for t_arg in union_args: + if t_arg is type(None): + continue + try: + final_value_for_param = await run_converters( + ctx, t_arg, arg_str_value + ) + converted_for_union = True + break + except BadArgument as e: + last_err_union = e + + if not converted_for_union: + if ( + is_optional and param.default is None + ): # Special handling for Optional[T] if conversion failed + # If arg_str_value was "" and type was Optional[str], StringConverter would return "" + # If arg_str_value was "" and type was Optional[int], BadArgument would be raised. + # This path is for when all actual types in Optional[T] fail conversion. + # If default is None, we can assign None. + final_value_for_param = None + elif last_err_union: + raise last_err_union + else: # Should not be reached if logic is correct + raise BadArgument( + f"Could not convert '{arg_str_value}' to any of {union_args} for param '{param.name}'." + ) + elif annotation is inspect.Parameter.empty or annotation is str: + final_value_for_param = arg_str_value + else: # Standard type hint + final_value_for_param = await run_converters( + ctx, annotation, arg_str_value + ) + + # Final check if value was resolved + if final_value_for_param is inspect.Parameter.empty: + if param.default is not inspect.Parameter.empty: + final_value_for_param = param.default + elif param.kind != inspect.Parameter.VAR_KEYWORD: + # This state implies an issue if required and no default, and no input was parsed. + raise MissingRequiredArgument( + f"Parameter '{param.name}' could not be resolved." + ) + + # Assign to args_list or kwargs_dict if a value was determined + if final_value_for_param is not inspect.Parameter.empty: + if ( + param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + or param.kind == inspect.Parameter.POSITIONAL_ONLY + ): + args_list.append(final_value_for_param) + elif param.kind == inspect.Parameter.KEYWORD_ONLY: + kwargs_dict[param.name] = final_value_for_param + + return args_list, kwargs_dict + + async def process_commands(self, message: "Message") -> None: + if not message.content: + return + + prefix_to_use = await self.get_prefix(message) + if not prefix_to_use: + return + + actual_prefix: Optional[str] = None + if isinstance(prefix_to_use, list): + for p in prefix_to_use: + if message.content.startswith(p): + actual_prefix = p + break + if not actual_prefix: + return + elif isinstance(prefix_to_use, str): + if message.content.startswith(prefix_to_use): + actual_prefix = prefix_to_use + else: + return + else: + return + + if actual_prefix is None: + return + + content_without_prefix = message.content[len(actual_prefix) :] + view = StringView(content_without_prefix) + + command_name = view.get_word() + if not command_name: + return + + command = self.get_command(command_name) + if not command: + return + + ctx = CommandContext( + message=message, + bot=self.client, + prefix=actual_prefix, + command=command, + invoked_with=command_name, + cog=command.cog, + ) + + try: + parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view) + ctx.args = parsed_args + ctx.kwargs = parsed_kwargs + await command.invoke(ctx, *parsed_args, **parsed_kwargs) + except CommandError as e: + print(f"Command error for '{command.name}': {e}") + if hasattr(self.client, "on_command_error"): + await self.client.on_command_error(ctx, e) + except Exception as e: + print(f"Unexpected error invoking command '{command.name}': {e}") + exc = CommandInvokeError(e) + if hasattr(self.client, "on_command_error"): + await self.client.on_command_error(ctx, exc) diff --git a/disagreement/ext/commands/decorators.py b/disagreement/ext/commands/decorators.py new file mode 100644 index 0000000..e988ea2 --- /dev/null +++ b/disagreement/ext/commands/decorators.py @@ -0,0 +1,150 @@ +# disagreement/ext/commands/decorators.py + +import asyncio +import inspect +import time +from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable + +if TYPE_CHECKING: + from .core import Command, CommandContext # For type hinting return or internal use + + # from .cog import Cog # For Cog specific decorators + + +def command( + name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any +) -> Callable: + """ + A decorator that transforms a function into a Command. + + Args: + name (Optional[str]): The name of the command. Defaults to the function name. + aliases (Optional[List[str]]): Alternative names for the command. + **attrs: Additional attributes to pass to the Command constructor + (e.g., brief, description, hidden). + + Returns: + Callable: A decorator that registers the command. + """ + + def decorator( + func: Callable[..., Awaitable[None]], + ) -> Callable[..., Awaitable[None]]: + if not asyncio.iscoroutinefunction(func): + raise TypeError("Command callback must be a coroutine function.") + + from .core import ( + Command, + ) # Late import to avoid circular dependencies at module load time + + # The actual registration will happen when a Cog is added or if commands are global. + # For now, this decorator creates a Command instance and attaches it to the function, + # or returns a Command instance that can be collected. + + cmd_name = name or func.__name__ + + # Store command attributes on the function itself for later collection by Cog or Client + # This is a common pattern. + if hasattr(func, "__command_attrs__"): + # This case might occur if decorators are stacked in an unusual way, + # or if a function is decorated multiple times (which should be disallowed or handled). + # For now, let's assume one @command decorator per function. + raise TypeError("Function is already a command or has command attributes.") + + # Create the command object. It will be registered by the Cog or Client. + cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs) + + # We can attach the command object to the function, so Cogs can find it. + func.__command_object__ = cmd # type: ignore # type: ignore[attr-defined] + return func # Return the original function, now marked. + # Or return `cmd` if commands are registered globally immediately. + # For Cogs, returning `func` and letting Cog collect is cleaner. + + return decorator + + +def listener( + name: Optional[str] = None, +) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: + """ + A decorator that marks a function as an event listener within a Cog. + The actual registration happens when the Cog is added to the client. + + Args: + name (Optional[str]): The name of the event to listen to. + Defaults to the function name (e.g., `on_message`). + """ + + def decorator( + func: Callable[..., Awaitable[None]], + ) -> Callable[..., Awaitable[None]]: + if not asyncio.iscoroutinefunction(func): + raise TypeError("Listener callback must be a coroutine function.") + + # 'name' here is from the outer 'listener' scope (closure) + actual_event_name = name or func.__name__ + # Store listener info on the function for Cog to collect + setattr(func, "__listener_name__", actual_event_name) + return func + + return decorator # This must be correctly indented under 'listener' + + +def check( + predicate: Callable[["CommandContext"], Awaitable[bool] | bool], +) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: + """Decorator to add a check to a command.""" + + def decorator( + func: Callable[..., Awaitable[None]], + ) -> Callable[..., Awaitable[None]]: + checks = getattr(func, "__command_checks__", []) + checks.append(predicate) + setattr(func, "__command_checks__", checks) + return func + + return decorator + + +def check_any( + *predicates: Callable[["CommandContext"], Awaitable[bool] | bool] +) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: + """Decorator that passes if any predicate returns ``True``.""" + + async def predicate(ctx: "CommandContext") -> bool: + from .errors import CheckAnyFailure, CheckFailure + + errors = [] + for p in predicates: + try: + result = p(ctx) + if inspect.isawaitable(result): + result = await result + if result: + return True + except CheckFailure as e: + errors.append(e) + raise CheckAnyFailure(errors) + + return check(predicate) + + +def cooldown( + rate: int, per: float +) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]: + """Simple per-user cooldown decorator.""" + + buckets: dict[str, dict[str, float]] = {} + + async def predicate(ctx: "CommandContext") -> bool: + from .errors import CommandOnCooldown + + now = time.monotonic() + user_buckets = buckets.setdefault(ctx.command.name, {}) + reset = user_buckets.get(ctx.author.id, 0) + if now < reset: + raise CommandOnCooldown(reset - now) + user_buckets[ctx.author.id] = now + per + return True + + return check(predicate) diff --git a/disagreement/ext/commands/errors.py b/disagreement/ext/commands/errors.py new file mode 100644 index 0000000..5fa6f06 --- /dev/null +++ b/disagreement/ext/commands/errors.py @@ -0,0 +1,76 @@ +# disagreement/ext/commands/errors.py + +""" +Custom exceptions for the command extension. +""" + +from disagreement.errors import DisagreementException + + +class CommandError(DisagreementException): + """Base exception for errors raised by the commands extension.""" + + pass + + +class CommandNotFound(CommandError): + """Exception raised when a command is not found.""" + + def __init__(self, command_name: str): + self.command_name = command_name + super().__init__(f"Command '{command_name}' not found.") + + +class BadArgument(CommandError): + """Exception raised when a command argument fails to parse or validate.""" + + pass + + +class MissingRequiredArgument(BadArgument): + """Exception raised when a required command argument is missing.""" + + def __init__(self, param_name: str): + self.param_name = param_name + super().__init__(f"Missing required argument: {param_name}") + + +class ArgumentParsingError(BadArgument): + """Exception raised during the argument parsing process.""" + + pass + + +class CheckFailure(CommandError): + """Exception raised when a command check fails.""" + + pass + + +class CheckAnyFailure(CheckFailure): + """Raised when :func:`check_any` fails all checks.""" + + def __init__(self, errors: list[CheckFailure]): + self.errors = errors + msg = "; ".join(str(e) for e in errors) + super().__init__(f"All checks failed: {msg}") + + +class CommandOnCooldown(CheckFailure): + """Raised when a command is invoked while on cooldown.""" + + def __init__(self, retry_after: float): + self.retry_after = retry_after + super().__init__(f"Command is on cooldown. Retry in {retry_after:.2f}s") + + +class CommandInvokeError(CommandError): + """Exception raised when an error occurs during command invocation.""" + + def __init__(self, original: Exception): + self.original = original + super().__init__(f"Error during command invocation: {original}") + + +# Add more specific errors as needed, e.g., UserNotFound, ChannelNotFound, etc. +# These might inherit from BadArgument. diff --git a/disagreement/ext/commands/help.py b/disagreement/ext/commands/help.py new file mode 100644 index 0000000..f81c961 --- /dev/null +++ b/disagreement/ext/commands/help.py @@ -0,0 +1,37 @@ +# disagreement/ext/commands/help.py + +from typing import List, Optional + +from .core import Command, CommandContext, CommandHandler + + +class HelpCommand(Command): + """Built-in command that displays help information for other commands.""" + + def __init__(self, handler: CommandHandler) -> None: + self.handler = handler + + async def callback(ctx: CommandContext, command: Optional[str] = None) -> None: + if command: + cmd = handler.get_command(command) + if not cmd or cmd.name.lower() != command.lower(): + await ctx.send(f"Command '{command}' not found.") + return + description = cmd.description or cmd.brief or "No description provided." + await ctx.send(f"**{ctx.prefix}{cmd.name}**\n{description}") + else: + lines: List[str] = [] + for registered in dict.fromkeys(handler.commands.values()): + brief = registered.brief or registered.description or "" + lines.append(f"{ctx.prefix}{registered.name} - {brief}".strip()) + if lines: + await ctx.send("\n".join(lines)) + else: + await ctx.send("No commands available.") + + super().__init__( + callback, + name="help", + brief="Show command help.", + description="Displays help for commands.", + ) diff --git a/disagreement/ext/commands/view.py b/disagreement/ext/commands/view.py new file mode 100644 index 0000000..d691266 --- /dev/null +++ b/disagreement/ext/commands/view.py @@ -0,0 +1,103 @@ +# disagreement/ext/commands/view.py + +import re + + +class StringView: + """ + A utility class to help with parsing strings, particularly for command arguments. + It keeps track of the current position in the string and provides methods + to read parts of it. + """ + + def __init__(self, buffer: str): + self.buffer: str = buffer + self.original: str = buffer # Keep original for error reporting if needed + self.index: int = 0 + self.end: int = len(buffer) + self.previous: int = 0 # Index before the last successful read + + @property + def remaining(self) -> str: + """Returns the rest of the string that hasn't been consumed.""" + return self.buffer[self.index :] + + @property + def eof(self) -> bool: + """Checks if the end of the string has been reached.""" + return self.index >= self.end + + def skip_whitespace(self) -> None: + """Skips any leading whitespace from the current position.""" + while not self.eof and self.buffer[self.index].isspace(): + self.index += 1 + + def get_word(self) -> str: + """ + Reads a "word" from the current position. + A word is a sequence of non-whitespace characters. + """ + self.skip_whitespace() + if self.eof: + return "" + + self.previous = self.index + match = re.match(r"\S+", self.buffer[self.index :]) + if match: + word = match.group(0) + self.index += len(word) + return word + return "" # Should not happen if not eof and skip_whitespace was called + + def get_quoted_string(self) -> str: + """ + Reads a string enclosed in double quotes. + Handles escaped quotes inside the string. + """ + self.skip_whitespace() + if self.eof or self.buffer[self.index] != '"': + return "" # Or raise an error, or return None + + self.previous = self.index + self.index += 1 # Skip the opening quote + result = [] + escaped = False + + while not self.eof: + char = self.buffer[self.index] + self.index += 1 + + if escaped: + result.append(char) + escaped = False + elif char == "\\": + escaped = True + elif char == '"': + return "".join(result) # Closing quote found + else: + result.append(char) + + # If loop finishes, means EOF was reached before closing quote + # This is an error condition. Restore index and indicate failure. + self.index = self.previous + # Consider raising an error like UnterminatedQuotedStringError + return "" # Or raise + + def read_rest(self) -> str: + """Reads all remaining characters from the current position.""" + self.skip_whitespace() + if self.eof: + return "" + + self.previous = self.index + result = self.buffer[self.index :] + self.index = self.end + return result + + def undo(self) -> None: + """Resets the current position to before the last successful read.""" + self.index = self.previous + + # Could add more methods like: + # peek() - look at next char without consuming + # match_regex(pattern) - consume if regex matches diff --git a/disagreement/ext/loader.py b/disagreement/ext/loader.py new file mode 100644 index 0000000..8e8d790 --- /dev/null +++ b/disagreement/ext/loader.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from importlib import import_module +import sys +from types import ModuleType +from typing import Dict + +__all__ = ["load_extension", "unload_extension"] + +_loaded_extensions: Dict[str, ModuleType] = {} + + +def load_extension(name: str) -> ModuleType: + """Load an extension by name. + + The extension module must define a ``setup`` coroutine or function that + will be called after loading. Any value returned by ``setup`` is ignored. + """ + + if name in _loaded_extensions: + raise ValueError(f"Extension '{name}' already loaded") + + module = import_module(name) + + if not hasattr(module, "setup"): + raise ImportError(f"Extension '{name}' does not define a setup function") + + module.setup() + _loaded_extensions[name] = module + return module + + +def unload_extension(name: str) -> None: + """Unload a previously loaded extension.""" + + module = _loaded_extensions.pop(name, None) + if module is None: + raise ValueError(f"Extension '{name}' is not loaded") + + if hasattr(module, "teardown"): + module.teardown() + + sys.modules.pop(name, None) diff --git a/disagreement/ext/tasks.py b/disagreement/ext/tasks.py new file mode 100644 index 0000000..028c0da --- /dev/null +++ b/disagreement/ext/tasks.py @@ -0,0 +1,89 @@ +import asyncio +from typing import Any, Awaitable, Callable, Optional + +__all__ = ["loop", "Task"] + + +class Task: + """Simple repeating task.""" + + def __init__(self, coro: Callable[..., Awaitable[Any]], *, seconds: float) -> None: + self._coro = coro + self._seconds = float(seconds) + self._task: Optional[asyncio.Task[None]] = None + + async def _run(self, *args: Any, **kwargs: Any) -> None: + try: + while True: + await self._coro(*args, **kwargs) + await asyncio.sleep(self._seconds) + except asyncio.CancelledError: + pass + + def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: + if self._task is None or self._task.done(): + self._task = asyncio.create_task(self._run(*args, **kwargs)) + return self._task + + def stop(self) -> None: + if self._task is not None: + self._task.cancel() + self._task = None + + @property + def running(self) -> bool: + return self._task is not None and not self._task.done() + + +class _Loop: + def __init__(self, func: Callable[..., Awaitable[Any]], seconds: float) -> None: + self.func = func + self.seconds = seconds + self._task: Optional[Task] = None + self._owner: Any = None + + def __get__(self, obj: Any, objtype: Any) -> "_BoundLoop": + return _BoundLoop(self, obj) + + def _coro(self, *args: Any, **kwargs: Any) -> Awaitable[Any]: + if self._owner is None: + return self.func(*args, **kwargs) + return self.func(self._owner, *args, **kwargs) + + def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: + self._task = Task(self._coro, seconds=self.seconds) + return self._task.start(*args, **kwargs) + + def stop(self) -> None: + if self._task is not None: + self._task.stop() + + @property + def running(self) -> bool: + return self._task.running if self._task else False + + +class _BoundLoop: + def __init__(self, parent: _Loop, owner: Any) -> None: + self._parent = parent + self._owner = owner + + def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: + self._parent._owner = self._owner + return self._parent.start(*args, **kwargs) + + def stop(self) -> None: + self._parent.stop() + + @property + def running(self) -> bool: + return self._parent.running + + +def loop(*, seconds: float) -> Callable[[Callable[..., Awaitable[Any]]], _Loop]: + """Decorator to create a looping task.""" + + def decorator(func: Callable[..., Awaitable[Any]]) -> _Loop: + return _Loop(func, seconds) + + return decorator diff --git a/disagreement/gateway.py b/disagreement/gateway.py new file mode 100644 index 0000000..91f0539 --- /dev/null +++ b/disagreement/gateway.py @@ -0,0 +1,490 @@ +# disagreement/gateway.py + +""" +Manages the WebSocket connection to the Discord Gateway. +""" + +import asyncio +import traceback +import aiohttp +import json +import zlib +import time +from typing import Optional, TYPE_CHECKING, Any, Dict + +from .enums import GatewayOpcode, GatewayIntent +from .errors import GatewayException, DisagreementException, AuthenticationError +from .interactions import Interaction + +if TYPE_CHECKING: + from .client import Client # For type hinting + from .event_dispatcher import EventDispatcher + from .http import HTTPClient + from .interactions import Interaction # Added for INTERACTION_CREATE + +# ZLIB Decompression constants +ZLIB_SUFFIX = b"\x00\x00\xff\xff" +MAX_DECOMPRESSION_SIZE = 10 * 1024 * 1024 # 10 MiB, adjust as needed + + +class GatewayClient: + """ + Handles the Discord Gateway WebSocket connection, heartbeating, and event dispatching. + """ + + def __init__( + self, + http_client: "HTTPClient", + event_dispatcher: "EventDispatcher", + token: str, + intents: int, + client_instance: "Client", # Pass the main client instance + verbose: bool = False, + *, + shard_id: Optional[int] = None, + shard_count: Optional[int] = None, + ): + self._http: "HTTPClient" = http_client + self._dispatcher: "EventDispatcher" = event_dispatcher + self._token: str = token + self._intents: int = intents + self._client_instance: "Client" = client_instance # Store client instance + self.verbose: bool = verbose + self._shard_id: Optional[int] = shard_id + self._shard_count: Optional[int] = shard_count + + self._ws: Optional[aiohttp.ClientWebSocketResponse] = None + self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() + self._heartbeat_interval: Optional[float] = None + self._last_sequence: Optional[int] = None + self._session_id: Optional[str] = None + self._resume_gateway_url: Optional[str] = None + + self._keep_alive_task: Optional[asyncio.Task] = None + self._receive_task: Optional[asyncio.Task] = None + + # For zlib decompression + self._buffer = bytearray() + self._inflator = zlib.decompressobj() + + async def _decompress_message( + self, message_bytes: bytes + ) -> Optional[Dict[str, Any]]: + """Decompresses a zlib-compressed message from the Gateway.""" + self._buffer.extend(message_bytes) + + if len(message_bytes) < 4 or message_bytes[-4:] != ZLIB_SUFFIX: + # Message is not complete or not zlib compressed in the expected way + return None + # Or handle partial messages if Discord ever sends them fragmented like this, + # but typically each binary message is a complete zlib stream. + + try: + decompressed = self._inflator.decompress(self._buffer) + self._buffer.clear() # Reset buffer after successful decompression + return json.loads(decompressed.decode("utf-8")) + except zlib.error as e: + print(f"Zlib decompression error: {e}") + self._buffer.clear() # Clear buffer on error + self._inflator = zlib.decompressobj() # Reset inflator + return None + except json.JSONDecodeError as e: + print(f"JSON decode error after decompression: {e}") + return None + + async def _send_json(self, payload: Dict[str, Any]): + if self._ws and not self._ws.closed: + if self.verbose: + print(f"GATEWAY SEND: {payload}") + await self._ws.send_json(payload) + else: + print("Gateway send attempted but WebSocket is closed or not available.") + # raise GatewayException("WebSocket is not connected.") + + async def _heartbeat(self): + """Sends a heartbeat to the Gateway.""" + payload = {"op": GatewayOpcode.HEARTBEAT, "d": self._last_sequence} + await self._send_json(payload) + # print("Sent heartbeat.") + + async def _keep_alive(self): + """Manages the heartbeating loop.""" + if self._heartbeat_interval is None: + # This should not happen if HELLO was processed correctly + print("Error: Heartbeat interval not set. Cannot start keep_alive.") + return + + try: + while True: + await self._heartbeat() + await asyncio.sleep( + self._heartbeat_interval / 1000 + ) # Interval is in ms + except asyncio.CancelledError: + print("Keep_alive task cancelled.") + except Exception as e: + print(f"Error in keep_alive loop: {e}") + # Potentially trigger a reconnect here or notify client + await self._client_instance.close_gateway(code=1000) # Generic close + + async def _identify(self): + """Sends the IDENTIFY payload to the Gateway.""" + payload = { + "op": GatewayOpcode.IDENTIFY, + "d": { + "token": self._token, + "intents": self._intents, + "properties": { + "$os": "python", # Or platform.system() + "$browser": "disagreement", # Library name + "$device": "disagreement", # Library name + }, + "compress": True, # Request zlib compression + }, + } + if self._shard_id is not None and self._shard_count is not None: + payload["d"]["shard"] = [self._shard_id, self._shard_count] + await self._send_json(payload) + print("Sent IDENTIFY.") + + async def _resume(self): + """Sends the RESUME payload to the Gateway.""" + if not self._session_id or self._last_sequence is None: + print("Cannot RESUME: session_id or last_sequence is missing.") + await self._identify() # Fallback to identify + return + + payload = { + "op": GatewayOpcode.RESUME, + "d": { + "token": self._token, + "session_id": self._session_id, + "seq": self._last_sequence, + }, + } + await self._send_json(payload) + print( + f"Sent RESUME for session {self._session_id} at sequence {self._last_sequence}." + ) + async def update_presence( + self, + status: str, + activity_name: Optional[str] = None, + activity_type: int = 0, + since: int = 0, + afk: bool = False, + ): + """Sends the presence update payload to the Gateway.""" + payload = { + "op": GatewayOpcode.PRESENCE_UPDATE, + "d": { + "since": since, + "activities": [ + { + "name": activity_name, + "type": activity_type, + } + ] + if activity_name + else [], + "status": status, + "afk": afk, + }, + } + await self._send_json(payload) + + async def _handle_dispatch(self, data: Dict[str, Any]): + """Handles DISPATCH events (actual Discord events).""" + event_name = data.get("t") + sequence_num = data.get("s") + raw_event_d_payload = data.get( + "d" + ) # This is the 'd' field from the gateway event + + if sequence_num is not None: + self._last_sequence = sequence_num + + if event_name == "READY": # Special handling for READY + if not isinstance(raw_event_d_payload, dict): + print( + f"Error: READY event 'd' payload is not a dict or is missing: {raw_event_d_payload}" + ) + # Consider raising an error or attempting a reconnect + return + self._session_id = raw_event_d_payload.get("session_id") + self._resume_gateway_url = raw_event_d_payload.get("resume_gateway_url") + + app_id_str = "N/A" + # Store application_id on the client instance + if ( + "application" in raw_event_d_payload + and isinstance(raw_event_d_payload["application"], dict) + and "id" in raw_event_d_payload["application"] + ): + app_id_value = raw_event_d_payload["application"]["id"] + self._client_instance.application_id = ( + app_id_value # Snowflake can be str or int + ) + app_id_str = str(app_id_value) + else: + print( + f"Warning: Could not find application ID in READY payload. App commands may not work." + ) + + # Parse and store the bot's own user object + if "user" in raw_event_d_payload and isinstance( + raw_event_d_payload["user"], dict + ): + try: + # Assuming Client has a parse_user method that takes user data dict + # and returns a User object, also caching it. + bot_user_obj = self._client_instance.parse_user( + raw_event_d_payload["user"] + ) + self._client_instance.user = bot_user_obj + print( + f"Gateway READY. Bot User: {bot_user_obj.username}#{bot_user_obj.discriminator}. Session ID: {self._session_id}. App ID: {app_id_str}. Resume URL: {self._resume_gateway_url}" + ) + except Exception as e: + print(f"Error parsing bot user from READY payload: {e}") + print( + f"Gateway READY (user parse failed). Session ID: {self._session_id}. App ID: {app_id_str}. Resume URL: {self._resume_gateway_url}" + ) + else: + print( + f"Warning: Bot user object not found or invalid in READY payload." + ) + print( + f"Gateway READY (no user). Session ID: {self._session_id}. App ID: {app_id_str}. Resume URL: {self._resume_gateway_url}" + ) + + await self._dispatcher.dispatch(event_name, raw_event_d_payload) + elif event_name == "INTERACTION_CREATE": + # print(f"GATEWAY RECV INTERACTION_CREATE: {raw_event_d_payload}") + if isinstance(raw_event_d_payload, dict): + interaction = Interaction( + data=raw_event_d_payload, client_instance=self._client_instance + ) + await self._dispatcher.dispatch( + "INTERACTION_CREATE", raw_event_d_payload + ) + # Dispatch to a new client method that will then call AppCommandHandler + if hasattr(self._client_instance, "process_interaction"): + asyncio.create_task( + self._client_instance.process_interaction(interaction) + ) # type: ignore + else: + print( + "Warning: Client instance does not have process_interaction method for INTERACTION_CREATE." + ) + else: + print( + f"Error: INTERACTION_CREATE event 'd' payload is not a dict: {raw_event_d_payload}" + ) + elif event_name == "RESUMED": + print("Gateway RESUMED successfully.") + # RESUMED 'd' payload is often an empty object or debug info. + # Ensure it's a dict for the dispatcher. + event_data_to_dispatch = ( + raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {} + ) + await self._dispatcher.dispatch(event_name, event_data_to_dispatch) + elif event_name: + # For other events, ensure 'd' is a dict, or pass {} if 'd' is null/missing. + # Models/parsers in EventDispatcher will need to handle potentially empty dicts. + event_data_to_dispatch = ( + raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {} + ) + # print(f"GATEWAY RECV EVENT: {event_name} | DATA: {event_data_to_dispatch}") + await self._dispatcher.dispatch(event_name, event_data_to_dispatch) + else: + print(f"Received dispatch with no event name: {data}") + + async def _process_message(self, msg: aiohttp.WSMessage): + """Processes a single message from the WebSocket.""" + if msg.type == aiohttp.WSMsgType.TEXT: + try: + data = json.loads(msg.data) + except json.JSONDecodeError: + print( + f"Failed to decode JSON from Gateway: {msg.data[:200]}" + ) # Log snippet + return + elif msg.type == aiohttp.WSMsgType.BINARY: + decompressed_data = await self._decompress_message(msg.data) + if decompressed_data is None: + print("Failed to decompress or decode binary message from Gateway.") + return + data = decompressed_data + elif msg.type == aiohttp.WSMsgType.ERROR: + print( + f"WebSocket error: {self._ws.exception() if self._ws else 'Unknown WSError'}" + ) + raise GatewayException( + f"WebSocket error: {self._ws.exception() if self._ws else 'Unknown WSError'}" + ) + elif msg.type == aiohttp.WSMsgType.CLOSED: + close_code = ( + self._ws.close_code + if self._ws and hasattr(self._ws, "close_code") + else "N/A" + ) + print(f"WebSocket connection closed by server. Code: {close_code}") + # Raise an exception to signal the closure to the client's main run loop + raise GatewayException(f"WebSocket closed by server. Code: {close_code}") + else: + print(f"Received unhandled WebSocket message type: {msg.type}") + return + + if self.verbose: + print(f"GATEWAY RECV: {data}") + op = data.get("op") + # 'd' payload (event_data) is handled specifically by each opcode handler below + + if op == GatewayOpcode.DISPATCH: + await self._handle_dispatch(data) # _handle_dispatch will extract 'd' + elif op == GatewayOpcode.HEARTBEAT: # Server requests a heartbeat + await self._heartbeat() + elif op == GatewayOpcode.RECONNECT: # Server requests a reconnect + print("Gateway requested RECONNECT. Closing and will attempt to reconnect.") + await self.close(code=4000) # Use a non-1000 code to indicate reconnect + elif op == GatewayOpcode.INVALID_SESSION: + # The 'd' payload for INVALID_SESSION is a boolean indicating resumability + can_resume = data.get("d") is True + print(f"Gateway indicated INVALID_SESSION. Resumable: {can_resume}") + if not can_resume: + self._session_id = None # Clear session_id to force re-identify + self._last_sequence = None + # Close and reconnect. The connect logic will decide to resume or identify. + await self.close( + code=4000 if can_resume else 4009 + ) # 4009 for non-resumable + elif op == GatewayOpcode.HELLO: + hello_d_payload = data.get("d") + if ( + not isinstance(hello_d_payload, dict) + or "heartbeat_interval" not in hello_d_payload + ): + print( + f"Error: HELLO event 'd' payload is invalid or missing heartbeat_interval: {hello_d_payload}" + ) + await self.close(code=1011) # Internal error, malformed HELLO + return + self._heartbeat_interval = hello_d_payload["heartbeat_interval"] + print(f"Gateway HELLO. Heartbeat interval: {self._heartbeat_interval}ms.") + # Start heartbeating + if self._keep_alive_task: + self._keep_alive_task.cancel() + self._keep_alive_task = self._loop.create_task(self._keep_alive()) + + # Identify or Resume + if self._session_id and self._resume_gateway_url: # Check if we can resume + print("Attempting to RESUME session.") + await self._resume() + else: + print("Performing initial IDENTIFY.") + await self._identify() + elif op == GatewayOpcode.HEARTBEAT_ACK: + # print("Received heartbeat ACK.") + pass # Good, connection is alive + else: + print(f"Received unhandled Gateway Opcode: {op} with data: {data}") + + async def _receive_loop(self): + """Continuously receives and processes messages from the WebSocket.""" + if not self._ws or self._ws.closed: + print("Receive loop cannot start: WebSocket is not connected or closed.") + return + + try: + async for msg in self._ws: + await self._process_message(msg) + except asyncio.CancelledError: + print("Receive_loop task cancelled.") + except aiohttp.ClientConnectionError as e: + print(f"ClientConnectionError in receive_loop: {e}. Attempting reconnect.") + # This might be handled by an outer reconnect loop in the Client class + await self.close(code=1006) # Abnormal closure + except Exception as e: + print(f"Unexpected error in receive_loop: {e}") + traceback.print_exc() + # Consider specific error types for more granular handling + await self.close(code=1011) # Internal error + finally: + print("Receive_loop ended.") + # If the loop ends unexpectedly (not due to explicit close), + # the main client might want to try reconnecting. + + async def connect(self): + """Connects to the Discord Gateway.""" + if self._ws and not self._ws.closed: + print("Gateway already connected or connecting.") + return + + gateway_url = ( + self._resume_gateway_url or (await self._http.get_gateway_bot())["url"] + ) + if not gateway_url.endswith("?v=10&encoding=json&compress=zlib-stream"): + gateway_url += "?v=10&encoding=json&compress=zlib-stream" + + print(f"Connecting to Gateway: {gateway_url}") + try: + await self._http._ensure_session() # Ensure the HTTP client's session is active + assert ( + self._http._session is not None + ), "HTTPClient session not initialized after ensure_session" + self._ws = await self._http._session.ws_connect(gateway_url, max_msg_size=0) + print("Gateway WebSocket connection established.") + + if self._receive_task: + self._receive_task.cancel() + self._receive_task = self._loop.create_task(self._receive_loop()) + + except aiohttp.ClientConnectorError as e: + raise GatewayException( + f"Failed to connect to Gateway (Connector Error): {e}" + ) from e + except aiohttp.WSServerHandshakeError as e: + if e.status == 401: # Unauthorized during handshake + raise AuthenticationError( + f"Gateway handshake failed (401 Unauthorized): {e.message}. Check your bot token." + ) from e + raise GatewayException( + f"Gateway handshake failed (Status: {e.status}): {e.message}" + ) from e + except Exception as e: # Catch other potential errors during connection + raise GatewayException( + f"An unexpected error occurred during Gateway connection: {e}" + ) from e + + async def close(self, code: int = 1000): + """Closes the Gateway connection.""" + print(f"Closing Gateway connection with code {code}...") + if self._keep_alive_task and not self._keep_alive_task.done(): + self._keep_alive_task.cancel() + try: + await self._keep_alive_task + except asyncio.CancelledError: + pass # Expected + + if self._receive_task and not self._receive_task.done(): + self._receive_task.cancel() + try: + await self._receive_task + except asyncio.CancelledError: + pass # Expected + + if self._ws and not self._ws.closed: + await self._ws.close(code=code) + print("Gateway WebSocket closed.") + + self._ws = None + # Do not reset session_id, last_sequence, or resume_gateway_url here + # if the close code indicates a resumable disconnect (e.g. 4000-4009, or server-initiated RECONNECT) + # The connect logic will decide whether to resume or re-identify. + # However, if it's a non-resumable close (e.g. Invalid Session non-resumable), clear them. + if code == 4009: # Invalid session, not resumable + print("Clearing session state due to non-resumable invalid session.") + self._session_id = None + self._last_sequence = None + self._resume_gateway_url = None # This might be re-fetched anyway diff --git a/disagreement/http.py b/disagreement/http.py new file mode 100644 index 0000000..1816581 --- /dev/null +++ b/disagreement/http.py @@ -0,0 +1,657 @@ +# disagreement/http.py + +""" +HTTP client for interacting with the Discord REST API. +""" + +import asyncio +import aiohttp # pylint: disable=import-error +import json +from urllib.parse import quote +from typing import Optional, Dict, Any, Union, TYPE_CHECKING, List + +from .errors import ( + HTTPException, + RateLimitError, + AuthenticationError, + DisagreementException, +) +from . import __version__ # For User-Agent + +if TYPE_CHECKING: + from .client import Client + from .models import Message + from .interactions import ApplicationCommand, InteractionResponsePayload, Snowflake + +# Discord API constants +API_BASE_URL = "https://discord.com/api/v10" # Using API v10 + + +class HTTPClient: + """Handles HTTP requests to the Discord API.""" + + def __init__( + self, + token: str, + client_session: Optional[aiohttp.ClientSession] = None, + verbose: bool = False, + ): + self.token = token + self._session: Optional[aiohttp.ClientSession] = ( + client_session # Can be externally managed + ) + self.user_agent = f"DiscordBot (https://github.com/yourusername/disagreement, {__version__})" # Customize URL + + self.verbose = verbose + + self._global_rate_limit_lock = asyncio.Event() + self._global_rate_limit_lock.set() # Initially unlocked + + async def _ensure_session(self): + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession() + + async def close(self): + """Closes the underlying aiohttp.ClientSession.""" + if self._session and not self._session.closed: + await self._session.close() + + async def request( + self, + method: str, + endpoint: str, + payload: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + params: Optional[Dict[str, Any]] = None, + is_json: bool = True, + use_auth_header: bool = True, + custom_headers: Optional[Dict[str, str]] = None, + ) -> Any: + """Makes an HTTP request to the Discord API.""" + await self._ensure_session() + + url = f"{API_BASE_URL}{endpoint}" + final_headers: Dict[str, str] = { # Renamed to final_headers + "User-Agent": self.user_agent, + } + if use_auth_header: + final_headers["Authorization"] = f"Bot {self.token}" + + if is_json and payload: + final_headers["Content-Type"] = "application/json" + + if custom_headers: # Merge custom headers + final_headers.update(custom_headers) + + if self.verbose: + print(f"HTTP REQUEST: {method} {url} | payload={payload} params={params}") + + # Global rate limit handling + await self._global_rate_limit_lock.wait() + + for attempt in range(5): # Max 5 retries for rate limits + assert self._session is not None, "ClientSession not initialized" + async with self._session.request( + method, + url, + json=payload if is_json else None, + data=payload if not is_json else None, + headers=final_headers, + params=params, + ) as response: + + data = None + try: + if response.headers.get("Content-Type", "").startswith( + "application/json" + ): + data = await response.json() + else: + # For non-JSON responses, like fetching images or other files + # We might return the raw response or handle it differently + # For now, let's assume most API calls expect JSON + data = await response.text() + except (aiohttp.ContentTypeError, json.JSONDecodeError): + data = ( + await response.text() + ) # Fallback to text if JSON parsing fails + + if self.verbose: + print(f"HTTP RESPONSE: {response.status} {url} | {data}") + + if 200 <= response.status < 300: + if response.status == 204: + return None + return data + + # Rate limit handling + if response.status == 429: # Rate limited + retry_after_str = response.headers.get("Retry-After", "1") + try: + retry_after = float(retry_after_str) + except ValueError: + retry_after = 1.0 # Default retry if header is malformed + + is_global = ( + response.headers.get("X-RateLimit-Global", "false").lower() + == "true" + ) + + error_message = f"Rate limited on {method} {endpoint}." + if data and isinstance(data, dict) and "message" in data: + error_message += f" Discord says: {data['message']}" + + if is_global: + self._global_rate_limit_lock.clear() + await asyncio.sleep(retry_after) + self._global_rate_limit_lock.set() + else: + await asyncio.sleep(retry_after) + + if attempt < 4: # Don't log on the last attempt before raising + print( + f"{error_message} Retrying after {retry_after}s (Attempt {attempt + 1}/5). Global: {is_global}" + ) + continue # Retry the request + else: # Last attempt failed + raise RateLimitError( + response, + message=error_message, + retry_after=retry_after, + is_global=is_global, + ) + + # Other error handling + if response.status == 401: # Unauthorized + raise AuthenticationError(response, "Invalid token provided.") + if response.status == 403: # Forbidden + raise HTTPException( + response, + "Missing permissions or access denied.", + status=response.status, + text=str(data), + ) + + # General HTTP error + error_text = str(data) if data else "Unknown error" + discord_error_code = ( + data.get("code") if isinstance(data, dict) else None + ) + raise HTTPException( + response, + f"API Error on {method} {endpoint}: {error_text}", + status=response.status, + text=error_text, + error_code=discord_error_code, + ) + + # Should not be reached if retries are exhausted by RateLimitError + raise DisagreementException( + f"Failed request to {method} {endpoint} after multiple retries." + ) + + # --- Specific API call methods --- + + async def get_gateway_bot(self) -> Dict[str, Any]: + """Gets the WSS URL and sharding information for the Gateway.""" + return await self.request("GET", "/gateway/bot") + + async def send_message( + self, + channel_id: str, + content: Optional[str] = None, + tts: bool = False, + embeds: Optional[List[Dict[str, Any]]] = None, + components: Optional[List[Dict[str, Any]]] = None, + allowed_mentions: Optional[dict] = None, + message_reference: Optional[Dict[str, Any]] = None, + flags: Optional[int] = None, + ) -> Dict[str, Any]: + """Sends a message to a channel. + + Returns the created message data as a dict. + """ + payload: Dict[str, Any] = {} + if content is not None: # Content is optional if embeds/components are present + payload["content"] = content + if tts: + payload["tts"] = True + if embeds: + payload["embeds"] = embeds + if components: + payload["components"] = components + if allowed_mentions: + payload["allowed_mentions"] = allowed_mentions + if flags: + payload["flags"] = flags + if message_reference: + payload["message_reference"] = message_reference + + if not payload: + raise ValueError("Message must have content, embeds, or components.") + + return await self.request( + "POST", f"/channels/{channel_id}/messages", payload=payload + ) + + async def edit_message( + self, + channel_id: str, + message_id: str, + payload: Dict[str, Any], + ) -> Dict[str, Any]: + """Edits a message in a channel.""" + + return await self.request( + "PATCH", + f"/channels/{channel_id}/messages/{message_id}", + payload=payload, + ) + + async def get_message( + self, channel_id: "Snowflake", message_id: "Snowflake" + ) -> Dict[str, Any]: + """Fetches a message from a channel.""" + + return await self.request( + "GET", f"/channels/{channel_id}/messages/{message_id}" + ) + + async def create_reaction( + self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str + ) -> None: + """Adds a reaction to a message as the current user.""" + encoded = quote(emoji) + await self.request( + "PUT", + f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/@me", + ) + + async def delete_reaction( + self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str + ) -> None: + """Removes the current user's reaction from a message.""" + encoded = quote(emoji) + await self.request( + "DELETE", + f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/@me", + ) + + async def get_reactions( + self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str + ) -> List[Dict[str, Any]]: + """Fetches the users that reacted with a specific emoji.""" + encoded = quote(emoji) + return await self.request( + "GET", + f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}", + ) + + async def delete_channel( + self, channel_id: str, reason: Optional[str] = None + ) -> None: + """Deletes a channel. + + If the channel is a guild channel, requires the MANAGE_CHANNELS permission. + If the channel is a thread, requires the MANAGE_THREADS permission (if locked) or + be the thread creator (if not locked). + Deleting a category does not delete its child channels. + """ + custom_headers = {} + if reason: + custom_headers["X-Audit-Log-Reason"] = reason + + await self.request( + "DELETE", + f"/channels/{channel_id}", + custom_headers=custom_headers if custom_headers else None, + ) + + async def get_channel(self, channel_id: str) -> Dict[str, Any]: + """Fetches a channel by ID.""" + return await self.request("GET", f"/channels/{channel_id}") + + async def get_user(self, user_id: "Snowflake") -> Dict[str, Any]: + """Fetches a user object for a given user ID.""" + return await self.request("GET", f"/users/{user_id}") + + async def get_guild_member( + self, guild_id: "Snowflake", user_id: "Snowflake" + ) -> Dict[str, Any]: + """Returns a guild member object for the specified user.""" + return await self.request("GET", f"/guilds/{guild_id}/members/{user_id}") + + async def kick_member( + self, guild_id: "Snowflake", user_id: "Snowflake", reason: Optional[str] = None + ) -> None: + """Kicks a member from the guild.""" + headers = {"X-Audit-Log-Reason": reason} if reason else None + await self.request( + "DELETE", + f"/guilds/{guild_id}/members/{user_id}", + custom_headers=headers, + ) + + async def ban_member( + self, + guild_id: "Snowflake", + user_id: "Snowflake", + *, + delete_message_seconds: int = 0, + reason: Optional[str] = None, + ) -> None: + """Bans a member from the guild.""" + payload = {} + if delete_message_seconds: + payload["delete_message_seconds"] = delete_message_seconds + headers = {"X-Audit-Log-Reason": reason} if reason else None + await self.request( + "PUT", + f"/guilds/{guild_id}/bans/{user_id}", + payload=payload if payload else None, + custom_headers=headers, + ) + + async def timeout_member( + self, + guild_id: "Snowflake", + user_id: "Snowflake", + *, + until: Optional[str], + reason: Optional[str] = None, + ) -> Dict[str, Any]: + """Times out a member until the given ISO8601 timestamp.""" + payload = {"communication_disabled_until": until} + headers = {"X-Audit-Log-Reason": reason} if reason else None + return await self.request( + "PATCH", + f"/guilds/{guild_id}/members/{user_id}", + payload=payload, + custom_headers=headers, + ) + + async def get_guild_roles(self, guild_id: "Snowflake") -> List[Dict[str, Any]]: + """Returns a list of role objects for the guild.""" + return await self.request("GET", f"/guilds/{guild_id}/roles") + + async def get_guild(self, guild_id: "Snowflake") -> Dict[str, Any]: + """Fetches a guild object for a given guild ID.""" + return await self.request("GET", f"/guilds/{guild_id}") + + # Add other methods like: + # async def get_guild(self, guild_id: str) -> Dict[str, Any]: ... + # async def create_reaction(self, channel_id: str, message_id: str, emoji: str) -> None: ... + # etc. + # --- Application Command Endpoints --- + + # Global Application Commands + async def get_global_application_commands( + self, application_id: "Snowflake", with_localizations: bool = False + ) -> List["ApplicationCommand"]: + """Fetches all global commands for your application.""" + params = {"with_localizations": str(with_localizations).lower()} + data = await self.request( + "GET", f"/applications/{application_id}/commands", params=params + ) + from .interactions import ApplicationCommand # Ensure constructor is available + + return [ApplicationCommand(cmd_data) for cmd_data in data] + + async def create_global_application_command( + self, application_id: "Snowflake", payload: Dict[str, Any] + ) -> "ApplicationCommand": + """Creates a new global command.""" + data = await self.request( + "POST", f"/applications/{application_id}/commands", payload=payload + ) + from .interactions import ApplicationCommand + + return ApplicationCommand(data) + + async def get_global_application_command( + self, application_id: "Snowflake", command_id: "Snowflake" + ) -> "ApplicationCommand": + """Fetches a specific global command.""" + data = await self.request( + "GET", f"/applications/{application_id}/commands/{command_id}" + ) + from .interactions import ApplicationCommand + + return ApplicationCommand(data) + + async def edit_global_application_command( + self, + application_id: "Snowflake", + command_id: "Snowflake", + payload: Dict[str, Any], + ) -> "ApplicationCommand": + """Edits a specific global command.""" + data = await self.request( + "PATCH", + f"/applications/{application_id}/commands/{command_id}", + payload=payload, + ) + from .interactions import ApplicationCommand + + return ApplicationCommand(data) + + async def delete_global_application_command( + self, application_id: "Snowflake", command_id: "Snowflake" + ) -> None: + """Deletes a specific global command.""" + await self.request( + "DELETE", f"/applications/{application_id}/commands/{command_id}" + ) + + async def bulk_overwrite_global_application_commands( + self, application_id: "Snowflake", payload: List[Dict[str, Any]] + ) -> List["ApplicationCommand"]: + """Bulk overwrites all global commands for your application.""" + data = await self.request( + "PUT", f"/applications/{application_id}/commands", payload=payload + ) + from .interactions import ApplicationCommand + + return [ApplicationCommand(cmd_data) for cmd_data in data] + + # Guild Application Commands + async def get_guild_application_commands( + self, + application_id: "Snowflake", + guild_id: "Snowflake", + with_localizations: bool = False, + ) -> List["ApplicationCommand"]: + """Fetches all commands for your application for a specific guild.""" + params = {"with_localizations": str(with_localizations).lower()} + data = await self.request( + "GET", + f"/applications/{application_id}/guilds/{guild_id}/commands", + params=params, + ) + from .interactions import ApplicationCommand + + return [ApplicationCommand(cmd_data) for cmd_data in data] + + async def create_guild_application_command( + self, + application_id: "Snowflake", + guild_id: "Snowflake", + payload: Dict[str, Any], + ) -> "ApplicationCommand": + """Creates a new guild command.""" + data = await self.request( + "POST", + f"/applications/{application_id}/guilds/{guild_id}/commands", + payload=payload, + ) + from .interactions import ApplicationCommand + + return ApplicationCommand(data) + + async def get_guild_application_command( + self, + application_id: "Snowflake", + guild_id: "Snowflake", + command_id: "Snowflake", + ) -> "ApplicationCommand": + """Fetches a specific guild command.""" + data = await self.request( + "GET", + f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}", + ) + from .interactions import ApplicationCommand + + return ApplicationCommand(data) + + async def edit_guild_application_command( + self, + application_id: "Snowflake", + guild_id: "Snowflake", + command_id: "Snowflake", + payload: Dict[str, Any], + ) -> "ApplicationCommand": + """Edits a specific guild command.""" + data = await self.request( + "PATCH", + f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}", + payload=payload, + ) + from .interactions import ApplicationCommand + + return ApplicationCommand(data) + + async def delete_guild_application_command( + self, + application_id: "Snowflake", + guild_id: "Snowflake", + command_id: "Snowflake", + ) -> None: + """Deletes a specific guild command.""" + await self.request( + "DELETE", + f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}", + ) + + async def bulk_overwrite_guild_application_commands( + self, + application_id: "Snowflake", + guild_id: "Snowflake", + payload: List[Dict[str, Any]], + ) -> List["ApplicationCommand"]: + """Bulk overwrites all commands for your application for a specific guild.""" + data = await self.request( + "PUT", + f"/applications/{application_id}/guilds/{guild_id}/commands", + payload=payload, + ) + from .interactions import ApplicationCommand + + return [ApplicationCommand(cmd_data) for cmd_data in data] + + # --- Interaction Response Endpoints --- + # Note: These methods return Dict[str, Any] representing the Message data. + # The caller (e.g., AppCommandHandler) will be responsible for constructing Message models + # if needed, as Message model instantiation requires a `client_instance`. + + async def create_interaction_response( + self, + interaction_id: "Snowflake", + interaction_token: str, + payload: "InteractionResponsePayload", + *, + ephemeral: bool = False, + ) -> None: + """Creates a response to an Interaction. + + Parameters + ---------- + ephemeral: bool + Ignored parameter for test compatibility. + """ + # Interaction responses do not use the bot token in the Authorization header. + # They are authenticated by the interaction_token in the URL. + await self.request( + "POST", + f"/interactions/{interaction_id}/{interaction_token}/callback", + payload=payload.to_dict(), + use_auth_header=False, + ) + + async def get_original_interaction_response( + self, application_id: "Snowflake", interaction_token: str + ) -> Dict[str, Any]: + """Gets the initial Interaction response.""" + # This endpoint uses the bot token for auth. + return await self.request( + "GET", f"/webhooks/{application_id}/{interaction_token}/messages/@original" + ) + + async def edit_original_interaction_response( + self, + application_id: "Snowflake", + interaction_token: str, + payload: Dict[str, Any], + ) -> Dict[str, Any]: + """Edits the initial Interaction response.""" + return await self.request( + "PATCH", + f"/webhooks/{application_id}/{interaction_token}/messages/@original", + payload=payload, + use_auth_header=False, + ) # Docs imply webhook-style auth + + async def delete_original_interaction_response( + self, application_id: "Snowflake", interaction_token: str + ) -> None: + """Deletes the initial Interaction response.""" + await self.request( + "DELETE", + f"/webhooks/{application_id}/{interaction_token}/messages/@original", + use_auth_header=False, + ) # Docs imply webhook-style auth + + async def create_followup_message( + self, + application_id: "Snowflake", + interaction_token: str, + payload: Dict[str, Any], + ) -> Dict[str, Any]: + """Creates a followup message for an Interaction.""" + # Followup messages are sent to a webhook endpoint. + return await self.request( + "POST", + f"/webhooks/{application_id}/{interaction_token}", + payload=payload, + use_auth_header=False, + ) # Docs imply webhook-style auth + + async def edit_followup_message( + self, + application_id: "Snowflake", + interaction_token: str, + message_id: "Snowflake", + payload: Dict[str, Any], + ) -> Dict[str, Any]: + """Edits a followup message for an Interaction.""" + return await self.request( + "PATCH", + f"/webhooks/{application_id}/{interaction_token}/messages/{message_id}", + payload=payload, + use_auth_header=False, + ) # Docs imply webhook-style auth + + async def delete_followup_message( + self, + application_id: "Snowflake", + interaction_token: str, + message_id: "Snowflake", + ) -> None: + """Deletes a followup message for an Interaction.""" + await self.request( + "DELETE", + f"/webhooks/{application_id}/{interaction_token}/messages/{message_id}", + use_auth_header=False, + ) + + async def trigger_typing(self, channel_id: str) -> None: + """Sends a typing indicator to the specified channel.""" + await self.request("POST", f"/channels/{channel_id}/typing") diff --git a/disagreement/hybrid_context.py b/disagreement/hybrid_context.py new file mode 100644 index 0000000..308a3e3 --- /dev/null +++ b/disagreement/hybrid_context.py @@ -0,0 +1,32 @@ +"""Utility class for working with either command or app contexts.""" + +from __future__ import annotations + +from typing import Any, Union + +from .ext.commands.core import CommandContext +from .ext.app_commands.context import AppCommandContext + + +class HybridContext: + """Wraps :class:`CommandContext` and :class:`AppCommandContext`. + + Provides a single :meth:`send` method that proxies to ``reply`` for + prefix commands and to ``send`` for slash commands. + """ + + def __init__(self, ctx: Union[CommandContext, AppCommandContext]): + self._ctx = ctx + + async def send(self, *args: Any, **kwargs: Any): + if isinstance(self._ctx, AppCommandContext): + return await self._ctx.send(*args, **kwargs) + return await self._ctx.reply(*args, **kwargs) + + async def edit(self, *args: Any, **kwargs: Any): + if hasattr(self._ctx, "edit"): + return await self._ctx.edit(*args, **kwargs) + raise AttributeError("Underlying context does not support editing.") + + def __getattr__(self, name: str) -> Any: + return getattr(self._ctx, name) diff --git a/disagreement/i18n.py b/disagreement/i18n.py new file mode 100644 index 0000000..36535f4 --- /dev/null +++ b/disagreement/i18n.py @@ -0,0 +1,22 @@ +import json +from typing import Dict, Optional + +_translations: Dict[str, Dict[str, str]] = {} + + +def set_translations(locale: str, mapping: Dict[str, str]) -> None: + """Set translations for a locale.""" + _translations[locale] = mapping + + +def load_translations(locale: str, file_path: str) -> None: + """Load translations for *locale* from a JSON file.""" + with open(file_path, "r", encoding="utf-8") as handle: + _translations[locale] = json.load(handle) + + +def translate(key: str, locale: str, *, default: Optional[str] = None) -> str: + """Return the translated string for *key* in *locale*.""" + return _translations.get(locale, {}).get( + key, default if default is not None else key + ) diff --git a/disagreement/interactions.py b/disagreement/interactions.py new file mode 100644 index 0000000..9e42de8 --- /dev/null +++ b/disagreement/interactions.py @@ -0,0 +1,572 @@ +# disagreement/interactions.py + +""" +Data models for Discord Interaction objects. +""" + +from typing import Optional, List, Dict, Union, Any, TYPE_CHECKING + +from .enums import ( + ApplicationCommandType, + ApplicationCommandOptionType, + InteractionType, + InteractionCallbackType, + IntegrationType, + InteractionContextType, + ChannelType, +) + +# Runtime imports for models used in this module +from .models import ( + User, + Message, + Member, + Role, + Embed, + PartialChannel, + Attachment, + ActionRow, + Component, + AllowedMentions, +) + +if TYPE_CHECKING: + # Import Client type only for type checking to avoid circular imports + from .client import Client + from .ui.modal import Modal + + # MessageFlags, PartialAttachment can be added if/when defined + +Snowflake = str + + +# Based on Application Command Option Choice Structure +class ApplicationCommandOptionChoice: + """Represents a choice for an application command option.""" + + def __init__(self, data: dict): + self.name: str = data["name"] + self.value: Union[str, int, float] = data["value"] + self.name_localizations: Optional[Dict[str, str]] = data.get( + "name_localizations" + ) + + def __repr__(self) -> str: + return ( + f"" + ) + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {"name": self.name, "value": self.value} + if self.name_localizations: + payload["name_localizations"] = self.name_localizations + return payload + + +# Based on Application Command Option Structure +class ApplicationCommandOption: + """Represents an option for an application command.""" + + def __init__(self, data: dict): + self.type: ApplicationCommandOptionType = ApplicationCommandOptionType( + data["type"] + ) + self.name: str = data["name"] + self.description: str = data["description"] + self.required: bool = data.get("required", False) + + self.choices: Optional[List[ApplicationCommandOptionChoice]] = ( + [ApplicationCommandOptionChoice(c) for c in data["choices"]] + if data.get("choices") + else None + ) + + self.options: Optional[List["ApplicationCommandOption"]] = ( + [ApplicationCommandOption(o) for o in data["options"]] + if data.get("options") + else None + ) # For subcommands/groups + + self.channel_types: Optional[List[ChannelType]] = ( + [ChannelType(ct) for ct in data.get("channel_types", [])] + if data.get("channel_types") + else None + ) + self.min_value: Optional[Union[int, float]] = data.get("min_value") + self.max_value: Optional[Union[int, float]] = data.get("max_value") + self.min_length: Optional[int] = data.get("min_length") + self.max_length: Optional[int] = data.get("max_length") + self.autocomplete: bool = data.get("autocomplete", False) + self.name_localizations: Optional[Dict[str, str]] = data.get( + "name_localizations" + ) + self.description_localizations: Optional[Dict[str, str]] = data.get( + "description_localizations" + ) + + def __repr__(self) -> str: + return f"" + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = { + "type": self.type.value, + "name": self.name, + "description": self.description, + } + if self.required: # Defaults to False, only include if True + payload["required"] = self.required + if self.choices: + payload["choices"] = [c.to_dict() for c in self.choices] + if self.options: # For subcommands/groups + payload["options"] = [o.to_dict() for o in self.options] + if self.channel_types: + payload["channel_types"] = [ct.value for ct in self.channel_types] + if self.min_value is not None: + payload["min_value"] = self.min_value + if self.max_value is not None: + payload["max_value"] = self.max_value + if self.min_length is not None: + payload["min_length"] = self.min_length + if self.max_length is not None: + payload["max_length"] = self.max_length + if self.autocomplete: # Defaults to False, only include if True + payload["autocomplete"] = self.autocomplete + if self.name_localizations: + payload["name_localizations"] = self.name_localizations + if self.description_localizations: + payload["description_localizations"] = self.description_localizations + return payload + + +# Based on Application Command Structure +class ApplicationCommand: + """Represents an application command.""" + + def __init__(self, data: dict): + self.id: Optional[Snowflake] = data.get("id") + self.type: ApplicationCommandType = ApplicationCommandType( + data.get("type", 1) + ) # Default to CHAT_INPUT + self.application_id: Optional[Snowflake] = data.get("application_id") + self.guild_id: Optional[Snowflake] = data.get("guild_id") + self.name: str = data["name"] + self.description: str = data.get( + "description", "" + ) # Empty for USER/MESSAGE commands + + self.options: Optional[List[ApplicationCommandOption]] = ( + [ApplicationCommandOption(o) for o in data["options"]] + if data.get("options") + else None + ) + + self.default_member_permissions: Optional[str] = data.get( + "default_member_permissions" + ) + self.dm_permission: Optional[bool] = data.get("dm_permission") # Deprecated + self.nsfw: bool = data.get("nsfw", False) + self.version: Optional[Snowflake] = data.get("version") + self.name_localizations: Optional[Dict[str, str]] = data.get( + "name_localizations" + ) + self.description_localizations: Optional[Dict[str, str]] = data.get( + "description_localizations" + ) + + self.integration_types: Optional[List[IntegrationType]] = ( + [IntegrationType(it) for it in data["integration_types"]] + if data.get("integration_types") + else None + ) + + self.contexts: Optional[List[InteractionContextType]] = ( + [InteractionContextType(c) for c in data["contexts"]] + if data.get("contexts") + else None + ) + + def __repr__(self) -> str: + return ( + f"" + ) + + +# Based on Interaction Object's Resolved Data Structure +class ResolvedData: + """Represents resolved data for an interaction.""" + + def __init__( + self, data: dict, client_instance: Optional["Client"] = None + ): # client_instance for model hydration + # Models are now imported in TYPE_CHECKING block + + users_data = data.get("users", {}) + self.users: Dict[Snowflake, "User"] = { + uid: User(udata) for uid, udata in users_data.items() + } + + self.members: Dict[Snowflake, "Member"] = {} + for mid, mdata in data.get("members", {}).items(): + member_payload = dict(mdata) + member_payload.setdefault("id", mid) + if "user" not in member_payload and mid in users_data: + member_payload["user"] = users_data[mid] + self.members[mid] = Member(member_payload, client_instance=client_instance) + + self.roles: Dict[Snowflake, "Role"] = { + rid: Role(rdata) for rid, rdata in data.get("roles", {}).items() + } + + self.channels: Dict[Snowflake, "PartialChannel"] = { + cid: PartialChannel(cdata, client_instance=client_instance) + for cid, cdata in data.get("channels", {}).items() + } + + self.messages: Dict[Snowflake, "Message"] = ( + { + mid: Message(mdata, client_instance=client_instance) for mid, mdata in data.get("messages", {}).items() # type: ignore[misc] + } + if client_instance + else {} + ) # Only hydrate if client is available + + self.attachments: Dict[Snowflake, "Attachment"] = { + aid: Attachment(adata) for aid, adata in data.get("attachments", {}).items() + } + + def __repr__(self) -> str: + return f"" + + +# Based on Interaction Object's Data Structure (for Application Commands) +class InteractionData: + """Represents the data payload for an interaction.""" + + def __init__(self, data: dict, client_instance: Optional["Client"] = None): + self.id: Optional[Snowflake] = data.get("id") # Command ID + self.name: Optional[str] = data.get("name") # Command name + self.type: Optional[ApplicationCommandType] = ( + ApplicationCommandType(data["type"]) if data.get("type") else None + ) + + self.resolved: Optional[ResolvedData] = ( + ResolvedData(data["resolved"], client_instance=client_instance) + if data.get("resolved") + else None + ) + + # For CHAT_INPUT, this is List[ApplicationCommandInteractionDataOption] + # For USER/MESSAGE, this is not present or different. + # For now, storing as raw list of dicts. Parsing can happen in handler. + self.options: Optional[List[Dict[str, Any]]] = data.get("options") + + # For message components + self.custom_id: Optional[str] = data.get("custom_id") + self.component_type: Optional[int] = data.get("component_type") + self.values: Optional[List[str]] = data.get("values") + + self.guild_id: Optional[Snowflake] = data.get("guild_id") + self.target_id: Optional[Snowflake] = data.get( + "target_id" + ) # For USER/MESSAGE commands + + def __repr__(self) -> str: + return f"" + + +# Based on Interaction Object Structure +class Interaction: + """Represents an interaction from Discord.""" + + def __init__(self, data: dict, client_instance: "Client"): + self._client: "Client" = client_instance + + self.id: Snowflake = data["id"] + self.application_id: Snowflake = data["application_id"] + self.type: InteractionType = InteractionType(data["type"]) + + self.data: Optional[InteractionData] = ( + InteractionData(data["data"], client_instance=client_instance) + if data.get("data") + else None + ) + + self.guild_id: Optional[Snowflake] = data.get("guild_id") + self.channel_id: Optional[Snowflake] = data.get( + "channel_id" + ) # Will be present on command invocations + + member_data = data.get("member") + user_data_from_member = ( + member_data.get("user") if isinstance(member_data, dict) else None + ) + + self.member: Optional["Member"] = ( + Member(member_data, client_instance=self._client) if member_data else None + ) + + # User object is included within member if in guild, otherwise it's top-level + # If self.member was successfully hydrated, its .user attribute should be preferred if it exists. + # However, Member.__init__ handles setting User attributes. + # The primary source for User is data.get("user") or member_data.get("user"). + + if data.get("user"): + self.user: Optional["User"] = User(data["user"]) + elif user_data_from_member: + self.user: Optional["User"] = User(user_data_from_member) + elif ( + self.member + ): # If member was hydrated and has user attributes (e.g. from Member(User) inheritance) + # This assumes Member correctly populates its User parts. + self.user: Optional["User"] = self.member # Member is a User subclass + else: + self.user: Optional["User"] = None + + self.token: str = data["token"] # For responding to the interaction + self.version: int = data["version"] + + self.message: Optional["Message"] = ( + Message(data["message"], client_instance=client_instance) + if data.get("message") + else None + ) # For component interactions + + self.app_permissions: Optional[str] = data.get( + "app_permissions" + ) # Bitwise set of permissions the app has in the source channel + self.locale: Optional[str] = data.get( + "locale" + ) # Selected language of the invoking user + self.guild_locale: Optional[str] = data.get( + "guild_locale" + ) # Guild's preferred language + + self.response = InteractionResponse(self) + + async def respond( + self, + content: Optional[str] = None, + *, + embed: Optional[Embed] = None, + embeds: Optional[List[Embed]] = None, + components: Optional[List[ActionRow]] = None, + ephemeral: bool = False, + tts: bool = False, + ) -> None: + """|coro| + + Responds to this interaction. + + Parameters: + content (Optional[str]): The content of the message. + embed (Optional[Embed]): A single embed to send. + embeds (Optional[List[Embed]]): A list of embeds to send. + components (Optional[List[ActionRow]]): A list of ActionRow components. + ephemeral (bool): Whether the response should be ephemeral (only visible to the user). + tts (bool): Whether the message should be sent with text-to-speech. + """ + if embed and embeds: + raise ValueError("Cannot provide both embed and embeds.") + + data: Dict[str, Any] = {} + if tts: + data["tts"] = True + if content: + data["content"] = content + if embed: + data["embeds"] = [embed.to_dict()] + elif embeds: + data["embeds"] = [e.to_dict() for e in embeds] + if components: + data["components"] = [c.to_dict() for c in components] + if ephemeral: + data["flags"] = 1 << 6 # EPHEMERAL flag + + payload = InteractionResponsePayload( + type=InteractionCallbackType.CHANNEL_MESSAGE_WITH_SOURCE, + data=InteractionCallbackData(data), + ) + + await self._client._http.create_interaction_response( + interaction_id=self.id, + interaction_token=self.token, + payload=payload, + ) + + async def respond_modal(self, modal: "Modal") -> None: + """|coro| Send a modal in response to this interaction.""" + + from typing import Any, cast + + payload = { + "type": InteractionCallbackType.MODAL.value, + "data": modal.to_dict(), + } + await self._client._http.create_interaction_response( + interaction_id=self.id, + interaction_token=self.token, + payload=cast(Any, payload), + ) + + async def edit( + self, + content: Optional[str] = None, + *, + embed: Optional[Embed] = None, + embeds: Optional[List[Embed]] = None, + components: Optional[List[ActionRow]] = None, + attachments: Optional[List[Any]] = None, + allowed_mentions: Optional[Dict[str, Any]] = None, + ) -> None: + """|coro| + + Edits the original response to this interaction. + + If the interaction is from a component, this will acknowledge the + interaction and update the message in one operation. + + Parameters: + content (Optional[str]): The new message content. + embed (Optional[Embed]): A single embed to send. Ignored if + ``embeds`` is provided. + embeds (Optional[List[Embed]]): A list of embeds to send. + components (Optional[List[ActionRow]]): Updated components for the + message. + attachments (Optional[List[Any]]): Attachments to include with the + message. + allowed_mentions (Optional[Dict[str, Any]]): Controls mentions in the + message. + """ + if embed and embeds: + raise ValueError("Cannot provide both embed and embeds.") + + payload_data: Dict[str, Any] = {} + if content is not None: + payload_data["content"] = content + if embed: + payload_data["embeds"] = [embed.to_dict()] + elif embeds is not None: + payload_data["embeds"] = [e.to_dict() for e in embeds] + if components is not None: + payload_data["components"] = [c.to_dict() for c in components] + if attachments is not None: + payload_data["attachments"] = [ + a.to_dict() if hasattr(a, "to_dict") else a for a in attachments + ] + if allowed_mentions is not None: + payload_data["allowed_mentions"] = allowed_mentions + + if self.type == InteractionType.MESSAGE_COMPONENT: + # For component interactions, we send an UPDATE_MESSAGE response + # to acknowledge the interaction and edit the message simultaneously. + payload = InteractionResponsePayload( + type=InteractionCallbackType.UPDATE_MESSAGE, + data=InteractionCallbackData(payload_data), + ) + await self._client._http.create_interaction_response( + self.id, self.token, payload + ) + else: + # For other interaction types (like an initial slash command response), + # we edit the original response via the webhook endpoint. + await self._client._http.edit_original_interaction_response( + application_id=self.application_id, + interaction_token=self.token, + payload=payload_data, + ) + + def __repr__(self) -> str: + return f"" + + +class InteractionResponse: + """Helper for sending responses for an :class:`Interaction`.""" + + def __init__(self, interaction: "Interaction") -> None: + self._interaction = interaction + + async def send_modal(self, modal: "Modal") -> None: + """Sends a modal response.""" + payload = InteractionResponsePayload( + type=InteractionCallbackType.MODAL, + data=InteractionCallbackData(modal.to_dict()), + ) + await self._interaction._client._http.create_interaction_response( + self._interaction.id, + self._interaction.token, + payload, + ) + + +# Based on Interaction Response Object's Data Structure +class InteractionCallbackData: + """Data for an interaction response.""" + + def __init__(self, data: dict): + self.tts: Optional[bool] = data.get("tts") + self.content: Optional[str] = data.get("content") + self.embeds: Optional[List[Embed]] = ( + [Embed(e) for e in data.get("embeds", [])] if data.get("embeds") else None + ) + self.allowed_mentions: Optional[AllowedMentions] = ( + AllowedMentions(data["allowed_mentions"]) + if data.get("allowed_mentions") + else None + ) + self.flags: Optional[int] = data.get("flags") # MessageFlags enum could be used + from .components import component_factory + + self.components: Optional[List[Component]] = ( + [component_factory(c) for c in data.get("components", [])] + if data.get("components") + else None + ) + self.attachments: Optional[List[Attachment]] = ( + [Attachment(a) for a in data.get("attachments", [])] + if data.get("attachments") + else None + ) + + def to_dict(self) -> dict: + # Helper to convert to dict for sending to Discord API + payload = {} + if self.tts is not None: + payload["tts"] = self.tts + if self.content is not None: + payload["content"] = self.content + if self.embeds is not None: + payload["embeds"] = [e.to_dict() for e in self.embeds] + if self.allowed_mentions is not None: + payload["allowed_mentions"] = self.allowed_mentions.to_dict() + if self.flags is not None: + payload["flags"] = self.flags + if self.components is not None: + payload["components"] = [c.to_dict() for c in self.components] + if self.attachments is not None: + payload["attachments"] = [a.to_dict() for a in self.attachments] + return payload + + def __repr__(self) -> str: + return f"" + + +# Based on Interaction Response Object Structure +class InteractionResponsePayload: + """Payload for responding to an interaction.""" + + def __init__( + self, + type: InteractionCallbackType, + data: Optional[InteractionCallbackData] = None, + ): + self.type: InteractionCallbackType = type + self.data: Optional[InteractionCallbackData] = data + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {"type": self.type.value} + if self.data: + payload["data"] = self.data.to_dict() + return payload + + def __repr__(self) -> str: + return f"" diff --git a/disagreement/logging_config.py b/disagreement/logging_config.py new file mode 100644 index 0000000..3839381 --- /dev/null +++ b/disagreement/logging_config.py @@ -0,0 +1,26 @@ +import logging +from typing import Optional + + +def setup_logging(level: int, file: Optional[str] = None) -> None: + """Configure logging for the library. + + Parameters + ---------- + level: + Logging level from the :mod:`logging` module. + file: + Optional file path to write logs to. If ``None``, logs are sent to + standard output. + """ + handlers: list[logging.Handler] = [] + if file is None: + handlers.append(logging.StreamHandler()) + else: + handlers.append(logging.FileHandler(file)) + + logging.basicConfig( + level=level, + handlers=handlers, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) diff --git a/disagreement/models.py b/disagreement/models.py new file mode 100644 index 0000000..698d134 --- /dev/null +++ b/disagreement/models.py @@ -0,0 +1,1642 @@ +# disagreement/models.py + +""" +Data models for Discord objects. +""" + +import json +from typing import Optional, TYPE_CHECKING, List, Dict, Any, Union + +from .errors import DisagreementException, HTTPException +from .enums import ( # These enums will need to be defined in disagreement/enums.py + VerificationLevel, + MessageNotificationLevel, + ExplicitContentFilterLevel, + MFALevel, + GuildNSFWLevel, + PremiumTier, + GuildFeature, + ChannelType, + ComponentType, + ButtonStyle, # Added for Button + # SelectMenuType will be part of ComponentType or a new enum if needed +) + + +if TYPE_CHECKING: + from .client import Client # For type hinting to avoid circular imports + from .enums import OverwriteType # For PermissionOverwrite model + from .ui.view import View + + # Forward reference Message if it were used in type hints before its definition + # from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc. + from .components import component_factory + + +class User: + """Represents a Discord User. + + Attributes: + id (str): The user's unique ID. + username (str): The user's username. + discriminator (str): The user's 4-digit discord-tag. + bot (bool): Whether the user belongs to an OAuth2 application. Defaults to False. + avatar (Optional[str]): The user's avatar hash, if any. + """ + + def __init__(self, data: dict): + self.id: str = data["id"] + self.username: str = data["username"] + self.discriminator: str = data["discriminator"] + self.bot: bool = data.get("bot", False) + self.avatar: Optional[str] = data.get("avatar") + + @property + def mention(self) -> str: + """str: Returns a string that allows you to mention the user.""" + return f"<@{self.id}>" + + def __repr__(self) -> str: + return f"" + + +class Message: + """Represents a message sent in a channel on Discord. + + Attributes: + id (str): The message's unique ID. + channel_id (str): The ID of the channel the message was sent in. + guild_id (Optional[str]): The ID of the guild the message was sent in, if applicable. + author (User): The user who sent the message. + content (str): The actual content of the message. + timestamp (str): When this message was sent (ISO8601 timestamp). + components (Optional[List[ActionRow]]): Structured components attached + to the message if present. + """ + + def __init__(self, data: dict, client_instance: "Client"): + self._client: "Client" = ( + client_instance # Store reference to client for methods like reply + ) + + self.id: str = data["id"] + self.channel_id: str = data["channel_id"] + self.guild_id: Optional[str] = data.get("guild_id") + self.author: User = User(data["author"]) + self.content: str = data["content"] + self.timestamp: str = data["timestamp"] + if data.get("components"): + self.components: Optional[List[ActionRow]] = [ + ActionRow.from_dict(c, client_instance) + for c in data.get("components", []) + ] + else: + self.components = None + # Add other fields as needed, e.g., attachments, embeds, reactions, etc. + # self.mentions: List[User] = [User(u) for u in data.get("mentions", [])] + # self.mention_roles: List[str] = data.get("mention_roles", []) + # self.mention_everyone: bool = data.get("mention_everyone", False) + + async def reply( + self, + content: Optional[str] = None, + *, # Make additional params keyword-only + tts: bool = False, + embed: Optional["Embed"] = None, + embeds: Optional[List["Embed"]] = None, + components: Optional[List["ActionRow"]] = None, + allowed_mentions: Optional[Dict[str, Any]] = None, + mention_author: Optional[bool] = None, + flags: Optional[int] = None, + view: Optional["View"] = None, + ) -> "Message": + """|coro| + + Sends a reply to the message. + This is a shorthand for `Client.send_message` in the message's channel. + + Parameters: + content (Optional[str]): The content of the message. + tts (bool): Whether the message should be sent with text-to-speech. + embed (Optional[Embed]): A single embed to send. Cannot be used with `embeds`. + embeds (Optional[List[Embed]]): A list of embeds to send. + components (Optional[List[ActionRow]]): A list of ActionRow components. + allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message. + mention_author (Optional[bool]): Whether to mention the author in the reply. If ``None`` the + client's :attr:`mention_replies` setting is used. + flags (Optional[int]): Message flags. + view (Optional[View]): A view to send with the message. + + Returns: + Message: The message that was sent. + + Raises: + HTTPException: Sending the message failed. + ValueError: If both `embed` and `embeds` are provided. + """ + # Determine allowed mentions for the reply + if mention_author is None: + mention_author = getattr(self._client, "mention_replies", False) + + if allowed_mentions is None: + allowed_mentions = {"replied_user": mention_author} + else: + allowed_mentions = dict(allowed_mentions) + allowed_mentions.setdefault("replied_user", mention_author) + + # Client.send_message is already updated to handle these parameters + return await self._client.send_message( + channel_id=self.channel_id, + content=content, + tts=tts, + embed=embed, + embeds=embeds, + components=components, + allowed_mentions=allowed_mentions, + message_reference={ + "message_id": self.id, + "channel_id": self.channel_id, + "guild_id": self.guild_id, + }, + flags=flags, + view=view, + ) + + async def edit( + self, + *, + content: Optional[str] = None, + embed: Optional["Embed"] = None, + embeds: Optional[List["Embed"]] = None, + components: Optional[List["ActionRow"]] = None, + allowed_mentions: Optional[Dict[str, Any]] = None, + flags: Optional[int] = None, + view: Optional["View"] = None, + ) -> "Message": + """|coro| + + Edits this message. + + Parameters are the same as :meth:`Client.edit_message`. + """ + + return await self._client.edit_message( + channel_id=self.channel_id, + message_id=self.id, + content=content, + embed=embed, + embeds=embeds, + components=components, + allowed_mentions=allowed_mentions, + flags=flags, + view=view, + ) + + def __repr__(self) -> str: + return f"" + + +class EmbedFooter: + """Represents an embed footer.""" + + def __init__(self, data: Dict[str, Any]): + self.text: str = data["text"] + self.icon_url: Optional[str] = data.get("icon_url") + self.proxy_icon_url: Optional[str] = data.get("proxy_icon_url") + + def to_dict(self) -> Dict[str, Any]: + payload = {"text": self.text} + if self.icon_url: + payload["icon_url"] = self.icon_url + if self.proxy_icon_url: + payload["proxy_icon_url"] = self.proxy_icon_url + return payload + + +class EmbedImage: + """Represents an embed image.""" + + def __init__(self, data: Dict[str, Any]): + self.url: str = data["url"] + self.proxy_url: Optional[str] = data.get("proxy_url") + self.height: Optional[int] = data.get("height") + self.width: Optional[int] = data.get("width") + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {"url": self.url} + if self.proxy_url: + payload["proxy_url"] = self.proxy_url + if self.height: + payload["height"] = self.height + if self.width: + payload["width"] = self.width + return payload + + def __repr__(self) -> str: + return f"" + + +class EmbedThumbnail(EmbedImage): # Similar structure to EmbedImage + """Represents an embed thumbnail.""" + + pass + + +class EmbedAuthor: + """Represents an embed author.""" + + def __init__(self, data: Dict[str, Any]): + self.name: str = data["name"] + self.url: Optional[str] = data.get("url") + self.icon_url: Optional[str] = data.get("icon_url") + self.proxy_icon_url: Optional[str] = data.get("proxy_icon_url") + + def to_dict(self) -> Dict[str, Any]: + payload = {"name": self.name} + if self.url: + payload["url"] = self.url + if self.icon_url: + payload["icon_url"] = self.icon_url + if self.proxy_icon_url: + payload["proxy_icon_url"] = self.proxy_icon_url + return payload + + +class EmbedField: + """Represents an embed field.""" + + def __init__(self, data: Dict[str, Any]): + self.name: str = data["name"] + self.value: str = data["value"] + self.inline: bool = data.get("inline", False) + + def to_dict(self) -> Dict[str, Any]: + return {"name": self.name, "value": self.value, "inline": self.inline} + + +class Embed: + """Represents a Discord embed. + + Attributes can be set directly or via methods like `set_author`, `add_field`. + """ + + def __init__(self, data: Optional[Dict[str, Any]] = None): + data = data or {} + self.title: Optional[str] = data.get("title") + self.type: str = data.get("type", "rich") # Default to "rich" for sending + self.description: Optional[str] = data.get("description") + self.url: Optional[str] = data.get("url") + self.timestamp: Optional[str] = data.get("timestamp") # ISO8601 timestamp + self.color: Optional[int] = data.get("color") + + self.footer: Optional[EmbedFooter] = ( + EmbedFooter(data["footer"]) if data.get("footer") else None + ) + self.image: Optional[EmbedImage] = ( + EmbedImage(data["image"]) if data.get("image") else None + ) + self.thumbnail: Optional[EmbedThumbnail] = ( + EmbedThumbnail(data["thumbnail"]) if data.get("thumbnail") else None + ) + # Video and Provider are less common for bot-sent embeds, can be added if needed. + self.author: Optional[EmbedAuthor] = ( + EmbedAuthor(data["author"]) if data.get("author") else None + ) + self.fields: List[EmbedField] = ( + [EmbedField(f) for f in data["fields"]] if data.get("fields") else [] + ) + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {"type": self.type} + if self.title: + payload["title"] = self.title + if self.description: + payload["description"] = self.description + if self.url: + payload["url"] = self.url + if self.timestamp: + payload["timestamp"] = self.timestamp + if self.color is not None: + payload["color"] = self.color + if self.footer: + payload["footer"] = self.footer.to_dict() + if self.image: + payload["image"] = self.image.to_dict() + if self.thumbnail: + payload["thumbnail"] = self.thumbnail.to_dict() + if self.author: + payload["author"] = self.author.to_dict() + if self.fields: + payload["fields"] = [f.to_dict() for f in self.fields] + return payload + + # Convenience methods for building embeds can be added here + # e.g., set_author, add_field, set_footer, set_image, etc. + + +class Attachment: + """Represents a message attachment.""" + + def __init__(self, data: Dict[str, Any]): + self.id: str = data["id"] + self.filename: str = data["filename"] + self.description: Optional[str] = data.get("description") + self.content_type: Optional[str] = data.get("content_type") + self.size: Optional[int] = data.get("size") + self.url: Optional[str] = data.get("url") + self.proxy_url: Optional[str] = data.get("proxy_url") + self.height: Optional[int] = data.get("height") # If image + self.width: Optional[int] = data.get("width") # If image + self.ephemeral: bool = data.get("ephemeral", False) + + def __repr__(self) -> str: + return f"" + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {"id": self.id, "filename": self.filename} + if self.description is not None: + payload["description"] = self.description + if self.content_type is not None: + payload["content_type"] = self.content_type + if self.size is not None: + payload["size"] = self.size + if self.url is not None: + payload["url"] = self.url + if self.proxy_url is not None: + payload["proxy_url"] = self.proxy_url + if self.height is not None: + payload["height"] = self.height + if self.width is not None: + payload["width"] = self.width + if self.ephemeral: + payload["ephemeral"] = self.ephemeral + return payload + + +class AllowedMentions: + """Represents allowed mentions for a message or interaction response.""" + + def __init__(self, data: Dict[str, Any]): + self.parse: List[str] = data.get("parse", []) + self.roles: List[str] = data.get("roles", []) + self.users: List[str] = data.get("users", []) + self.replied_user: bool = data.get("replied_user", False) + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {"parse": self.parse} + if self.roles: + payload["roles"] = self.roles + if self.users: + payload["users"] = self.users + if self.replied_user: + payload["replied_user"] = self.replied_user + return payload + + +class RoleTags: + """Represents tags for a role.""" + + def __init__(self, data: Dict[str, Any]): + self.bot_id: Optional[str] = data.get("bot_id") + self.integration_id: Optional[str] = data.get("integration_id") + self.premium_subscriber: Optional[bool] = ( + data.get("premium_subscriber") is None + ) # presence of null value means true + + def to_dict(self) -> Dict[str, Any]: + payload = {} + if self.bot_id: + payload["bot_id"] = self.bot_id + if self.integration_id: + payload["integration_id"] = self.integration_id + if self.premium_subscriber: + payload["premium_subscriber"] = None # Explicitly null + return payload + + +class Role: + """Represents a Discord Role.""" + + def __init__(self, data: Dict[str, Any]): + self.id: str = data["id"] + self.name: str = data["name"] + self.color: int = data["color"] + self.hoist: bool = data["hoist"] + self.icon: Optional[str] = data.get("icon") + self.unicode_emoji: Optional[str] = data.get("unicode_emoji") + self.position: int = data["position"] + self.permissions: str = data["permissions"] # String of bitwise permissions + self.managed: bool = data["managed"] + self.mentionable: bool = data["mentionable"] + self.tags: Optional[RoleTags] = ( + RoleTags(data["tags"]) if data.get("tags") else None + ) + + @property + def mention(self) -> str: + """str: Returns a string that allows you to mention the role.""" + return f"<@&{self.id}>" + + def __repr__(self) -> str: + return f"" + + +class Member(User): # Member inherits from User + """Represents a Guild Member. + This class combines User attributes with guild-specific Member attributes. + """ + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client: Optional["Client"] = client_instance + self.guild_id: Optional[str] = None + # User part is nested under 'user' key in member data from gateway/API + user_data = data.get("user", {}) + # If 'id' is not in user_data but is top-level (e.g. from interaction resolved member without user object) + if "id" not in user_data and "id" in data: + # This case is less common for full member objects but can happen. + # We'd need to construct a partial user from top-level member fields if 'user' is missing. + # For now, assume 'user' object is present for full Member hydration. + # If 'user' is missing, the User part might be incomplete. + pass # User fields will be missing or default if 'user' not in data. + + super().__init__( + user_data if user_data else data + ) # Pass user_data or data if user_data is empty + + self.nick: Optional[str] = data.get("nick") + self.avatar: Optional[str] = data.get("avatar") # Guild-specific avatar hash + self.roles: List[str] = data.get("roles", []) # List of role IDs + self.joined_at: str = data["joined_at"] # ISO8601 timestamp + self.premium_since: Optional[str] = data.get( + "premium_since" + ) # ISO8601 timestamp + self.deaf: bool = data.get("deaf", False) + self.mute: bool = data.get("mute", False) + self.pending: bool = data.get("pending", False) + self.permissions: Optional[str] = data.get( + "permissions" + ) # Permissions in the channel, if applicable + self.communication_disabled_until: Optional[str] = data.get( + "communication_disabled_until" + ) # ISO8601 timestamp + + # If 'user' object was present, ensure User attributes are from there + if user_data: + self.id = user_data.get("id", self.id) # Prefer user.id if available + self.username = user_data.get("username", self.username) + self.discriminator = user_data.get("discriminator", self.discriminator) + self.bot = user_data.get("bot", self.bot) + # User's global avatar is User.avatar, Member.avatar is guild-specific + # super() already set self.avatar from user_data if present. + # The self.avatar = data.get("avatar") line above overwrites it with guild avatar. This is correct. + + def __repr__(self) -> str: + return f"" + + async def kick(self, *, reason: Optional[str] = None) -> None: + if not self.guild_id or not self._client: + raise DisagreementException("Member.kick requires guild_id and client") + await self._client._http.kick_member(self.guild_id, self.id, reason=reason) + + async def ban( + self, + *, + delete_message_seconds: int = 0, + reason: Optional[str] = None, + ) -> None: + if not self.guild_id or not self._client: + raise DisagreementException("Member.ban requires guild_id and client") + await self._client._http.ban_member( + self.guild_id, + self.id, + delete_message_seconds=delete_message_seconds, + reason=reason, + ) + + async def timeout( + self, until: Optional[str], *, reason: Optional[str] = None + ) -> None: + if not self.guild_id or not self._client: + raise DisagreementException("Member.timeout requires guild_id and client") + await self._client._http.timeout_member( + self.guild_id, + self.id, + until=until, + reason=reason, + ) + + +class PartialEmoji: + """Represents a partial emoji, often used in components or reactions. + + This typically means only id, name, and animated are known. + For unicode emojis, id will be None and name will be the unicode character. + """ + + def __init__(self, data: Dict[str, Any]): + self.id: Optional[str] = data.get("id") + self.name: Optional[str] = data.get( + "name" + ) # Can be None for unknown custom emoji, or unicode char + self.animated: bool = data.get("animated", False) + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {} + if self.id: + payload["id"] = self.id + if self.name: + payload["name"] = self.name + if self.animated: # Only include if true, as per some Discord patterns + payload["animated"] = self.animated + return payload + + def __str__(self) -> str: + if self.id: + return f"<{'a' if self.animated else ''}:{self.name}:{self.id}>" + return self.name or "" # For unicode emoji + + def __repr__(self) -> str: + return ( + f"" + ) + + +def to_partial_emoji( + value: Union[str, "PartialEmoji", None], +) -> Optional["PartialEmoji"]: + """Convert a string or PartialEmoji to a PartialEmoji instance. + + Args: + value: Either a unicode emoji string, a :class:`PartialEmoji`, or ``None``. + + Returns: + A :class:`PartialEmoji` or ``None`` if ``value`` was ``None``. + + Raises: + TypeError: If ``value`` is not ``str`` or :class:`PartialEmoji`. + """ + + if value is None or isinstance(value, PartialEmoji): + return value + if isinstance(value, str): + return PartialEmoji({"name": value, "id": None}) + raise TypeError("emoji must be a str or PartialEmoji") + + +class Emoji(PartialEmoji): + """Represents a custom guild emoji. + + Inherits id, name, animated from PartialEmoji. + """ + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + super().__init__(data) + self._client: Optional["Client"] = ( + client_instance # For potential future methods + ) + + # Roles this emoji is whitelisted to + self.roles: List[str] = data.get("roles", []) # List of role IDs + + # User object for the user that created this emoji (optional, only for GUILD_EMOJIS_AND_STICKERS intent) + self.user: Optional[User] = User(data["user"]) if data.get("user") else None + + self.require_colons: bool = data.get("require_colons", False) + self.managed: bool = data.get( + "managed", False + ) # If this emoji is managed by an integration + self.available: bool = data.get( + "available", True + ) # Whether this emoji can be used + + def __repr__(self) -> str: + return f"" + + +class StickerItem: + """Represents a sticker item, a basic representation of a sticker. + + Used in sticker packs and sometimes in message data. + """ + + def __init__(self, data: Dict[str, Any]): + self.id: str = data["id"] + self.name: str = data["name"] + self.format_type: int = data["format_type"] # StickerFormatType enum + + def __repr__(self) -> str: + return f"" + + +class Sticker(StickerItem): + """Represents a Discord sticker. + + Inherits id, name, format_type from StickerItem. + """ + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + super().__init__(data) + self._client: Optional["Client"] = client_instance + + self.pack_id: Optional[str] = data.get( + "pack_id" + ) # For standard stickers, ID of the pack + self.description: Optional[str] = data.get("description") + self.tags: str = data.get( + "tags", "" + ) # Comma-separated list of tags for guild stickers + # type is StickerType enum (STANDARD or GUILD) + # For guild stickers, this is 2. For standard stickers, this is 1. + self.type: int = data["type"] + self.available: bool = data.get( + "available", True + ) # Whether this sticker can be used + self.guild_id: Optional[str] = data.get( + "guild_id" + ) # ID of the guild that owns this sticker + + # User object of the user that uploaded the guild sticker + self.user: Optional[User] = User(data["user"]) if data.get("user") else None + + self.sort_value: Optional[int] = data.get( + "sort_value" + ) # The standard sticker's sort order within its pack + + def __repr__(self) -> str: + return f"" + + +class StickerPack: + """Represents a pack of standard stickers.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client: Optional["Client"] = client_instance + self.id: str = data["id"] + self.stickers: List[Sticker] = [ + Sticker(s_data, client_instance) for s_data in data.get("stickers", []) + ] + self.name: str = data["name"] + self.sku_id: str = data["sku_id"] + self.cover_sticker_id: Optional[str] = data.get("cover_sticker_id") + self.description: str = data["description"] + self.banner_asset_id: Optional[str] = data.get( + "banner_asset_id" + ) # ID of the pack's banner image + + def __repr__(self) -> str: + return f"" + + +class PermissionOverwrite: + """Represents a permission overwrite for a role or member in a channel.""" + + def __init__(self, data: Dict[str, Any]): + self.id: str = data["id"] # Role or user ID + self._type_val: int = int(data["type"]) # Store raw type for enum property + self.allow: str = data["allow"] # Bitwise value of allowed permissions + self.deny: str = data["deny"] # Bitwise value of denied permissions + + @property + def type(self) -> "OverwriteType": + from .enums import ( + OverwriteType, + ) # Local import to avoid circularity at module level + + return OverwriteType(self._type_val) + + def to_dict(self) -> Dict[str, Any]: + return { + "id": self.id, + "type": self.type.value, + "allow": self.allow, + "deny": self.deny, + } + + def __repr__(self) -> str: + return f"" + + +class Guild: + """Represents a Discord Guild (Server). + + Attributes: + id (str): Guild ID. + name (str): Guild name (2-100 characters, excluding @, #, :, ```). + icon (Optional[str]): Icon hash. + splash (Optional[str]): Splash hash. + discovery_splash (Optional[str]): Discovery splash hash; only present for discoverable guilds. + owner (Optional[bool]): True if the user is the owner of the guild. (Only for /users/@me/guilds endpoint) + owner_id (str): ID of owner. + permissions (Optional[str]): Total permissions for the user in the guild (excludes overwrites). (Only for /users/@me/guilds endpoint) + afk_channel_id (Optional[str]): ID of afk channel. + afk_timeout (int): AFK timeout in seconds. + widget_enabled (Optional[bool]): True if the server widget is enabled. + widget_channel_id (Optional[str]): The channel id that the widget will generate an invite to, or null if set to no invite. + verification_level (VerificationLevel): Verification level required for the guild. + default_message_notifications (MessageNotificationLevel): Default message notifications level. + explicit_content_filter (ExplicitContentFilterLevel): Explicit content filter level. + roles (List[Role]): Roles in the guild. + emojis (List[Dict]): Custom emojis. (Consider creating an Emoji model) + features (List[GuildFeature]): Enabled guild features. + mfa_level (MFALevel): Required MFA level for the guild. + application_id (Optional[str]): Application ID of the guild creator if it is bot-created. + system_channel_id (Optional[str]): The id of the channel where guild notices such as welcome messages and boost events are posted. + system_channel_flags (int): System channel flags. + rules_channel_id (Optional[str]): The id of the channel where Community guilds can display rules. + max_members (Optional[int]): The maximum number of members for the guild. + vanity_url_code (Optional[str]): The vanity url code for the guild. + description (Optional[str]): The description of a Community guild. + banner (Optional[str]): Banner hash. + premium_tier (PremiumTier): Premium tier (Server Boost level). + premium_subscription_count (Optional[int]): The number of boosts this guild currently has. + preferred_locale (str): The preferred locale of a Community guild. Defaults to "en-US". + public_updates_channel_id (Optional[str]): The id of the channel where admins and moderators of Community guilds receive notices from Discord. + max_video_channel_users (Optional[int]): The maximum number of users in a video channel. + welcome_screen (Optional[Dict]): The welcome screen of a Community guild. (Consider a WelcomeScreen model) + nsfw_level (GuildNSFWLevel): Guild NSFW level. + stickers (Optional[List[Dict]]): Custom stickers in the guild. (Consider a Sticker model) + premium_progress_bar_enabled (bool): Whether the guild has the premium progress bar enabled. + """ + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + self._client: "Client" = client_instance + self.id: str = data["id"] + self.name: str = data["name"] + self.icon: Optional[str] = data.get("icon") + self.splash: Optional[str] = data.get("splash") + self.discovery_splash: Optional[str] = data.get("discovery_splash") + self.owner: Optional[bool] = data.get("owner") + self.owner_id: str = data["owner_id"] + self.permissions: Optional[str] = data.get("permissions") + self.afk_channel_id: Optional[str] = data.get("afk_channel_id") + self.afk_timeout: int = data["afk_timeout"] + self.widget_enabled: Optional[bool] = data.get("widget_enabled") + self.widget_channel_id: Optional[str] = data.get("widget_channel_id") + self.verification_level: VerificationLevel = VerificationLevel( + data["verification_level"] + ) + self.default_message_notifications: MessageNotificationLevel = ( + MessageNotificationLevel(data["default_message_notifications"]) + ) + self.explicit_content_filter: ExplicitContentFilterLevel = ( + ExplicitContentFilterLevel(data["explicit_content_filter"]) + ) + + self.roles: List[Role] = [Role(r) for r in data.get("roles", [])] + self.emojis: List[Emoji] = [ + Emoji(e_data, client_instance) for e_data in data.get("emojis", []) + ] + + # Assuming GuildFeature can be constructed from string feature names or their values + self.features: List[GuildFeature] = [ + GuildFeature(f) if not isinstance(f, GuildFeature) else f + for f in data.get("features", []) + ] + + self.mfa_level: MFALevel = MFALevel(data["mfa_level"]) + self.application_id: Optional[str] = data.get("application_id") + self.system_channel_id: Optional[str] = data.get("system_channel_id") + self.system_channel_flags: int = data["system_channel_flags"] + self.rules_channel_id: Optional[str] = data.get("rules_channel_id") + self.max_members: Optional[int] = data.get("max_members") + self.vanity_url_code: Optional[str] = data.get("vanity_url_code") + self.description: Optional[str] = data.get("description") + self.banner: Optional[str] = data.get("banner") + self.premium_tier: PremiumTier = PremiumTier(data["premium_tier"]) + self.premium_subscription_count: Optional[int] = data.get( + "premium_subscription_count" + ) + self.preferred_locale: str = data.get("preferred_locale", "en-US") + self.public_updates_channel_id: Optional[str] = data.get( + "public_updates_channel_id" + ) + self.max_video_channel_users: Optional[int] = data.get( + "max_video_channel_users" + ) + self.approximate_member_count: Optional[int] = data.get( + "approximate_member_count" + ) + self.approximate_presence_count: Optional[int] = data.get( + "approximate_presence_count" + ) + self.welcome_screen: Optional["WelcomeScreen"] = ( + WelcomeScreen(data["welcome_screen"], client_instance) + if data.get("welcome_screen") + else None + ) + self.nsfw_level: GuildNSFWLevel = GuildNSFWLevel(data["nsfw_level"]) + self.stickers: Optional[List[Sticker]] = ( + [Sticker(s_data, client_instance) for s_data in data.get("stickers", [])] + if data.get("stickers") + else None + ) + self.premium_progress_bar_enabled: bool = data.get( + "premium_progress_bar_enabled", False + ) + + # Internal caches, populated by events or specific fetches + self._channels: Dict[str, "Channel"] = {} + self._members: Dict[str, Member] = {} + self._threads: Dict[str, "Thread"] = {} + + def get_channel(self, channel_id: str) -> Optional["Channel"]: + return self._channels.get(channel_id) + + def get_member(self, user_id: str) -> Optional[Member]: + return self._members.get(user_id) + + def get_role(self, role_id: str) -> Optional[Role]: + return next((role for role in self.roles if role.id == role_id), None) + + def __repr__(self) -> str: + return f"" + + +class Channel: + """Base class for Discord channels.""" + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + self._client: "Client" = client_instance + self.id: str = data["id"] + self._type_val: int = int(data["type"]) # Store raw type for enum property + + self.guild_id: Optional[str] = data.get("guild_id") + self.name: Optional[str] = data.get("name") + self.position: Optional[int] = data.get("position") + self.permission_overwrites: List["PermissionOverwrite"] = [ + PermissionOverwrite(d) for d in data.get("permission_overwrites", []) + ] + self.nsfw: Optional[bool] = data.get("nsfw", False) + self.parent_id: Optional[str] = data.get( + "parent_id" + ) # ID of the parent category channel or thread parent + + @property + def type(self) -> ChannelType: + return ChannelType(self._type_val) + + @property + def mention(self) -> str: + return f"<#{self.id}>" + + async def delete(self, reason: Optional[str] = None): + await self._client._http.delete_channel(self.id, reason=reason) + + def __repr__(self) -> str: + return f"" + + +class TextChannel(Channel): + """Represents a guild text channel or announcement channel.""" + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + super().__init__(data, client_instance) + self.topic: Optional[str] = data.get("topic") + self.last_message_id: Optional[str] = data.get("last_message_id") + self.rate_limit_per_user: Optional[int] = data.get("rate_limit_per_user", 0) + self.default_auto_archive_duration: Optional[int] = data.get( + "default_auto_archive_duration" + ) + self.last_pin_timestamp: Optional[str] = data.get("last_pin_timestamp") + + async def send( + self, + content: Optional[str] = None, + *, + embed: Optional[Embed] = None, + embeds: Optional[List[Embed]] = None, + components: Optional[List["ActionRow"]] = None, # Added components + ) -> "Message": # Forward reference Message + if not hasattr(self._client, "send_message"): + raise NotImplementedError( + "Client.send_message is required for TextChannel.send" + ) + + return await self._client.send_message( + channel_id=self.id, + content=content, + embed=embed, + embeds=embeds, + components=components, + ) + + def __repr__(self) -> str: + return f"" + + +class VoiceChannel(Channel): + """Represents a guild voice channel or stage voice channel.""" + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + super().__init__(data, client_instance) + self.bitrate: int = data.get("bitrate", 64000) + self.user_limit: int = data.get("user_limit", 0) + self.rtc_region: Optional[str] = data.get("rtc_region") + self.video_quality_mode: Optional[int] = data.get("video_quality_mode") + + def __repr__(self) -> str: + return f"" + + +class CategoryChannel(Channel): + """Represents a guild category channel.""" + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + super().__init__(data, client_instance) + + @property + def channels(self) -> List[Channel]: + if not self.guild_id or not hasattr(self._client, "get_guild"): + return [] + guild = self._client.get_guild(self.guild_id) + if not guild or not hasattr( + guild, "_channels" + ): # Ensure guild and _channels exist + return [] + + categorized_channels = [ + ch + for ch in guild._channels.values() + if getattr(ch, "parent_id", None) == self.id + ] + return sorted( + categorized_channels, + key=lambda c: c.position if c.position is not None else -1, + ) + + def __repr__(self) -> str: + return f"" + + +class ThreadMetadata: + """Represents the metadata of a thread.""" + + def __init__(self, data: Dict[str, Any]): + self.archived: bool = data["archived"] + self.auto_archive_duration: int = data["auto_archive_duration"] + self.archive_timestamp: str = data["archive_timestamp"] + self.locked: bool = data["locked"] + self.invitable: Optional[bool] = data.get("invitable") + self.create_timestamp: Optional[str] = data.get("create_timestamp") + + +class Thread(TextChannel): # Threads are a specialized TextChannel + """Represents a Discord Thread.""" + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + super().__init__(data, client_instance) # Handles common text channel fields + self.owner_id: Optional[str] = data.get("owner_id") + # parent_id is already handled by base Channel init if present in data + self.message_count: Optional[int] = data.get("message_count") + self.member_count: Optional[int] = data.get("member_count") + self.thread_metadata: ThreadMetadata = ThreadMetadata(data["thread_metadata"]) + self.member: Optional["ThreadMember"] = ( + ThreadMember(data["member"], client_instance) + if data.get("member") + else None + ) + + def __repr__(self) -> str: + return ( + f"" + ) + + +class DMChannel(Channel): + """Represents a Direct Message channel.""" + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + super().__init__(data, client_instance) + self.last_message_id: Optional[str] = data.get("last_message_id") + self.recipients: List[User] = [ + User(u_data) for u_data in data.get("recipients", []) + ] + + @property + def recipient(self) -> Optional[User]: + return self.recipients[0] if self.recipients else None + + async def send( + self, + content: Optional[str] = None, + *, + embed: Optional[Embed] = None, + embeds: Optional[List[Embed]] = None, + components: Optional[List["ActionRow"]] = None, # Added components + ) -> "Message": + if not hasattr(self._client, "send_message"): + raise NotImplementedError( + "Client.send_message is required for DMChannel.send" + ) + + return await self._client.send_message( + channel_id=self.id, + content=content, + embed=embed, + embeds=embeds, + components=components, + ) + + def __repr__(self) -> str: + recipient_repr = self.recipient.username if self.recipient else "Unknown" + return f"" + + +class PartialChannel: + """Represents a partial channel object, often from interactions.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client: Optional["Client"] = client_instance + self.id: str = data["id"] + self.name: Optional[str] = data.get("name") + self._type_val: int = int(data["type"]) + self.permissions: Optional[str] = data.get("permissions") + + @property + def type(self) -> ChannelType: + return ChannelType(self._type_val) + + @property + def mention(self) -> str: + return f"<#{self.id}>" + + async def fetch_full_channel(self) -> Optional[Channel]: + if not self._client or not hasattr(self._client, "fetch_channel"): + # Log or raise if fetching is not possible + return None + try: + # This assumes Client.fetch_channel exists and returns a full Channel object + return await self._client.fetch_channel(self.id) + except HTTPException as exc: + print(f"HTTP error while fetching channel {self.id}: {exc}") + except (json.JSONDecodeError, KeyError, ValueError) as exc: + print(f"Failed to parse channel {self.id}: {exc}") + except DisagreementException as exc: + print(f"Error fetching channel {self.id}: {exc}") + return None + + def __repr__(self) -> str: + type_name = self.type.name if hasattr(self.type, "name") else self._type_val + return f"" + + +# --- Message Components --- + + +class Component: + """Base class for message components.""" + + def __init__(self, type: ComponentType): + self.type: ComponentType = type + self.custom_id: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {"type": self.type.value} + if self.custom_id: + payload["custom_id"] = self.custom_id + return payload + + +class ActionRow(Component): + """Represents an Action Row, a container for other components.""" + + def __init__(self, components: Optional[List[Component]] = None): + super().__init__(ComponentType.ACTION_ROW) + self.components: List[Component] = components or [] + + def add_component(self, component: Component): + if isinstance(component, ActionRow): + raise ValueError("Cannot nest ActionRows inside another ActionRow.") + + select_types = { + ComponentType.STRING_SELECT, + ComponentType.USER_SELECT, + ComponentType.ROLE_SELECT, + ComponentType.MENTIONABLE_SELECT, + ComponentType.CHANNEL_SELECT, + } + + if component.type in select_types: + if self.components: + raise ValueError( + "Select menu components must be the only component in an ActionRow." + ) + self.components.append(component) + return self + + if any(c.type in select_types for c in self.components): + raise ValueError( + "Cannot add components to an ActionRow that already contains a select menu." + ) + + if len(self.components) >= 5: + raise ValueError("ActionRow cannot have more than 5 components.") + + self.components.append(component) + return self + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["components"] = [c.to_dict() for c in self.components] + return payload + + @classmethod + def from_dict( + cls, data: Dict[str, Any], client: Optional["Client"] = None + ) -> "ActionRow": + """Deserialize an action row payload.""" + from .components import component_factory + + row = cls() + for comp_data in data.get("components", []): + try: + row.add_component(component_factory(comp_data, client)) + except Exception: + # Skip components that fail to parse for now + continue + return row + + +class Button(Component): + """Represents a button component.""" + + def __init__( + self, + *, # Make parameters keyword-only for clarity + style: ButtonStyle, + label: Optional[str] = None, + emoji: Optional["PartialEmoji"] = None, # Changed to PartialEmoji type + custom_id: Optional[str] = None, + url: Optional[str] = None, + disabled: bool = False, + ): + super().__init__(ComponentType.BUTTON) + + if style == ButtonStyle.LINK and url is None: + raise ValueError("Link buttons must have a URL.") + if style != ButtonStyle.LINK and custom_id is None: + raise ValueError("Non-link buttons must have a custom_id.") + if label is None and emoji is None: + raise ValueError("Button must have a label or an emoji.") + + self.style: ButtonStyle = style + self.label: Optional[str] = label + self.emoji: Optional[PartialEmoji] = emoji + self.custom_id = custom_id + self.url: Optional[str] = url + self.disabled: bool = disabled + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["style"] = self.style.value + if self.label: + payload["label"] = self.label + if self.emoji: + payload["emoji"] = self.emoji.to_dict() # Call to_dict() + if self.custom_id: + payload["custom_id"] = self.custom_id + if self.url: + payload["url"] = self.url + if self.disabled: + payload["disabled"] = self.disabled + return payload + + +class SelectOption: + """Represents an option in a select menu.""" + + def __init__( + self, + *, # Make parameters keyword-only + label: str, + value: str, + description: Optional[str] = None, + emoji: Optional["PartialEmoji"] = None, # Changed to PartialEmoji type + default: bool = False, + ): + self.label: str = label + self.value: str = value + self.description: Optional[str] = description + self.emoji: Optional["PartialEmoji"] = emoji + self.default: bool = default + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = { + "label": self.label, + "value": self.value, + } + if self.description: + payload["description"] = self.description + if self.emoji: + payload["emoji"] = self.emoji.to_dict() # Call to_dict() + if self.default: + payload["default"] = self.default + return payload + + +class SelectMenu(Component): + """Represents a select menu component. + + Currently supports STRING_SELECT (type 3). + User (5), Role (6), Mentionable (7), Channel (8) selects are not yet fully modeled. + """ + + def __init__( + self, + *, # Make parameters keyword-only + custom_id: str, + options: List[SelectOption], + placeholder: Optional[str] = None, + min_values: int = 1, + max_values: int = 1, + disabled: bool = False, + channel_types: Optional[List[ChannelType]] = None, + # For other select types, specific fields would be needed. + # This constructor primarily targets STRING_SELECT (type 3). + type: ComponentType = ComponentType.STRING_SELECT, # Default to string select + ): + super().__init__(type) # Pass the specific select menu type + + if not (1 <= len(options) <= 25): + raise ValueError("Select menu must have between 1 and 25 options.") + if not ( + 0 <= min_values <= 25 + ): # Discord docs say min_values can be 0 for some types + raise ValueError("min_values must be between 0 and 25.") + if not (1 <= max_values <= 25): + raise ValueError("max_values must be between 1 and 25.") + if min_values > max_values: + raise ValueError("min_values cannot be greater than max_values.") + + self.custom_id = custom_id + self.options: List[SelectOption] = options + self.placeholder: Optional[str] = placeholder + self.min_values: int = min_values + self.max_values: int = max_values + self.disabled: bool = disabled + self.channel_types: Optional[List[ChannelType]] = channel_types + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() # Gets {"type": self.type.value} + payload["custom_id"] = self.custom_id + payload["options"] = [opt.to_dict() for opt in self.options] + if self.placeholder: + payload["placeholder"] = self.placeholder + payload["min_values"] = self.min_values + payload["max_values"] = self.max_values + if self.disabled: + payload["disabled"] = self.disabled + if self.type == ComponentType.CHANNEL_SELECT and self.channel_types: + payload["channel_types"] = [ct.value for ct in self.channel_types] + return payload + + +class UnfurledMediaItem: + """Represents an unfurled media item.""" + + def __init__( + self, + url: str, + proxy_url: Optional[str] = None, + height: Optional[int] = None, + width: Optional[int] = None, + content_type: Optional[str] = None, + ): + self.url = url + self.proxy_url = proxy_url + self.height = height + self.width = width + self.content_type = content_type + + def to_dict(self) -> Dict[str, Any]: + return { + "url": self.url, + "proxy_url": self.proxy_url, + "height": self.height, + "width": self.width, + "content_type": self.content_type, + } + + +class MediaGalleryItem: + """Represents an item in a media gallery.""" + + def __init__( + self, + media: UnfurledMediaItem, + description: Optional[str] = None, + spoiler: bool = False, + ): + self.media = media + self.description = description + self.spoiler = spoiler + + def to_dict(self) -> Dict[str, Any]: + return { + "media": self.media.to_dict(), + "description": self.description, + "spoiler": self.spoiler, + } + + +class TextDisplay(Component): + """Represents a text display component.""" + + def __init__(self, content: str, id: Optional[int] = None): + super().__init__(ComponentType.TEXT_DISPLAY) + self.content = content + self.id = id + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["content"] = self.content + if self.id is not None: + payload["id"] = self.id + return payload + + +class Thumbnail(Component): + """Represents a thumbnail component.""" + + def __init__( + self, + media: UnfurledMediaItem, + description: Optional[str] = None, + spoiler: bool = False, + id: Optional[int] = None, + ): + super().__init__(ComponentType.THUMBNAIL) + self.media = media + self.description = description + self.spoiler = spoiler + self.id = id + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["media"] = self.media.to_dict() + if self.description: + payload["description"] = self.description + if self.spoiler: + payload["spoiler"] = self.spoiler + if self.id is not None: + payload["id"] = self.id + return payload + + +class Section(Component): + """Represents a section component.""" + + def __init__( + self, + components: List[TextDisplay], + accessory: Optional[Union[Thumbnail, Button]] = None, + id: Optional[int] = None, + ): + super().__init__(ComponentType.SECTION) + self.components = components + self.accessory = accessory + self.id = id + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["components"] = [c.to_dict() for c in self.components] + if self.accessory: + payload["accessory"] = self.accessory.to_dict() + if self.id is not None: + payload["id"] = self.id + return payload + + +class MediaGallery(Component): + """Represents a media gallery component.""" + + def __init__(self, items: List[MediaGalleryItem], id: Optional[int] = None): + super().__init__(ComponentType.MEDIA_GALLERY) + self.items = items + self.id = id + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["items"] = [i.to_dict() for i in self.items] + if self.id is not None: + payload["id"] = self.id + return payload + + +class File(Component): + """Represents a file component.""" + + def __init__( + self, file: UnfurledMediaItem, spoiler: bool = False, id: Optional[int] = None + ): + super().__init__(ComponentType.FILE) + self.file = file + self.spoiler = spoiler + self.id = id + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["file"] = self.file.to_dict() + if self.spoiler: + payload["spoiler"] = self.spoiler + if self.id is not None: + payload["id"] = self.id + return payload + + +class Separator(Component): + """Represents a separator component.""" + + def __init__( + self, divider: bool = True, spacing: int = 1, id: Optional[int] = None + ): + super().__init__(ComponentType.SEPARATOR) + self.divider = divider + self.spacing = spacing + self.id = id + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["divider"] = self.divider + payload["spacing"] = self.spacing + if self.id is not None: + payload["id"] = self.id + return payload + + +class Container(Component): + """Represents a container component.""" + + def __init__( + self, + components: List[Component], + accent_color: Optional[int] = None, + spoiler: bool = False, + id: Optional[int] = None, + ): + super().__init__(ComponentType.CONTAINER) + self.components = components + self.accent_color = accent_color + self.spoiler = spoiler + self.id = id + + def to_dict(self) -> Dict[str, Any]: + payload = super().to_dict() + payload["components"] = [c.to_dict() for c in self.components] + if self.accent_color: + payload["accent_color"] = self.accent_color + if self.spoiler: + payload["spoiler"] = self.spoiler + if self.id is not None: + payload["id"] = self.id + return payload + + +class WelcomeChannel: + """Represents a channel shown in the server's welcome screen. + + Attributes: + channel_id (str): The ID of the channel. + description (str): The description shown for the channel. + emoji_id (Optional[str]): The ID of the emoji, if custom. + emoji_name (Optional[str]): The name of the emoji if custom, or the unicode character if standard. + """ + + def __init__(self, data: Dict[str, Any]): + self.channel_id: str = data["channel_id"] + self.description: str = data["description"] + self.emoji_id: Optional[str] = data.get("emoji_id") + self.emoji_name: Optional[str] = data.get("emoji_name") + + def __repr__(self) -> str: + return ( + f"" + ) + + +class WelcomeScreen: + """Represents the welcome screen of a Community guild. + + Attributes: + description (Optional[str]): The server description shown in the welcome screen. + welcome_channels (List[WelcomeChannel]): The channels shown in the welcome screen. + """ + + def __init__(self, data: Dict[str, Any], client_instance: "Client"): + self._client: "Client" = ( + client_instance # May be useful for fetching channel objects + ) + self.description: Optional[str] = data.get("description") + self.welcome_channels: List[WelcomeChannel] = [ + WelcomeChannel(wc_data) for wc_data in data.get("welcome_channels", []) + ] + + def __repr__(self) -> str: + return f"" + + +class ThreadMember: + """Represents a member of a thread. + + Attributes: + id (Optional[str]): The ID of the thread. Not always present. + user_id (Optional[str]): The ID of the user. Not always present. + join_timestamp (str): When the user joined the thread (ISO8601 timestamp). + flags (int): User-specific flags for thread settings. + member (Optional[Member]): The guild member object for this user, if resolved. + Only available from GUILD_MEMBERS intent and if fetched. + """ + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): # client_instance for member resolution + self._client: Optional["Client"] = client_instance + self.id: Optional[str] = data.get("id") # Thread ID + self.user_id: Optional[str] = data.get("user_id") + self.join_timestamp: str = data["join_timestamp"] + self.flags: int = data["flags"] + + # The 'member' field in ThreadMember payload is a full guild member object. + # This is present in some contexts like when listing thread members. + self.member: Optional[Member] = ( + Member(data["member"], client_instance) if data.get("member") else None + ) + + # Note: The 'presence' field is not included as it's often unavailable or too dynamic for a simple model. + + def __repr__(self) -> str: + return f"" + + +class PresenceUpdate: + """Represents a PRESENCE_UPDATE event.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client = client_instance + self.user = User(data["user"]) + self.guild_id: Optional[str] = data.get("guild_id") + self.status: Optional[str] = data.get("status") + self.activities: List[Dict[str, Any]] = data.get("activities", []) + self.client_status: Dict[str, Any] = data.get("client_status", {}) + + def __repr__(self) -> str: + return f"" + + +class TypingStart: + """Represents a TYPING_START event.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client = client_instance + self.channel_id: str = data["channel_id"] + self.guild_id: Optional[str] = data.get("guild_id") + self.user_id: str = data["user_id"] + self.timestamp: int = data["timestamp"] + self.member: Optional[Member] = ( + Member(data["member"], client_instance) if data.get("member") else None + ) + + def __repr__(self) -> str: + return f"" + + +def channel_factory(data: Dict[str, Any], client: "Client") -> Channel: + """Create a channel object from raw API data.""" + channel_type = data.get("type") + + if channel_type in ( + ChannelType.GUILD_TEXT.value, + ChannelType.GUILD_ANNOUNCEMENT.value, + ): + return TextChannel(data, client) + if channel_type in ( + ChannelType.GUILD_VOICE.value, + ChannelType.GUILD_STAGE_VOICE.value, + ): + return VoiceChannel(data, client) + if channel_type == ChannelType.GUILD_CATEGORY.value: + return CategoryChannel(data, client) + if channel_type in ( + ChannelType.ANNOUNCEMENT_THREAD.value, + ChannelType.PUBLIC_THREAD.value, + ChannelType.PRIVATE_THREAD.value, + ): + return Thread(data, client) + if channel_type in (ChannelType.DM.value, ChannelType.GROUP_DM.value): + return DMChannel(data, client) + + return Channel(data, client) diff --git a/disagreement/oauth.py b/disagreement/oauth.py new file mode 100644 index 0000000..623b3fd --- /dev/null +++ b/disagreement/oauth.py @@ -0,0 +1,109 @@ +"""OAuth2 utilities.""" + +from __future__ import annotations + +import aiohttp +from typing import List, Optional, Dict, Any, Union +from urllib.parse import urlencode + +from .errors import HTTPException + + +def build_authorization_url( + client_id: str, + redirect_uri: str, + scope: Union[str, List[str]], + *, + state: Optional[str] = None, + response_type: str = "code", + prompt: Optional[str] = None, +) -> str: + """Return the Discord OAuth2 authorization URL.""" + if isinstance(scope, list): + scope = " ".join(scope) + + params = { + "client_id": client_id, + "redirect_uri": redirect_uri, + "response_type": response_type, + "scope": scope, + } + if state is not None: + params["state"] = state + if prompt is not None: + params["prompt"] = prompt + + return "https://discord.com/oauth2/authorize?" + urlencode(params) + + +async def exchange_code_for_token( + client_id: str, + client_secret: str, + code: str, + redirect_uri: str, + *, + session: Optional[aiohttp.ClientSession] = None, +) -> Dict[str, Any]: + """Exchange an authorization code for an access token.""" + close = False + if session is None: + session = aiohttp.ClientSession() + close = True + + data = { + "client_id": client_id, + "client_secret": client_secret, + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + } + + resp = await session.post( + "https://discord.com/api/v10/oauth2/token", + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + try: + json_data = await resp.json() + if resp.status != 200: + raise HTTPException(resp, message="OAuth token exchange failed") + finally: + if close: + await session.close() + return json_data + + +async def refresh_access_token( + refresh_token: str, + client_id: str, + client_secret: str, + *, + session: Optional[aiohttp.ClientSession] = None, +) -> Dict[str, Any]: + """Refresh an access token using a refresh token.""" + + close = False + if session is None: + session = aiohttp.ClientSession() + close = True + + data = { + "client_id": client_id, + "client_secret": client_secret, + "grant_type": "refresh_token", + "refresh_token": refresh_token, + } + + resp = await session.post( + "https://discord.com/api/v10/oauth2/token", + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + try: + json_data = await resp.json() + if resp.status != 200: + raise HTTPException(resp, message="OAuth token refresh failed") + finally: + if close: + await session.close() + return json_data diff --git a/disagreement/permissions.py b/disagreement/permissions.py new file mode 100644 index 0000000..4f1a00f --- /dev/null +++ b/disagreement/permissions.py @@ -0,0 +1,99 @@ +"""Utility helpers for working with Discord permission bitmasks.""" + +from __future__ import annotations + +from enum import IntFlag +from typing import Iterable, List + + +class Permissions(IntFlag): + """Discord guild and channel permissions.""" + + CREATE_INSTANT_INVITE = 1 << 0 + KICK_MEMBERS = 1 << 1 + BAN_MEMBERS = 1 << 2 + ADMINISTRATOR = 1 << 3 + MANAGE_CHANNELS = 1 << 4 + MANAGE_GUILD = 1 << 5 + ADD_REACTIONS = 1 << 6 + VIEW_AUDIT_LOG = 1 << 7 + PRIORITY_SPEAKER = 1 << 8 + STREAM = 1 << 9 + VIEW_CHANNEL = 1 << 10 + SEND_MESSAGES = 1 << 11 + SEND_TTS_MESSAGES = 1 << 12 + MANAGE_MESSAGES = 1 << 13 + EMBED_LINKS = 1 << 14 + ATTACH_FILES = 1 << 15 + READ_MESSAGE_HISTORY = 1 << 16 + MENTION_EVERYONE = 1 << 17 + USE_EXTERNAL_EMOJIS = 1 << 18 + VIEW_GUILD_INSIGHTS = 1 << 19 + CONNECT = 1 << 20 + SPEAK = 1 << 21 + MUTE_MEMBERS = 1 << 22 + DEAFEN_MEMBERS = 1 << 23 + MOVE_MEMBERS = 1 << 24 + USE_VAD = 1 << 25 + CHANGE_NICKNAME = 1 << 26 + MANAGE_NICKNAMES = 1 << 27 + MANAGE_ROLES = 1 << 28 + MANAGE_WEBHOOKS = 1 << 29 + MANAGE_GUILD_EXPRESSIONS = 1 << 30 + USE_APPLICATION_COMMANDS = 1 << 31 + REQUEST_TO_SPEAK = 1 << 32 + MANAGE_EVENTS = 1 << 33 + MANAGE_THREADS = 1 << 34 + CREATE_PUBLIC_THREADS = 1 << 35 + CREATE_PRIVATE_THREADS = 1 << 36 + USE_EXTERNAL_STICKERS = 1 << 37 + SEND_MESSAGES_IN_THREADS = 1 << 38 + USE_EMBEDDED_ACTIVITIES = 1 << 39 + MODERATE_MEMBERS = 1 << 40 + VIEW_CREATOR_MONETIZATION_ANALYTICS = 1 << 41 + USE_SOUNDBOARD = 1 << 42 + CREATE_GUILD_EXPRESSIONS = 1 << 43 + CREATE_EVENTS = 1 << 44 + USE_EXTERNAL_SOUNDS = 1 << 45 + SEND_VOICE_MESSAGES = 1 << 46 + + +def permissions_value(*perms: Permissions | int | Iterable[Permissions | int]) -> int: + """Return a combined integer value for multiple permissions.""" + + value = 0 + for perm in perms: + if isinstance(perm, Iterable) and not isinstance(perm, (Permissions, int)): + value |= permissions_value(*perm) + else: + value |= int(perm) + return value + + +def has_permissions( + current: int | str | Permissions, + *perms: Permissions | int | Iterable[Permissions | int], +) -> bool: + """Return ``True`` if ``current`` includes all ``perms``.""" + + current_val = int(current) + needed = permissions_value(*perms) + return (current_val & needed) == needed + + +def missing_permissions( + current: int | str | Permissions, + *perms: Permissions | int | Iterable[Permissions | int], +) -> List[Permissions]: + """Return the subset of ``perms`` not present in ``current``.""" + + current_val = int(current) + missing: List[Permissions] = [] + for perm in perms: + if isinstance(perm, Iterable) and not isinstance(perm, (Permissions, int)): + missing.extend(missing_permissions(current_val, *perm)) + else: + perm_val = int(perm) + if not current_val & perm_val: + missing.append(Permissions(perm_val)) + return missing diff --git a/disagreement/rate_limiter.py b/disagreement/rate_limiter.py new file mode 100644 index 0000000..8043a88 --- /dev/null +++ b/disagreement/rate_limiter.py @@ -0,0 +1,75 @@ +"""Asynchronous rate limiter for Discord HTTP requests.""" + +from __future__ import annotations + +import asyncio +import time +from typing import Dict, Mapping + + +class _Bucket: + def __init__(self) -> None: + self.remaining: int = 1 + self.reset_at: float = 0.0 + self.lock = asyncio.Lock() + + +class RateLimiter: + """Rate limiter implementing per-route buckets and a global queue.""" + + def __init__(self) -> None: + self._buckets: Dict[str, _Bucket] = {} + self._global_event = asyncio.Event() + self._global_event.set() + + def _get_bucket(self, route: str) -> _Bucket: + bucket = self._buckets.get(route) + if bucket is None: + bucket = _Bucket() + self._buckets[route] = bucket + return bucket + + async def acquire(self, route: str) -> _Bucket: + bucket = self._get_bucket(route) + while True: + await self._global_event.wait() + async with bucket.lock: + now = time.monotonic() + if bucket.remaining <= 0 and now < bucket.reset_at: + await asyncio.sleep(bucket.reset_at - now) + continue + if bucket.remaining > 0: + bucket.remaining -= 1 + return bucket + + def release(self, route: str, headers: Mapping[str, str]) -> None: + bucket = self._get_bucket(route) + try: + remaining = int(headers.get("X-RateLimit-Remaining", bucket.remaining)) + reset_after = float(headers.get("X-RateLimit-Reset-After", "0")) + bucket.remaining = remaining + bucket.reset_at = time.monotonic() + reset_after + except ValueError: + pass + + if headers.get("X-RateLimit-Global", "false").lower() == "true": + retry_after = float(headers.get("Retry-After", "0")) + self._global_event.clear() + asyncio.create_task(self._lift_global(retry_after)) + + async def handle_rate_limit( + self, route: str, retry_after: float, is_global: bool + ) -> None: + bucket = self._get_bucket(route) + bucket.remaining = 0 + bucket.reset_at = time.monotonic() + retry_after + if is_global: + self._global_event.clear() + await asyncio.sleep(retry_after) + self._global_event.set() + else: + await asyncio.sleep(retry_after) + + async def _lift_global(self, delay: float) -> None: + await asyncio.sleep(delay) + self._global_event.set() diff --git a/disagreement/shard_manager.py b/disagreement/shard_manager.py new file mode 100644 index 0000000..a457628 --- /dev/null +++ b/disagreement/shard_manager.py @@ -0,0 +1,65 @@ +# disagreement/shard_manager.py + +"""Sharding utilities for managing multiple gateway connections.""" + +from __future__ import annotations + +import asyncio +from typing import List, TYPE_CHECKING + +from .gateway import GatewayClient + +if TYPE_CHECKING: # pragma: no cover - for type checking only + from .client import Client + + +class Shard: + """Represents a single gateway shard.""" + + def __init__(self, shard_id: int, shard_count: int, gateway: GatewayClient) -> None: + self.id: int = shard_id + self.count: int = shard_count + self.gateway: GatewayClient = gateway + + async def connect(self) -> None: + """Connects this shard's gateway.""" + await self.gateway.connect() + + async def close(self) -> None: + """Closes this shard's gateway.""" + await self.gateway.close() + + +class ShardManager: + """Manages multiple :class:`Shard` instances.""" + + def __init__(self, client: "Client", shard_count: int) -> None: + self.client: "Client" = client + self.shard_count: int = shard_count + self.shards: List[Shard] = [] + + def _create_shards(self) -> None: + if self.shards: + return + for shard_id in range(self.shard_count): + gateway = GatewayClient( + http_client=self.client._http, + event_dispatcher=self.client._event_dispatcher, + token=self.client.token, + intents=self.client.intents, + client_instance=self.client, + verbose=self.client.verbose, + shard_id=shard_id, + shard_count=self.shard_count, + ) + self.shards.append(Shard(shard_id, self.shard_count, gateway)) + + async def start(self) -> None: + """Starts all shards.""" + self._create_shards() + await asyncio.gather(*(s.connect() for s in self.shards)) + + async def close(self) -> None: + """Closes all shards.""" + await asyncio.gather(*(s.close() for s in self.shards)) + self.shards.clear() diff --git a/disagreement/typing.py b/disagreement/typing.py new file mode 100644 index 0000000..73c5ce8 --- /dev/null +++ b/disagreement/typing.py @@ -0,0 +1,42 @@ +import asyncio +from contextlib import suppress +from typing import Optional, TYPE_CHECKING + +from .errors import DisagreementException + +if TYPE_CHECKING: + from .client import Client + +if __name__ == "typing": + # For direct module execution testing + pass + + +class Typing: + """Async context manager for Discord typing indicator.""" + + def __init__(self, client: "Client", channel_id: str) -> None: + self._client = client + self._channel_id = channel_id + self._task: Optional[asyncio.Task] = None + + async def _run(self) -> None: + try: + while True: + await self._client._http.trigger_typing(self._channel_id) + await asyncio.sleep(5) + except asyncio.CancelledError: + pass + + async def __aenter__(self) -> "Typing": + if self._client._closed: + raise DisagreementException("Client is closed.") + await self._client._http.trigger_typing(self._channel_id) + self._task = asyncio.create_task(self._run()) + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + if self._task: + self._task.cancel() + with suppress(asyncio.CancelledError): + await self._task diff --git a/disagreement/ui/__init__.py b/disagreement/ui/__init__.py new file mode 100644 index 0000000..9d93129 --- /dev/null +++ b/disagreement/ui/__init__.py @@ -0,0 +1,17 @@ +from .view import View +from .item import Item +from .button import Button, button +from .select import Select, select +from .modal import Modal, TextInput, text_input + +__all__ = [ + "View", + "Item", + "Button", + "button", + "Select", + "select", + "Modal", + "TextInput", + "text_input", +] diff --git a/disagreement/ui/button.py b/disagreement/ui/button.py new file mode 100644 index 0000000..aa308b0 --- /dev/null +++ b/disagreement/ui/button.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import asyncio +from typing import Any, Callable, Coroutine, Optional, TYPE_CHECKING + +from .item import Item +from ..enums import ComponentType, ButtonStyle +from ..models import PartialEmoji, to_partial_emoji + +if TYPE_CHECKING: + from ..interactions import Interaction + + +class Button(Item): + """Represents a button component in a View. + + Args: + style (ButtonStyle): The style of the button. + label (Optional[str]): The text that appears on the button. + emoji (Optional[str | PartialEmoji]): The emoji that appears on the button. + custom_id (Optional[str]): The developer-defined identifier for the button. + url (Optional[str]): The URL for the button. + disabled (bool): Whether the button is disabled. + row (Optional[int]): The row the button should be placed in, from 0 to 4. + """ + + def __init__( + self, + *, + style: ButtonStyle = ButtonStyle.SECONDARY, + label: Optional[str] = None, + emoji: Optional[str | PartialEmoji] = None, + custom_id: Optional[str] = None, + url: Optional[str] = None, + disabled: bool = False, + row: Optional[int] = None, + ): + super().__init__(type=ComponentType.BUTTON) + if not label and not emoji: + raise ValueError("A button must have a label and/or an emoji.") + + if url and custom_id: + raise ValueError("A button cannot have both a URL and a custom_id.") + + self.style = style + self.label = label + self.emoji = to_partial_emoji(emoji) + self.custom_id = custom_id + self.url = url + self.disabled = disabled + self._row = row + + def to_dict(self) -> dict[str, Any]: + """Converts the button to a dictionary that can be sent to Discord.""" + payload = { + "type": ComponentType.BUTTON.value, + "style": self.style.value, + "disabled": self.disabled, + } + if self.label: + payload["label"] = self.label + if self.emoji: + payload["emoji"] = self.emoji.to_dict() + if self.url: + payload["url"] = self.url + if self.custom_id: + payload["custom_id"] = self.custom_id + return payload + + +def button( + *, + label: Optional[str] = None, + custom_id: Optional[str] = None, + style: ButtonStyle = ButtonStyle.SECONDARY, + emoji: Optional[str | PartialEmoji] = None, + url: Optional[str] = None, + disabled: bool = False, + row: Optional[int] = None, +) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Button]: + """A decorator to create a button in a View.""" + + def decorator(func: Callable[..., Coroutine[Any, Any, Any]]) -> Button: + if not asyncio.iscoroutinefunction(func): + raise TypeError("Button callback must be a coroutine function.") + + item = Button( + label=label, + custom_id=custom_id, + style=style, + emoji=emoji, + url=url, + disabled=disabled, + row=row, + ) + item.callback = func + return item + + return decorator diff --git a/disagreement/ui/item.py b/disagreement/ui/item.py new file mode 100644 index 0000000..ec6d2c9 --- /dev/null +++ b/disagreement/ui/item.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from typing import Any, Callable, Coroutine, Optional, TYPE_CHECKING + +from ..models import Component + +if TYPE_CHECKING: + from .view import View + from ..interactions import Interaction + + +class Item(Component): + """Represents a UI item that can be placed in a View. + + This is a base class and is not meant to be used directly. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._view: Optional[View] = None + self._row: Optional[int] = None + # This is the callback associated with this item. + self.callback: Optional[ + Callable[["View", Interaction], Coroutine[Any, Any, Any]] + ] = None + + @property + def view(self) -> Optional[View]: + return self._view + + @property + def row(self) -> Optional[int]: + return self._row + + def _refresh_from_data(self, data: dict[str, Any]) -> None: + # This is used to update the item's state from incoming interaction data. + # For example, a button's disabled state could be updated here. + pass diff --git a/disagreement/ui/modal.py b/disagreement/ui/modal.py new file mode 100644 index 0000000..80271cc --- /dev/null +++ b/disagreement/ui/modal.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from typing import Any, Callable, Coroutine, Optional, List, TYPE_CHECKING +import asyncio + +from .item import Item +from .view import View +from ..enums import ComponentType, TextInputStyle +from ..models import ActionRow + +if TYPE_CHECKING: # pragma: no cover - for type hints only + from ..interactions import Interaction + + +class TextInput(Item): + """Represents a text input component inside a modal.""" + + def __init__( + self, + *, + label: str, + custom_id: Optional[str] = None, + style: TextInputStyle = TextInputStyle.SHORT, + placeholder: Optional[str] = None, + required: bool = True, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + row: Optional[int] = None, + ) -> None: + super().__init__(type=ComponentType.TEXT_INPUT) + self.label = label + self.custom_id = custom_id + self.style = style + self.placeholder = placeholder + self.required = required + self.min_length = min_length + self.max_length = max_length + self._row = row + + def to_dict(self) -> dict[str, Any]: + payload = { + "type": ComponentType.TEXT_INPUT.value, + "style": self.style.value, + "label": self.label, + "required": self.required, + } + if self.custom_id: + payload["custom_id"] = self.custom_id + if self.placeholder: + payload["placeholder"] = self.placeholder + if self.min_length is not None: + payload["min_length"] = self.min_length + if self.max_length is not None: + payload["max_length"] = self.max_length + return payload + + +def text_input( + *, + label: str, + custom_id: Optional[str] = None, + style: TextInputStyle = TextInputStyle.SHORT, + placeholder: Optional[str] = None, + required: bool = True, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + row: Optional[int] = None, +) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], TextInput]: + """Decorator to define a text input callback inside a :class:`Modal`.""" + + def decorator(func: Callable[..., Coroutine[Any, Any, Any]]) -> TextInput: + if not asyncio.iscoroutinefunction(func): + raise TypeError("TextInput callback must be a coroutine function.") + + item = TextInput( + label=label, + custom_id=custom_id, + style=style, + placeholder=placeholder, + required=required, + min_length=min_length, + max_length=max_length, + row=row, + ) + item.callback = func + return item + + return decorator + + +class Modal: + """Represents a modal dialog.""" + + def __init__(self, *, title: str, custom_id: str) -> None: + self.title = title + self.custom_id = custom_id + self._children: List[TextInput] = [] + + for item in self.__class__.__dict__.values(): + if isinstance(item, TextInput): + self.add_item(item) + + @property + def children(self) -> List[TextInput]: + return self._children + + def add_item(self, item: TextInput) -> None: + if not isinstance(item, TextInput): + raise TypeError("Only TextInput items can be added to a Modal.") + if len(self._children) >= 5: + raise ValueError("A modal can only have up to 5 text inputs.") + item._view = None # Not part of a view but reuse item base + self._children.append(item) + + def to_components(self) -> List[ActionRow]: + rows: List[ActionRow] = [] + for child in self.children: + row = ActionRow(components=[child]) + rows.append(row) + return rows + + def to_dict(self) -> dict[str, Any]: + return { + "title": self.title, + "custom_id": self.custom_id, + "components": [r.to_dict() for r in self.to_components()], + } + + async def callback( + self, interaction: Interaction + ) -> None: # pragma: no cover - default + pass diff --git a/disagreement/ui/select.py b/disagreement/ui/select.py new file mode 100644 index 0000000..b3929ef --- /dev/null +++ b/disagreement/ui/select.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import asyncio +from typing import Any, Callable, Coroutine, List, Optional, TYPE_CHECKING + +from .item import Item +from ..enums import ComponentType +from ..models import SelectOption + +if TYPE_CHECKING: + from ..interactions import Interaction + + +class Select(Item): + """Represents a select menu component in a View. + + Args: + custom_id (str): The developer-defined identifier for the select menu. + options (List[SelectOption]): The choices in the select menu. + placeholder (Optional[str]): The placeholder text that is shown if nothing is selected. + min_values (int): The minimum number of items that must be chosen. + max_values (int): The maximum number of items that can be chosen. + disabled (bool): Whether the select menu is disabled. + row (Optional[int]): The row the select menu should be placed in, from 0 to 4. + """ + + def __init__( + self, + *, + custom_id: str, + options: List[SelectOption], + placeholder: Optional[str] = None, + min_values: int = 1, + max_values: int = 1, + disabled: bool = False, + row: Optional[int] = None, + ): + super().__init__(type=ComponentType.STRING_SELECT) + self.custom_id = custom_id + self.options = options + self.placeholder = placeholder + self.min_values = min_values + self.max_values = max_values + self.disabled = disabled + self._row = row + + def to_dict(self) -> dict[str, Any]: + """Converts the select menu to a dictionary that can be sent to Discord.""" + payload = { + "type": ComponentType.STRING_SELECT.value, + "custom_id": self.custom_id, + "options": [option.to_dict() for option in self.options], + "disabled": self.disabled, + } + if self.placeholder: + payload["placeholder"] = self.placeholder + if self.min_values is not None: + payload["min_values"] = self.min_values + if self.max_values is not None: + payload["max_values"] = self.max_values + return payload + + +def select( + *, + custom_id: str, + options: List[SelectOption], + placeholder: Optional[str] = None, + min_values: int = 1, + max_values: int = 1, + disabled: bool = False, + row: Optional[int] = None, +) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Select]: + """A decorator to create a select menu in a View.""" + + def decorator(func: Callable[..., Coroutine[Any, Any, Any]]) -> Select: + if not asyncio.iscoroutinefunction(func): + raise TypeError("Select callback must be a coroutine function.") + + item = Select( + custom_id=custom_id, + options=options, + placeholder=placeholder, + min_values=min_values, + max_values=max_values, + disabled=disabled, + row=row, + ) + item.callback = func + return item + + return decorator diff --git a/disagreement/ui/view.py b/disagreement/ui/view.py new file mode 100644 index 0000000..1891775 --- /dev/null +++ b/disagreement/ui/view.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import asyncio +import uuid +from typing import Any, Callable, Coroutine, Dict, List, Optional, TYPE_CHECKING + +from ..models import ActionRow +from .item import Item + +if TYPE_CHECKING: + from ..client import Client + from ..interactions import Interaction + + +class View: + """Represents a container for UI components that can be sent with a message. + + Args: + timeout (Optional[float]): The number of seconds to wait for an interaction before the view times out. + Defaults to 180. + """ + + def __init__(self, *, timeout: Optional[float] = 180.0): + self.timeout = timeout + self.id = str(uuid.uuid4()) + self.__children: List[Item] = [] + self.__stopped = asyncio.Event() + self._client: Optional[Client] = None + self._message_id: Optional[str] = None + + for item in self.__class__.__dict__.values(): + if isinstance(item, Item): + self.add_item(item) + + @property + def children(self) -> List[Item]: + return self.__children + + def add_item(self, item: Item): + """Adds an item to the view.""" + if not isinstance(item, Item): + raise TypeError("Only instances of 'Item' can be added to a View.") + + if len(self.__children) >= 25: + raise ValueError("A view can only have a maximum of 25 components.") + + item._view = self + self.__children.append(item) + + @property + def message_id(self) -> Optional[str]: + return self._message_id + + @message_id.setter + def message_id(self, value: str): + self._message_id = value + + def to_components(self) -> List[ActionRow]: + """Converts the view's children into a list of ActionRow components. + + This retains the original, simple layout behaviour where each item is + placed in its own :class:`ActionRow` to ensure backward compatibility. + """ + + rows: List[ActionRow] = [] + + for item in self.children: + if item.custom_id is None: + item.custom_id = ( + f"{self.id}:{item.__class__.__name__}:{len(self.__children)}" + ) + + rows.append(ActionRow(components=[item])) + + return rows + + def layout_components_advanced(self) -> List[ActionRow]: + """Group compatible components into rows following Discord rules.""" + + rows: List[ActionRow] = [] + + for item in self.children: + if item.custom_id is None: + item.custom_id = ( + f"{self.id}:{item.__class__.__name__}:{len(self.__children)}" + ) + + target_row = item.row + if target_row is not None: + if not 0 <= target_row <= 4: + raise ValueError("Row index must be between 0 and 4.") + + while len(rows) <= target_row: + if len(rows) >= 5: + raise ValueError("A view can have at most 5 action rows.") + rows.append(ActionRow()) + + rows[target_row].add_component(item) + continue + + placed = False + for row in rows: + try: + row.add_component(item) + placed = True + break + except ValueError: + continue + + if not placed: + if len(rows) >= 5: + raise ValueError("A view can have at most 5 action rows.") + new_row = ActionRow([item]) + rows.append(new_row) + + return rows + + def to_components_payload(self) -> List[Dict[str, Any]]: + """Converts the view's children into a list of component dictionaries + that can be sent to the Discord API.""" + return [row.to_dict() for row in self.to_components()] + + async def _dispatch(self, interaction: Interaction): + """Called by the client to dispatch an interaction to the correct item.""" + if self.timeout is not None: + self.__stopped.set() # Reset the timeout on each interaction + self.__stopped.clear() + + if interaction.data: + custom_id = interaction.data.custom_id + for child in self.children: + if child.custom_id == custom_id: + if child.callback: + await child.callback(self, interaction) + break + + async def wait(self) -> bool: + """Waits until the view has stopped interacting.""" + return await self.__stopped.wait() + + def stop(self): + """Stops the view from listening to interactions.""" + if not self.__stopped.is_set(): + self.__stopped.set() + + async def on_timeout(self): + """Called when the view times out.""" + pass # User can override this + + async def _start(self, client: Client): + """Starts the view's internal listener.""" + self._client = client + if self.timeout is not None: + asyncio.create_task(self._timeout_task()) + + async def _timeout_task(self): + """The task that waits for the timeout and then stops the view.""" + try: + await asyncio.wait_for(self.wait(), timeout=self.timeout) + except asyncio.TimeoutError: + self.stop() + await self.on_timeout() + if self._client and self._message_id: + # Remove the view from the client's listeners + self._client._views.pop(self._message_id, None) diff --git a/disagreement/voice_client.py b/disagreement/voice_client.py new file mode 100644 index 0000000..f173fb6 --- /dev/null +++ b/disagreement/voice_client.py @@ -0,0 +1,120 @@ +# disagreement/voice_client.py +"""Voice gateway and UDP audio client.""" + +from __future__ import annotations + +import asyncio +import contextlib +import socket +from typing import Optional, Sequence + +import aiohttp + + +class VoiceClient: + """Handles the Discord voice WebSocket connection and UDP streaming.""" + + def __init__( + self, + endpoint: str, + session_id: str, + token: str, + guild_id: int, + user_id: int, + *, + ws=None, + udp: Optional[socket.socket] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + verbose: bool = False, + ) -> None: + self.endpoint = endpoint + self.session_id = session_id + self.token = token + self.guild_id = str(guild_id) + self.user_id = str(user_id) + self._ws: Optional[aiohttp.ClientWebSocketResponse] = ws + self._udp = udp + self._session: Optional[aiohttp.ClientSession] = None + self._heartbeat_task: Optional[asyncio.Task] = None + self._heartbeat_interval: Optional[float] = None + self._loop = loop or asyncio.get_event_loop() + self.verbose = verbose + self.ssrc: Optional[int] = None + self.secret_key: Optional[Sequence[int]] = None + self._server_ip: Optional[str] = None + self._server_port: Optional[int] = None + + async def connect(self) -> None: + if self._ws is None: + self._session = aiohttp.ClientSession() + self._ws = await self._session.ws_connect(self.endpoint) + + hello = await self._ws.receive_json() + self._heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000 + self._heartbeat_task = self._loop.create_task(self._heartbeat()) + + await self._ws.send_json( + { + "op": 0, + "d": { + "server_id": self.guild_id, + "user_id": self.user_id, + "session_id": self.session_id, + "token": self.token, + }, + } + ) + + ready = await self._ws.receive_json() + data = ready["d"] + self.ssrc = data["ssrc"] + self._server_ip = data["ip"] + self._server_port = data["port"] + + if self._udp is None: + self._udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self._udp.connect((self._server_ip, self._server_port)) + + await self._ws.send_json( + { + "op": 1, + "d": { + "protocol": "udp", + "data": { + "address": self._udp.getsockname()[0], + "port": self._udp.getsockname()[1], + "mode": "xsalsa20_poly1305", + }, + }, + } + ) + + session_desc = await self._ws.receive_json() + self.secret_key = session_desc["d"].get("secret_key") + + async def _heartbeat(self) -> None: + assert self._ws is not None + assert self._heartbeat_interval is not None + try: + while True: + await self._ws.send_json({"op": 3, "d": int(self._loop.time() * 1000)}) + await asyncio.sleep(self._heartbeat_interval) + except asyncio.CancelledError: + pass + + async def send_audio_frame(self, frame: bytes) -> None: + if not self._udp: + raise RuntimeError("UDP socket not initialised") + self._udp.send(frame) + + async def close(self) -> None: + if self._heartbeat_task: + self._heartbeat_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._heartbeat_task + if self._ws: + await self._ws.close() + if self._session: + await self._session.close() + if self._udp: + self._udp.close() diff --git a/docs/caching.md b/docs/caching.md new file mode 100644 index 0000000..9a30d45 --- /dev/null +++ b/docs/caching.md @@ -0,0 +1,18 @@ +# Caching + +Disagreement ships with a simple in-memory cache used by the HTTP and Gateway clients. Cached objects reduce API requests and improve performance. + +The client automatically caches guilds, channels and users as they are received from events or HTTP calls. You can access cached data through lookup helpers such as `Client.get_guild`. + +The cache can be cleared manually if needed: + +```python +client.cache.clear() +``` + +## Next Steps + +- [Components](using_components.md) +- [Slash Commands](slash_commands.md) +- [Voice Features](voice_features.md) + diff --git a/docs/commands.md b/docs/commands.md new file mode 100644 index 0000000..e109899 --- /dev/null +++ b/docs/commands.md @@ -0,0 +1,51 @@ +# Commands Extension + +This guide covers the built-in prefix command system. + +## Help Command + +The command handler registers a `help` command automatically. Use it to list all available commands or get information about a single command. + +``` +!help # lists commands +!help ping # shows help for the "ping" command +``` + +The help command will show each command's brief description if provided. + +## Checks + +Use `commands.check` to prevent a command from running unless a predicate +returns ``True``. Checks may be regular or async callables that accept a +`CommandContext`. + +```python +from disagreement.ext.commands import command, check, CheckFailure + +def is_owner(ctx): + return ctx.author.id == "1" + +@command() +@check(is_owner) +async def secret(ctx): + await ctx.send("Only for the owner!") +``` + +When a check fails a :class:`CheckFailure` is raised and dispatched through the +command error handler. + +## Cooldowns + +Commands can be rate limited using the ``cooldown`` decorator. The example +below restricts usage to once every three seconds per user: + +```python +from disagreement.ext.commands import command, cooldown + +@command() +@cooldown(1, 3.0) +async def ping(ctx): + await ctx.send("Pong!") +``` + +Invoking a command while it is on cooldown raises :class:`CommandOnCooldown`. diff --git a/docs/context_menus.md b/docs/context_menus.md new file mode 100644 index 0000000..af10f7b --- /dev/null +++ b/docs/context_menus.md @@ -0,0 +1,21 @@ +# Context Menu Commands + +`disagreement` supports Discord's user and message context menu commands. Use the +`user_command` and `message_command` decorators from `ext.app_commands` to +define them. + +```python +from disagreement.ext.app_commands import user_command, message_command, AppCommandContext +from disagreement.models import User, Message + +@user_command(name="User Info") +async def user_info(ctx: AppCommandContext, user: User) -> None: + await ctx.send(f"User: {user.username}#{user.discriminator}") + +@message_command(name="Quote") +async def quote(ctx: AppCommandContext, message: Message) -> None: + await ctx.send(message.content) +``` + +Add the commands to your client's handler and run `sync_commands()` to register +them with Discord. diff --git a/docs/converters.md b/docs/converters.md new file mode 100644 index 0000000..b3007fd --- /dev/null +++ b/docs/converters.md @@ -0,0 +1,25 @@ +# Command Argument Converters + +`disagreement.ext.commands` provides a number of built in converters that will parse string arguments into richer objects. These converters are automatically used when a command callback annotates its parameters with one of the supported types. + +## Supported Types + +- `int`, `float`, `bool`, and `str` +- `Member` – resolves a user mention or ID to a `Member` object for the current guild +- `Role` – resolves a role mention or ID to a `Role` object +- `Guild` – resolves a guild ID to a `Guild` object + +## Example + +```python +from disagreement.ext.commands import command +from disagreement.ext.commands.core import CommandContext +from disagreement.models import Member + +@command() +async def kick(ctx: CommandContext, target: Member): + await target.kick() + await ctx.send(f"Kicked {target.display_name}") +``` + +The framework will automatically convert the first argument to a `Member` using the mention or ID provided by the user. diff --git a/docs/events.md b/docs/events.md new file mode 100644 index 0000000..21baf95 --- /dev/null +++ b/docs/events.md @@ -0,0 +1,25 @@ +# Events + +Disagreement dispatches Gateway events to asynchronous callbacks. Handlers can be registered with `@client.event` or `client.on_event`. +Listeners may be removed later using `EventDispatcher.unregister(event_name, coro)`. + + +## PRESENCE_UPDATE + +Triggered when a user's presence changes. The callback receives a `PresenceUpdate` model. + +```python +@client.event +async def on_presence_update(presence: disagreement.PresenceUpdate): + ... +``` + +## TYPING_START + +Dispatched when a user begins typing in a channel. The callback receives a `TypingStart` model. + +```python +@client.event +async def on_typing_start(typing: disagreement.TypingStart): + ... +``` diff --git a/docs/gateway.md b/docs/gateway.md new file mode 100644 index 0000000..a71a4fe --- /dev/null +++ b/docs/gateway.md @@ -0,0 +1,17 @@ +# Gateway Connection and Reconnection + +`GatewayClient` manages the library's WebSocket connection. When the connection drops unexpectedly, it will now automatically attempt to reconnect using an exponential backoff strategy with jitter. + +The default behaviour tries up to five reconnect attempts, doubling the delay each time up to a configurable maximum. A small random jitter is added to spread out reconnect attempts when multiple clients restart at once. + +You can control the maximum number of retries and the backoff cap when constructing `Client`: + +```python +bot = Client( + token="your-token", + gateway_max_retries=10, + gateway_max_backoff=120.0, +) +``` + +These values are passed to `GatewayClient` and applied whenever the connection needs to be re-established. diff --git a/docs/i18n.md b/docs/i18n.md new file mode 100644 index 0000000..1b97816 --- /dev/null +++ b/docs/i18n.md @@ -0,0 +1,36 @@ +# Internationalization + +Disagreement can translate command names, descriptions and other text using JSON translation files. + +## Providing Translations + +Use `disagreement.i18n.load_translations` to load a JSON file for a locale. + +```python +from disagreement import i18n + +i18n.load_translations("es", "path/to/es.json") +``` + +The JSON file should map translation keys to translated strings: + +```json +{ + "greet": "Hola", + "description": "Comando de saludo" +} +``` + +You can also set translations programmatically with `i18n.set_translations`. + +## Using with Commands + +Pass a `locale` argument when defining an `AppCommand` or using the decorators. The command name and description will be looked up using the loaded translations. + +```python +@slash_command(name="greet", description="description", locale="es") +async def greet(ctx): + await ctx.send(i18n.translate("greet", ctx.locale or "es")) +``` + +If a translation is missing the key itself is returned. diff --git a/docs/oauth2.md b/docs/oauth2.md new file mode 100644 index 0000000..e279dbf --- /dev/null +++ b/docs/oauth2.md @@ -0,0 +1,48 @@ +# OAuth2 Setup + +This guide explains how to perform a basic OAuth2 flow with `disagreement`. + +1. Generate the authorization URL: + +```python +from disagreement.oauth import build_authorization_url + +url = build_authorization_url( + client_id="YOUR_CLIENT_ID", + redirect_uri="https://your.app/callback", + scope=["identify"], +) +print(url) +``` + +2. After the user authorizes your application and you receive a code, exchange it for a token: + +```python +import aiohttp +from disagreement.oauth import exchange_code_for_token + +async def get_token(code: str): + return await exchange_code_for_token( + client_id="YOUR_CLIENT_ID", + client_secret="YOUR_CLIENT_SECRET", + code=code, + redirect_uri="https://your.app/callback", + ) +``` + +`exchange_code_for_token` returns the JSON payload from Discord which includes +`access_token`, `refresh_token` and expiry information. + +3. When the access token expires, you can refresh it using the provided refresh +token: + +```python +from disagreement.oauth import refresh_access_token + +async def refresh(token: str): + return await refresh_access_token( + refresh_token=token, + client_id="YOUR_CLIENT_ID", + client_secret="YOUR_CLIENT_SECRET", + ) +``` diff --git a/docs/permissions.md b/docs/permissions.md new file mode 100644 index 0000000..7392dac --- /dev/null +++ b/docs/permissions.md @@ -0,0 +1,62 @@ +# Permission Helpers + +The `disagreement.permissions` module defines an :class:`~enum.IntFlag` +`Permissions` enumeration along with helper functions for working with +the Discord permission bitmask. + +## Permissions Enum + +Each attribute of ``Permissions`` represents a single permission bit. The value +is a power of two so multiple permissions can be combined using bitwise OR. + +```python +from disagreement.permissions import Permissions + +value = Permissions.SEND_MESSAGES | Permissions.MANAGE_MESSAGES +``` + +## Helper Functions + +### ``permissions_value`` + +```python +permissions_value(*perms) -> int +``` + +Return an integer bitmask from one or more ``Permissions`` values. Nested +iterables are flattened automatically. + +### ``has_permissions`` + +```python +has_permissions(current, *perms) -> bool +``` + +Return ``True`` if ``current`` (an ``int`` or ``Permissions``) contains all of +the provided permissions. + +### ``missing_permissions`` + +```python +missing_permissions(current, *perms) -> List[Permissions] +``` + +Return a list of permissions that ``current`` does not contain. + +## Example + +```python +from disagreement.permissions import ( + Permissions, + has_permissions, + missing_permissions, +) + +current = Permissions.SEND_MESSAGES | Permissions.MANAGE_MESSAGES + +if has_permissions(current, Permissions.SEND_MESSAGES): + print("Can send messages") + +print(missing_permissions(current, Permissions.ADMINISTRATOR)) +``` + diff --git a/docs/presence.md b/docs/presence.md new file mode 100644 index 0000000..0e5da19 --- /dev/null +++ b/docs/presence.md @@ -0,0 +1,29 @@ +# Updating Presence + +The `Client.change_presence` method allows you to update the bot's status and displayed activity. + +## Status Strings + +- `online` – show the bot as online +- `idle` – mark the bot as away +- `dnd` – do not disturb +- `invisible` – appear offline + +## Activity Types + +An activity dictionary must include a `name` and a `type` field. The type value corresponds to Discord's activity types: + +| Type | Meaning | +|-----:|--------------| +| `0` | Playing | +| `1` | Streaming | +| `2` | Listening | +| `3` | Watching | +| `4` | Custom | +| `5` | Competing | + +Example: + +```python +await client.change_presence(status="idle", activity={"name": "with Discord", "type": 0}) +``` diff --git a/docs/reactions.md b/docs/reactions.md new file mode 100644 index 0000000..40ca754 --- /dev/null +++ b/docs/reactions.md @@ -0,0 +1,32 @@ +# Handling Reactions + +`disagreement` provides simple helpers for adding and removing message reactions. + +## HTTP Methods + +Use the `HTTPClient` methods directly if you need lower level control: + +```python +await client._http.create_reaction(channel_id, message_id, "👍") +await client._http.delete_reaction(channel_id, message_id, "👍") +users = await client._http.get_reactions(channel_id, message_id, "👍") +``` + +You can also use the higher level helpers on :class:`Client`: + +```python +await client.create_reaction(channel_id, message_id, "👍") +await client.delete_reaction(channel_id, message_id, "👍") +users = await client.get_reactions(channel_id, message_id, "👍") +``` + +## Reaction Events + +Register listeners for `MESSAGE_REACTION_ADD` and `MESSAGE_REACTION_REMOVE`. +Each listener receives a `Reaction` model instance. + +```python +@client.on_event("MESSAGE_REACTION_ADD") +async def on_reaction(reaction: disagreement.Reaction): + print(f"{reaction.user_id} reacted with {reaction.emoji}") +``` diff --git a/docs/slash_commands.md b/docs/slash_commands.md new file mode 100644 index 0000000..42b24ff --- /dev/null +++ b/docs/slash_commands.md @@ -0,0 +1,22 @@ +# Using Slash Commands + +The library provides a slash command framework via the `ext.app_commands` package. Define commands with decorators and register them with Discord. + +```python +from disagreement.ext.app_commands import AppCommandGroup + +bot_commands = AppCommandGroup("bot", "Bot commands") + +@bot_commands.command(name="ping") +async def ping(ctx): + await ctx.respond("Pong!") +``` + +Use `AppCommandGroup` to group related commands. See the [components guide](using_components.md) for building interactive responses. + +## Next Steps + +- [Components](using_components.md) +- [Caching](caching.md) +- [Voice Features](voice_features.md) + diff --git a/docs/task_loop.md b/docs/task_loop.md new file mode 100644 index 0000000..86da7c6 --- /dev/null +++ b/docs/task_loop.md @@ -0,0 +1,15 @@ +# Task Loops + +The tasks extension allows you to run functions periodically. Decorate an async function with `@tasks.loop(seconds=...)` and start it using `.start()`. + +```python +from disagreement.ext import tasks + +@tasks.loop(seconds=5.0) +async def announce(): + print("Hello from a loop") + +announce.start() +``` + +Stop the loop with `.stop()` when you no longer need it. diff --git a/docs/typing_indicator.md b/docs/typing_indicator.md new file mode 100644 index 0000000..0550047 --- /dev/null +++ b/docs/typing_indicator.md @@ -0,0 +1,16 @@ +# Typing Indicator + +The library exposes an async context manager to send the typing indicator for a channel. + +```python +import asyncio +import disagreement + +client = disagreement.Client(token="YOUR_TOKEN") + +async def indicate(channel_id: str): + async with client.typing(channel_id): + await long_running_task() +``` + +This uses the underlying HTTP endpoint `/channels/{channel_id}/typing`. diff --git a/docs/using_components.md b/docs/using_components.md new file mode 100644 index 0000000..9ebd65a --- /dev/null +++ b/docs/using_components.md @@ -0,0 +1,168 @@ +# Using Message Components + +This guide explains how to work with the `disagreement` message component models. These examples are up to date with the current code base. + +## Enabling the New Component System + +Messages that use the component system must include the flag `IS_COMPONENTS_V2` (value `1 << 15`). Once this flag is set on a message it cannot be removed. + +## Component Categories + +The library exposes three broad categories of components: + +- **Layout Components** – organize the placement of other components. +- **Content Components** – display static text or media. +- **Interactive Components** – allow the user to interact with your message. + +## Action Row + +`ActionRow` is a layout container. It may hold up to five buttons or a single select menu. + +```python +from disagreement.models import ActionRow, Button +from disagreement.enums import ButtonStyle + +row = ActionRow(components=[ + Button(style=ButtonStyle.PRIMARY, label="Click", custom_id="btn") +]) +``` + +## Button + +Buttons provide a clickable UI element. + +```python +from disagreement.models import Button +from disagreement.enums import ButtonStyle + +button = Button( + style=ButtonStyle.SUCCESS, + label="Confirm", + custom_id="confirm_button", +) +``` + +## Select Menus + +`SelectMenu` lets the user choose one or more options. The `type` parameter controls the menu variety (`STRING_SELECT`, `USER_SELECT`, `ROLE_SELECT`, `MENTIONABLE_SELECT`, `CHANNEL_SELECT`). + +```python +from disagreement.models import SelectMenu, SelectOption +from disagreement.enums import ComponentType, ChannelType + +menu = SelectMenu( + custom_id="example", + options=[ + SelectOption(label="Option 1", value="1"), + SelectOption(label="Option 2", value="2"), + ], + placeholder="Choose an option", + min_values=1, + max_values=1, + type=ComponentType.STRING_SELECT, +) +``` + +For channel selects you may pass `channel_types` with a list of allowed `ChannelType` values. + +## Section + +`Section` groups one or more `TextDisplay` components and can include an accessory `Button` or `Thumbnail`. + +```python +from disagreement.models import Section, TextDisplay, Thumbnail, UnfurledMediaItem + +section = Section( + components=[ + TextDisplay(content="## Section Title"), + TextDisplay(content="Sections can hold multiple text displays."), + ], + accessory=Thumbnail(media=UnfurledMediaItem(url="https://example.com/img.png")), +) +``` + +## Text Display + +`TextDisplay` simply renders markdown text. + +```python +from disagreement.models import TextDisplay + +text_display = TextDisplay(content="**Bold text**") +``` + +## Thumbnail + +`Thumbnail` shows a small image. Set `spoiler=True` to hide the image until clicked. + +```python +from disagreement.models import Thumbnail, UnfurledMediaItem + +thumb = Thumbnail( + media=UnfurledMediaItem(url="https://example.com/image.png"), + description="A picture", + spoiler=False, +) +``` + +## Media Gallery + +`MediaGallery` holds multiple `MediaGalleryItem` objects. + +```python +from disagreement.models import MediaGallery, MediaGalleryItem, UnfurledMediaItem + +gallery = MediaGallery( + items=[ + MediaGalleryItem(media=UnfurledMediaItem(url="https://example.com/1.png")), + MediaGalleryItem(media=UnfurledMediaItem(url="https://example.com/2.png")), + ] +) +``` + +## File + +`File` displays an uploaded file. Use `spoiler=True` to mark it as a spoiler. + +```python +from disagreement.models import File, UnfurledMediaItem + +file_component = File( + file=UnfurledMediaItem(url="attachment://file.zip"), + spoiler=False, +) +``` + +## Separator + +`Separator` adds vertical spacing or an optional divider line between components. + +```python +from disagreement.models import Separator + +separator = Separator(divider=True, spacing=2) +``` + +## Container + +`Container` visually groups a set of components and can apply an accent colour or spoiler. + +```python +from disagreement.models import Container, TextDisplay + +container = Container( + components=[TextDisplay(content="Inside a container")], + accent_color=0xFF0000, + spoiler=False, +) +``` + +A container can itself contain layout and content components, letting you build complex messages. + + +## Next Steps + +- [Slash Commands](slash_commands.md) +- [Caching](caching.md) +- [Voice Features](voice_features.md) + diff --git a/docs/voice_client.md b/docs/voice_client.md new file mode 100644 index 0000000..d91db2b --- /dev/null +++ b/docs/voice_client.md @@ -0,0 +1,29 @@ +# VoiceClient + +`VoiceClient` provides a minimal interface to Discord's voice gateway. It handles the WebSocket handshake and lets you stream audio over UDP. + +## Basic Usage + +```python +import asyncio +import os +import disagreement + +vc = disagreement.VoiceClient( + os.environ["DISCORD_VOICE_ENDPOINT"], + os.environ["DISCORD_SESSION_ID"], + os.environ["DISCORD_VOICE_TOKEN"], + int(os.environ["DISCORD_GUILD_ID"]), + int(os.environ["DISCORD_USER_ID"]), +) + +asyncio.run(vc.connect()) +``` + +After connecting you can send raw Opus frames: + +```python +await vc.send_audio_frame(opus_bytes) +``` + +Call `await vc.close()` when finished. diff --git a/docs/voice_features.md b/docs/voice_features.md new file mode 100644 index 0000000..bf6aa74 --- /dev/null +++ b/docs/voice_features.md @@ -0,0 +1,17 @@ +# Voice Features + +Disagreement includes experimental support for connecting to voice channels. You can join a voice channel and play audio using an FFmpeg subprocess. + +```python +voice = await client.join_voice(guild_id, channel_id) +voice.play_file("welcome.mp3") +``` + +Voice support is optional and may require additional system dependencies such as FFmpeg. + +## Next Steps + +- [Components](using_components.md) +- [Slash Commands](slash_commands.md) +- [Caching](caching.md) + diff --git a/docs/webhooks.md b/docs/webhooks.md new file mode 100644 index 0000000..e099208 --- /dev/null +++ b/docs/webhooks.md @@ -0,0 +1,34 @@ +# Working with Webhooks + +The `HTTPClient` includes helper methods for creating, editing and deleting Discord webhooks. + +## Create a webhook + +```python +from disagreement.http import HTTPClient + +http = HTTPClient(token="TOKEN") +payload = {"name": "My Webhook"} +webhook_data = await http.create_webhook("123", payload) +``` + +## Edit a webhook + +```python +await http.edit_webhook("456", {"name": "Renamed"}) +``` + +## Delete a webhook + +```python +await http.delete_webhook("456") +``` + +The methods return the raw webhook JSON. You can construct a `Webhook` model if needed: + +```python +from disagreement.models import Webhook + +webhook = Webhook(webhook_data) +print(webhook.id, webhook.name) +``` diff --git a/examples/basic_bot.py b/examples/basic_bot.py new file mode 100644 index 0000000..ffbd995 --- /dev/null +++ b/examples/basic_bot.py @@ -0,0 +1,218 @@ +# examples/basic_bot.py + +""" +A basic example bot using the Disagreement library. + +To run this bot: +1. Make sure you have the 'disagreement' library installed or accessible in your PYTHONPATH. + If running from the project root, it should be discoverable. +2. Set the DISCORD_BOT_TOKEN environment variable to your bot's token. + e.g., export DISCORD_BOT_TOKEN="your_actual_token_here" (Linux/macOS) + set DISCORD_BOT_TOKEN="your_actual_token_here" (Windows CMD) + $env:DISCORD_BOT_TOKEN="your_actual_token_here" (Windows PowerShell) +3. Run this script: python examples/basic_bot.py +""" + +import asyncio +import os +import logging # Optional: for more detailed logging + +# Assuming the 'disagreement' package is in the parent directory or installed +import sys +import traceback + +# Add project root to path if running script directly from examples folder +# and disagreement is not installed +if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)): + sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +try: + import disagreement + from disagreement.ext import commands # Import the new commands extension +except ImportError: + print( + "Failed to import disagreement. Make sure it's installed or PYTHONPATH is set correctly." + ) + print( + "If running from the 'examples' directory, try running from the project root: python -m examples.basic_bot" + ) + sys.exit(1) + +from dotenv import load_dotenv + +load_dotenv() + +# Optional: Configure logging for more insight, especially for gateway events +# logging.basicConfig(level=logging.DEBUG) # For very verbose output +# logging.getLogger('disagreement.gateway').setLevel(logging.INFO) # Or DEBUG +# logging.getLogger('disagreement.http').setLevel(logging.INFO) + +# --- Bot Configuration --- +BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN") + +# --- Intents Configuration --- +# Define the intents your bot needs. For basic message reading and responding: +intents = ( + disagreement.GatewayIntent.GUILDS + | disagreement.GatewayIntent.GUILD_MESSAGES + | disagreement.GatewayIntent.MESSAGE_CONTENT +) # MESSAGE_CONTENT is privileged! + +# If you don't need message content and only react to commands/mentions, +# you might not need MESSAGE_CONTENT intent. +# intents = disagreement.GatewayIntent.default() # A good starting point without privileged intents +# intents |= disagreement.GatewayIntent.MESSAGE_CONTENT # Add if needed + +# --- Initialize the Client --- +if not BOT_TOKEN: + print("Error: The DISCORD_BOT_TOKEN environment variable is not set.") + print("Please set it before running the bot.") + sys.exit(1) + +# Initialize Client with a command prefix +client = disagreement.Client(token=BOT_TOKEN, intents=intents, command_prefix="!") + + +# --- Define a Cog for example commands --- +class ExampleCog(commands.Cog): # Ensuring this uses commands.Cog + def __init__( + self, bot_client + ): # Renamed client to bot_client to avoid conflict with self.client + super().__init__(bot_client) # Pass the client instance to the base Cog + + @commands.command(name="hello", aliases=["hi"]) + async def hello_command(self, ctx: commands.CommandContext, *, who: str = "world"): + """Greets someone.""" + await ctx.reply(f"Hello {ctx.author.mention} and {who}!") + print(f"Executed 'hello' command for {ctx.author.username}, greeting {who}") + + @commands.command() + async def ping(self, ctx: commands.CommandContext): + """Responds with Pong!""" + await ctx.reply("Pong!") + print(f"Executed 'ping' command for {ctx.author.username}") + + @commands.command() + async def me(self, ctx: commands.CommandContext): + """Shows information about the invoking user.""" + reply_content = ( + f"Hello {ctx.author.mention}!\n" + f"Your User ID is: {ctx.author.id}\n" + f"Your Username: {ctx.author.username}#{ctx.author.discriminator}\n" + f"Are you a bot? {'Yes' if ctx.author.bot else 'No'}" + ) + await ctx.reply(reply_content) + print(f"Executed 'me' command for {ctx.author.username}") + + @commands.command(name="add") + async def add_numbers(self, ctx: commands.CommandContext, num1: int, num2: int): + """Adds two numbers.""" + result = num1 + num2 + await ctx.reply(f"The sum of {num1} and {num2} is {result}.") + print( + f"Executed 'add' command for {ctx.author.username}: {num1} + {num2} = {result}" + ) + + @commands.command(name="say") + async def say_something(self, ctx: commands.CommandContext, *, text_to_say: str): + """Repeats the text you provide.""" + await ctx.reply(f"You said: {text_to_say}") + print( + f"Executed 'say' command for {ctx.author.username}, saying: {text_to_say}" + ) + + @commands.command(name="quit") + async def quit_command(self, ctx: commands.CommandContext): + """Shuts down the bot (requires YOUR_USER_ID to be set).""" + # Replace YOUR_USER_ID with your actual Discord User ID for a safe shutdown command + your_user_id = "YOUR_USER_ID_REPLACE_ME" # IMPORTANT: Replace this + if str(ctx.author.id) == your_user_id: + print("Quit command received. Shutting down...") + await ctx.reply("Shutting down...") + await self.client.close() # Access client via self.client from Cog + else: + await ctx.reply("You are not authorized to use this command.") + print( + f"Unauthorized quit attempt by {ctx.author.username} ({ctx.author.id})" + ) + + +# --- Event Handlers --- + + +@client.event +async def on_ready(): + """Called when the bot is ready and connected to Discord.""" + if client.user: + print( + f"Bot is ready! Logged in as {client.user.username}#{client.user.discriminator}" + ) + print(f"User ID: {client.user.id}") + else: + print("Bot is ready, but client.user is missing!") + print("------") + print("Disagreement Bot is operational.") + print("Listening for commands...") + + +@client.event +async def on_message(message: disagreement.Message): + """Called when a message is created and received.""" + # Command processing is now handled by the CommandHandler via client._process_message_for_commands + # This on_message can be used for other message-related logic if needed, + # or removed if all message handling is command-based. + + # Example: Log all messages (excluding bot's own, if client.user was available) + # if client.user and message.author.id == client.user.id: + # return + + print( + f"General on_message: #{message.channel_id} from {message.author.username}: {message.content}" + ) + # The old if/elif command structure is no longer needed here. + + +@client.on_event( + "GUILD_CREATE" +) # Example of listening to a specific event by its Discord name +async def on_guild_available(guild_data: dict): # Receives raw data for now + # In a real scenario, guild_data would be parsed into a Guild model + print(f"Guild available: {guild_data.get('name')} (ID: {guild_data.get('id')})") + + +# --- Main Execution --- +async def main(): + print("Starting Disagreement Bot...") + try: + # Add the Cog to the client + client.add_cog(ExampleCog(client)) # Pass client instance to Cog constructor + # client.add_cog is synchronous, but it schedules cog.cog_load() if it's async. + + await client.run() + except disagreement.AuthenticationError: + print( + "Authentication failed. Please check your bot token and ensure it's correct." + ) + except disagreement.DisagreementException as e: + print(f"A Disagreement library error occurred: {e}") + except KeyboardInterrupt: + print("Bot shutting down due to KeyboardInterrupt...") + except Exception as e: + print(f"An unexpected error occurred: {e}") + traceback.print_exc() + finally: + if not client.is_closed(): + print("Ensuring client is closed...") + await client.close() + print("Bot has been shut down.") + + +if __name__ == "__main__": + # Note: On Windows, the default asyncio event loop policy might not support add_signal_handler. + # If you encounter issues with Ctrl+C not working as expected, + # you might need to adjust the event loop policy or handle shutdown differently. + # For example, for Windows: + # if os.name == 'nt': + # asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + asyncio.run(main()) diff --git a/examples/component_bot.py b/examples/component_bot.py new file mode 100644 index 0000000..eae1e67 --- /dev/null +++ b/examples/component_bot.py @@ -0,0 +1,292 @@ +import os +import asyncio +from typing import Union, Optional +from disagreement import Client, ui, HybridContext +from disagreement.models import ( + Message, + SelectOption, + User, + Member, + Role, + Attachment, + Channel, + ActionRow, + Button, + Section, + TextDisplay, + Thumbnail, + UnfurledMediaItem, + MediaGallery, + MediaGalleryItem, + Container, +) +from disagreement.enums import ( + ButtonStyle, + GatewayIntent, + ChannelType, + MessageFlags, + InteractionCallbackType, + MessageFlags, +) +from disagreement.ext.commands.cog import Cog +from disagreement.ext.commands.core import CommandContext +from disagreement.ext.app_commands.decorators import hybrid_command, slash_command +from disagreement.ext.app_commands.context import AppCommandContext +from disagreement.interactions import ( + Interaction, + InteractionResponsePayload, + InteractionCallbackData, +) +from dotenv import load_dotenv + +load_dotenv() + +# Get the bot token and application ID from the environment variables +token = os.getenv("DISCORD_BOT_TOKEN") +application_id = os.getenv("DISCORD_APPLICATION_ID") + +if not token: + raise ValueError("Bot token not found in environment variables") +if not application_id: + raise ValueError("Application ID not found in environment variables") + +# Define the intents +intents = GatewayIntent.default() | GatewayIntent.MESSAGE_CONTENT + +# Create a new client +client = Client( + + token=token, + + intents=intents, + + command_prefix="!", + + mention_replies=True, +) + +# Simple stock data used for the stock cycling example +STOCKS = [ + {"symbol": "AAPL", "name": "Apple Inc.", "price": "$175"}, + {"symbol": "MSFT", "name": "Microsoft Corp.", "price": "$315"}, + {"symbol": "GOOGL", "name": "Alphabet Inc.", "price": "$128"}, +] + + +# Define a View class that contains our components +class MyView(ui.View): + def __init__(self): + super().__init__(timeout=180) # 180-second timeout + self.click_count = 0 + + @ui.button(label="Click Me!", style=ButtonStyle.SUCCESS, emoji="🖱️") + async def click_me(self, interaction: Interaction): + self.click_count += 1 + await interaction.respond( + content=f"You've clicked the button {self.click_count} times!", + ephemeral=True, + ) + + @ui.select( + custom_id="string_select", + placeholder="Choose an option", + options=[ + SelectOption( + label="Option 1", value="opt1", description="This is the first option" + ), + SelectOption( + label="Option 2", value="opt2", description="This is the second option" + ), + ], + ) + async def select_menu(self, interaction: Interaction): + if interaction.data and interaction.data.values: + await interaction.respond( + content=f"You selected: {interaction.data.values[0]}", + ephemeral=True, + ) + + async def on_timeout(self): + # This method is called when the view times out. + # You can use this to edit the original message, for example. + print("View has timed out!") + + +# View for cycling through available stocks +class StockView(ui.View): + def __init__(self): + super().__init__(timeout=180) + self.index = 0 + + @ui.button(label="Next Stock", style=ButtonStyle.PRIMARY) + async def next_stock(self, interaction: Interaction): + self.index = (self.index + 1) % len(STOCKS) + stock = STOCKS[self.index] + # Edit the message by responding to the interaction with an update + await interaction.edit( + content=f"**{stock['symbol']}** - {stock['name']}\nPrice: {stock['price']}", + components=self.to_components(), + # Preserve the original reply mention + allowed_mentions={"replied_user": True}, + ) + + +class ComponentCommandsCog(Cog): + def __init__(self, client: Client): + super().__init__(client) + + @hybrid_command(name="components", description="Sends interactive components.") + async def components_command(self, ctx: Union[CommandContext, AppCommandContext]): + # Send a message with the view using a hybrid context helper + hybrid = HybridContext(ctx) + await hybrid.send("Here are your components:", view=MyView()) + + @hybrid_command( + name="stocks", description="Shows stock information with navigation." + ) + async def stocks_command(self, ctx: Union[CommandContext, AppCommandContext]): + # Show the first stock and attach a button to cycle through them + first = STOCKS[0] + hybrid = HybridContext(ctx) + await hybrid.send( + f"**{first['symbol']}** - {first['name']}\nPrice: {first['price']}", + view=StockView(), + ) + + @hybrid_command(name="sectiondemo", description="Shows a section layout.") + async def section_demo(self, ctx: Union[CommandContext, AppCommandContext]) -> None: + section = Section( + components=[ + TextDisplay(content="## Advanced Components"), + TextDisplay(content="Sections group text with accessories."), + ], + accessory=Thumbnail( + media=UnfurledMediaItem(url="https://placehold.co/100x100.png") + ), + ) + container = Container(components=[section], accent_color=0x5865F2) + hybrid = HybridContext(ctx) + await hybrid.send( + components=[container], + flags=MessageFlags.IS_COMPONENTS_V2.value, + ) + + @hybrid_command(name="gallerydemo", description="Shows a media gallery.") + async def gallery_demo(self, ctx: Union[CommandContext, AppCommandContext]) -> None: + gallery = MediaGallery( + items=[ + MediaGalleryItem( + media=UnfurledMediaItem(url="https://placehold.co/600x400.png") + ), + MediaGalleryItem( + media=UnfurledMediaItem(url="https://placehold.co/600x400.jpg") + ), + ] + ) + hybrid = HybridContext(ctx) + await hybrid.send( + components=[gallery], + flags=MessageFlags.IS_COMPONENTS_V2.value, + ) + + @hybrid_command( + name="complex_components", + description="Shows a complex layout with multiple containers.", + ) + async def complex_components( + self, ctx: Union[CommandContext, AppCommandContext] + ) -> None: + container1 = Container( + components=[ + Section( + components=[ + TextDisplay(content="## Complex Layout Example"), + TextDisplay( + content="This container has an accent color and includes a section with an action row of buttons. There is a thumbnail accessory to the right." + ), + ], + accessory=Thumbnail( + media=UnfurledMediaItem(url="https://placehold.co/100x100.png") + ), + ), + ActionRow( + components=[ + Button( + style=ButtonStyle.PRIMARY, + label="Primary", + custom_id="complex_primary", + ), + Button( + style=ButtonStyle.SUCCESS, + label="Success", + custom_id="complex_success", + ), + Button( + style=ButtonStyle.DANGER, + label="Destructive", + custom_id="complex_destructive", + ), + ] + ), + ], + accent_color=0x5865F2, + ) + container2 = Container( + components=[ + TextDisplay( + content="## Another Container\nThis container has no accent color and includes a media gallery." + ), + MediaGallery( + items=[ + MediaGalleryItem( + media=UnfurledMediaItem( + url="https://placehold.co/300x200.png" + ) + ), + MediaGalleryItem( + media=UnfurledMediaItem( + url="https://placehold.co/300x200.jpg" + ) + ), + MediaGalleryItem( + media=UnfurledMediaItem( + url="https://placehold.co/300x200.gif" + ) + ), + ] + ), + ] + ) + + hybrid = HybridContext(ctx) + await hybrid.send( + components=[container1, container2], + flags=MessageFlags.IS_COMPONENTS_V2.value, + ) + + +async def main(): + @client.event + async def on_ready(): + if client.user: + print(f"Logged in as {client.user.username}") + if client.application_id: + try: + print("Attempting to sync application commands...") + await client.app_command_handler.sync_commands( + application_id=client.application_id + ) + print("Application commands synced successfully.") + except Exception as e: + print(f"Error syncing application commands: {e}") + else: + print( + "Client's application ID is not set. Skipping application command sync." + ) + + client.add_cog(ComponentCommandsCog(client)) + await client.run() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/context_menus.py b/examples/context_menus.py new file mode 100644 index 0000000..8ceb1e7 --- /dev/null +++ b/examples/context_menus.py @@ -0,0 +1,71 @@ +"""Examples showing how to use context menu commands.""" + +import os +import sys + +# Allow running example from repository root +if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)): + sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from disagreement.client import Client +from disagreement.ext.app_commands import ( + user_command, + message_command, + AppCommandContext, +) +from disagreement.models import User, Message + +from dotenv import load_dotenv + +load_dotenv() + +BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "") +APP_ID = os.environ.get("DISCORD_APPLICATION_ID", "") +client = Client(token=BOT_TOKEN, application_id=APP_ID) + + +@client.event +async def on_ready(): + """Called when the bot is ready and connected to Discord.""" + if client.user: + print(f"Bot is ready! Logged in as {client.user.username}") + print("Attempting to sync application commands...") + try: + if client.application_id: + await client.app_command_handler.sync_commands( + application_id=client.application_id + ) + print("Application commands synced successfully.") + else: + print("Skipping command sync: application ID is not set.") + except Exception as e: + print(f"Error syncing application commands: {e}") + else: + print("Bot is ready, but client.user is missing!") + print("------") + + +@user_command(name="User Info") +async def user_info(ctx: AppCommandContext, user: User) -> None: + await ctx.send( + f"Selected user: {user.username}#{user.discriminator}", ephemeral=True + ) + + +@message_command(name="Quote") +async def quote(ctx: AppCommandContext, message: Message) -> None: + await ctx.send(f"Quoted message: {message.content}", ephemeral=True) + + +client.app_command_handler.add_command(user_info) +client.app_command_handler.add_command(quote) + + +async def main() -> None: + await client.run() + + +if __name__ == "__main__": + import asyncio + + asyncio.run(main()) diff --git a/examples/hybrid_bot.py b/examples/hybrid_bot.py new file mode 100644 index 0000000..eeb8f19 --- /dev/null +++ b/examples/hybrid_bot.py @@ -0,0 +1,315 @@ +import asyncio +import os +import logging +from typing import Any, Optional, Literal, Union + +from disagreement import HybridContext + +from disagreement.client import Client +from disagreement.ext.commands.cog import Cog +from disagreement.ext.commands.core import CommandContext +from disagreement.ext.app_commands.decorators import ( + slash_command, + user_command, + message_command, + hybrid_command, +) +from disagreement.ext.app_commands.commands import ( + AppCommandGroup, +) # For defining groups +from disagreement.ext.app_commands.context import AppCommandContext # Added +from disagreement.models import ( + User, + Member, + Role, + Attachment, + Message, + Channel, +) # For type hints +from disagreement.enums import ( + ChannelType, +) # For channel option type hints, assuming it exists + +# from disagreement.interactions import Interaction # Replaced by AppCommandContext + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +from dotenv import load_dotenv + +load_dotenv() + + +# --- Define a Test Cog --- +class TestCog(Cog): + def __init__(self, client: Client): + super().__init__(client) + + @slash_command(name="greet", description="Sends a greeting.") + async def greet_slash(self, ctx: AppCommandContext, name: str): + await ctx.send(f"Hello, {name}! (Slash)") + + @user_command(name="Show User Info") + async def show_user_info_user( + self, ctx: AppCommandContext, user: User + ): # Target user is in ctx.interaction.data.target_id and resolved + target_user = ( + ctx.interaction.data.resolved.users.get(ctx.interaction.data.target_id) + if ctx.interaction.data + and ctx.interaction.data.resolved + and ctx.interaction.data.target_id + else user + ) + if target_user: + await ctx.send( + f"User: {target_user.username}#{target_user.discriminator} (ID: {target_user.id}) (User Cmd)", + ephemeral=True, + ) + else: + await ctx.send("Could not find user information.", ephemeral=True) + + @message_command(name="Quote Message") + async def quote_message_msg( + self, ctx: AppCommandContext, message: Message + ): # Target message is in ctx.interaction.data.target_id and resolved + target_message = ( + ctx.interaction.data.resolved.messages.get(ctx.interaction.data.target_id) + if ctx.interaction.data + and ctx.interaction.data.resolved + and ctx.interaction.data.target_id + else message + ) + if target_message: + await ctx.send( + f'Quoting {target_message.author.username}: "{target_message.content}" (Message Cmd)', + ephemeral=True, + ) + else: + await ctx.send("Could not find message to quote.", ephemeral=True) + + @hybrid_command(name="ping", description="Checks bot latency.", aliases=["pong"]) + async def ping_hybrid( + self, ctx: Union[CommandContext, AppCommandContext], arg: Optional[str] = None + ): + # latency = self.client.latency # Assuming client has latency attribute from gateway - Commented out for now + latency_ms = "N/A" # Placeholder + hybrid = HybridContext(ctx) + await hybrid.send(f"Pong! Arg: {arg} (Hybrid)") + + @slash_command(name="options_test", description="Tests various option types.") + async def options_test_slash( + self, + ctx: AppCommandContext, + text: str, + integer: int, + boolean: bool, + number: float, + user_option: User, + role_option: Role, + attachment_option: Attachment, + choice_option_str: Literal["apple", "banana", "cherry"], + # Channel and member options as well as numeric Literal choices are + # not yet exercised in tests pending full library support. + ): + response_parts = [ + f"Text: {text}", + f"Integer: {integer}", + f"Boolean: {boolean}", + f"Number: {number}", + f"User: {user_option.username}#{user_option.discriminator}", + f"Role: {role_option.name}", + f"Attachment: {attachment_option.filename} (URL: {attachment_option.url})", + f"Choice Str: {choice_option_str}", + ] + await ctx.send("\n".join(response_parts), ephemeral=True) + + # --- Subcommand Group Test --- + # Define the group as a class attribute. + # The AppCommandHandler's discovery mechanism (via Cog) should pick up AppCommandGroup instances. + settings_group = AppCommandGroup( + name="settings", + description="Manage bot settings.", + # guild_ids can be added here if the group is guild-specific + ) + + @slash_command( + name="show", description="Shows current setting values.", parent=settings_group + ) + async def settings_show( + self, ctx: AppCommandContext, setting_name: Optional[str] = None + ): + if setting_name: + await ctx.send( + f"Showing value for setting: {setting_name} (Value: Placeholder)", + ephemeral=True, + ) + else: + await ctx.send( + "Showing all settings: (Placeholder for all settings)", ephemeral=True + ) + + @slash_command( + name="update", description="Updates a setting.", parent=settings_group + ) + async def settings_update( + self, ctx: AppCommandContext, setting_name: str, value: str + ): + await ctx.send( + f"Updated setting: {setting_name} to value: {value}", ephemeral=True + ) + + # The Cog's metaclass or command registration logic should handle adding `settings_group` + # (and its subcommands) to the client's AppCommandHandler. + # The decorators now handle associating subcommands with their parent group. + + @slash_command( + name="numeric_choices_test", description="Tests integer and float choices." + ) + async def numeric_choices_test_slash( + self, + ctx: AppCommandContext, + int_choice: Literal[10, 20, 30, 42], + float_choice: float, + ): + response = ( + f"Integer Choice: {int_choice} (Type: {type(int_choice).__name__})\n" + f"Float Choice: {float_choice} (Type: {type(float_choice).__name__})" + ) + await ctx.send(response, ephemeral=True) + + @slash_command( + name="numeric_choices_extended", + description="Tests additional integer and float choice handling.", + ) + async def numeric_choices_extended_slash( + self, + ctx: AppCommandContext, + int_choice: Literal[-5, 0, 5], + float_choice: float, + ): + response = ( + f"Int Choice: {int_choice} (Type: {type(int_choice).__name__})\n" + f"Float Choice: {float_choice} (Type: {type(float_choice).__name__})" + ) + await ctx.send(response, ephemeral=True) + + @slash_command( + name="channel_member_test", + description="Tests channel and member options.", + ) + async def channel_member_test_slash( + self, + ctx: AppCommandContext, + channel: Channel, + member: Member, + ): + response = ( + f"Channel: {channel.name} (Type: {channel.type.name})\n" + f"Member: {member.username}#{member.discriminator}" + ) + await ctx.send(response, ephemeral=True) + + @slash_command( + name="channel_types_test", + description="Demonstrates multiple channel type options.", + ) + async def channel_types_test_slash( + self, + ctx: AppCommandContext, + text_channel: Channel, + voice_channel: Channel, + category_channel: Channel, + ): + response = ( + f"Text: {text_channel.type.name}\n" + f"Voice: {voice_channel.type.name}\n" + f"Category: {category_channel.type.name}" + ) + await ctx.send(response, ephemeral=True) + + +# --- Main Bot Script --- +async def main(): + bot_token = os.getenv("DISCORD_BOT_TOKEN") + application_id = os.getenv("DISCORD_APPLICATION_ID") + + if not bot_token: + logger.error("Error: DISCORD_BOT_TOKEN environment variable not set.") + return + if not application_id: + logger.error("Error: DISCORD_APPLICATION_ID environment variable not set.") + return + + client = Client(token=bot_token, command_prefix="!", application_id=application_id) + + @client.event + async def on_ready(): + if client.user: + logger.info( + f"Bot logged in as {client.user.username}#{client.user.discriminator}" + ) + else: + logger.error( + "Client ready, but client.user is not populated! This should not happen." + ) + return # Avoid proceeding if basic client info isn't there + + if client.application_id: + logger.info(f"Application ID is: {client.application_id}") + # Sync application commands (global in this case) + try: + logger.info("Attempting to sync application commands...") + # Ensure application_id is not None before passing + app_id_to_sync = client.application_id + if ( + app_id_to_sync is not None + ): # Redundant due to outer if, but good for clarity + await client.app_command_handler.sync_commands( + application_id=app_id_to_sync + ) + logger.info("Application commands synced successfully.") + else: # Should not be reached if outer if client.application_id is true + logger.error( + "Application ID was None despite initial check. Skipping sync." + ) + except Exception as e: + logger.error(f"Error syncing application commands: {e}", exc_info=True) + else: + # This case should be less likely now that Client gets it from READY. + # If DISCORD_APPLICATION_ID was critical as a fallback, that logic would be here. + # For now, we rely on the READY event. + logger.warning( + "Client's application ID is not set after READY. Skipping application command sync." + ) + # Check if the environment variable was provided, as a diagnostic. + if not application_id: + logger.warning( + "DISCORD_APPLICATION_ID environment variable was also not provided." + ) + + client.add_cog(TestCog(client)) + + try: + await client.run() + except KeyboardInterrupt: + logger.info("Bot shutting down...") + except Exception as e: + logger.error( + f"An error occurred in the bot's main run loop: {e}", exc_info=True + ) + finally: + if not client.is_closed(): + await client.close() + logger.info("Bot has been closed.") + + +if __name__ == "__main__": + # For Windows, to allow graceful shutdown with Ctrl+C + if os.name == "nt": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Main loop interrupted. Exiting.") diff --git a/examples/modal_command.py b/examples/modal_command.py new file mode 100644 index 0000000..120747f --- /dev/null +++ b/examples/modal_command.py @@ -0,0 +1,65 @@ +"""Example showing how to present a modal using a slash command.""" + +import os +import asyncio +from dotenv import load_dotenv + +from disagreement import Client, ui +from disagreement.enums import GatewayIntent, TextInputStyle +from disagreement.ext.app_commands.decorators import slash_command +from disagreement.ext.app_commands.context import AppCommandContext + +load_dotenv() + +token = os.getenv("DISCORD_BOT_TOKEN", "") +application_id = os.getenv("DISCORD_APPLICATION_ID", "") + +client = Client( + token=token, application_id=application_id, intents=GatewayIntent.default() +) + + +class FeedbackModal(ui.Modal): + def __init__(self) -> None: + super().__init__(title="Feedback", custom_id="feedback") + + @ui.text_input(label="Your feedback", style=TextInputStyle.PARAGRAPH) + async def feedback(self, interaction): + await interaction.respond(content="Thanks for your feedback!", ephemeral=True) + + +@slash_command(name="feedback", description="Send feedback via a modal") +async def feedback_command(ctx: AppCommandContext): + await ctx.interaction.respond_modal(FeedbackModal()) + + +client.app_command_handler.add_command(feedback_command) + + +@client.event +async def on_ready(): + """Called when the bot is ready and connected to Discord.""" + if client.user: + print(f"Bot is ready! Logged in as {client.user.username}") + print("Attempting to sync application commands...") + try: + if client.application_id: + await client.app_command_handler.sync_commands( + application_id=client.application_id + ) + print("Application commands synced successfully.") + else: + print("Skipping command sync: application ID is not set.") + except Exception as e: + print(f"Error syncing application commands: {e}") + else: + print("Bot is ready, but client.user is missing!") + print("------") + + +async def main(): + await client.run() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/modal_send.py b/examples/modal_send.py new file mode 100644 index 0000000..b6b007e --- /dev/null +++ b/examples/modal_send.py @@ -0,0 +1,63 @@ +"""Example showing how to send a modal.""" + +import os +import sys + +from dotenv import load_dotenv + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from disagreement import Client, GatewayIntent, ui # type: ignore +from disagreement.ext.app_commands.decorators import slash_command +from disagreement.ext.app_commands.context import AppCommandContext + +load_dotenv() +TOKEN = os.getenv("DISCORD_BOT_TOKEN", "") +APP_ID = os.getenv("DISCORD_APPLICATION_ID", "") + +if not TOKEN: + print("DISCORD_BOT_TOKEN not set") + sys.exit(1) + +client = Client(token=TOKEN, intents=GatewayIntent.default(), application_id=APP_ID) + + +class NameModal(ui.Modal): + def __init__(self): + super().__init__(title="Your Name", custom_id="name_modal") + self.name = ui.TextInput(label="Name", custom_id="name") + + +@slash_command(name="namemodal", description="Shows a modal") +async def _namemodal(ctx: AppCommandContext): + await ctx.interaction.response.send_modal(NameModal()) + + +client.app_command_handler.add_command(_namemodal) + + +@client.event +async def on_ready(): + """Called when the bot is ready and connected to Discord.""" + if client.user: + print(f"Bot is ready! Logged in as {client.user.username}") + print("Attempting to sync application commands...") + try: + if client.application_id: + await client.app_command_handler.sync_commands( + application_id=client.application_id + ) + print("Application commands synced successfully.") + else: + print("Skipping command sync: application ID is not set.") + except Exception as e: + print(f"Error syncing application commands: {e}") + else: + print("Bot is ready, but client.user is missing!") + print("------") + + +if __name__ == "__main__": + import asyncio + + asyncio.run(client.run()) diff --git a/examples/sharded_bot.py b/examples/sharded_bot.py new file mode 100644 index 0000000..fb20e54 --- /dev/null +++ b/examples/sharded_bot.py @@ -0,0 +1,36 @@ +"""Example bot demonstrating gateway sharding.""" + +import asyncio +import os +import sys + +# Ensure local package is importable when running from the examples directory +if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)): + sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +import disagreement + +TOKEN = os.environ.get("DISCORD_BOT_TOKEN") +if not TOKEN: + raise RuntimeError("DISCORD_BOT_TOKEN environment variable not set") + +client = disagreement.Client(token=TOKEN, shard_count=2) + + +@client.event +async def on_ready(): + if client.user: + print(f"Shard bot ready as {client.user.username}#{client.user.discriminator}") + else: + print("Shard bot ready") + + +async def main(): + if not TOKEN: + print("DISCORD_BOT_TOKEN environment variable not set") + return + await client.run() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/task_loop.py b/examples/task_loop.py new file mode 100644 index 0000000..e8f1c1f --- /dev/null +++ b/examples/task_loop.py @@ -0,0 +1,30 @@ +"""Example showing the tasks extension.""" + +import asyncio +import os +import sys + +# Allow running from the examples folder without installing +if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)): + sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from disagreement.ext import tasks + +counter = 0 + + +@tasks.loop(seconds=1.0) +async def ticker() -> None: + global counter + counter += 1 + print(f"Tick {counter}") + + +async def main() -> None: + ticker.start() + await asyncio.sleep(5) + ticker.stop() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/voice_bot.py b/examples/voice_bot.py new file mode 100644 index 0000000..c60b0b6 --- /dev/null +++ b/examples/voice_bot.py @@ -0,0 +1,61 @@ +"""Example bot demonstrating VoiceClient usage.""" + +import os +import asyncio +import sys + +# If running from the examples directory +if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)): + sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from typing import cast + +from dotenv import load_dotenv + +import disagreement + +load_dotenv() + +_VOICE_ENDPOINT = os.getenv("DISCORD_VOICE_ENDPOINT") +_VOICE_TOKEN = os.getenv("DISCORD_VOICE_TOKEN") +_VOICE_SESSION_ID = os.getenv("DISCORD_SESSION_ID") +_GUILD_ID = os.getenv("DISCORD_GUILD_ID") +_USER_ID = os.getenv("DISCORD_USER_ID") + +if not all([_VOICE_ENDPOINT, _VOICE_TOKEN, _VOICE_SESSION_ID, _GUILD_ID, _USER_ID]): + print("Missing one or more required environment variables for voice connection") + sys.exit(1) + +assert _VOICE_ENDPOINT +assert _VOICE_TOKEN +assert _VOICE_SESSION_ID +assert _GUILD_ID +assert _USER_ID + +VOICE_ENDPOINT = cast(str, _VOICE_ENDPOINT) +VOICE_TOKEN = cast(str, _VOICE_TOKEN) +VOICE_SESSION_ID = cast(str, _VOICE_SESSION_ID) +GUILD_ID = int(cast(str, _GUILD_ID)) +USER_ID = int(cast(str, _USER_ID)) + + +async def main() -> None: + vc = disagreement.VoiceClient( + VOICE_ENDPOINT, + VOICE_SESSION_ID, + VOICE_TOKEN, + GUILD_ID, + USER_ID, + ) + await vc.connect() + + try: + # Send silence frame as an example + await vc.send_audio_frame(b"\xf8\xff\xfe") + await asyncio.sleep(1) + finally: + await vc.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..a78f520 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,55 @@ +[project] +name = "disagreement" +version = "0.0.1" +description = "A Python library for the Discord API." +readme = "README.md" +requires-python = ">=3.11" +license = {text = "BSD 3-Clause"} +authors = [ + {name = "Slipstream", email = "me@slipstreamm.dev"} +] +keywords = ["discord", "api", "bot", "async", "aiohttp"] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Internet", +] + +dependencies = [ + "aiohttp>=3.9.0,<4.0.0", +] + +[project.optional-dependencies] +test = [ + "pytest>=8.0.0", + "pytest-asyncio>=1.0.0", + "hypothesis>=6.89.0", +] +dev = [ + "dotenv>=0.0.5", +] + +[project.urls] +Homepage = "https://github.com/Slipstreamm/disagreement" +Issues = "https://github.com/Slipstreamm/disagreement/issues" + +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +# Optional: for linting/formatting, e.g., Ruff +# [tool.ruff] +# line-length = 88 +# select = ["E", "W", "F", "I", "UP", "C4", "B"] # Example rule set +# ignore = [] + +# [tool.ruff.format] +# quote-style = "double" diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 0000000..084e35e --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,9 @@ +{ + "include": ["."], + "exclude": ["**/node_modules", "**/__pycache__", "**/.venv", "**/.git", "**/dist", "**/build", "**/tests/**", "tavilytool.py"], + "ignore": [], + "reportMissingImports": true, + "reportMissingTypeStubs": false, + "pythonVersion": "3.13", + "typeCheckingMode": "standard" +} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7f70d29 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +-e .[test,dev] \ No newline at end of file diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..1d70936 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[options] +packages = find: \ No newline at end of file diff --git a/tavilytool.py b/tavilytool.py new file mode 100644 index 0000000..8f63da7 --- /dev/null +++ b/tavilytool.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +""" +Tavily API Script for AI Agents +Execute with: python tavily.py "your search query" +""" + +import os +import sys +import json +import requests # type: ignore +import argparse +from typing import Dict, List, Optional + + +class TavilyAPI: + def __init__(self, api_key: str): + self.api_key = api_key + self.base_url = "https://api.tavily.com" + + def search( + self, + query: str, + search_depth: str = "basic", + include_answer: bool = True, + include_images: bool = False, + include_raw_content: bool = False, + max_results: int = 5, + include_domains: Optional[List[str]] = None, + exclude_domains: Optional[List[str]] = None, + ) -> Dict: + """ + Perform a search using Tavily API + + Args: + query: Search query string + search_depth: "basic" or "advanced" + include_answer: Include AI-generated answer + include_images: Include images in results + include_raw_content: Include raw HTML content + max_results: Maximum number of results (1-20) + include_domains: List of domains to include + exclude_domains: List of domains to exclude + + Returns: + Dictionary containing search results + """ + url = f"{self.base_url}/search" + + payload = { + "api_key": self.api_key, + "query": query, + "search_depth": search_depth, + "include_answer": include_answer, + "include_images": include_images, + "include_raw_content": include_raw_content, + "max_results": max_results, + } + + if include_domains: + payload["include_domains"] = include_domains + if exclude_domains: + payload["exclude_domains"] = exclude_domains + + try: + response = requests.post(url, json=payload, timeout=30) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + return {"error": f"API request failed: {str(e)}"} + except json.JSONDecodeError: + return {"error": "Invalid JSON response from API"} + + +def format_results(results: Dict) -> str: + """Format search results for display""" + if "error" in results: + return f"❌ Error: {results['error']}" + + output = [] + + # Add answer if available + if results.get("answer"): + output.append("🤖 AI Answer:") + output.append(f" {results['answer']}") + output.append("") + + # Add search results + if results.get("results"): + output.append("🔍 Search Results:") + for i, result in enumerate(results["results"], 1): + output.append(f" {i}. {result.get('title', 'No title')}") + output.append(f" URL: {result.get('url', 'No URL')}") + if result.get("content"): + # Truncate content to first 200 chars + content = ( + result["content"][:200] + "..." + if len(result["content"]) > 200 + else result["content"] + ) + output.append(f" Content: {content}") + output.append("") + + # Add images if available + if results.get("images"): + output.append("🖼️ Images:") + for img in results["images"][:3]: # Show first 3 images + output.append(f" {img}") + output.append("") + + return "\n".join(output) + + +def main(): + parser = argparse.ArgumentParser(description="Search using Tavily API") + parser.add_argument("query", help="Search query") + parser.add_argument( + "--depth", + choices=["basic", "advanced"], + default="basic", + help="Search depth (default: basic)", + ) + parser.add_argument( + "--max-results", + type=int, + default=5, + help="Maximum number of results (default: 5)", + ) + parser.add_argument( + "--include-images", action="store_true", help="Include images in results" + ) + parser.add_argument( + "--no-answer", action="store_true", help="Don't include AI-generated answer" + ) + parser.add_argument( + "--include-domains", nargs="+", help="Include only these domains" + ) + parser.add_argument("--exclude-domains", nargs="+", help="Exclude these domains") + parser.add_argument("--raw", action="store_true", help="Output raw JSON response") + + args = parser.parse_args() + + # Get API key from environment + api_key = os.getenv("TAVILY_API_KEY") + if not api_key: + print("❌ Error: TAVILY_API_KEY environment variable not set") + print("Set it with: export TAVILY_API_KEY='your-api-key-here'") + sys.exit(1) + + # Initialize Tavily API + tavily = TavilyAPI(api_key) + + # Perform search + results = tavily.search( + query=args.query, + search_depth=args.depth, + include_answer=not args.no_answer, + include_images=args.include_images, + max_results=args.max_results, + include_domains=args.include_domains, + exclude_domains=args.exclude_domains, + ) + + # Output results + if args.raw: + print(json.dumps(results, indent=2)) + else: + print(format_results(results)) + + +if __name__ == "__main__": + main() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..2203604 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,111 @@ +import pytest +from unittest.mock import AsyncMock + +from disagreement.interactions import Interaction +from disagreement.enums import InteractionType +from disagreement.models import Message + + +class DummyHTTP: + def __init__(self): + self.create_interaction_response = AsyncMock() + self.create_followup_message = AsyncMock( + return_value={ + "id": "123", + "channel_id": "c", + "author": {"id": "1", "username": "u", "discriminator": "0001"}, + "content": "hi", + "timestamp": "t", + } + ) + self.edit_original_interaction_response = AsyncMock(return_value={"id": "123"}) + self.edit_followup_message = AsyncMock( + return_value={ + "id": "123", + "channel_id": "c", + "author": {"id": "1", "username": "u", "discriminator": "0001"}, + "content": "hi", + "timestamp": "t", + } + ) + + +class DummyBot: + def __init__(self): + self._http = DummyHTTP() + self.application_id = "app123" + self._guilds = {} + self._channels = {} + + def get_guild(self, gid): + return self._guilds.get(gid) + + async def fetch_channel(self, cid): + return self._channels.get(cid) + + +class DummyCommandHTTP: + def __init__(self): + self.edit_message = AsyncMock(return_value={"id": "321", "channel_id": "c"}) + + +class DummyCommandBot: + def __init__(self): + self._http = DummyCommandHTTP() + + async def edit_message(self, channel_id, message_id, *, content=None, **kwargs): + return await self._http.edit_message( + channel_id, message_id, {"content": content, **kwargs} + ) + + +class DummyClient: + pass + + +class DummyInteraction: + data = None + + +@pytest.fixture() +def dummy_bot(): + return DummyBot() + + +@pytest.fixture() +def interaction(dummy_bot): + data = { + "id": "1", + "application_id": dummy_bot.application_id, + "type": InteractionType.APPLICATION_COMMAND.value, + "token": "tok", + "version": 1, + } + return Interaction(data, client_instance=dummy_bot) + + +@pytest.fixture() +def command_bot(): + return DummyCommandBot() + + +@pytest.fixture() +def message(command_bot): + message_data = { + "id": "1", + "channel_id": "c", + "author": {"id": "2", "username": "u", "discriminator": "0001"}, + "content": "hi", + "timestamp": "t", + } + return Message(message_data, client_instance=command_bot) + + +@pytest.fixture() +def dummy_client(): + return DummyClient() + + +@pytest.fixture() +def basic_interaction(): + return DummyInteraction() diff --git a/tests/test_additional_converters.py b/tests/test_additional_converters.py new file mode 100644 index 0000000..88f4975 --- /dev/null +++ b/tests/test_additional_converters.py @@ -0,0 +1,172 @@ +import pytest + +from disagreement.ext.commands.converters import run_converters +from disagreement.ext.commands.core import CommandContext, Command +from disagreement.ext.commands.errors import BadArgument +from disagreement.models import Message, Member, Role, Guild +from disagreement.enums import ( + VerificationLevel, + MessageNotificationLevel, + ExplicitContentFilterLevel, + MFALevel, + GuildNSFWLevel, + PremiumTier, +) + + +class DummyBot: + def __init__(self, guild: Guild): + self._guilds = {guild.id: guild} + + def get_guild(self, gid): + return self._guilds.get(gid) + + async def fetch_member(self, gid, mid): + guild = self._guilds.get(gid) + return guild.get_member(mid) if guild else None + + async def fetch_role(self, gid, rid): + guild = self._guilds.get(gid) + return guild.get_role(rid) if guild else None + + async def fetch_guild(self, gid): + return self._guilds.get(gid) + + +@pytest.fixture() +def guild_objects(): + guild_data = { + "id": "1", + "name": "g", + "owner_id": "2", + "afk_timeout": 60, + "verification_level": VerificationLevel.NONE.value, + "default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value, + "explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value, + "roles": [], + "emojis": [], + "features": [], + "mfa_level": MFALevel.NONE.value, + "system_channel_flags": 0, + "premium_tier": PremiumTier.NONE.value, + "nsfw_level": GuildNSFWLevel.DEFAULT.value, + } + guild = Guild(guild_data, client_instance=None) + + member = Member( + { + "user": {"id": "3", "username": "m", "discriminator": "0001"}, + "joined_at": "t", + "roles": [], + }, + None, + ) + member.guild_id = guild.id + + role = Role( + { + "id": "5", + "name": "r", + "color": 0, + "hoist": False, + "position": 0, + "permissions": "0", + "managed": False, + "mentionable": True, + } + ) + + guild._members[member.id] = member + guild.roles.append(role) + + return guild, member, role + + +@pytest.fixture() +def command_context(guild_objects): + guild, member, role = guild_objects + bot = DummyBot(guild) + message_data = { + "id": "10", + "channel_id": "20", + "guild_id": guild.id, + "author": {"id": "2", "username": "u", "discriminator": "0001"}, + "content": "hi", + "timestamp": "t", + } + msg = Message(message_data, client_instance=bot) + + async def dummy(ctx): + pass + + cmd = Command(dummy) + return CommandContext( + message=msg, bot=bot, prefix="!", command=cmd, invoked_with="dummy" + ) + + +@pytest.mark.asyncio +async def test_member_converter(command_context, guild_objects): + _, member, _ = guild_objects + mention = f"<@!{member.id}>" + result = await run_converters(command_context, Member, mention) + assert result is member + result = await run_converters(command_context, Member, member.id) + assert result is member + + +@pytest.mark.asyncio +async def test_role_converter(command_context, guild_objects): + _, _, role = guild_objects + mention = f"<@&{role.id}>" + result = await run_converters(command_context, Role, mention) + assert result is role + result = await run_converters(command_context, Role, role.id) + assert result is role + + +@pytest.mark.asyncio +async def test_guild_converter(command_context, guild_objects): + guild, _, _ = guild_objects + result = await run_converters(command_context, Guild, guild.id) + assert result is guild + + +@pytest.mark.asyncio +async def test_member_converter_no_guild(): + guild_data = { + "id": "99", + "name": "g", + "owner_id": "2", + "afk_timeout": 60, + "verification_level": VerificationLevel.NONE.value, + "default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value, + "explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value, + "roles": [], + "emojis": [], + "features": [], + "mfa_level": MFALevel.NONE.value, + "system_channel_flags": 0, + "premium_tier": PremiumTier.NONE.value, + "nsfw_level": GuildNSFWLevel.DEFAULT.value, + } + guild = Guild(guild_data, client_instance=None) + bot = DummyBot(guild) + message_data = { + "id": "11", + "channel_id": "20", + "author": {"id": "2", "username": "u", "discriminator": "0001"}, + "content": "hi", + "timestamp": "t", + } + msg = Message(message_data, client_instance=bot) + + async def dummy(ctx): + pass + + ctx = CommandContext( + message=msg, bot=bot, prefix="!", command=Command(dummy), invoked_with="dummy" + ) + + with pytest.raises(BadArgument): + await run_converters(ctx, Member, "<@!1>") diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 0000000..234077e --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,17 @@ +import time + +from disagreement.cache import Cache + + +def test_cache_store_and_get(): + cache = Cache() + cache.set("a", 123) + assert cache.get("a") == 123 + + +def test_cache_ttl_expiry(): + cache = Cache(ttl=0.01) + cache.set("b", 1) + assert cache.get("b") == 1 + time.sleep(0.02) + assert cache.get("b") is None diff --git a/tests/test_client_context_manager.py b/tests/test_client_context_manager.py new file mode 100644 index 0000000..ba4ef7f --- /dev/null +++ b/tests/test_client_context_manager.py @@ -0,0 +1,33 @@ +import asyncio +import pytest +from unittest.mock import AsyncMock + +from disagreement.client import Client + + +@pytest.mark.asyncio +async def test_client_async_context_closes(monkeypatch): + client = Client(token="t") + monkeypatch.setattr(client, "connect", AsyncMock()) + monkeypatch.setattr(client._http, "close", AsyncMock()) + + async with client: + client.connect.assert_awaited_once() + + client._http.close.assert_awaited_once() + assert client.is_closed() + + +@pytest.mark.asyncio +async def test_client_async_context_closes_on_exception(monkeypatch): + client = Client(token="t") + monkeypatch.setattr(client, "connect", AsyncMock()) + monkeypatch.setattr(client._http, "close", AsyncMock()) + + with pytest.raises(ValueError): + async with client: + raise ValueError("boom") + + client.connect.assert_awaited_once() + client._http.close.assert_awaited_once() + assert client.is_closed() diff --git a/tests/test_command_checks.py b/tests/test_command_checks.py new file mode 100644 index 0000000..50754f2 --- /dev/null +++ b/tests/test_command_checks.py @@ -0,0 +1,51 @@ +import asyncio +import pytest + +from disagreement.ext.commands.core import Command, CommandContext +from disagreement.ext.commands.decorators import check, cooldown +from disagreement.ext.commands.errors import CheckFailure, CommandOnCooldown + + +@pytest.mark.asyncio +async def test_check_decorator_blocks(message): + async def cb(ctx): + pass + + cmd = Command(check(lambda c: False)(cb)) + ctx = CommandContext( + message=message, + bot=message._client, + prefix="!", + command=cmd, + invoked_with="test", + ) + + with pytest.raises(CheckFailure): + await cmd.invoke(ctx) + + +@pytest.mark.asyncio +async def test_cooldown_per_user(message): + uses = [] + + @cooldown(1, 0.05) + async def cb(ctx): + uses.append(1) + + cmd = Command(cb) + ctx = CommandContext( + message=message, + bot=message._client, + prefix="!", + command=cmd, + invoked_with="test", + ) + + await cmd.invoke(ctx) + + with pytest.raises(CommandOnCooldown): + await cmd.invoke(ctx) + + await asyncio.sleep(0.05) + await cmd.invoke(ctx) + assert len(uses) == 2 diff --git a/tests/test_components_factory.py b/tests/test_components_factory.py new file mode 100644 index 0000000..c1cdf12 --- /dev/null +++ b/tests/test_components_factory.py @@ -0,0 +1,29 @@ +from disagreement.components import component_factory +from disagreement.enums import ComponentType, ButtonStyle + + +def test_component_factory_button(): + data = { + "type": ComponentType.BUTTON.value, + "style": ButtonStyle.PRIMARY.value, + "label": "Click", + "custom_id": "x", + } + comp = component_factory(data) + assert comp.to_dict()["label"] == "Click" + + +def test_component_factory_action_row(): + data = { + "type": ComponentType.ACTION_ROW.value, + "components": [ + { + "type": ComponentType.BUTTON.value, + "style": ButtonStyle.PRIMARY.value, + "label": "A", + "custom_id": "b", + } + ], + } + row = component_factory(data) + assert len(row.components) == 1 diff --git a/tests/test_context.py b/tests/test_context.py new file mode 100644 index 0000000..463a00a --- /dev/null +++ b/tests/test_context.py @@ -0,0 +1,199 @@ +import asyncio +from unittest.mock import AsyncMock + +import pytest + +from disagreement.ext.app_commands.context import AppCommandContext +from disagreement.interactions import Interaction +from disagreement.enums import InteractionType, MessageFlags +from disagreement.models import ( + Embed, + ActionRow, + Button, + Container, + TextDisplay, +) +from disagreement.enums import ButtonStyle, ComponentType + + +class DummyHTTP: + def __init__(self): + self.create_interaction_response = AsyncMock() + self.create_followup_message = AsyncMock(return_value={"id": "123"}) + self.edit_original_interaction_response = AsyncMock(return_value={"id": "123"}) + self.edit_followup_message = AsyncMock(return_value={"id": "123"}) + + +class DummyBot: + def __init__(self): + self._http = DummyHTTP() + self.application_id = "app123" + self._guilds = {} + self._channels = {} + + def get_guild(self, gid): + return self._guilds.get(gid) + + async def fetch_channel(self, cid): + return self._channels.get(cid) +from disagreement.ext.commands.core import CommandContext, Command +from disagreement.enums import MessageFlags, ButtonStyle, ComponentType +from disagreement.models import ActionRow, Button, Container, TextDisplay + + +@pytest.mark.asyncio +async def test_sends_extra_payload(dummy_bot, interaction): + ctx = AppCommandContext(dummy_bot, interaction) + button = Button(style=ButtonStyle.PRIMARY, label="Click", custom_id="a") + row = ActionRow([button]) + await ctx.send( + content="hi", + tts=True, + components=[row], + allowed_mentions={"parse": []}, + files=[{"id": 1, "filename": "f.txt"}], + ephemeral=True, + ) + dummy_bot._http.create_interaction_response.assert_called_once() + payload = dummy_bot._http.create_interaction_response.call_args.kwargs[ + "payload" + ].data.to_dict() + assert payload["tts"] is True + assert payload["components"] + assert payload["allowed_mentions"] == {"parse": []} + assert payload["attachments"] == [{"id": 1, "filename": "f.txt"}] + assert payload["flags"] == MessageFlags.EPHEMERAL.value + await ctx.send_followup(content="again") + dummy_bot._http.create_followup_message.assert_called_once() + + +@pytest.mark.asyncio +async def test_second_send_is_followup(dummy_bot, interaction): + ctx = AppCommandContext(dummy_bot, interaction) + await ctx.send(content="first") + await ctx.send_followup(content="second") + assert dummy_bot._http.create_interaction_response.call_count == 1 + assert dummy_bot._http.create_followup_message.call_count == 1 + + +@pytest.mark.asyncio +async def test_edit_with_components_and_attachments(dummy_bot, interaction): + ctx = AppCommandContext(dummy_bot, interaction) + await ctx.send(content="orig") + row = ActionRow([Button(style=ButtonStyle.PRIMARY, label="B", custom_id="b")]) + await ctx.edit(content="new", components=[row], attachments=[{"id": 1}]) + dummy_bot._http.edit_original_interaction_response.assert_called_once() + payload = dummy_bot._http.edit_original_interaction_response.call_args.kwargs[ + "payload" + ] + assert payload["components"] + assert payload["attachments"] == [{"id": 1}] + + +@pytest.mark.asyncio +async def test_send_with_flags(dummy_bot, interaction): + ctx = AppCommandContext(dummy_bot, interaction) + await ctx.send(content="hi", flags=MessageFlags.IS_COMPONENTS_V2.value) + payload = dummy_bot._http.create_interaction_response.call_args.kwargs[ + "payload" + ].data.to_dict() + assert payload["flags"] == MessageFlags.IS_COMPONENTS_V2.value + + +@pytest.mark.asyncio +async def test_send_container_component(dummy_bot, interaction): + ctx = AppCommandContext(dummy_bot, interaction) + container = Container(components=[TextDisplay(content="hi")]) + await ctx.send(components=[container], flags=MessageFlags.IS_COMPONENTS_V2.value) + payload = dummy_bot._http.create_interaction_response.call_args.kwargs[ + "payload" + ].data.to_dict() + assert payload["components"][0]["type"] == ComponentType.CONTAINER.value + assert payload["flags"] == MessageFlags.IS_COMPONENTS_V2.value + + +@pytest.mark.asyncio +async def test_command_context_edit(command_bot, message): + async def dummy(ctx): + pass + + cmd = Command(dummy) + ctx = CommandContext( + message=message, bot=command_bot, prefix="!", command=cmd, invoked_with="dummy" + ) + await ctx.edit(message.id, content="new") + command_bot._http.edit_message.assert_called_once() + args = command_bot._http.edit_message.call_args[0] + assert args[0] == message.channel_id + assert args[1] == message.id + assert args[2]["content"] == "new" + + +@pytest.mark.asyncio +async def test_send_http_error_propagates(dummy_bot, interaction): + ctx = AppCommandContext(dummy_bot, interaction) + dummy_bot._http.create_interaction_response.side_effect = RuntimeError("boom") + with pytest.raises(RuntimeError): + await ctx.send(content="hi") + + +@pytest.mark.asyncio +async def test_concurrent_send_only_initial_once(dummy_bot, interaction): + ctx = AppCommandContext(dummy_bot, interaction) + + async def send_msg(i: int): + if i == 0: + await ctx.send(content=str(i)) + else: + await ctx.send_followup(content=str(i)) + + await asyncio.gather(*(send_msg(i) for i in range(50))) + assert dummy_bot._http.create_interaction_response.call_count == 1 + assert dummy_bot._http.create_followup_message.call_count == 49 + + +@pytest.mark.asyncio +async def test_send_with_flags_2(): + bot = DummyBot() + interaction = Interaction( + { + "id": "1", + "application_id": bot.application_id, + "type": InteractionType.APPLICATION_COMMAND.value, + "token": "tok", + "version": 1, + }, + client_instance=bot, + ) + ctx = AppCommandContext(bot, interaction) + await ctx.send(content="hi", flags=MessageFlags.IS_COMPONENTS_V2.value) + payload = bot._http.create_interaction_response.call_args[1][ + "payload" + ].data.to_dict() + assert payload["flags"] == MessageFlags.IS_COMPONENTS_V2.value + + +@pytest.mark.asyncio +async def test_send_container_component_2(): + bot = DummyBot() + interaction = Interaction( + { + "id": "1", + "application_id": bot.application_id, + "type": InteractionType.APPLICATION_COMMAND.value, + "token": "tok", + "version": 1, + }, + client_instance=bot, + ) + ctx = AppCommandContext(bot, interaction) + container = Container(components=[TextDisplay(content="hi")]) + await ctx.send( + components=[container], + flags=MessageFlags.IS_COMPONENTS_V2.value, + ) + payload = bot._http.create_interaction_response.call_args[1][ + "payload" + ].data.to_dict() + assert payload["components"][0]["type"] == ComponentType.CONTAINER.value + assert payload["flags"] == MessageFlags.IS_COMPONENTS_V2.value diff --git a/tests/test_context_menus.py b/tests/test_context_menus.py new file mode 100644 index 0000000..8b88d65 --- /dev/null +++ b/tests/test_context_menus.py @@ -0,0 +1,80 @@ +import pytest + +from disagreement.ext.app_commands.handler import AppCommandHandler +from disagreement.ext.app_commands.decorators import user_command, message_command +from disagreement.enums import ApplicationCommandType, InteractionType +from disagreement.interactions import Interaction +from disagreement.models import User, Message + + +@pytest.mark.asyncio +async def test_user_context_menu_invokes(dummy_bot): + handler = AppCommandHandler(dummy_bot) + captured = {} + + @user_command(name="Info") + async def info(ctx, user: User): + captured["user"] = user + + handler.add_command(info) + + data = { + "id": "cmd", + "name": "Info", + "type": ApplicationCommandType.USER.value, + "target_id": "42", + "resolved": { + "users": {"42": {"id": "42", "username": "Bob", "discriminator": "0001"}} + }, + } + payload = { + "id": "1", + "application_id": dummy_bot.application_id, + "type": InteractionType.APPLICATION_COMMAND.value, + "token": "tok", + "version": 1, + "data": data, + } + interaction = Interaction(payload, client_instance=dummy_bot) + await handler.process_interaction(interaction) + assert isinstance(captured.get("user"), User) + assert captured["user"].id == "42" + + +@pytest.mark.asyncio +async def test_message_context_menu_invokes(dummy_bot): + handler = AppCommandHandler(dummy_bot) + captured = {} + + @message_command(name="Quote") + async def quote(ctx, message: Message): + captured["msg"] = message + + handler.add_command(quote) + + msg_data = { + "id": "99", + "channel_id": "c", + "author": {"id": "2", "username": "Ann", "discriminator": "0001"}, + "content": "Hello", + "timestamp": "t", + } + data = { + "id": "cmd", + "name": "Quote", + "type": ApplicationCommandType.MESSAGE.value, + "target_id": "99", + "resolved": {"messages": {"99": msg_data}}, + } + payload = { + "id": "1", + "application_id": dummy_bot.application_id, + "type": InteractionType.APPLICATION_COMMAND.value, + "token": "tok", + "version": 1, + "data": data, + } + interaction = Interaction(payload, client_instance=dummy_bot) + await handler.process_interaction(interaction) + assert isinstance(captured.get("msg"), Message) + assert captured["msg"].id == "99" diff --git a/tests/test_converter_registration.py b/tests/test_converter_registration.py new file mode 100644 index 0000000..c2c4e49 --- /dev/null +++ b/tests/test_converter_registration.py @@ -0,0 +1,30 @@ +import pytest + +from disagreement.ext.app_commands.handler import AppCommandHandler +from disagreement.ext.app_commands.converters import Converter + + +class MyType: + def __init__(self, value): + self.value = value + + +class MyConverter(Converter[MyType]): + async def convert(self, interaction, value): + return MyType(f"converted-{value}") + + +@pytest.mark.asyncio +async def test_custom_converter_registration(dummy_client): + handler = AppCommandHandler(client=dummy_client) + handler.register_converter(MyType, MyConverter) + assert handler.get_converter(MyType) is MyConverter + + result = await handler._resolve_value( + value="example", + expected_type=MyType, + resolved_data=None, + guild_id=None, + ) + assert isinstance(result, MyType) + assert result.value == "converted-example" diff --git a/tests/test_converters.py b/tests/test_converters.py new file mode 100644 index 0000000..90792c3 --- /dev/null +++ b/tests/test_converters.py @@ -0,0 +1,113 @@ +import asyncio +import pytest +from hypothesis import given, strategies as st + +from disagreement.ext.app_commands.converters import run_converters +from disagreement.enums import ApplicationCommandOptionType +from disagreement.errors import AppCommandOptionConversionError +from conftest import DummyInteraction, DummyClient + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "py_type, option_type, input_value, expected", + [ + (str, ApplicationCommandOptionType.STRING, "hello", "hello"), + (int, ApplicationCommandOptionType.INTEGER, "42", 42), + (bool, ApplicationCommandOptionType.BOOLEAN, "true", True), + (float, ApplicationCommandOptionType.NUMBER, "3.14", pytest.approx(3.14)), + ], +) +async def test_basic_type_converters( + basic_interaction, dummy_client, py_type, option_type, input_value, expected +): + result = await run_converters( + basic_interaction, py_type, option_type, input_value, dummy_client + ) + assert result == expected + + +@pytest.mark.asyncio +async def test_run_converters_error_cases(basic_interaction, dummy_client): + with pytest.raises(AppCommandOptionConversionError): + await run_converters( + basic_interaction, + bool, + ApplicationCommandOptionType.BOOLEAN, + "maybe", + dummy_client, + ) + + with pytest.raises(AppCommandOptionConversionError): + await run_converters( + basic_interaction, + list, + ApplicationCommandOptionType.MENTIONABLE, + "x", + dummy_client, + ) + + +@given(st.text()) +def test_string_roundtrip(value): + interaction = DummyInteraction() + client = DummyClient() + result = asyncio.run( + run_converters( + interaction, + str, + ApplicationCommandOptionType.STRING, + value, + client, + ) + ) + assert result == value + + +@given(st.integers()) +def test_integer_roundtrip(value): + interaction = DummyInteraction() + client = DummyClient() + result = asyncio.run( + run_converters( + interaction, + int, + ApplicationCommandOptionType.INTEGER, + str(value), + client, + ) + ) + assert result == value + + +@given(st.booleans()) +def test_boolean_roundtrip(value): + interaction = DummyInteraction() + client = DummyClient() + raw = "true" if value else "false" + result = asyncio.run( + run_converters( + interaction, + bool, + ApplicationCommandOptionType.BOOLEAN, + raw, + client, + ) + ) + assert result is value + + +@given(st.floats(allow_nan=False, allow_infinity=False)) +def test_number_roundtrip(value): + interaction = DummyInteraction() + client = DummyClient() + result = asyncio.run( + run_converters( + interaction, + float, + ApplicationCommandOptionType.NUMBER, + str(value), + client, + ) + ) + assert result == pytest.approx(value) diff --git a/tests/test_error_handler.py b/tests/test_error_handler.py new file mode 100644 index 0000000..49be211 --- /dev/null +++ b/tests/test_error_handler.py @@ -0,0 +1,38 @@ +import asyncio +import logging +from types import SimpleNamespace + +import pytest + +from disagreement.error_handler import setup_global_error_handler + + +@pytest.mark.asyncio +async def test_handle_exception_logs_error(monkeypatch, capsys): + loop = asyncio.new_event_loop() + records = [] + + def fake_error(msg, *args, **kwargs): + records.append(msg % args if args else msg) + + monkeypatch.setattr(logging, "error", fake_error) + setup_global_error_handler(loop) + exc = RuntimeError("boom") + loop.call_exception_handler({"exception": exc}) + assert any("Unhandled exception" in r for r in records) + loop.close() + + +@pytest.mark.asyncio +async def test_handle_message_logs_error(monkeypatch): + loop = asyncio.new_event_loop() + logged = {} + + def fake_error(msg, *args, **kwargs): + logged["msg"] = msg % args if args else msg + + monkeypatch.setattr(logging, "error", fake_error) + setup_global_error_handler(loop) + loop.call_exception_handler({"message": "oops"}) + assert "oops" in logged["msg"] + loop.close() diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100644 index 0000000..244f3d5 --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,23 @@ +import pytest + +from disagreement.errors import ( + HTTPException, + RateLimitError, + AppCommandOptionConversionError, +) + + +def test_http_exception_message(): + exc = HTTPException(message="Bad", status=400) + assert str(exc) == "HTTP 400: Bad" + + +def test_rate_limit_error_inherits_httpexception(): + exc = RateLimitError(response=None, retry_after=1.0, is_global=True) + assert isinstance(exc, HTTPException) + assert "Rate limited" in str(exc) + + +def test_app_command_option_conversion_error(): + exc = AppCommandOptionConversionError("bad", option_name="opt", original_value="x") + assert "opt" in str(exc) and "x" in str(exc) diff --git a/tests/test_event_dispatcher.py b/tests/test_event_dispatcher.py new file mode 100644 index 0000000..a685b31 --- /dev/null +++ b/tests/test_event_dispatcher.py @@ -0,0 +1,68 @@ +import asyncio + +import pytest + +from disagreement.event_dispatcher import EventDispatcher + + +class DummyClient: + def __init__(self): + self.parsed = {} + + def parse_message(self, data): + self.parsed["message"] = True + return data + + def parse_guild(self, data): + self.parsed["guild"] = True + return data + + def parse_channel(self, data): + self.parsed["channel"] = True + return data + + +@pytest.mark.asyncio +async def test_dispatch_calls_listener(): + client = DummyClient() + dispatcher = EventDispatcher(client) + called = {} + + async def listener(payload): + called["data"] = payload + + dispatcher.register("MESSAGE_CREATE", listener) + await dispatcher.dispatch("MESSAGE_CREATE", {"id": 1}) + assert called["data"] == {"id": 1} + assert client.parsed.get("message") + + +@pytest.mark.asyncio +async def test_dispatch_listener_no_args(): + client = DummyClient() + dispatcher = EventDispatcher(client) + called = False + + async def listener(): + nonlocal called + called = True + + dispatcher.register("GUILD_CREATE", listener) + await dispatcher.dispatch("GUILD_CREATE", {"id": 123}) + assert called + + +@pytest.mark.asyncio +async def test_unregister_listener(): + client = DummyClient() + dispatcher = EventDispatcher(client) + called = False + + async def listener(_): + nonlocal called + called = True + + dispatcher.register("MESSAGE_CREATE", listener) + dispatcher.unregister("MESSAGE_CREATE", listener) + await dispatcher.dispatch("MESSAGE_CREATE", {"id": 1}) + assert not called diff --git a/tests/test_event_error_hook.py b/tests/test_event_error_hook.py new file mode 100644 index 0000000..c08567c --- /dev/null +++ b/tests/test_event_error_hook.py @@ -0,0 +1,42 @@ +import pytest +from unittest.mock import AsyncMock + +from disagreement.event_dispatcher import EventDispatcher + + +class DummyClient: + pass + + +@pytest.mark.asyncio +async def test_dispatch_error_hook_called(): + dispatcher = EventDispatcher(DummyClient()) + hook = AsyncMock() + dispatcher.on_dispatch_error = hook + + async def listener(_): + raise RuntimeError("boom") + + dispatcher.register("TEST_EVENT", listener) + await dispatcher.dispatch("TEST_EVENT", {}) + + hook.assert_awaited_once() + args = hook.call_args.args + assert args[0] == "TEST_EVENT" + assert isinstance(args[1], RuntimeError) + assert args[2] is listener + + +@pytest.mark.asyncio +async def test_dispatch_error_hook_not_called_when_ok(): + dispatcher = EventDispatcher(DummyClient()) + hook = AsyncMock() + dispatcher.on_dispatch_error = hook + + async def listener(_): + return + + dispatcher.register("TEST_EVENT", listener) + await dispatcher.dispatch("TEST_EVENT", {}) + + hook.assert_not_awaited() diff --git a/tests/test_extension_loader.py b/tests/test_extension_loader.py new file mode 100644 index 0000000..37f45f6 --- /dev/null +++ b/tests/test_extension_loader.py @@ -0,0 +1,44 @@ +import sys +import types + +import pytest + +from disagreement.ext import loader + + +def create_dummy_module(name): + mod = types.ModuleType(name) + called = {"setup": False, "teardown": False} + + def setup(): + called["setup"] = True + + def teardown(): + called["teardown"] = True + + mod.setup = setup + mod.teardown = teardown + sys.modules[name] = mod + return called + + +def test_load_and_unload_extension(): + called = create_dummy_module("dummy_ext") + + module = loader.load_extension("dummy_ext") + assert module is sys.modules["dummy_ext"] + assert called["setup"] is True + + loader.unload_extension("dummy_ext") + assert called["teardown"] is True + assert "dummy_ext" not in loader._loaded_extensions + assert "dummy_ext" not in sys.modules + + +def test_load_extension_twice_raises(): + called = create_dummy_module("repeat_ext") + loader.load_extension("repeat_ext") + with pytest.raises(ValueError): + loader.load_extension("repeat_ext") + loader.unload_extension("repeat_ext") + assert called["teardown"] is True diff --git a/tests/test_gateway_backoff.py b/tests/test_gateway_backoff.py new file mode 100644 index 0000000..c5dbb21 --- /dev/null +++ b/tests/test_gateway_backoff.py @@ -0,0 +1,67 @@ +import asyncio +from unittest.mock import AsyncMock + +import pytest + +from disagreement.gateway import GatewayClient, GatewayException +from disagreement.client import Client + + +class DummyHTTP: + async def get_gateway_bot(self): + return {"url": "ws://example"} + + async def _ensure_session(self): + self._session = AsyncMock() + self._session.ws_connect = AsyncMock() + + +class DummyDispatcher: + async def dispatch(self, *_): + pass + + +class DummyClient: + def __init__(self): + self.loop = asyncio.get_event_loop() + self.application_id = None # Mock application_id for Client.connect + +@pytest.mark.asyncio +async def test_client_connect_backoff(monkeypatch): + http = DummyHTTP() + # Mock the GatewayClient's connect method to simulate failures and then success + mock_gateway_connect = AsyncMock( + side_effect=[GatewayException("boom"), GatewayException("boom"), None] + ) + # Create a dummy client instance + client = Client( + token="test_token", + intents=0, + loop=asyncio.get_event_loop(), + command_prefix="!", + verbose=False, + mention_replies=False, + shard_count=None, + ) + # Patch the internal _gateway attribute after client initialization + # This ensures _initialize_gateway is called and _gateway is set + await client._initialize_gateway() + monkeypatch.setattr(client._gateway, "connect", mock_gateway_connect) + + # Mock wait_until_ready to prevent it from blocking the test + monkeypatch.setattr(client, "wait_until_ready", AsyncMock()) + + delays = [] + + async def fake_sleep(d): + delays.append(d) + + monkeypatch.setattr(asyncio, "sleep", fake_sleep) + + # Call the client's connect method, which contains the backoff logic + await client.connect() + + # Assert that GatewayClient.connect was called the correct number of times + assert mock_gateway_connect.call_count == 3 + # Assert the delays experienced due to exponential backoff + assert delays == [5, 10] diff --git a/tests/test_help_command.py b/tests/test_help_command.py new file mode 100644 index 0000000..23a2c7a --- /dev/null +++ b/tests/test_help_command.py @@ -0,0 +1,57 @@ +import pytest + +from disagreement.ext.commands.core import CommandHandler, Command +from disagreement.models import Message + + +class DummyBot: + def __init__(self): + self.sent = [] + + async def send_message(self, channel_id, content, **kwargs): + self.sent.append(content) + return {"id": "1", "channel_id": channel_id, "content": content} + + +@pytest.mark.asyncio +async def test_help_lists_commands(): + bot = DummyBot() + handler = CommandHandler(client=bot, prefix="!") + + async def foo(ctx): + pass + + handler.add_command(Command(foo, name="foo", brief="Foo cmd")) + + msg_data = { + "id": "1", + "channel_id": "c", + "author": {"id": "2", "username": "u", "discriminator": "0001"}, + "content": "!help", + "timestamp": "t", + } + msg = Message(msg_data, client_instance=bot) + await handler.process_commands(msg) + assert any("foo" in m for m in bot.sent) + + +@pytest.mark.asyncio +async def test_help_specific_command(): + bot = DummyBot() + handler = CommandHandler(client=bot, prefix="!") + + async def bar(ctx): + pass + + handler.add_command(Command(bar, name="bar", description="Bar desc")) + + msg_data = { + "id": "1", + "channel_id": "c", + "author": {"id": "2", "username": "u", "discriminator": "0001"}, + "content": "!help bar", + "timestamp": "t", + } + msg = Message(msg_data, client_instance=bot) + await handler.process_commands(msg) + assert any("Bar desc" in m for m in bot.sent) diff --git a/tests/test_http_reactions.py b/tests/test_http_reactions.py new file mode 100644 index 0000000..101875f --- /dev/null +++ b/tests/test_http_reactions.py @@ -0,0 +1,57 @@ +import pytest +from types import SimpleNamespace +from unittest.mock import AsyncMock + +from disagreement.client import Client +from disagreement.errors import DisagreementException +from disagreement.models import User + + +@pytest.mark.asyncio +async def test_create_reaction_calls_http(): + http = SimpleNamespace(create_reaction=AsyncMock()) + client = Client.__new__(Client) + client._http = http + client._closed = False + + await client.create_reaction("1", "2", "😀") + + http.create_reaction.assert_called_once_with("1", "2", "😀") + + +@pytest.mark.asyncio +async def test_create_reaction_closed(): + http = SimpleNamespace(create_reaction=AsyncMock()) + client = Client.__new__(Client) + client._http = http + client._closed = True + + with pytest.raises(DisagreementException): + await client.create_reaction("1", "2", "😀") + + +@pytest.mark.asyncio +async def test_delete_reaction_calls_http(): + http = SimpleNamespace(delete_reaction=AsyncMock()) + client = Client.__new__(Client) + client._http = http + client._closed = False + + await client.delete_reaction("1", "2", "😀") + + http.delete_reaction.assert_called_once_with("1", "2", "😀") + + +@pytest.mark.asyncio +async def test_get_reactions_parses_users(): + users_payload = [{"id": "1", "username": "u", "discriminator": "0001"}] + http = SimpleNamespace(get_reactions=AsyncMock(return_value=users_payload)) + client = Client.__new__(Client) + client._http = http + client._closed = False + client._users = {} + + users = await client.get_reactions("1", "2", "😀") + + http.get_reactions.assert_called_once_with("1", "2", "😀") + assert isinstance(users[0], User) diff --git a/tests/test_hybrid_context.py b/tests/test_hybrid_context.py new file mode 100644 index 0000000..1a9973d --- /dev/null +++ b/tests/test_hybrid_context.py @@ -0,0 +1,44 @@ +import asyncio +import pytest + +from disagreement.hybrid_context import HybridContext +from disagreement.ext.app_commands.context import AppCommandContext + + +class DummyCommandCtx: + def __init__(self): + self.sent = [] + + async def reply(self, *a, **kw): + self.sent.append(("reply", a, kw)) + + async def edit(self, *a, **kw): + self.sent.append(("edit", a, kw)) + + +class DummyAppCtx(AppCommandContext): + def __init__(self): + self.sent = [] + + async def send(self, *a, **kw): + self.sent.append(("send", a, kw)) + + +@pytest.mark.asyncio +async def test_send_routes_based_on_context(): + cctx = DummyCommandCtx() + actx = DummyAppCtx() + await HybridContext(cctx).send("hi") + await HybridContext(actx).send("hi") + assert cctx.sent[0][0] == "reply" + assert actx.sent[0][0] == "send" + + +@pytest.mark.asyncio +async def test_edit_delegation_and_error(): + cctx = DummyCommandCtx() + hctx = HybridContext(cctx) + await hctx.edit("m") + assert cctx.sent[0][0] == "edit" + with pytest.raises(AttributeError): + await HybridContext(DummyAppCtx()).edit("m") diff --git a/tests/test_i18n.py b/tests/test_i18n.py new file mode 100644 index 0000000..80a92f7 --- /dev/null +++ b/tests/test_i18n.py @@ -0,0 +1,21 @@ +import pytest # pylint: disable=import-error + +from disagreement.i18n import set_translations, translate +from disagreement.ext.app_commands.commands import SlashCommand + + +async def dummy(ctx): + pass + + +def test_translate_lookup(): + set_translations("xx", {"hello": "bonjour"}) + assert translate("hello", "xx") == "bonjour" + assert translate("missing", "xx") == "missing" + + +def test_appcommand_uses_locale(): + set_translations("xx", {"cmd": "c", "desc": "d"}) + cmd = SlashCommand(dummy, name="cmd", description="desc", locale="xx") + assert cmd.name == "c" + assert cmd.description == "d" diff --git a/tests/test_interaction.py b/tests/test_interaction.py new file mode 100644 index 0000000..812fa45 --- /dev/null +++ b/tests/test_interaction.py @@ -0,0 +1,19 @@ +import pytest + +from disagreement.models import Embed + + +@pytest.mark.asyncio +async def test_edit_calls_http_with_payload(dummy_bot, interaction): + await interaction.edit(content="updated") + dummy_bot._http.edit_original_interaction_response.assert_called_once() + kwargs = dummy_bot._http.edit_original_interaction_response.call_args.kwargs + assert kwargs["application_id"] == dummy_bot.application_id + assert kwargs["interaction_token"] == interaction.token + assert kwargs["payload"] == {"content": "updated"} + + +@pytest.mark.asyncio +async def test_edit_embed_and_embeds_raises(dummy_bot, interaction): + with pytest.raises(ValueError): + await interaction.edit(embed=Embed(), embeds=[Embed()]) diff --git a/tests/test_logging_config.py b/tests/test_logging_config.py new file mode 100644 index 0000000..61e73fa --- /dev/null +++ b/tests/test_logging_config.py @@ -0,0 +1,30 @@ +import logging +from disagreement.logging_config import setup_logging + + +def test_setup_logging_sets_level(tmp_path): + root_logger = logging.getLogger() + original_handlers = root_logger.handlers.copy() + root_logger.handlers.clear() + try: + setup_logging(logging.INFO) + assert root_logger.level == logging.INFO + assert root_logger.handlers + assert isinstance(root_logger.handlers[0], logging.StreamHandler) + finally: + root_logger.handlers.clear() + root_logger.handlers.extend(original_handlers) + + +def test_setup_logging_file(tmp_path): + log_file = tmp_path / "test.log" + root_logger = logging.getLogger() + original_handlers = root_logger.handlers.copy() + root_logger.handlers.clear() + try: + setup_logging(logging.WARNING, file=str(log_file)) + logging.warning("hello") + assert log_file.read_text() + finally: + root_logger.handlers.clear() + root_logger.handlers.extend(original_handlers) diff --git a/tests/test_modal_send.py b/tests/test_modal_send.py new file mode 100644 index 0000000..bceeb0b --- /dev/null +++ b/tests/test_modal_send.py @@ -0,0 +1,16 @@ +import pytest + +from disagreement.ui import Modal, TextInput + + +class MyModal(Modal): + def __init__(self): + super().__init__(title="T", custom_id="m") + self.input = TextInput(label="L", custom_id="i") + + +@pytest.mark.asyncio +async def test_send_modal(dummy_bot, interaction): + modal = MyModal() + await interaction.response.send_modal(modal) + dummy_bot._http.create_interaction_response.assert_called_once() diff --git a/tests/test_modals.py b/tests/test_modals.py new file mode 100644 index 0000000..3d1fdff --- /dev/null +++ b/tests/test_modals.py @@ -0,0 +1,32 @@ +import pytest + +from disagreement.interactions import Interaction +from disagreement.ui import Modal, text_input +from disagreement.enums import InteractionCallbackType, TextInputStyle + + +class MyModal(Modal): + def __init__(self): + super().__init__(title="Test", custom_id="m1") + + @text_input(label="Name", style=TextInputStyle.SHORT) + async def name(self, interaction: Interaction): + pass + + +def test_modal_to_dict(): + modal = MyModal() + data = modal.to_dict() + assert data["title"] == "Test" + assert data["custom_id"] == "m1" + assert data["components"][0]["components"][0]["label"] == "Name" + + +@pytest.mark.asyncio +async def test_respond_modal(dummy_bot, interaction): + modal = MyModal() + await interaction.respond_modal(modal) + dummy_bot._http.create_interaction_response.assert_called_once() + payload = dummy_bot._http.create_interaction_response.call_args.kwargs["payload"] + assert payload["type"] == InteractionCallbackType.MODAL.value + assert payload["data"]["custom_id"] == "m1" diff --git a/tests/test_oauth.py b/tests/test_oauth.py new file mode 100644 index 0000000..6221116 --- /dev/null +++ b/tests/test_oauth.py @@ -0,0 +1,114 @@ +import pytest +import aiohttp +from unittest.mock import AsyncMock + +from disagreement.http import HTTPClient +from disagreement import oauth + + +@pytest.mark.asyncio +async def test_build_authorization_url(): + url = oauth.build_authorization_url( + client_id="123", + redirect_uri="https://example.com/cb", + scope=["identify", "guilds"], + state="xyz", + ) + assert url.startswith("https://discord.com/oauth2/authorize?") + assert "client_id=123" in url + assert "state=xyz" in url + assert "scope=identify+guilds" in url + + +@pytest.mark.asyncio +@pytest.mark.asyncio +async def test_exchange_code_for_token_makes_correct_request(monkeypatch): + mock_client_response = AsyncMock() + mock_client_response.status = 200 + mock_client_response.json = AsyncMock(return_value={"access_token": "a"}) + mock_client_response.__aenter__ = AsyncMock(return_value=mock_client_response) + mock_client_response.__aexit__ = AsyncMock(return_value=None) + + post_mock = AsyncMock(return_value=mock_client_response) + monkeypatch.setattr("aiohttp.ClientSession.post", post_mock) + + data = await oauth.exchange_code_for_token( + client_id="id", + client_secret="secret", + code="code", + redirect_uri="https://cb", + ) + + assert data == {"access_token": "a"} + post_mock.assert_called_once() + args, kwargs = post_mock.call_args + assert args[0] == "https://discord.com/api/v10/oauth2/token" + assert kwargs["headers"]["Content-Type"] == "application/x-www-form-urlencoded" + assert kwargs["data"]["grant_type"] == "authorization_code" + assert kwargs["data"]["client_id"] == "id" + + +@pytest.mark.asyncio +async def test_exchange_code_for_token_custom_session(): + mock_client_response = AsyncMock() + mock_client_response.status = 200 + mock_client_response.json = AsyncMock(return_value={"access_token": "x"}) + mock_client_response.__aenter__ = AsyncMock(return_value=mock_client_response) + mock_client_response.__aexit__ = AsyncMock(return_value=None) + + mock_session = AsyncMock() + mock_session.post = AsyncMock(return_value=mock_client_response) + + data = await oauth.exchange_code_for_token( + client_id="c1", + client_secret="c2", + code="code", + redirect_uri="https://cb", + session=mock_session, + ) + assert data == {"access_token": "x"} + mock_session.post.assert_called_once() + + +@pytest.mark.asyncio +async def test_refresh_access_token_success(monkeypatch): + mock_client_response = AsyncMock() + mock_client_response.status = 200 + mock_client_response.json = AsyncMock(return_value={"access_token": "b"}) + mock_client_response.__aenter__ = AsyncMock(return_value=mock_client_response) + mock_client_response.__aexit__ = AsyncMock(return_value=None) + + post_mock = AsyncMock(return_value=mock_client_response) + monkeypatch.setattr("aiohttp.ClientSession.post", post_mock) + + data = await oauth.refresh_access_token( + refresh_token="rt", + client_id="cid", + client_secret="sec", + ) + + assert data == {"access_token": "b"} + post_mock.assert_called_once() + args, kwargs = post_mock.call_args + assert args[0] == "https://discord.com/api/v10/oauth2/token" + assert kwargs["data"]["grant_type"] == "refresh_token" + assert kwargs["data"]["refresh_token"] == "rt" + + +@pytest.mark.asyncio +async def test_refresh_access_token_error(monkeypatch): + mock_client_response = AsyncMock() + mock_client_response.status = 400 + mock_client_response.json = AsyncMock(return_value={"error": "invalid"}) + mock_client_response.__aenter__ = AsyncMock(return_value=mock_client_response) + mock_client_response.__aexit__ = AsyncMock(return_value=None) + + post_mock = AsyncMock(return_value=mock_client_response) + monkeypatch.setattr("aiohttp.ClientSession.post", post_mock) + + with pytest.raises(oauth.HTTPException): + await oauth.refresh_access_token( + refresh_token="bad", + client_id="cid", + client_secret="sec", + ) diff --git a/tests/test_permissions.py b/tests/test_permissions.py new file mode 100644 index 0000000..4281311 --- /dev/null +++ b/tests/test_permissions.py @@ -0,0 +1,34 @@ +import pytest + +from disagreement.permissions import ( + Permissions, + has_permissions, + missing_permissions, + permissions_value, +) + + +def test_permissions_value_combination(): + perm = permissions_value(Permissions.SEND_MESSAGES, Permissions.MANAGE_MESSAGES) + assert perm == (Permissions.SEND_MESSAGES | Permissions.MANAGE_MESSAGES) + + +def test_has_permissions_true(): + current = Permissions.SEND_MESSAGES | Permissions.MANAGE_MESSAGES + assert has_permissions(current, Permissions.SEND_MESSAGES) + assert has_permissions( + current, Permissions.MANAGE_MESSAGES, Permissions.SEND_MESSAGES + ) + + +def test_has_permissions_false(): + current = Permissions.SEND_MESSAGES + assert not has_permissions(current, Permissions.MANAGE_MESSAGES) + + +def test_missing_permissions(): + current = Permissions.SEND_MESSAGES + missing = missing_permissions( + current, Permissions.SEND_MESSAGES, Permissions.MANAGE_MESSAGES + ) + assert missing == [Permissions.MANAGE_MESSAGES] diff --git a/tests/test_presence_and_typing.py b/tests/test_presence_and_typing.py new file mode 100644 index 0000000..8e75663 --- /dev/null +++ b/tests/test_presence_and_typing.py @@ -0,0 +1,39 @@ +import pytest + +from disagreement.event_dispatcher import EventDispatcher +from disagreement.models import PresenceUpdate, TypingStart + + +@pytest.mark.asyncio +async def test_presence_and_typing_parsing(dummy_client): + dispatcher = EventDispatcher(dummy_client) + events = {} + + async def on_presence(presence): + events["presence"] = presence + + async def on_typing(typing): + events["typing"] = typing + + dispatcher.register("PRESENCE_UPDATE", on_presence) + dispatcher.register("TYPING_START", on_typing) + + presence_data = { + "user": {"id": "1", "username": "u", "discriminator": "0001"}, + "guild_id": "g", + "status": "online", + "activities": [], + "client_status": {}, + } + typing_data = { + "channel_id": "c", + "user_id": "1", + "timestamp": 123, + } + await dispatcher.dispatch("PRESENCE_UPDATE", presence_data) + await dispatcher.dispatch("TYPING_START", typing_data) + + assert isinstance(events.get("presence"), PresenceUpdate) + assert events["presence"].status == "online" + assert isinstance(events.get("typing"), TypingStart) + assert events["typing"].channel_id == "c" diff --git a/tests/test_presence_update.py b/tests/test_presence_update.py new file mode 100644 index 0000000..cf43818 --- /dev/null +++ b/tests/test_presence_update.py @@ -0,0 +1,33 @@ +import pytest +from unittest.mock import AsyncMock + +from disagreement.client import Client +from disagreement.errors import DisagreementException + + +from unittest.mock import MagicMock + +class DummyGateway(MagicMock): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.update_presence = AsyncMock() + + +@pytest.mark.asyncio +async def test_change_presence_passes_arguments(): + client = Client(token="t") + client._gateway = DummyGateway() + + await client.change_presence(status="idle", activity_name="hi", activity_type=0) + + client._gateway.update_presence.assert_awaited_once_with( + status="idle", activity_name="hi", activity_type=0, since=0, afk=False + ) + + +@pytest.mark.asyncio +async def test_change_presence_when_closed(): + client = Client(token="t") + client._closed = True + with pytest.raises(DisagreementException): + await client.change_presence(status="online") diff --git a/tests/test_rate_limiter.py b/tests/test_rate_limiter.py new file mode 100644 index 0000000..dfc2afa --- /dev/null +++ b/tests/test_rate_limiter.py @@ -0,0 +1,30 @@ +import asyncio +import time + +import pytest + +from disagreement.rate_limiter import RateLimiter + + +@pytest.mark.asyncio +async def test_route_rate_limit_sleep(): + rl = RateLimiter() + task = asyncio.create_task(rl.handle_rate_limit("GET:/a", 0.05, False)) + await asyncio.sleep(0) # ensure task starts + start = time.monotonic() + await rl.acquire("GET:/a") + duration = time.monotonic() - start + await task + assert duration >= 0.05 + + +@pytest.mark.asyncio +async def test_global_rate_limit_blocks_all_routes(): + rl = RateLimiter() + task = asyncio.create_task(rl.handle_rate_limit("GET:/a", 0.05, True)) + await asyncio.sleep(0) + start = time.monotonic() + await rl.acquire("POST:/b") + duration = time.monotonic() - start + await task + assert duration >= 0.05 diff --git a/tests/test_reactions.py b/tests/test_reactions.py new file mode 100644 index 0000000..b73ea37 --- /dev/null +++ b/tests/test_reactions.py @@ -0,0 +1,38 @@ +import pytest +from disagreement.event_dispatcher import EventDispatcher + + +@pytest.mark.asyncio +async def test_reaction_payload(): + # This test now checks the raw payload dictionary, as the Reaction model is removed. + data = { + "user_id": "1", + "channel_id": "2", + "message_id": "3", + "emoji": {"name": "😀", "id": None}, + } + # The "reaction" is just the data dictionary itself. + assert data["user_id"] == "1" + assert data["emoji"]["name"] == "😀" + + +@pytest.mark.asyncio +async def test_dispatch_reaction_event(dummy_client): + dispatcher = EventDispatcher(dummy_client) + captured = [] + + async def listener(payload: dict): + captured.append(payload) + + # The event name is now MESSAGE_REACTION_ADD as per the original test setup. + # If this were to fail, the next step would be to confirm the correct event name. + dispatcher.register("MESSAGE_REACTION_ADD", listener) + payload = { + "user_id": "1", + "channel_id": "2", + "message_id": "3", + "emoji": {"name": "👍", "id": None}, + } + await dispatcher.dispatch("MESSAGE_REACTION_ADD", payload) + assert len(captured) == 1 + assert isinstance(captured[0], dict) diff --git a/tests/test_sharding.py b/tests/test_sharding.py new file mode 100644 index 0000000..d69d209 --- /dev/null +++ b/tests/test_sharding.py @@ -0,0 +1,52 @@ +import pytest +from unittest.mock import AsyncMock + +from disagreement.shard_manager import ShardManager +from disagreement.client import Client + + +class DummyGateway: + def __init__(self, *args, **kwargs): + self.connect = AsyncMock() + self.close = AsyncMock() + + +class DummyClient: + def __init__(self): + self._http = object() + self._event_dispatcher = object() + self.token = "t" + self.intents = 0 + self.verbose = False + + +def test_shard_manager_creates_shards(monkeypatch): + monkeypatch.setattr("disagreement.shard_manager.GatewayClient", DummyGateway) + client = DummyClient() + manager = ShardManager(client, shard_count=3) + assert len(manager.shards) == 0 + manager._create_shards() + assert len(manager.shards) == 3 + + +@pytest.mark.asyncio +async def test_shard_manager_start_and_close(monkeypatch): + monkeypatch.setattr("disagreement.shard_manager.GatewayClient", DummyGateway) + client = DummyClient() + manager = ShardManager(client, shard_count=2) + await manager.start() + for shard in manager.shards: + shard.gateway.connect.assert_awaited_once() + await manager.close() + for shard in manager.shards: + shard.gateway.close.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_client_uses_shard_manager(monkeypatch): + dummy_manager = AsyncMock() + monkeypatch.setattr("disagreement.client.ShardManager", lambda c, n: dummy_manager) + c = Client(token="x", shard_count=2) + monkeypatch.setattr(c, "wait_until_ready", AsyncMock()) + await c.connect() + dummy_manager.start.assert_awaited_once() diff --git a/tests/test_slash_contexts.py b/tests/test_slash_contexts.py new file mode 100644 index 0000000..527dc12 --- /dev/null +++ b/tests/test_slash_contexts.py @@ -0,0 +1,15 @@ +import pytest + +from disagreement.ext.app_commands.decorators import slash_command +from disagreement.ext.app_commands.commands import SlashCommand +from disagreement.enums import InteractionContextType + + +async def dummy(ctx): + pass + + +def test_boolean_context_parameters(): + cmd = slash_command(description="test", dms=False, private_channels=False)(dummy) + assert isinstance(cmd, SlashCommand) + assert cmd.contexts == [InteractionContextType.GUILD] diff --git a/tests/test_tasks_extension.py b/tests/test_tasks_extension.py new file mode 100644 index 0000000..d8d453e --- /dev/null +++ b/tests/test_tasks_extension.py @@ -0,0 +1,24 @@ +import asyncio + +import pytest + +from disagreement.ext import tasks + + +class Dummy: + def __init__(self) -> None: + self.count = 0 + + @tasks.loop(seconds=0.01) + async def work(self) -> None: + self.count += 1 + + +@pytest.mark.asyncio +async def test_loop_runs_and_stops() -> None: + dummy = Dummy() + dummy.work.start() # pylint: disable=no-member + await asyncio.sleep(0.05) + dummy.work.stop() # pylint: disable=no-member + assert dummy.count >= 2 + assert not dummy.work.running # pylint: disable=no-member diff --git a/tests/test_typing_indicator.py b/tests/test_typing_indicator.py new file mode 100644 index 0000000..7c86e28 --- /dev/null +++ b/tests/test_typing_indicator.py @@ -0,0 +1,64 @@ +import pytest +from types import SimpleNamespace +from unittest.mock import AsyncMock + +from disagreement.client import Client +from disagreement.errors import DisagreementException + + +@pytest.mark.asyncio +async def test_typing_context_manager_calls_http(): + http = SimpleNamespace(trigger_typing=AsyncMock()) + client = Client.__new__(Client) + client._http = http + client._closed = False + + async with client.typing("123"): + pass + + http.trigger_typing.assert_called_once_with("123") + + +@pytest.mark.asyncio +async def test_typing_closed(): + http = SimpleNamespace(trigger_typing=AsyncMock()) + client = Client.__new__(Client) + client._http = http + client._closed = True + + with pytest.raises(DisagreementException): + async with client.typing("123"): + pass + + +@pytest.mark.asyncio +async def test_context_typing(): + http = SimpleNamespace(trigger_typing=AsyncMock()) + + class DummyBot: + def __init__(self): + self._http = http + self._closed = False + + def typing(self, channel_id): + from disagreement.typing import Typing + + return Typing(self, channel_id) + + bot = DummyBot() + msg = SimpleNamespace(channel_id="c", id="1", guild_id=None, author=None) + + async def dummy(ctx): + pass + + from disagreement.ext.commands.core import Command, CommandContext + + cmd = Command(dummy) + ctx = CommandContext( + message=msg, bot=bot, prefix="!", command=cmd, invoked_with="dummy" + ) + + async with ctx.typing(): + pass + + http.trigger_typing.assert_called_once_with("c") diff --git a/tests/test_ui.py b/tests/test_ui.py new file mode 100644 index 0000000..b95282a --- /dev/null +++ b/tests/test_ui.py @@ -0,0 +1,65 @@ +import asyncio + +import pytest +from types import SimpleNamespace + +from disagreement.enums import ButtonStyle +from disagreement.models import SelectOption +from disagreement.ui.button import button, Button +from disagreement.ui.select import select, Select +from disagreement.ui.view import View + + +@pytest.mark.asyncio +async def test_button_decorator_creates_button(): + @button(label="Hi", custom_id="x") + async def cb(view, inter): + pass + + assert isinstance(cb, Button) + assert cb.label == "Hi" + view = View() + view.add_item(cb) + comps = view.to_components_payload() + assert comps[0]["components"][0]["custom_id"] == "x" + + +@pytest.mark.asyncio +async def test_button_decorator_requires_coroutine(): + with pytest.raises(TypeError): + button()(lambda x, y: None) + + +@pytest.mark.asyncio +async def test_select_decorator_creates_select(): + options = [SelectOption(label="A", value="a")] + + @select(custom_id="sel", options=options) + async def cb(view, inter): + pass + + assert isinstance(cb, Select) + view = View() + view.add_item(cb) + payload = view.to_components_payload()[0]["components"][0] + assert payload["custom_id"] == "sel" + + +@pytest.mark.asyncio +async def test_view_dispatch_calls_callback(monkeypatch): + called = {} + + @button(label="B", custom_id="b") + async def cb(view, inter): + called["ok"] = True + + view = View() + view.add_item(cb) + + class DummyInteraction: + def __init__(self): + self.data = SimpleNamespace(custom_id="b") + + interaction = DummyInteraction() + await view._dispatch(interaction) + assert called.get("ok") diff --git a/tests/test_view_layout.py b/tests/test_view_layout.py new file mode 100644 index 0000000..898b4d5 --- /dev/null +++ b/tests/test_view_layout.py @@ -0,0 +1,53 @@ +import pytest + +from disagreement.ui.view import View +from disagreement.ui.button import Button +from disagreement.ui.select import Select +from disagreement.enums import ButtonStyle, ComponentType +from disagreement.models import SelectOption + + +def test_basic_layout_keeps_one_item_per_row(): + view = View() + view.add_item(Button(style=ButtonStyle.PRIMARY, label="a")) + view.add_item(Button(style=ButtonStyle.PRIMARY, label="b")) + + rows = view.to_components() + assert len(rows) == 2 + assert all(len(r.components) == 1 for r in rows) + + +def test_advanced_layout_groups_buttons(): + view = View() + for i in range(6): + view.add_item(Button(style=ButtonStyle.PRIMARY, label=str(i))) + + rows = view.layout_components_advanced() + assert len(rows) == 2 + assert len(rows[0].components) == 5 + assert len(rows[1].components) == 1 + + +def test_advanced_layout_select_separate(): + view = View() + view.add_item(Select(custom_id="s", options=[SelectOption(label="A", value="a")])) + view.add_item(Button(style=ButtonStyle.PRIMARY, label="b1")) + view.add_item(Button(style=ButtonStyle.PRIMARY, label="b2")) + + rows = view.layout_components_advanced() + assert len(rows) == 2 + assert rows[0].components[0].type == ComponentType.STRING_SELECT + assert all(c.type == ComponentType.BUTTON for c in rows[1].components) + assert len(rows[1].components) == 2 + + +def test_advanced_layout_respects_row_attribute(): + view = View() + view.add_item(Button(style=ButtonStyle.PRIMARY, label="x", row=1)) + view.add_item(Button(style=ButtonStyle.PRIMARY, label="y", row=1)) + view.add_item(Button(style=ButtonStyle.PRIMARY, label="z", row=0)) + + rows = view.layout_components_advanced() + assert len(rows) == 2 + assert len(rows[0].components) == 1 + assert len(rows[1].components) == 2 diff --git a/tests/test_voice_client.py b/tests/test_voice_client.py new file mode 100644 index 0000000..4f10e50 --- /dev/null +++ b/tests/test_voice_client.py @@ -0,0 +1,75 @@ +import asyncio +import pytest + +from disagreement.voice_client import VoiceClient + + +class DummyWebSocket: + def __init__(self, messages): + self.sent = [] + self._queue = asyncio.Queue() + for m in messages: + self._queue.put_nowait(m) + + async def send_json(self, data): + self.sent.append(data) + + async def receive_json(self): + return await self._queue.get() + + async def close(self): + pass + + +class DummyUDP: + def __init__(self): + self.connected = None + self.sent = [] + + def connect(self, address): + self.connected = address + + def send(self, data): + self.sent.append(data) + + def getsockname(self): + return ("127.0.0.1", 12345) + + def close(self): + pass + + +@pytest.mark.asyncio +async def test_voice_client_handshake(): + hello = {"d": {"heartbeat_interval": 50}} + ready = {"d": {"ssrc": 1, "ip": "127.0.0.1", "port": 4000}} + session_desc = {"d": {"secret_key": [1, 2, 3]}} + ws = DummyWebSocket([hello, ready, session_desc]) + udp = DummyUDP() + + vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp) + await vc.connect() + vc._heartbeat_task.cancel() + + assert ws.sent[0]["op"] == 0 + assert ws.sent[1]["op"] == 1 + assert udp.connected == ("127.0.0.1", 4000) + assert vc.secret_key == [1, 2, 3] + + +@pytest.mark.asyncio +async def test_send_audio_frame(): + ws = DummyWebSocket( + [ + {"d": {"heartbeat_interval": 50}}, + {"d": {"ssrc": 1, "ip": "127.0.0.1", "port": 4000}}, + {"d": {"secret_key": []}}, + ] + ) + udp = DummyUDP() + vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp) + await vc.connect() + vc._heartbeat_task.cancel() + + await vc.send_audio_frame(b"abc") + assert udp.sent[-1] == b"abc" diff --git a/tests/test_wait_for.py b/tests/test_wait_for.py new file mode 100644 index 0000000..8418ae7 --- /dev/null +++ b/tests/test_wait_for.py @@ -0,0 +1,35 @@ +import asyncio + +import pytest # pylint: disable=import-error + +from disagreement.client import Client + + +@pytest.mark.asyncio +async def test_wait_for_resolves_on_event(): + client = Client(token="t") + + async def dispatch_event(): + await asyncio.sleep(0.05) + data = { + "id": "42", + "channel_id": "c", + "author": {"id": "1", "username": "u", "discriminator": "0001"}, + "content": "hello", + "timestamp": "t", + } + await client._event_dispatcher.dispatch("MESSAGE_CREATE", data) + + asyncio.create_task(dispatch_event()) + message = await client.wait_for( + "MESSAGE_CREATE", check=lambda m: m.id == "42", timeout=1 + ) + + assert message.content == "hello" + + +@pytest.mark.asyncio +async def test_wait_for_timeout(): + client = Client(token="t") + with pytest.raises(asyncio.TimeoutError): + await client.wait_for("MESSAGE_CREATE", timeout=0.1) diff --git a/tests/test_webhooks.py b/tests/test_webhooks.py new file mode 100644 index 0000000..af63cb1 --- /dev/null +++ b/tests/test_webhooks.py @@ -0,0 +1,44 @@ +import pytest +from unittest.mock import AsyncMock + +from disagreement.http import HTTPClient + + +@pytest.mark.asyncio +async def test_create_followup_message_calls_request(): + http = HTTPClient(token="t") + http.request = AsyncMock() + payload = {"content": "hello"} + await http.create_followup_message("app_id", "token", payload) + http.request.assert_called_once_with( + "POST", + f"/webhooks/app_id/token", + payload=payload, + use_auth_header=False, + ) + + +@pytest.mark.asyncio +async def test_edit_followup_message_calls_request(): + http = HTTPClient(token="t") + http.request = AsyncMock() + payload = {"content": "new content"} + await http.edit_followup_message("app_id", "token", "123", payload) + http.request.assert_called_once_with( + "PATCH", + f"/webhooks/app_id/token/messages/123", + payload=payload, + use_auth_header=False, + ) + + +@pytest.mark.asyncio +async def test_delete_followup_message_calls_request(): + http = HTTPClient(token="t") + http.request = AsyncMock() + await http.delete_followup_message("app_id", "token", "456") + http.request.assert_called_once_with( + "DELETE", + f"/webhooks/app_id/token/messages/456", + use_auth_header=False, + )