diff --git a/disagreement/client.py b/disagreement/client.py index d81c533..6e7c7a7 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -1700,10 +1700,20 @@ class Client: pass - async def on_typing_start(self, typing) -> None: - """|coro| Called when a user starts typing in a channel.""" - - pass + async def on_typing_start(self, typing) -> None: + """|coro| Called when a user starts typing in a channel.""" + + pass + + async def on_connect(self) -> None: + """|coro| Called when the WebSocket connection opens.""" + + pass + + async def on_disconnect(self) -> None: + """|coro| Called when the WebSocket connection closes.""" + + pass async def on_app_command_error( self, context: AppCommandContext, error: Exception diff --git a/disagreement/gateway.py b/disagreement/gateway.py index 20591a5..bab89ca 100644 --- a/disagreement/gateway.py +++ b/disagreement/gateway.py @@ -569,6 +569,7 @@ class GatewayClient: await self._dispatcher.dispatch( "SHARD_CONNECT", {"shard_id": self._shard_id} ) + await self._dispatcher.dispatch("CONNECT", {"shard_id": self._shard_id}) except aiohttp.ClientConnectorError as e: raise GatewayException( @@ -624,6 +625,7 @@ class GatewayClient: await self._dispatcher.dispatch( "SHARD_DISCONNECT", {"shard_id": self._shard_id} ) + await self._dispatcher.dispatch("DISCONNECT", {"shard_id": self._shard_id}) @property def latency(self) -> Optional[float]: diff --git a/docs/events.md b/docs/events.md index e669b5b..57a470e 100644 --- a/docs/events.md +++ b/docs/events.md @@ -132,6 +132,28 @@ async def on_shard_resume(info: dict): ... ``` +## CONNECT + +Dispatched when the WebSocket connection opens. The callback receives a +dictionary with the shard ID. + +```python +@client.event +async def on_connect(info: dict): + print("connected", info.get("shard_id")) +``` + +## DISCONNECT + +Fired when the WebSocket connection closes. The callback receives a dictionary +with the shard ID. + +```python +@client.event +async def on_disconnect(info: dict): + ... +``` + ## VOICE_STATE_UPDATE Triggered when a user's voice connection state changes, such as joining or leaving a voice channel. The callback receives a `VoiceStateUpdate` model. diff --git a/tests/test_connect_events.py b/tests/test_connect_events.py new file mode 100644 index 0000000..4c3e651 --- /dev/null +++ b/tests/test_connect_events.py @@ -0,0 +1,59 @@ +import asyncio +import pytest +from unittest.mock import AsyncMock + +from disagreement.shard_manager import ShardManager +from disagreement.event_dispatcher import EventDispatcher + + +class DummyGateway: + def __init__(self, *args, **kwargs): + self.connect = AsyncMock() + self.close = AsyncMock() + + dispatcher = kwargs.get("event_dispatcher") + shard_id = kwargs.get("shard_id") + + async def emit_connect(): + await dispatcher.dispatch("CONNECT", {"shard_id": shard_id}) + + async def emit_close(): + await dispatcher.dispatch("DISCONNECT", {"shard_id": shard_id}) + + self.connect.side_effect = emit_connect + self.close.side_effect = emit_close + + +class DummyClient: + def __init__(self): + self._http = object() + self._event_dispatcher = EventDispatcher(self) + self.token = "t" + self.intents = 0 + self.verbose = False + self.gateway_max_retries = 5 + self.gateway_max_backoff = 60.0 + + +@pytest.mark.asyncio +async def test_connect_disconnect_events(monkeypatch): + monkeypatch.setattr("disagreement.shard_manager.GatewayClient", DummyGateway) + client = DummyClient() + manager = ShardManager(client, shard_count=1) + + events: list[tuple[str, int | None]] = [] + + async def on_connect(info): + events.append(("connect", info.get("shard_id"))) + + async def on_disconnect(info): + events.append(("disconnect", info.get("shard_id"))) + + client._event_dispatcher.register("CONNECT", on_connect) + client._event_dispatcher.register("DISCONNECT", on_disconnect) + + await manager.start() + await manager.close() + + assert ("connect", 0) in events + assert ("disconnect", 0) in events