Implement gateway reconnection backoff (#24)

This commit is contained in:
Slipstream 2025-06-10 15:49:11 -06:00 committed by GitHub
parent 09fae8a489
commit 1ff56106c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 83 additions and 14 deletions

View File

@ -84,6 +84,8 @@ class Client:
verbose: bool = False,
mention_replies: bool = False,
shard_count: Optional[int] = None,
gateway_max_retries: int = 5,
gateway_max_backoff: float = 60.0,
):
if not token:
raise ValueError("A bot token must be provided.")
@ -103,6 +105,8 @@ class Client:
None # Initialized in run() or connect()
)
self.shard_count: Optional[int] = shard_count
self.gateway_max_retries: int = gateway_max_retries
self.gateway_max_backoff: float = gateway_max_backoff
self._shard_manager: Optional[ShardManager] = None
# Initialize CommandHandler
@ -169,6 +173,8 @@ class Client:
intents=self.intents,
client_instance=self,
verbose=self.verbose,
max_retries=self.gateway_max_retries,
max_backoff=self.gateway_max_backoff,
)
async def _initialize_shard_manager(self) -> None:

View File

@ -10,6 +10,7 @@ import aiohttp
import json
import zlib
import time
import random
from typing import Optional, TYPE_CHECKING, Any, Dict
from .enums import GatewayOpcode, GatewayIntent
@ -43,6 +44,8 @@ class GatewayClient:
*,
shard_id: Optional[int] = None,
shard_count: Optional[int] = None,
max_retries: int = 5,
max_backoff: float = 60.0,
):
self._http: "HTTPClient" = http_client
self._dispatcher: "EventDispatcher" = event_dispatcher
@ -52,6 +55,8 @@ class GatewayClient:
self.verbose: bool = verbose
self._shard_id: Optional[int] = shard_id
self._shard_count: Optional[int] = shard_count
self._max_retries: int = max_retries
self._max_backoff: float = max_backoff
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
@ -70,6 +75,26 @@ class GatewayClient:
self._buffer = bytearray()
self._inflator = zlib.decompressobj()
async def _reconnect(self) -> None:
    """Reconnect to the Gateway using exponential backoff with jitter.

    Calls ``self.connect()`` until it succeeds or the attempt budget
    (``self._max_retries``) is exhausted. After each failed attempt the
    method sleeps for the current base delay plus a random jitter between
    0 and that delay, capped at ``self._max_backoff``. The base delay
    starts at one second and doubles after every failure, also capped at
    ``self._max_backoff``.

    At least one attempt is always made, even when ``self._max_retries``
    is zero or negative, so the method never returns normally without
    ``connect()`` having succeeded.

    Raises:
        Exception: Whatever the final ``connect()`` attempt raised.
    """
    # With a non-positive budget, range() would be empty and the method
    # would return as if the reconnect had succeeded; force one attempt.
    attempts = max(1, self._max_retries)
    delay = 1.0
    for attempt in range(attempts):
        try:
            await self.connect()
            return
        except Exception as e:  # noqa: BLE001
            if attempt >= attempts - 1:
                # Out of attempts: surface the last connection error.
                print(f"Reconnect failed after {attempt + 1} attempts: {e}")
                raise
            # Jitter spreads out clients that start reconnecting together.
            jitter = random.uniform(0, delay)
            wait_time = min(delay + jitter, self._max_backoff)
            print(
                f"Reconnect attempt {attempt + 1} failed: {e}. "
                f"Retrying in {wait_time:.2f} seconds..."
            )
            await asyncio.sleep(wait_time)
            delay = min(delay * 2, self._max_backoff)
async def _decompress_message(
self, message_bytes: bytes
) -> Optional[Dict[str, Any]]:
@ -354,7 +379,7 @@ class GatewayClient:
await self._heartbeat()
elif op == GatewayOpcode.RECONNECT: # Server requests a reconnect
print("Gateway requested RECONNECT. Closing and will attempt to reconnect.")
await self.close(code=4000) # Use a non-1000 code to indicate reconnect
await self.close(code=4000, reconnect=True)
elif op == GatewayOpcode.INVALID_SESSION:
# The 'd' payload for INVALID_SESSION is a boolean indicating resumability
can_resume = data.get("d") is True
@ -363,9 +388,7 @@ class GatewayClient:
self._session_id = None # Clear session_id to force re-identify
self._last_sequence = None
# Close and reconnect. The connect logic will decide to resume or identify.
await self.close(
code=4000 if can_resume else 4009
) # 4009 for non-resumable
await self.close(code=4000 if can_resume else 4009, reconnect=True)
elif op == GatewayOpcode.HELLO:
hello_d_payload = data.get("d")
if (
@ -411,13 +434,11 @@ class GatewayClient:
print("Receive_loop task cancelled.")
except aiohttp.ClientConnectionError as e:
print(f"ClientConnectionError in receive_loop: {e}. Attempting reconnect.")
# This might be handled by an outer reconnect loop in the Client class
await self.close(code=1006) # Abnormal closure
await self.close(code=1006, reconnect=True) # Abnormal closure
except Exception as e:
print(f"Unexpected error in receive_loop: {e}")
traceback.print_exc()
# Consider specific error types for more granular handling
await self.close(code=1011) # Internal error
await self.close(code=1011, reconnect=True)
finally:
print("Receive_loop ended.")
# If the loop ends unexpectedly (not due to explicit close),
@ -465,7 +486,7 @@ class GatewayClient:
f"An unexpected error occurred during Gateway connection: {e}"
) from e
async def close(self, code: int = 1000):
async def close(self, code: int = 1000, *, reconnect: bool = False):
"""Closes the Gateway connection."""
print(f"Closing Gateway connection with code {code}...")
if self._keep_alive_task and not self._keep_alive_task.done():
@ -476,11 +497,13 @@ class GatewayClient:
pass # Expected
if self._receive_task and not self._receive_task.done():
current = asyncio.current_task(loop=self._loop)
self._receive_task.cancel()
try:
await self._receive_task
except asyncio.CancelledError:
pass # Expected
if self._receive_task is not current:
try:
await self._receive_task
except asyncio.CancelledError:
pass # Expected
if self._ws and not self._ws.closed:
await self._ws.close(code=code)

View File

@ -51,6 +51,8 @@ class ShardManager:
verbose=self.client.verbose,
shard_id=shard_id,
shard_count=self.shard_count,
max_retries=self.client.gateway_max_retries,
max_backoff=self.client.gateway_max_backoff,
)
self.shards.append(Shard(shard_id, self.shard_count, gateway))

View File

@ -4,7 +4,8 @@
By default the client makes up to five reconnect attempts, doubling the delay after each failure up to a configurable maximum. Random jitter is added to each delay to spread out reconnect attempts when multiple clients restart at once.
You can control the maximum number of retries and the backoff cap when constructing `Client`:
You can control the maximum number of retries and the backoff cap when constructing `Client`.
These options are forwarded to `GatewayClient` as `max_retries` and `max_backoff`:
```python
bot = Client(

View File

@ -1,5 +1,6 @@
import asyncio
from unittest.mock import AsyncMock
import random
import pytest
@ -66,3 +67,37 @@ async def test_client_connect_backoff(monkeypatch):
assert mock_gateway_connect.call_count == 3
# Assert the delays experienced due to exponential backoff
assert delays == [5, 10]
@pytest.mark.asyncio
async def test_gateway_reconnect_backoff(monkeypatch):
    """_reconnect retries failed connects and doubles the wait each time."""
    gateway = GatewayClient(
        http_client=DummyHTTP(),
        event_dispatcher=DummyDispatcher(),
        token="t",
        intents=0,
        client_instance=DummyClient(),
        max_retries=3,
        max_backoff=10.0,
    )

    # Two failed connection attempts followed by a success.
    failures = [GatewayException("boom") for _ in range(2)]
    fake_connect = AsyncMock(side_effect=[*failures, None])
    monkeypatch.setattr(gateway, "connect", fake_connect)

    recorded_sleeps = []

    async def record_sleep(duration):
        recorded_sleeps.append(duration)

    # Replace real waiting and randomness so the delays are deterministic.
    monkeypatch.setattr(asyncio, "sleep", record_sleep)
    monkeypatch.setattr(random, "uniform", lambda a, b: 0)

    await gateway._reconnect()

    assert fake_connect.call_count == 3
    assert recorded_sleeps == [1.0, 2.0]

View File

@ -18,6 +18,8 @@ class DummyClient:
self.token = "t"
self.intents = 0
self.verbose = False
self.gateway_max_retries = 5
self.gateway_max_backoff = 60.0
def test_shard_manager_creates_shards(monkeypatch):