# disagreement/tests/test_gateway_backoff.py
#
# 104 lines
# 2.8 KiB
# Python

import asyncio
from unittest.mock import AsyncMock
import random
import pytest
from disagreement.gateway import GatewayClient, GatewayException
from disagreement.client import Client
class DummyHTTP:
    """Minimal stand-in for the HTTP client handed to ``GatewayClient`` in these tests."""

    async def get_gateway_bot(self):
        """Return a canned gateway payload containing only a websocket URL."""
        return {"url": "ws://example"}

    async def _ensure_session(self):
        """Install a mocked session whose ``ws_connect`` is an ``AsyncMock``.

        A fresh mock is created on every call, so no real network session is
        ever opened.
        """
        self._session = AsyncMock()
        self._session.ws_connect = AsyncMock()
class DummyDispatcher:
    """Event dispatcher stub that discards whatever it is asked to dispatch."""

    async def dispatch(self, *_):
        """Accept any positional arguments and do nothing."""
        return None
class DummyClient:
    """Bare object passed as ``client_instance`` when building a ``GatewayClient``."""

    def __init__(self):
        """Record the current event loop and a placeholder application id."""
        self.loop = asyncio.get_event_loop()
        self.application_id = None  # Mock application_id for Client.connect
@pytest.mark.asyncio
async def test_client_connect_backoff(monkeypatch):
    """Check that ``Client.connect`` retries a failing gateway connect with growing delays.

    The gateway's ``connect`` is mocked to raise ``GatewayException`` twice and
    then succeed. ``asyncio.sleep`` is replaced with a recorder so the delays
    requested between attempts can be asserted without actually waiting.
    """
    # Mock the GatewayClient's connect method to simulate failures and then success
    mock_gateway_connect = AsyncMock(
        side_effect=[GatewayException("boom"), GatewayException("boom"), None]
    )

    # Create a dummy client instance
    client = Client(
        token="test_token",
        intents=0,
        # We are inside a coroutine, so the running loop is the one to pass on.
        loop=asyncio.get_running_loop(),
        command_prefix="!",
        verbose=False,
        mention_replies=False,
        shard_count=None,
    )

    # Patch the internal _gateway attribute after client initialization
    # This ensures _initialize_gateway is called and _gateway is set
    await client._initialize_gateway()
    monkeypatch.setattr(client._gateway, "connect", mock_gateway_connect)

    # Mock wait_until_ready to prevent it from blocking the test
    monkeypatch.setattr(client, "wait_until_ready", AsyncMock())

    delays = []

    async def fake_sleep(d):
        # Record the requested delay instead of sleeping.
        delays.append(d)

    monkeypatch.setattr(asyncio, "sleep", fake_sleep)

    # Call the client's connect method, which contains the backoff logic
    await client.connect()

    # Assert that GatewayClient.connect was called the correct number of times
    assert mock_gateway_connect.call_count == 3
    # Assert the delays experienced due to exponential backoff
    assert delays == [5, 10]
@pytest.mark.asyncio
async def test_gateway_reconnect_backoff(monkeypatch):
    """Check that ``GatewayClient._reconnect`` retries with doubling delays.

    ``connect`` is mocked to raise ``GatewayException`` twice and then succeed.
    ``asyncio.sleep`` is replaced with a recorder, and ``random.uniform`` is
    forced to return 0 (presumably the jitter source for the backoff; confirm
    against ``GatewayClient._reconnect``) so the recorded delays are exact.
    """
    http = DummyHTTP()
    dispatcher = DummyDispatcher()
    client = DummyClient()
    gw = GatewayClient(
        http_client=http,
        event_dispatcher=dispatcher,
        token="t",
        intents=0,
        client_instance=client,
        max_retries=3,
        max_backoff=10.0,
    )

    # Two failures followed by a success: three connect attempts in total.
    connect_mock = AsyncMock(
        side_effect=[GatewayException("boom"), GatewayException("boom"), None]
    )
    monkeypatch.setattr(gw, "connect", connect_mock)

    delays = []

    async def fake_sleep(d):
        # Record the requested delay instead of sleeping.
        delays.append(d)

    monkeypatch.setattr(asyncio, "sleep", fake_sleep)
    monkeypatch.setattr(random, "uniform", lambda a, b: 0)

    await gw._reconnect()

    assert connect_mock.call_count == 3
    assert delays == [1.0, 2.0]