Add rate limiter integration with HTTPClient

This commit is contained in:
Slipstream 2025-06-10 17:51:58 -06:00
parent 534b5b3980
commit e55e963a59
3 changed files with 91 additions and 17 deletions

View File

@ -17,11 +17,13 @@ from .errors import (
DisagreementException,
)
from . import __version__ # For User-Agent
from .rate_limiter import RateLimiter
from .interactions import InteractionResponsePayload
if TYPE_CHECKING:
from .client import Client
from .models import Message, Webhook, File
from .interactions import ApplicationCommand, InteractionResponsePayload, Snowflake
from .interactions import ApplicationCommand, Snowflake
# Discord API constants
API_BASE_URL = "https://discord.com/api/v10" # Using API v10
@ -44,8 +46,7 @@ class HTTPClient:
self.verbose = verbose
self._global_rate_limit_lock = asyncio.Event()
self._global_rate_limit_lock.set() # Initially unlocked
self._rate_limiter = RateLimiter()
async def _ensure_session(self):
if self._session is None or self._session.closed:
@ -87,10 +88,10 @@ class HTTPClient:
if self.verbose:
print(f"HTTP REQUEST: {method} {url} | payload={payload} params={params}")
# Global rate limit handling
await self._global_rate_limit_lock.wait()
route = f"{method.upper()}:{endpoint}"
for attempt in range(5): # Max 5 retries for rate limits
await self._rate_limiter.acquire(route)
assert self._session is not None, "ClientSession not initialized"
async with self._session.request(
method,
@ -120,6 +121,8 @@ class HTTPClient:
if self.verbose:
print(f"HTTP RESPONSE: {response.status} {url} | {data}")
self._rate_limiter.release(route, response.headers)
if 200 <= response.status < 300:
if response.status == 204:
return None
@ -142,12 +145,9 @@ class HTTPClient:
if data and isinstance(data, dict) and "message" in data:
error_message += f" Discord says: {data['message']}"
if is_global:
self._global_rate_limit_lock.clear()
await asyncio.sleep(retry_after)
self._global_rate_limit_lock.set()
else:
await asyncio.sleep(retry_after)
await self._rate_limiter.handle_rate_limit(
route, retry_after, is_global
)
if attempt < 4: # Don't log on the last attempt before raising
print(
@ -657,7 +657,7 @@ class HTTPClient:
self,
interaction_id: "Snowflake",
interaction_token: str,
payload: "InteractionResponsePayload",
payload: Union["InteractionResponsePayload", Dict[str, Any]],
*,
ephemeral: bool = False,
) -> None:
@ -670,10 +670,16 @@ class HTTPClient:
"""
# Interaction responses do not use the bot token in the Authorization header.
# They are authenticated by the interaction_token in the URL.
payload_data: Dict[str, Any]
if isinstance(payload, InteractionResponsePayload):
payload_data = payload.to_dict()
else:
payload_data = payload
await self.request(
"POST",
f"/interactions/{interaction_id}/{interaction_token}/callback",
payload=payload.to_dict(),
payload=payload_data,
use_auth_header=False,
)

View File

@ -402,7 +402,7 @@ class Interaction:
await self._client._http.create_interaction_response(
interaction_id=self.id,
interaction_token=self.token,
payload=payload,
payload=payload.to_dict(),
)
async def edit(
@ -503,9 +503,7 @@ class InteractionCallbackData:
self.tts: Optional[bool] = data.get("tts")
self.content: Optional[str] = data.get("content")
self.embeds: Optional[List[Embed]] = (
[Embed(e) for e in data.get("embeds", [])]
if data.get("embeds")
else None
[Embed(e) for e in data.get("embeds", [])] if data.get("embeds") else None
)
self.allowed_mentions: Optional[AllowedMentions] = (
AllowedMentions(data["allowed_mentions"])

View File

@ -0,0 +1,70 @@
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
from disagreement.http import HTTPClient
class DummyResp:
def __init__(self, status, headers=None, data=None):
self.status = status
self.headers = headers or {}
self._data = data or {}
self.headers.setdefault("Content-Type", "application/json")
async def json(self):
return self._data
async def text(self):
return str(self._data)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
@pytest.mark.asyncio
async def test_request_acquires_and_releases(monkeypatch):
http = HTTPClient(token="t")
monkeypatch.setattr(http, "_ensure_session", AsyncMock())
resp = DummyResp(200)
http._session = SimpleNamespace(request=MagicMock(return_value=resp))
http._rate_limiter.acquire = AsyncMock()
http._rate_limiter.release = MagicMock()
http._rate_limiter.handle_rate_limit = AsyncMock()
await http.request("GET", "/a")
http._rate_limiter.acquire.assert_awaited_once_with("GET:/a")
http._rate_limiter.release.assert_called_once_with("GET:/a", resp.headers)
http._rate_limiter.handle_rate_limit.assert_not_called()
@pytest.mark.asyncio
async def test_request_handles_rate_limit(monkeypatch):
http = HTTPClient(token="t")
monkeypatch.setattr(http, "_ensure_session", AsyncMock())
resp1 = DummyResp(
429,
{
"Retry-After": "0.1",
"X-RateLimit-Global": "false",
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset-After": "0.1",
},
{"message": "slow"},
)
resp2 = DummyResp(200, {}, {})
http._session = SimpleNamespace(request=MagicMock(side_effect=[resp1, resp2]))
http._rate_limiter.acquire = AsyncMock()
http._rate_limiter.release = MagicMock()
http._rate_limiter.handle_rate_limit = AsyncMock()
result = await http.request("GET", "/a")
assert http._rate_limiter.acquire.await_count == 2
assert http._rate_limiter.release.call_count == 2
http._rate_limiter.handle_rate_limit.assert_awaited_once_with("GET:/a", 0.1, False)
assert result == {}