diff --git a/disagreement/http.py b/disagreement/http.py index 410e37b..f350f15 100644 --- a/disagreement/http.py +++ b/disagreement/http.py @@ -17,11 +17,13 @@ from .errors import ( DisagreementException, ) from . import __version__ # For User-Agent +from .rate_limiter import RateLimiter +from .interactions import InteractionResponsePayload if TYPE_CHECKING: from .client import Client from .models import Message, Webhook, File - from .interactions import ApplicationCommand, InteractionResponsePayload, Snowflake + from .interactions import ApplicationCommand, Snowflake # Discord API constants API_BASE_URL = "https://discord.com/api/v10" # Using API v10 @@ -44,8 +46,7 @@ class HTTPClient: self.verbose = verbose - self._global_rate_limit_lock = asyncio.Event() - self._global_rate_limit_lock.set() # Initially unlocked + self._rate_limiter = RateLimiter() async def _ensure_session(self): if self._session is None or self._session.closed: @@ -87,10 +88,10 @@ class HTTPClient: if self.verbose: print(f"HTTP REQUEST: {method} {url} | payload={payload} params={params}") - # Global rate limit handling - await self._global_rate_limit_lock.wait() + route = f"{method.upper()}:{endpoint}" for attempt in range(5): # Max 5 retries for rate limits + await self._rate_limiter.acquire(route) assert self._session is not None, "ClientSession not initialized" async with self._session.request( method, @@ -120,6 +121,8 @@ class HTTPClient: if self.verbose: print(f"HTTP RESPONSE: {response.status} {url} | {data}") + self._rate_limiter.release(route, response.headers) + if 200 <= response.status < 300: if response.status == 204: return None @@ -142,12 +145,9 @@ class HTTPClient: if data and isinstance(data, dict) and "message" in data: error_message += f" Discord says: {data['message']}" - if is_global: - self._global_rate_limit_lock.clear() - await asyncio.sleep(retry_after) - self._global_rate_limit_lock.set() - else: - await asyncio.sleep(retry_after) + await self._rate_limiter.handle_rate_limit( + route, retry_after, is_global + ) if attempt < 4: # Don't log on the last attempt before raising print( @@ -657,7 +657,7 @@ class HTTPClient: self, interaction_id: "Snowflake", interaction_token: str, - payload: "InteractionResponsePayload", + payload: Union["InteractionResponsePayload", Dict[str, Any]], *, ephemeral: bool = False, ) -> None: @@ -670,10 +670,16 @@ class HTTPClient: """ # Interaction responses do not use the bot token in the Authorization header. # They are authenticated by the interaction_token in the URL. + payload_data: Dict[str, Any] + if isinstance(payload, InteractionResponsePayload): + payload_data = payload.to_dict() + else: + payload_data = payload + await self.request( "POST", f"/interactions/{interaction_id}/{interaction_token}/callback", - payload=payload.to_dict(), + payload=payload_data, use_auth_header=False, ) diff --git a/disagreement/interactions.py b/disagreement/interactions.py index 4dc7337..262164f 100644 --- a/disagreement/interactions.py +++ b/disagreement/interactions.py @@ -402,7 +402,7 @@ class Interaction: await self._client._http.create_interaction_response( interaction_id=self.id, interaction_token=self.token, - payload=payload, + payload=payload.to_dict(), ) async def edit( @@ -503,9 +503,7 @@ class InteractionCallbackData: self.tts: Optional[bool] = data.get("tts") self.content: Optional[str] = data.get("content") self.embeds: Optional[List[Embed]] = ( - [Embed(e) for e in data.get("embeds", [])] - if data.get("embeds") - else None + [Embed(e) for e in data.get("embeds", [])] if data.get("embeds") else None ) self.allowed_mentions: Optional[AllowedMentions] = ( AllowedMentions(data["allowed_mentions"]) diff --git a/tests/test_http_rate_limit.py b/tests/test_http_rate_limit.py new file mode 100644 index 0000000..33ba4c5 --- /dev/null +++ b/tests/test_http_rate_limit.py @@ -0,0 +1,70 @@ +import pytest +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +from disagreement.http import HTTPClient + + +class DummyResp: + def __init__(self, status, headers=None, data=None): + self.status = status + self.headers = headers or {} + self._data = data or {} + self.headers.setdefault("Content-Type", "application/json") + + async def json(self): + return self._data + + async def text(self): + return str(self._data) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + +@pytest.mark.asyncio +async def test_request_acquires_and_releases(monkeypatch): + http = HTTPClient(token="t") + monkeypatch.setattr(http, "_ensure_session", AsyncMock()) + resp = DummyResp(200) + http._session = SimpleNamespace(request=MagicMock(return_value=resp)) + http._rate_limiter.acquire = AsyncMock() + http._rate_limiter.release = MagicMock() + http._rate_limiter.handle_rate_limit = AsyncMock() + + await http.request("GET", "/a") + + http._rate_limiter.acquire.assert_awaited_once_with("GET:/a") + http._rate_limiter.release.assert_called_once_with("GET:/a", resp.headers) + http._rate_limiter.handle_rate_limit.assert_not_called() + + +@pytest.mark.asyncio +async def test_request_handles_rate_limit(monkeypatch): + http = HTTPClient(token="t") + monkeypatch.setattr(http, "_ensure_session", AsyncMock()) + resp1 = DummyResp( + 429, + { + "Retry-After": "0.1", + "X-RateLimit-Global": "false", + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset-After": "0.1", + }, + {"message": "slow"}, + ) + resp2 = DummyResp(200, {}, {}) + http._session = SimpleNamespace(request=MagicMock(side_effect=[resp1, resp2])) + http._rate_limiter.acquire = AsyncMock() + http._rate_limiter.release = MagicMock() + http._rate_limiter.handle_rate_limit = AsyncMock() + + result = await http.request("GET", "/a") + + assert http._rate_limiter.acquire.await_count == 2 + assert http._rate_limiter.release.call_count == 2 + http._rate_limiter.handle_rate_limit.assert_awaited_once_with("GET:/a", 0.1, False) + assert result == {}