diff --git a/disagreement/client.py b/disagreement/client.py index fc58f54..f4557b5 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -85,12 +85,15 @@ class Client: verbose (bool): If True, print raw HTTP and Gateway traffic for debugging. mention_replies (bool): Whether replies mention the author by default. allowed_mentions (Optional[Dict[str, Any]]): Default allowed mentions for messages. - http_options (Optional[Dict[str, Any]]): Extra options passed to - :class:`HTTPClient` for creating the internal - :class:`aiohttp.ClientSession`. - message_cache_maxlen (Optional[int]): Maximum number of messages to keep - in the cache. When ``None``, the cache size is unlimited. - """ + http_options (Optional[Dict[str, Any]]): Extra options passed to + :class:`HTTPClient` for creating the internal + :class:`aiohttp.ClientSession`. + message_cache_maxlen (Optional[int]): Maximum number of messages to keep + in the cache. When ``None``, the cache size is unlimited. + sync_commands_on_ready (bool): If ``True``, automatically call + :meth:`Client.sync_application_commands` after the ``READY`` event + when :attr:`Client.application_id` is available. + """ def __init__( self, @@ -109,8 +112,9 @@ class Client: gateway_max_backoff: float = 60.0, member_cache_flags: Optional[MemberCacheFlags] = None, message_cache_maxlen: Optional[int] = None, - http_options: Optional[Dict[str, Any]] = None, - ): + http_options: Optional[Dict[str, Any]] = None, + sync_commands_on_ready: bool = True, + ): if not token: raise ValueError("A bot token must be provided.") @@ -175,8 +179,9 @@ class Client: self._webhooks: Dict[Snowflake, "Webhook"] = {} # Default whether replies mention the user - self.mention_replies: bool = mention_replies - self.allowed_mentions: Optional[Dict[str, Any]] = allowed_mentions + self.mention_replies: bool = mention_replies + self.allowed_mentions: Optional[Dict[str, Any]] = allowed_mentions + self.sync_commands_on_ready: bool = sync_commands_on_ready # Basic signal handling for graceful shutdown # This might be better handled by the user's application code, but can be a nice default. diff --git a/disagreement/gateway.py b/disagreement/gateway.py index 10bb87e..20591a5 100644 --- a/disagreement/gateway.py +++ b/disagreement/gateway.py @@ -341,6 +341,12 @@ class GatewayClient: if isinstance(raw_event_d_payload, dict) and self._shard_id is not None: raw_event_d_payload["shard_id"] = self._shard_id await self._dispatcher.dispatch(event_name, raw_event_d_payload) + + if ( + getattr(self._client_instance, "sync_commands_on_ready", True) + and self._client_instance.application_id + ): + asyncio.create_task(self._client_instance.sync_application_commands()) elif event_name == "GUILD_MEMBERS_CHUNK": if isinstance(raw_event_d_payload, dict): nonce = raw_event_d_payload.get("nonce") diff --git a/tests/test_auto_sync_commands.py b/tests/test_auto_sync_commands.py new file mode 100644 index 0000000..84cd457 --- /dev/null +++ b/tests/test_auto_sync_commands.py @@ -0,0 +1,83 @@ +import asyncio +from unittest.mock import AsyncMock + +import pytest + +from disagreement.client import Client +from disagreement.gateway import GatewayClient +from disagreement.event_dispatcher import EventDispatcher + + +class DummyHTTP: + pass + + +class DummyUser: + username = "u" + discriminator = "0001" + + +@pytest.mark.asyncio +async def test_auto_sync_on_ready(monkeypatch): + client = Client(token="t", application_id="123") + http = DummyHTTP() + dispatcher = EventDispatcher(client) + gw = GatewayClient( + http_client=http, + event_dispatcher=dispatcher, + token="t", + intents=0, + client_instance=client, + ) + monkeypatch.setattr(client, "parse_user", lambda d: DummyUser()) + monkeypatch.setattr(gw._dispatcher, "dispatch", AsyncMock()) + sync_mock = AsyncMock() + monkeypatch.setattr(client, "sync_application_commands", sync_mock) + + data = { + "t": "READY", + "s": 1, + "d": { + "session_id": "s1", + "resume_gateway_url": "url", + "application": {"id": "123"}, + "user": {"id": "1"}, + }, + } + + await gw._handle_dispatch(data) + await asyncio.sleep(0) + sync_mock.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_auto_sync_disabled(monkeypatch): + client = Client(token="t", application_id="123", sync_commands_on_ready=False) + http = DummyHTTP() + dispatcher = EventDispatcher(client) + gw = GatewayClient( + http_client=http, + event_dispatcher=dispatcher, + token="t", + intents=0, + client_instance=client, + ) + monkeypatch.setattr(client, "parse_user", lambda d: DummyUser()) + monkeypatch.setattr(gw._dispatcher, "dispatch", AsyncMock()) + sync_mock = AsyncMock() + monkeypatch.setattr(client, "sync_application_commands", sync_mock) + + data = { + "t": "READY", + "s": 1, + "d": { + "session_id": "s1", + "resume_gateway_url": "url", + "application": {"id": "123"}, + "user": {"id": "1"}, + }, + } + + await gw._handle_dispatch(data) + await asyncio.sleep(0) + sync_mock.assert_not_called()