diff --git a/disagreement/gateway.py b/disagreement/gateway.py index ddcea94..7cd53cc 100644 --- a/disagreement/gateway.py +++ b/disagreement/gateway.py @@ -344,6 +344,9 @@ class GatewayClient: raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {} ) await self._dispatcher.dispatch(event_name, event_data_to_dispatch) + await self._dispatcher.dispatch( + "SHARD_RESUME", {"shard_id": self._shard_id} + ) elif event_name: # For other events, ensure 'd' is a dict, or pass {} if 'd' is null/missing. # Models/parsers in EventDispatcher will need to handle potentially empty dicts. @@ -508,6 +511,10 @@ class GatewayClient: self._receive_task.cancel() self._receive_task = self._loop.create_task(self._receive_loop()) + await self._dispatcher.dispatch( + "SHARD_CONNECT", {"shard_id": self._shard_id} + ) + except aiohttp.ClientConnectorError as e: raise GatewayException( f"Failed to connect to Gateway (Connector Error): {e}" @@ -559,6 +566,10 @@ class GatewayClient: self._last_sequence = None self._resume_gateway_url = None # This might be re-fetched anyway + await self._dispatcher.dispatch( + "SHARD_DISCONNECT", {"shard_id": self._shard_id} + ) + @property def latency(self) -> Optional[float]: """Returns the latency between heartbeat and ACK in seconds.""" diff --git a/docs/events.md b/docs/events.md index d22294e..5cd57ef 100644 --- a/docs/events.md +++ b/docs/events.md @@ -98,3 +98,36 @@ Emitted when a guild role is updated. The callback receives a async def on_guild_role_update(event: disagreement.GuildRoleUpdate): ... ``` + +## SHARD_CONNECT + +Fired when a shard establishes its gateway connection. The callback receives a +dictionary with the shard ID. + +```python +@client.event +async def on_shard_connect(info: dict): + print("shard connected", info["shard_id"]) +``` + +## SHARD_DISCONNECT + +Emitted when a shard's gateway connection is closed. The callback receives a +dictionary with the shard ID. + +```python +@client.event +async def on_shard_disconnect(info: dict): + ... +``` + +## SHARD_RESUME + +Sent when a shard successfully resumes after a reconnect. The callback receives +a dictionary with the shard ID. + +```python +@client.event +async def on_shard_resume(info: dict): + ... +``` diff --git a/tests/test_sharding.py b/tests/test_sharding.py index 3c2249d..15c1d10 100644 --- a/tests/test_sharding.py +++ b/tests/test_sharding.py @@ -3,6 +3,7 @@ from unittest.mock import AsyncMock from disagreement.shard_manager import ShardManager from disagreement.client import Client, AutoShardedClient +from disagreement.event_dispatcher import EventDispatcher class DummyGateway: @@ -10,11 +11,27 @@ class DummyGateway: self.connect = AsyncMock() self.close = AsyncMock() + dispatcher = kwargs.get("event_dispatcher") + shard_id = kwargs.get("shard_id") + + async def emit_connect(): + await dispatcher.dispatch("SHARD_CONNECT", {"shard_id": shard_id}) + + async def emit_close(): + await dispatcher.dispatch("SHARD_DISCONNECT", {"shard_id": shard_id}) + + async def emit_resume(): + await dispatcher.dispatch("SHARD_RESUME", {"shard_id": shard_id}) + + self.connect.side_effect = emit_connect + self.close.side_effect = emit_close + self.resume = AsyncMock(side_effect=emit_resume) + class DummyClient: def __init__(self): self._http = object() - self._event_dispatcher = object() + self._event_dispatcher = EventDispatcher(self) self.token = "t" self.intents = 0 self.verbose = False @@ -68,3 +85,34 @@ async def test_auto_sharded_client_fetches_count(monkeypatch): await c.connect() dummy_manager.start.assert_awaited_once() assert c.shard_count == 4 + + +@pytest.mark.asyncio +async def test_shard_events_emitted(monkeypatch): + monkeypatch.setattr("disagreement.shard_manager.GatewayClient", DummyGateway) + + client = DummyClient() + manager = ShardManager(client, shard_count=1) + + events: list[tuple[str, int | None]] = [] + + async def on_connect(info): + events.append(("connect", info.get("shard_id"))) + + async def on_disconnect(info): + events.append(("disconnect", info.get("shard_id"))) + + async def on_resume(info): + events.append(("resume", info.get("shard_id"))) + + client._event_dispatcher.register("SHARD_CONNECT", on_connect) + client._event_dispatcher.register("SHARD_DISCONNECT", on_disconnect) + client._event_dispatcher.register("SHARD_RESUME", on_resume) + + await manager.start() + await manager.shards[0].gateway.resume() + await manager.close() + + assert ("connect", 0) in events + assert ("disconnect", 0) in events + assert ("resume", 0) in events