Merge pull request #26 from Slipstreamm/codex/implement-ratelimiter-in-httpclient
Merge PR #26
This commit is contained in:
commit
b9bfa24511
@ -17,11 +17,13 @@ from .errors import (
|
||||
DisagreementException,
|
||||
)
|
||||
from . import __version__ # For User-Agent
|
||||
from .rate_limiter import RateLimiter
|
||||
from .interactions import InteractionResponsePayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client
|
||||
from .models import Message, Webhook, File
|
||||
from .interactions import ApplicationCommand, InteractionResponsePayload, Snowflake
|
||||
from .interactions import ApplicationCommand, Snowflake
|
||||
|
||||
# Discord API constants
|
||||
API_BASE_URL = "https://discord.com/api/v10" # Using API v10
|
||||
@ -44,8 +46,7 @@ class HTTPClient:
|
||||
|
||||
self.verbose = verbose
|
||||
|
||||
self._global_rate_limit_lock = asyncio.Event()
|
||||
self._global_rate_limit_lock.set() # Initially unlocked
|
||||
self._rate_limiter = RateLimiter()
|
||||
|
||||
async def _ensure_session(self):
|
||||
if self._session is None or self._session.closed:
|
||||
@ -87,10 +88,10 @@ class HTTPClient:
|
||||
if self.verbose:
|
||||
print(f"HTTP REQUEST: {method} {url} | payload={payload} params={params}")
|
||||
|
||||
# Global rate limit handling
|
||||
await self._global_rate_limit_lock.wait()
|
||||
route = f"{method.upper()}:{endpoint}"
|
||||
|
||||
for attempt in range(5): # Max 5 retries for rate limits
|
||||
await self._rate_limiter.acquire(route)
|
||||
assert self._session is not None, "ClientSession not initialized"
|
||||
async with self._session.request(
|
||||
method,
|
||||
@ -120,6 +121,8 @@ class HTTPClient:
|
||||
if self.verbose:
|
||||
print(f"HTTP RESPONSE: {response.status} {url} | {data}")
|
||||
|
||||
self._rate_limiter.release(route, response.headers)
|
||||
|
||||
if 200 <= response.status < 300:
|
||||
if response.status == 204:
|
||||
return None
|
||||
@ -142,12 +145,9 @@ class HTTPClient:
|
||||
if data and isinstance(data, dict) and "message" in data:
|
||||
error_message += f" Discord says: {data['message']}"
|
||||
|
||||
if is_global:
|
||||
self._global_rate_limit_lock.clear()
|
||||
await asyncio.sleep(retry_after)
|
||||
self._global_rate_limit_lock.set()
|
||||
else:
|
||||
await asyncio.sleep(retry_after)
|
||||
await self._rate_limiter.handle_rate_limit(
|
||||
route, retry_after, is_global
|
||||
)
|
||||
|
||||
if attempt < 4: # Don't log on the last attempt before raising
|
||||
print(
|
||||
@ -657,7 +657,7 @@ class HTTPClient:
|
||||
self,
|
||||
interaction_id: "Snowflake",
|
||||
interaction_token: str,
|
||||
payload: "InteractionResponsePayload",
|
||||
payload: Union["InteractionResponsePayload", Dict[str, Any]],
|
||||
*,
|
||||
ephemeral: bool = False,
|
||||
) -> None:
|
||||
@ -670,10 +670,16 @@ class HTTPClient:
|
||||
"""
|
||||
# Interaction responses do not use the bot token in the Authorization header.
|
||||
# They are authenticated by the interaction_token in the URL.
|
||||
payload_data: Dict[str, Any]
|
||||
if isinstance(payload, InteractionResponsePayload):
|
||||
payload_data = payload.to_dict()
|
||||
else:
|
||||
payload_data = payload
|
||||
|
||||
await self.request(
|
||||
"POST",
|
||||
f"/interactions/{interaction_id}/{interaction_token}/callback",
|
||||
payload=payload.to_dict(),
|
||||
payload=payload_data,
|
||||
use_auth_header=False,
|
||||
)
|
||||
|
||||
|
70
tests/test_http_rate_limit.py
Normal file
70
tests/test_http_rate_limit.py
Normal file
@ -0,0 +1,70 @@
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from disagreement.http import HTTPClient
|
||||
|
||||
|
||||
class DummyResp:
|
||||
def __init__(self, status, headers=None, data=None):
|
||||
self.status = status
|
||||
self.headers = headers or {}
|
||||
self._data = data or {}
|
||||
self.headers.setdefault("Content-Type", "application/json")
|
||||
|
||||
async def json(self):
|
||||
return self._data
|
||||
|
||||
async def text(self):
|
||||
return str(self._data)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_acquires_and_releases(monkeypatch):
|
||||
http = HTTPClient(token="t")
|
||||
monkeypatch.setattr(http, "_ensure_session", AsyncMock())
|
||||
resp = DummyResp(200)
|
||||
http._session = SimpleNamespace(request=MagicMock(return_value=resp))
|
||||
http._rate_limiter.acquire = AsyncMock()
|
||||
http._rate_limiter.release = MagicMock()
|
||||
http._rate_limiter.handle_rate_limit = AsyncMock()
|
||||
|
||||
await http.request("GET", "/a")
|
||||
|
||||
http._rate_limiter.acquire.assert_awaited_once_with("GET:/a")
|
||||
http._rate_limiter.release.assert_called_once_with("GET:/a", resp.headers)
|
||||
http._rate_limiter.handle_rate_limit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_handles_rate_limit(monkeypatch):
|
||||
http = HTTPClient(token="t")
|
||||
monkeypatch.setattr(http, "_ensure_session", AsyncMock())
|
||||
resp1 = DummyResp(
|
||||
429,
|
||||
{
|
||||
"Retry-After": "0.1",
|
||||
"X-RateLimit-Global": "false",
|
||||
"X-RateLimit-Remaining": "0",
|
||||
"X-RateLimit-Reset-After": "0.1",
|
||||
},
|
||||
{"message": "slow"},
|
||||
)
|
||||
resp2 = DummyResp(200, {}, {})
|
||||
http._session = SimpleNamespace(request=MagicMock(side_effect=[resp1, resp2]))
|
||||
http._rate_limiter.acquire = AsyncMock()
|
||||
http._rate_limiter.release = MagicMock()
|
||||
http._rate_limiter.handle_rate_limit = AsyncMock()
|
||||
|
||||
result = await http.request("GET", "/a")
|
||||
|
||||
assert http._rate_limiter.acquire.await_count == 2
|
||||
assert http._rate_limiter.release.call_count == 2
|
||||
http._rate_limiter.handle_rate_limit.assert_awaited_once_with("GET:/a", 0.1, False)
|
||||
assert result == {}
|
Loading…
x
Reference in New Issue
Block a user