Merge pull request #26 from Slipstreamm/codex/implement-ratelimiter-in-httpclient

Merge PR #26
This commit is contained in:
Slipstream 2025-06-10 17:56:28 -06:00 committed by GitHub
commit b9bfa24511
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 89 additions and 13 deletions

View File

@ -17,11 +17,13 @@ from .errors import (
DisagreementException, DisagreementException,
) )
from . import __version__ # For User-Agent from . import __version__ # For User-Agent
from .rate_limiter import RateLimiter
from .interactions import InteractionResponsePayload
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import Client from .client import Client
from .models import Message, Webhook, File from .models import Message, Webhook, File
from .interactions import ApplicationCommand, InteractionResponsePayload, Snowflake from .interactions import ApplicationCommand, Snowflake
# Discord API constants # Discord API constants
API_BASE_URL = "https://discord.com/api/v10" # Using API v10 API_BASE_URL = "https://discord.com/api/v10" # Using API v10
@ -44,8 +46,7 @@ class HTTPClient:
self.verbose = verbose self.verbose = verbose
self._global_rate_limit_lock = asyncio.Event() self._rate_limiter = RateLimiter()
self._global_rate_limit_lock.set() # Initially unlocked
async def _ensure_session(self): async def _ensure_session(self):
if self._session is None or self._session.closed: if self._session is None or self._session.closed:
@ -87,10 +88,10 @@ class HTTPClient:
if self.verbose: if self.verbose:
print(f"HTTP REQUEST: {method} {url} | payload={payload} params={params}") print(f"HTTP REQUEST: {method} {url} | payload={payload} params={params}")
# Global rate limit handling route = f"{method.upper()}:{endpoint}"
await self._global_rate_limit_lock.wait()
for attempt in range(5): # Max 5 retries for rate limits for attempt in range(5): # Max 5 retries for rate limits
await self._rate_limiter.acquire(route)
assert self._session is not None, "ClientSession not initialized" assert self._session is not None, "ClientSession not initialized"
async with self._session.request( async with self._session.request(
method, method,
@ -120,6 +121,8 @@ class HTTPClient:
if self.verbose: if self.verbose:
print(f"HTTP RESPONSE: {response.status} {url} | {data}") print(f"HTTP RESPONSE: {response.status} {url} | {data}")
self._rate_limiter.release(route, response.headers)
if 200 <= response.status < 300: if 200 <= response.status < 300:
if response.status == 204: if response.status == 204:
return None return None
@ -142,12 +145,9 @@ class HTTPClient:
if data and isinstance(data, dict) and "message" in data: if data and isinstance(data, dict) and "message" in data:
error_message += f" Discord says: {data['message']}" error_message += f" Discord says: {data['message']}"
if is_global: await self._rate_limiter.handle_rate_limit(
self._global_rate_limit_lock.clear() route, retry_after, is_global
await asyncio.sleep(retry_after) )
self._global_rate_limit_lock.set()
else:
await asyncio.sleep(retry_after)
if attempt < 4: # Don't log on the last attempt before raising if attempt < 4: # Don't log on the last attempt before raising
print( print(
@ -657,7 +657,7 @@ class HTTPClient:
self, self,
interaction_id: "Snowflake", interaction_id: "Snowflake",
interaction_token: str, interaction_token: str,
payload: "InteractionResponsePayload", payload: Union["InteractionResponsePayload", Dict[str, Any]],
*, *,
ephemeral: bool = False, ephemeral: bool = False,
) -> None: ) -> None:
@ -670,10 +670,16 @@ class HTTPClient:
""" """
# Interaction responses do not use the bot token in the Authorization header. # Interaction responses do not use the bot token in the Authorization header.
# They are authenticated by the interaction_token in the URL. # They are authenticated by the interaction_token in the URL.
payload_data: Dict[str, Any]
if isinstance(payload, InteractionResponsePayload):
payload_data = payload.to_dict()
else:
payload_data = payload
await self.request( await self.request(
"POST", "POST",
f"/interactions/{interaction_id}/{interaction_token}/callback", f"/interactions/{interaction_id}/{interaction_token}/callback",
payload=payload.to_dict(), payload=payload_data,
use_auth_header=False, use_auth_header=False,
) )

View File

@ -0,0 +1,70 @@
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
from disagreement.http import HTTPClient
class DummyResp:
def __init__(self, status, headers=None, data=None):
self.status = status
self.headers = headers or {}
self._data = data or {}
self.headers.setdefault("Content-Type", "application/json")
async def json(self):
return self._data
async def text(self):
return str(self._data)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
@pytest.mark.asyncio
async def test_request_acquires_and_releases(monkeypatch):
http = HTTPClient(token="t")
monkeypatch.setattr(http, "_ensure_session", AsyncMock())
resp = DummyResp(200)
http._session = SimpleNamespace(request=MagicMock(return_value=resp))
http._rate_limiter.acquire = AsyncMock()
http._rate_limiter.release = MagicMock()
http._rate_limiter.handle_rate_limit = AsyncMock()
await http.request("GET", "/a")
http._rate_limiter.acquire.assert_awaited_once_with("GET:/a")
http._rate_limiter.release.assert_called_once_with("GET:/a", resp.headers)
http._rate_limiter.handle_rate_limit.assert_not_called()
@pytest.mark.asyncio
async def test_request_handles_rate_limit(monkeypatch):
http = HTTPClient(token="t")
monkeypatch.setattr(http, "_ensure_session", AsyncMock())
resp1 = DummyResp(
429,
{
"Retry-After": "0.1",
"X-RateLimit-Global": "false",
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset-After": "0.1",
},
{"message": "slow"},
)
resp2 = DummyResp(200, {}, {})
http._session = SimpleNamespace(request=MagicMock(side_effect=[resp1, resp2]))
http._rate_limiter.acquire = AsyncMock()
http._rate_limiter.release = MagicMock()
http._rate_limiter.handle_rate_limit = AsyncMock()
result = await http.request("GET", "/a")
assert http._rate_limiter.acquire.await_count == 2
assert http._rate_limiter.release.call_count == 2
http._rate_limiter.handle_rate_limit.assert_awaited_once_with("GET:/a", 0.1, False)
assert result == {}