parent
36b06c6c7a
commit
7595e33fd1
@ -344,6 +344,9 @@ class GatewayClient:
|
||||
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
|
||||
)
|
||||
await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
|
||||
await self._dispatcher.dispatch(
|
||||
"SHARD_RESUME", {"shard_id": self._shard_id}
|
||||
)
|
||||
elif event_name:
|
||||
# For other events, ensure 'd' is a dict, or pass {} if 'd' is null/missing.
|
||||
# Models/parsers in EventDispatcher will need to handle potentially empty dicts.
|
||||
@ -508,6 +511,10 @@ class GatewayClient:
|
||||
self._receive_task.cancel()
|
||||
self._receive_task = self._loop.create_task(self._receive_loop())
|
||||
|
||||
await self._dispatcher.dispatch(
|
||||
"SHARD_CONNECT", {"shard_id": self._shard_id}
|
||||
)
|
||||
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
raise GatewayException(
|
||||
f"Failed to connect to Gateway (Connector Error): {e}"
|
||||
@ -559,6 +566,10 @@ class GatewayClient:
|
||||
self._last_sequence = None
|
||||
self._resume_gateway_url = None # This might be re-fetched anyway
|
||||
|
||||
await self._dispatcher.dispatch(
|
||||
"SHARD_DISCONNECT", {"shard_id": self._shard_id}
|
||||
)
|
||||
|
||||
@property
|
||||
def latency(self) -> Optional[float]:
|
||||
"""Returns the latency between heartbeat and ACK in seconds."""
|
||||
|
@ -98,3 +98,36 @@ Emitted when a guild role is updated. The callback receives a
|
||||
async def on_guild_role_update(event: disagreement.GuildRoleUpdate):
|
||||
...
|
||||
```
|
||||
|
||||
## SHARD_CONNECT
|
||||
|
||||
Fired when a shard establishes its gateway connection. The callback receives a
|
||||
dictionary with the shard ID.
|
||||
|
||||
```python
|
||||
@client.event
|
||||
async def on_shard_connect(info: dict):
|
||||
print("shard connected", info["shard_id"])
|
||||
```
|
||||
|
||||
## SHARD_DISCONNECT
|
||||
|
||||
Emitted when a shard's gateway connection is closed. The callback receives a
|
||||
dictionary with the shard ID.
|
||||
|
||||
```python
|
||||
@client.event
|
||||
async def on_shard_disconnect(info: dict):
|
||||
...
|
||||
```
|
||||
|
||||
## SHARD_RESUME
|
||||
|
||||
Sent when a shard successfully resumes after a reconnect. The callback receives
|
||||
a dictionary with the shard ID.
|
||||
|
||||
```python
|
||||
@client.event
|
||||
async def on_shard_resume(info: dict):
|
||||
...
|
||||
```
|
||||
|
@ -3,6 +3,7 @@ from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.shard_manager import ShardManager
|
||||
from disagreement.client import Client, AutoShardedClient
|
||||
from disagreement.event_dispatcher import EventDispatcher
|
||||
|
||||
|
||||
class DummyGateway:
|
||||
@ -10,11 +11,27 @@ class DummyGateway:
|
||||
self.connect = AsyncMock()
|
||||
self.close = AsyncMock()
|
||||
|
||||
dispatcher = kwargs.get("event_dispatcher")
|
||||
shard_id = kwargs.get("shard_id")
|
||||
|
||||
async def emit_connect():
|
||||
await dispatcher.dispatch("SHARD_CONNECT", {"shard_id": shard_id})
|
||||
|
||||
async def emit_close():
|
||||
await dispatcher.dispatch("SHARD_DISCONNECT", {"shard_id": shard_id})
|
||||
|
||||
async def emit_resume():
|
||||
await dispatcher.dispatch("SHARD_RESUME", {"shard_id": shard_id})
|
||||
|
||||
self.connect.side_effect = emit_connect
|
||||
self.close.side_effect = emit_close
|
||||
self.resume = AsyncMock(side_effect=emit_resume)
|
||||
|
||||
|
||||
class DummyClient:
|
||||
def __init__(self):
|
||||
self._http = object()
|
||||
self._event_dispatcher = object()
|
||||
self._event_dispatcher = EventDispatcher(self)
|
||||
self.token = "t"
|
||||
self.intents = 0
|
||||
self.verbose = False
|
||||
@ -68,3 +85,34 @@ async def test_auto_sharded_client_fetches_count(monkeypatch):
|
||||
await c.connect()
|
||||
dummy_manager.start.assert_awaited_once()
|
||||
assert c.shard_count == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shard_events_emitted(monkeypatch):
|
||||
monkeypatch.setattr("disagreement.shard_manager.GatewayClient", DummyGateway)
|
||||
|
||||
client = DummyClient()
|
||||
manager = ShardManager(client, shard_count=1)
|
||||
|
||||
events: list[tuple[str, int | None]] = []
|
||||
|
||||
async def on_connect(info):
|
||||
events.append(("connect", info.get("shard_id")))
|
||||
|
||||
async def on_disconnect(info):
|
||||
events.append(("disconnect", info.get("shard_id")))
|
||||
|
||||
async def on_resume(info):
|
||||
events.append(("resume", info.get("shard_id")))
|
||||
|
||||
client._event_dispatcher.register("SHARD_CONNECT", on_connect)
|
||||
client._event_dispatcher.register("SHARD_DISCONNECT", on_disconnect)
|
||||
client._event_dispatcher.register("SHARD_RESUME", on_resume)
|
||||
|
||||
await manager.start()
|
||||
await manager.shards[0].gateway.resume()
|
||||
await manager.close()
|
||||
|
||||
assert ("connect", 0) in events
|
||||
assert ("disconnect", 0) in events
|
||||
assert ("resume", 0) in events
|
||||
|
Loading…
x
Reference in New Issue
Block a user