Implement gateway reconnection backoff (#24)
This commit is contained in:
parent
09fae8a489
commit
1ff56106c9
@ -84,6 +84,8 @@ class Client:
|
||||
verbose: bool = False,
|
||||
mention_replies: bool = False,
|
||||
shard_count: Optional[int] = None,
|
||||
gateway_max_retries: int = 5,
|
||||
gateway_max_backoff: float = 60.0,
|
||||
):
|
||||
if not token:
|
||||
raise ValueError("A bot token must be provided.")
|
||||
@ -103,6 +105,8 @@ class Client:
|
||||
None # Initialized in run() or connect()
|
||||
)
|
||||
self.shard_count: Optional[int] = shard_count
|
||||
self.gateway_max_retries: int = gateway_max_retries
|
||||
self.gateway_max_backoff: float = gateway_max_backoff
|
||||
self._shard_manager: Optional[ShardManager] = None
|
||||
|
||||
# Initialize CommandHandler
|
||||
@ -169,6 +173,8 @@ class Client:
|
||||
intents=self.intents,
|
||||
client_instance=self,
|
||||
verbose=self.verbose,
|
||||
max_retries=self.gateway_max_retries,
|
||||
max_backoff=self.gateway_max_backoff,
|
||||
)
|
||||
|
||||
async def _initialize_shard_manager(self) -> None:
|
||||
|
@ -10,6 +10,7 @@ import aiohttp
|
||||
import json
|
||||
import zlib
|
||||
import time
|
||||
import random
|
||||
from typing import Optional, TYPE_CHECKING, Any, Dict
|
||||
|
||||
from .enums import GatewayOpcode, GatewayIntent
|
||||
@ -43,6 +44,8 @@ class GatewayClient:
|
||||
*,
|
||||
shard_id: Optional[int] = None,
|
||||
shard_count: Optional[int] = None,
|
||||
max_retries: int = 5,
|
||||
max_backoff: float = 60.0,
|
||||
):
|
||||
self._http: "HTTPClient" = http_client
|
||||
self._dispatcher: "EventDispatcher" = event_dispatcher
|
||||
@ -52,6 +55,8 @@ class GatewayClient:
|
||||
self.verbose: bool = verbose
|
||||
self._shard_id: Optional[int] = shard_id
|
||||
self._shard_count: Optional[int] = shard_count
|
||||
self._max_retries: int = max_retries
|
||||
self._max_backoff: float = max_backoff
|
||||
|
||||
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
|
||||
self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
||||
@ -70,6 +75,26 @@ class GatewayClient:
|
||||
self._buffer = bytearray()
|
||||
self._inflator = zlib.decompressobj()
|
||||
|
||||
async def _reconnect(self) -> None:
|
||||
"""Attempts to reconnect using exponential backoff with jitter."""
|
||||
delay = 1.0
|
||||
for attempt in range(self._max_retries):
|
||||
try:
|
||||
await self.connect()
|
||||
return
|
||||
except Exception as e: # noqa: BLE001
|
||||
if attempt >= self._max_retries - 1:
|
||||
print(f"Reconnect failed after {attempt + 1} attempts: {e}")
|
||||
raise
|
||||
jitter = random.uniform(0, delay)
|
||||
wait_time = min(delay + jitter, self._max_backoff)
|
||||
print(
|
||||
f"Reconnect attempt {attempt + 1} failed: {e}. "
|
||||
f"Retrying in {wait_time:.2f} seconds..."
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
delay = min(delay * 2, self._max_backoff)
|
||||
|
||||
async def _decompress_message(
|
||||
self, message_bytes: bytes
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
@ -354,7 +379,7 @@ class GatewayClient:
|
||||
await self._heartbeat()
|
||||
elif op == GatewayOpcode.RECONNECT: # Server requests a reconnect
|
||||
print("Gateway requested RECONNECT. Closing and will attempt to reconnect.")
|
||||
await self.close(code=4000) # Use a non-1000 code to indicate reconnect
|
||||
await self.close(code=4000, reconnect=True)
|
||||
elif op == GatewayOpcode.INVALID_SESSION:
|
||||
# The 'd' payload for INVALID_SESSION is a boolean indicating resumability
|
||||
can_resume = data.get("d") is True
|
||||
@ -363,9 +388,7 @@ class GatewayClient:
|
||||
self._session_id = None # Clear session_id to force re-identify
|
||||
self._last_sequence = None
|
||||
# Close and reconnect. The connect logic will decide to resume or identify.
|
||||
await self.close(
|
||||
code=4000 if can_resume else 4009
|
||||
) # 4009 for non-resumable
|
||||
await self.close(code=4000 if can_resume else 4009, reconnect=True)
|
||||
elif op == GatewayOpcode.HELLO:
|
||||
hello_d_payload = data.get("d")
|
||||
if (
|
||||
@ -411,13 +434,11 @@ class GatewayClient:
|
||||
print("Receive_loop task cancelled.")
|
||||
except aiohttp.ClientConnectionError as e:
|
||||
print(f"ClientConnectionError in receive_loop: {e}. Attempting reconnect.")
|
||||
# This might be handled by an outer reconnect loop in the Client class
|
||||
await self.close(code=1006) # Abnormal closure
|
||||
await self.close(code=1006, reconnect=True) # Abnormal closure
|
||||
except Exception as e:
|
||||
print(f"Unexpected error in receive_loop: {e}")
|
||||
traceback.print_exc()
|
||||
# Consider specific error types for more granular handling
|
||||
await self.close(code=1011) # Internal error
|
||||
await self.close(code=1011, reconnect=True)
|
||||
finally:
|
||||
print("Receive_loop ended.")
|
||||
# If the loop ends unexpectedly (not due to explicit close),
|
||||
@ -465,7 +486,7 @@ class GatewayClient:
|
||||
f"An unexpected error occurred during Gateway connection: {e}"
|
||||
) from e
|
||||
|
||||
async def close(self, code: int = 1000):
|
||||
async def close(self, code: int = 1000, *, reconnect: bool = False):
|
||||
"""Closes the Gateway connection."""
|
||||
print(f"Closing Gateway connection with code {code}...")
|
||||
if self._keep_alive_task and not self._keep_alive_task.done():
|
||||
@ -476,11 +497,13 @@ class GatewayClient:
|
||||
pass # Expected
|
||||
|
||||
if self._receive_task and not self._receive_task.done():
|
||||
current = asyncio.current_task(loop=self._loop)
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass # Expected
|
||||
if self._receive_task is not current:
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass # Expected
|
||||
|
||||
if self._ws and not self._ws.closed:
|
||||
await self._ws.close(code=code)
|
||||
|
@ -51,6 +51,8 @@ class ShardManager:
|
||||
verbose=self.client.verbose,
|
||||
shard_id=shard_id,
|
||||
shard_count=self.shard_count,
|
||||
max_retries=self.client.gateway_max_retries,
|
||||
max_backoff=self.client.gateway_max_backoff,
|
||||
)
|
||||
self.shards.append(Shard(shard_id, self.shard_count, gateway))
|
||||
|
||||
|
@ -4,7 +4,8 @@
|
||||
|
||||
The default behaviour tries up to five reconnect attempts, doubling the delay each time up to a configurable maximum. A small random jitter is added to spread out reconnect attempts when multiple clients restart at once.
|
||||
|
||||
You can control the maximum number of retries and the backoff cap when constructing `Client`:
|
||||
You can control the maximum number of retries and the backoff cap when constructing `Client`.
|
||||
These options are forwarded to `GatewayClient` as `max_retries` and `max_backoff`:
|
||||
|
||||
```python
|
||||
bot = Client(
|
||||
|
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
@ -66,3 +67,37 @@ async def test_client_connect_backoff(monkeypatch):
|
||||
assert mock_gateway_connect.call_count == 3
|
||||
# Assert the delays experienced due to exponential backoff
|
||||
assert delays == [5, 10]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_reconnect_backoff(monkeypatch):
|
||||
http = DummyHTTP()
|
||||
dispatcher = DummyDispatcher()
|
||||
client = DummyClient()
|
||||
gw = GatewayClient(
|
||||
http_client=http,
|
||||
event_dispatcher=dispatcher,
|
||||
token="t",
|
||||
intents=0,
|
||||
client_instance=client,
|
||||
max_retries=3,
|
||||
max_backoff=10.0,
|
||||
)
|
||||
|
||||
connect_mock = AsyncMock(
|
||||
side_effect=[GatewayException("boom"), GatewayException("boom"), None]
|
||||
)
|
||||
monkeypatch.setattr(gw, "connect", connect_mock)
|
||||
|
||||
delays = []
|
||||
|
||||
async def fake_sleep(d):
|
||||
delays.append(d)
|
||||
|
||||
monkeypatch.setattr(asyncio, "sleep", fake_sleep)
|
||||
monkeypatch.setattr(random, "uniform", lambda a, b: 0)
|
||||
|
||||
await gw._reconnect()
|
||||
|
||||
assert connect_mock.call_count == 3
|
||||
assert delays == [1.0, 2.0]
|
||||
|
@ -18,6 +18,8 @@ class DummyClient:
|
||||
self.token = "t"
|
||||
self.intents = 0
|
||||
self.verbose = False
|
||||
self.gateway_max_retries = 5
|
||||
self.gateway_max_backoff = 60.0
|
||||
|
||||
|
||||
def test_shard_manager_creates_shards(monkeypatch):
|
||||
|
Loading…
x
Reference in New Issue
Block a user