Add rate limiter integration with HTTPClient
This commit is contained in:
parent
534b5b3980
commit
e55e963a59
@ -17,11 +17,13 @@ from .errors import (
|
|||||||
DisagreementException,
|
DisagreementException,
|
||||||
)
|
)
|
||||||
from . import __version__ # For User-Agent
|
from . import __version__ # For User-Agent
|
||||||
|
from .rate_limiter import RateLimiter
|
||||||
|
from .interactions import InteractionResponsePayload
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import Client
|
from .client import Client
|
||||||
from .models import Message, Webhook, File
|
from .models import Message, Webhook, File
|
||||||
from .interactions import ApplicationCommand, InteractionResponsePayload, Snowflake
|
from .interactions import ApplicationCommand, Snowflake
|
||||||
|
|
||||||
# Discord API constants
|
# Discord API constants
|
||||||
API_BASE_URL = "https://discord.com/api/v10" # Using API v10
|
API_BASE_URL = "https://discord.com/api/v10" # Using API v10
|
||||||
@ -44,8 +46,7 @@ class HTTPClient:
|
|||||||
|
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
|
||||||
self._global_rate_limit_lock = asyncio.Event()
|
self._rate_limiter = RateLimiter()
|
||||||
self._global_rate_limit_lock.set() # Initially unlocked
|
|
||||||
|
|
||||||
async def _ensure_session(self):
|
async def _ensure_session(self):
|
||||||
if self._session is None or self._session.closed:
|
if self._session is None or self._session.closed:
|
||||||
@ -87,10 +88,10 @@ class HTTPClient:
|
|||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"HTTP REQUEST: {method} {url} | payload={payload} params={params}")
|
print(f"HTTP REQUEST: {method} {url} | payload={payload} params={params}")
|
||||||
|
|
||||||
# Global rate limit handling
|
route = f"{method.upper()}:{endpoint}"
|
||||||
await self._global_rate_limit_lock.wait()
|
|
||||||
|
|
||||||
for attempt in range(5): # Max 5 retries for rate limits
|
for attempt in range(5): # Max 5 retries for rate limits
|
||||||
|
await self._rate_limiter.acquire(route)
|
||||||
assert self._session is not None, "ClientSession not initialized"
|
assert self._session is not None, "ClientSession not initialized"
|
||||||
async with self._session.request(
|
async with self._session.request(
|
||||||
method,
|
method,
|
||||||
@ -120,6 +121,8 @@ class HTTPClient:
|
|||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"HTTP RESPONSE: {response.status} {url} | {data}")
|
print(f"HTTP RESPONSE: {response.status} {url} | {data}")
|
||||||
|
|
||||||
|
self._rate_limiter.release(route, response.headers)
|
||||||
|
|
||||||
if 200 <= response.status < 300:
|
if 200 <= response.status < 300:
|
||||||
if response.status == 204:
|
if response.status == 204:
|
||||||
return None
|
return None
|
||||||
@ -142,12 +145,9 @@ class HTTPClient:
|
|||||||
if data and isinstance(data, dict) and "message" in data:
|
if data and isinstance(data, dict) and "message" in data:
|
||||||
error_message += f" Discord says: {data['message']}"
|
error_message += f" Discord says: {data['message']}"
|
||||||
|
|
||||||
if is_global:
|
await self._rate_limiter.handle_rate_limit(
|
||||||
self._global_rate_limit_lock.clear()
|
route, retry_after, is_global
|
||||||
await asyncio.sleep(retry_after)
|
)
|
||||||
self._global_rate_limit_lock.set()
|
|
||||||
else:
|
|
||||||
await asyncio.sleep(retry_after)
|
|
||||||
|
|
||||||
if attempt < 4: # Don't log on the last attempt before raising
|
if attempt < 4: # Don't log on the last attempt before raising
|
||||||
print(
|
print(
|
||||||
@ -657,7 +657,7 @@ class HTTPClient:
|
|||||||
self,
|
self,
|
||||||
interaction_id: "Snowflake",
|
interaction_id: "Snowflake",
|
||||||
interaction_token: str,
|
interaction_token: str,
|
||||||
payload: "InteractionResponsePayload",
|
payload: Union["InteractionResponsePayload", Dict[str, Any]],
|
||||||
*,
|
*,
|
||||||
ephemeral: bool = False,
|
ephemeral: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -670,10 +670,16 @@ class HTTPClient:
|
|||||||
"""
|
"""
|
||||||
# Interaction responses do not use the bot token in the Authorization header.
|
# Interaction responses do not use the bot token in the Authorization header.
|
||||||
# They are authenticated by the interaction_token in the URL.
|
# They are authenticated by the interaction_token in the URL.
|
||||||
|
payload_data: Dict[str, Any]
|
||||||
|
if isinstance(payload, InteractionResponsePayload):
|
||||||
|
payload_data = payload.to_dict()
|
||||||
|
else:
|
||||||
|
payload_data = payload
|
||||||
|
|
||||||
await self.request(
|
await self.request(
|
||||||
"POST",
|
"POST",
|
||||||
f"/interactions/{interaction_id}/{interaction_token}/callback",
|
f"/interactions/{interaction_id}/{interaction_token}/callback",
|
||||||
payload=payload.to_dict(),
|
payload=payload_data,
|
||||||
use_auth_header=False,
|
use_auth_header=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -402,7 +402,7 @@ class Interaction:
|
|||||||
await self._client._http.create_interaction_response(
|
await self._client._http.create_interaction_response(
|
||||||
interaction_id=self.id,
|
interaction_id=self.id,
|
||||||
interaction_token=self.token,
|
interaction_token=self.token,
|
||||||
payload=payload,
|
payload=payload.to_dict(),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def edit(
|
async def edit(
|
||||||
@ -503,9 +503,7 @@ class InteractionCallbackData:
|
|||||||
self.tts: Optional[bool] = data.get("tts")
|
self.tts: Optional[bool] = data.get("tts")
|
||||||
self.content: Optional[str] = data.get("content")
|
self.content: Optional[str] = data.get("content")
|
||||||
self.embeds: Optional[List[Embed]] = (
|
self.embeds: Optional[List[Embed]] = (
|
||||||
[Embed(e) for e in data.get("embeds", [])]
|
[Embed(e) for e in data.get("embeds", [])] if data.get("embeds") else None
|
||||||
if data.get("embeds")
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
self.allowed_mentions: Optional[AllowedMentions] = (
|
self.allowed_mentions: Optional[AllowedMentions] = (
|
||||||
AllowedMentions(data["allowed_mentions"])
|
AllowedMentions(data["allowed_mentions"])
|
||||||
|
70
tests/test_http_rate_limit.py
Normal file
70
tests/test_http_rate_limit.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
import pytest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from disagreement.http import HTTPClient
|
||||||
|
|
||||||
|
|
||||||
|
class DummyResp:
|
||||||
|
def __init__(self, status, headers=None, data=None):
|
||||||
|
self.status = status
|
||||||
|
self.headers = headers or {}
|
||||||
|
self._data = data or {}
|
||||||
|
self.headers.setdefault("Content-Type", "application/json")
|
||||||
|
|
||||||
|
async def json(self):
|
||||||
|
return self._data
|
||||||
|
|
||||||
|
async def text(self):
|
||||||
|
return str(self._data)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_request_acquires_and_releases(monkeypatch):
|
||||||
|
http = HTTPClient(token="t")
|
||||||
|
monkeypatch.setattr(http, "_ensure_session", AsyncMock())
|
||||||
|
resp = DummyResp(200)
|
||||||
|
http._session = SimpleNamespace(request=MagicMock(return_value=resp))
|
||||||
|
http._rate_limiter.acquire = AsyncMock()
|
||||||
|
http._rate_limiter.release = MagicMock()
|
||||||
|
http._rate_limiter.handle_rate_limit = AsyncMock()
|
||||||
|
|
||||||
|
await http.request("GET", "/a")
|
||||||
|
|
||||||
|
http._rate_limiter.acquire.assert_awaited_once_with("GET:/a")
|
||||||
|
http._rate_limiter.release.assert_called_once_with("GET:/a", resp.headers)
|
||||||
|
http._rate_limiter.handle_rate_limit.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_request_handles_rate_limit(monkeypatch):
|
||||||
|
http = HTTPClient(token="t")
|
||||||
|
monkeypatch.setattr(http, "_ensure_session", AsyncMock())
|
||||||
|
resp1 = DummyResp(
|
||||||
|
429,
|
||||||
|
{
|
||||||
|
"Retry-After": "0.1",
|
||||||
|
"X-RateLimit-Global": "false",
|
||||||
|
"X-RateLimit-Remaining": "0",
|
||||||
|
"X-RateLimit-Reset-After": "0.1",
|
||||||
|
},
|
||||||
|
{"message": "slow"},
|
||||||
|
)
|
||||||
|
resp2 = DummyResp(200, {}, {})
|
||||||
|
http._session = SimpleNamespace(request=MagicMock(side_effect=[resp1, resp2]))
|
||||||
|
http._rate_limiter.acquire = AsyncMock()
|
||||||
|
http._rate_limiter.release = MagicMock()
|
||||||
|
http._rate_limiter.handle_rate_limit = AsyncMock()
|
||||||
|
|
||||||
|
result = await http.request("GET", "/a")
|
||||||
|
|
||||||
|
assert http._rate_limiter.acquire.await_count == 2
|
||||||
|
assert http._rate_limiter.release.call_count == 2
|
||||||
|
http._rate_limiter.handle_rate_limit.assert_awaited_once_with("GET:/a", 0.1, False)
|
||||||
|
assert result == {}
|
Loading…
x
Reference in New Issue
Block a user