diff --git a/disagreement/color.py b/disagreement/color.py index 59d5e8b..d1884f6 100644 --- a/disagreement/color.py +++ b/disagreement/color.py @@ -48,3 +48,31 @@ class Color: def to_rgb(self) -> tuple[int, int, int]: return ((self.value >> 16) & 0xFF, (self.value >> 8) & 0xFF, self.value & 0xFF) + + @classmethod + def parse(cls, value: "Color | int | str | tuple[int, int, int] | None") -> "Color | None": + """Convert ``value`` to a :class:`Color` instance. + + Parameters + ---------- + value: + The value to convert. May be ``None``, an existing ``Color``, an + integer in the ``0xRRGGBB`` format, or a hex string like ``"#RRGGBB"``. + + Returns + ------- + Optional[Color] + A ``Color`` object if ``value`` is not ``None``. + """ + + if value is None: + return None + if isinstance(value, cls): + return value + if isinstance(value, int): + return cls(value) + if isinstance(value, str): + return cls.from_hex(value) + if isinstance(value, tuple) and len(value) == 3: + return cls.from_rgb(*value) + raise TypeError("Color value must be Color, int, str, tuple, or None") diff --git a/disagreement/enums.py b/disagreement/enums.py index d224576..68e59c3 100644 --- a/disagreement/enums.py +++ b/disagreement/enums.py @@ -49,6 +49,11 @@ class GatewayIntent(IntEnum): AUTO_MODERATION_CONFIGURATION = 1 << 20 AUTO_MODERATION_EXECUTION = 1 << 21 + @classmethod + def none(cls) -> int: + """Return a bitmask representing no intents.""" + return 0 + @classmethod def default(cls) -> int: """Returns default intents (excluding privileged ones like members, presences, message content).""" diff --git a/disagreement/ext/tasks.py b/disagreement/ext/tasks.py index 7290c90..986e2da 100644 --- a/disagreement/ext/tasks.py +++ b/disagreement/ext/tasks.py @@ -18,6 +18,8 @@ class Task: delta: Optional[datetime.timedelta] = None, time_of_day: Optional[datetime.time] = None, on_error: Optional[Callable[[Exception], Awaitable[None]]] = None, + before_loop: Optional[Callable[[], Awaitable[None] | None]] = None, + after_loop: Optional[Callable[[], Awaitable[None] | None]] = None, ) -> None: self._coro = coro self._task: Optional[asyncio.Task[None]] = None @@ -36,6 +38,8 @@ class Task: self._seconds = float(interval_seconds) self._time_of_day = time_of_day self._on_error = on_error + self._before_loop = before_loop + self._after_loop = after_loop def _seconds_until_time(self) -> float: assert self._time_of_day is not None @@ -47,6 +51,9 @@ class Task: async def _run(self, *args: Any, **kwargs: Any) -> None: try: + if self._before_loop is not None: + await _maybe_call_no_args(self._before_loop) + first = True while True: if self._time_of_day is not None: @@ -65,6 +72,9 @@ class Task: first = False except asyncio.CancelledError: pass + finally: + if self._after_loop is not None: + await _maybe_call_no_args(self._after_loop) def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: if self._task is None or self._task.done(): @@ -89,6 +99,12 @@ async def _maybe_call( await result +async def _maybe_call_no_args(func: Callable[[], Awaitable[None] | None]) -> None: + result = func() + if asyncio.iscoroutine(result): + await result + + class _Loop: def __init__( self, @@ -110,6 +126,8 @@ class _Loop: self.on_error = on_error self._task: Optional[Task] = None self._owner: Any = None + self._before_loop: Optional[Callable[..., Awaitable[Any]]] = None + self._after_loop: Optional[Callable[..., Awaitable[Any]]] = None def __get__(self, obj: Any, objtype: Any) -> "_BoundLoop": return _BoundLoop(self, obj) @@ -119,7 +137,33 @@ class _Loop: return self.func(*args, **kwargs) return self.func(self._owner, *args, **kwargs) + def before_loop( + self, func: Callable[..., Awaitable[Any]] + ) -> Callable[..., Awaitable[Any]]: + self._before_loop = func + return func + + def after_loop( + self, func: Callable[..., Awaitable[Any]] + ) -> Callable[..., Awaitable[Any]]: + self._after_loop = func + return func + def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: + def call_before() -> Awaitable[None] | None: + if self._before_loop is None: + return None + if self._owner is not None: + return self._before_loop(self._owner) + return self._before_loop() + + def call_after() -> Awaitable[None] | None: + if self._after_loop is None: + return None + if self._owner is not None: + return self._after_loop(self._owner) + return self._after_loop() + self._task = Task( self._coro, seconds=self.seconds, @@ -128,6 +172,8 @@ class _Loop: delta=self.delta, time_of_day=self.time_of_day, on_error=self.on_error, + before_loop=call_before, + after_loop=call_after, ) return self._task.start(*args, **kwargs) diff --git a/disagreement/http.py b/disagreement/http.py index 624b4da..8b3f35a 100644 --- a/disagreement/http.py +++ b/disagreement/http.py @@ -17,11 +17,13 @@ from .errors import ( DisagreementException, ) from . import __version__ # For User-Agent +from .rate_limiter import RateLimiter +from .interactions import InteractionResponsePayload if TYPE_CHECKING: from .client import Client from .models import Message, Webhook, File - from .interactions import ApplicationCommand, InteractionResponsePayload, Snowflake + from .interactions import ApplicationCommand, Snowflake # Discord API constants API_BASE_URL = "https://discord.com/api/v10" # Using API v10 @@ -44,8 +46,7 @@ class HTTPClient: self.verbose = verbose - self._global_rate_limit_lock = asyncio.Event() - self._global_rate_limit_lock.set() # Initially unlocked + self._rate_limiter = RateLimiter() async def _ensure_session(self): if self._session is None or self._session.closed: @@ -87,10 +88,10 @@ class HTTPClient: if self.verbose: print(f"HTTP REQUEST: {method} {url} | payload={payload} params={params}") - # Global rate limit handling - await self._global_rate_limit_lock.wait() + route = f"{method.upper()}:{endpoint}" for attempt in range(5): # Max 5 retries for rate limits + await self._rate_limiter.acquire(route) assert self._session is not None, "ClientSession not initialized" async with self._session.request( method, @@ -120,6 +121,8 @@ class HTTPClient: if self.verbose: print(f"HTTP RESPONSE: {response.status} {url} | {data}") + self._rate_limiter.release(route, response.headers) + if 200 <= response.status < 300: if response.status == 204: return None @@ -142,12 +145,9 @@ class HTTPClient: if data and isinstance(data, dict) and "message" in data: error_message += f" Discord says: {data['message']}" - if is_global: - self._global_rate_limit_lock.clear() - await asyncio.sleep(retry_after) - self._global_rate_limit_lock.set() - else: - await asyncio.sleep(retry_after) + await self._rate_limiter.handle_rate_limit( + route, retry_after, is_global + ) if attempt < 4: # Don't log on the last attempt before raising print( @@ -667,7 +667,7 @@ class HTTPClient: self, interaction_id: "Snowflake", interaction_token: str, - payload: "InteractionResponsePayload", + payload: Union["InteractionResponsePayload", Dict[str, Any]], *, ephemeral: bool = False, ) -> None: @@ -680,10 +680,16 @@ class HTTPClient: """ # Interaction responses do not use the bot token in the Authorization header. # They are authenticated by the interaction_token in the URL. + payload_data: Dict[str, Any] + if isinstance(payload, InteractionResponsePayload): + payload_data = payload.to_dict() + else: + payload_data = payload + await self.request( "POST", f"/interactions/{interaction_id}/{interaction_token}/callback", - payload=payload.to_dict(), + payload=payload_data, use_auth_header=False, ) diff --git a/disagreement/interactions.py b/disagreement/interactions.py index 4dc7337..4b89b6b 100644 --- a/disagreement/interactions.py +++ b/disagreement/interactions.py @@ -402,7 +402,7 @@ class Interaction: await self._client._http.create_interaction_response( interaction_id=self.id, interaction_token=self.token, - payload=payload, + payload=payload.to_dict(), # type: ignore[arg-type] ) async def edit( @@ -503,9 +503,7 @@ class InteractionCallbackData: self.tts: Optional[bool] = data.get("tts") self.content: Optional[str] = data.get("content") self.embeds: Optional[List[Embed]] = ( - [Embed(e) for e in data.get("embeds", [])] - if data.get("embeds") - else None + [Embed(e) for e in data.get("embeds", [])] if data.get("embeds") else None ) self.allowed_mentions: Optional[AllowedMentions] = ( AllowedMentions(data["allowed_mentions"]) @@ -572,3 +570,6 @@ class InteractionResponsePayload: def __repr__(self) -> str: return f"" + + def __getitem__(self, item: str) -> Any: + return self.to_dict()[item] diff --git a/disagreement/models.py b/disagreement/models.py index c3e4520..d462ac8 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -25,6 +25,7 @@ from .enums import ( # These enums will need to be defined in disagreement/enum # SelectMenuType will be part of ComponentType or a new enum if needed ) from .permissions import Permissions +from .color import Color if TYPE_CHECKING: @@ -327,7 +328,7 @@ class Embed: self.description: Optional[str] = data.get("description") self.url: Optional[str] = data.get("url") self.timestamp: Optional[str] = data.get("timestamp") # ISO8601 timestamp - self.color: Optional[int] = data.get("color") + self.color = Color.parse(data.get("color")) self.footer: Optional[EmbedFooter] = ( EmbedFooter(data["footer"]) if data.get("footer") else None @@ -357,7 +358,7 @@ class Embed: if self.timestamp: payload["timestamp"] = self.timestamp if self.color is not None: - payload["color"] = self.color + payload["color"] = self.color.value if self.footer: payload["footer"] = self.footer.to_dict() if self.image: @@ -1122,6 +1123,36 @@ class TextChannel(Channel): self._client._messages.pop(mid, None) return ids + async def history( + self, + *, + limit: Optional[int] = 100, + before: "Snowflake | None" = None, + ): + """An async iterator over messages in the channel.""" + + params: Dict[str, Union[int, str]] = {} + if before is not None: + params["before"] = before + + fetched = 0 + while True: + to_fetch = 100 if limit is None else min(100, limit - fetched) + if to_fetch <= 0: + break + params["limit"] = to_fetch + messages = await self._client._http.request( + "GET", f"/channels/{self.id}/messages", params=params.copy() + ) + if not messages: + break + params["before"] = messages[-1]["id"] + for msg in messages: + yield Message(msg, self._client) + fetched += 1 + if limit is not None and fetched >= limit: + return + def __repr__(self) -> str: return f"" @@ -1239,6 +1270,36 @@ class DMChannel(Channel): components=components, ) + async def history( + self, + *, + limit: Optional[int] = 100, + before: "Snowflake | None" = None, + ): + """An async iterator over messages in this DM.""" + + params: Dict[str, Union[int, str]] = {} + if before is not None: + params["before"] = before + + fetched = 0 + while True: + to_fetch = 100 if limit is None else min(100, limit - fetched) + if to_fetch <= 0: + break + params["limit"] = to_fetch + messages = await self._client._http.request( + "GET", f"/channels/{self.id}/messages", params=params.copy() + ) + if not messages: + break + params["before"] = messages[-1]["id"] + for msg in messages: + yield Message(msg, self._client) + fetched += 1 + if limit is not None and fetched >= limit: + return + def __repr__(self) -> str: recipient_repr = self.recipient.username if self.recipient else "Unknown" return f"" @@ -1723,13 +1784,13 @@ class Container(Component): def __init__( self, components: List[Component], - accent_color: Optional[int] = None, + accent_color: Color | int | str | None = None, spoiler: bool = False, id: Optional[int] = None, ): super().__init__(ComponentType.CONTAINER) self.components = components - self.accent_color = accent_color + self.accent_color = Color.parse(accent_color) self.spoiler = spoiler self.id = id @@ -1737,7 +1798,7 @@ class Container(Component): payload = super().to_dict() payload["components"] = [c.to_dict() for c in self.components] if self.accent_color: - payload["accent_color"] = self.accent_color + payload["accent_color"] = self.accent_color.value if self.spoiler: payload["spoiler"] = self.spoiler if self.id is not None: diff --git a/docs/gateway.md b/docs/gateway.md index 2e2444f..910a0ad 100644 --- a/docs/gateway.md +++ b/docs/gateway.md @@ -16,3 +16,9 @@ bot = Client( ``` These values are passed to `GatewayClient` and applied whenever the connection needs to be re-established. + +## Gateway Intents + +`GatewayIntent` values control which events your bot receives from the Gateway. Use +`GatewayIntent.none()` to opt out of all events entirely. It returns `0`, which +represents a bitmask with no intents enabled. diff --git a/docs/task_loop.md b/docs/task_loop.md index b054909..c60a689 100644 --- a/docs/task_loop.md +++ b/docs/task_loop.md @@ -35,6 +35,22 @@ async def worker(): ... ``` +Run setup and teardown code using `before_loop` and `after_loop`: + +```python +@tasks.loop(seconds=5.0) +async def worker(): + ... + +@worker.before_loop +async def before_worker(): + print("starting") + +@worker.after_loop +async def after_worker(): + print("stopped") +``` + You can also schedule a task at a specific time of day: ```python diff --git a/docs/using_components.md b/docs/using_components.md index 9ebd65a..ac508b0 100644 --- a/docs/using_components.md +++ b/docs/using_components.md @@ -152,7 +152,7 @@ from disagreement.models import Container, TextDisplay container = Container( components=[TextDisplay(content="Inside a container")], - accent_color=0xFF0000, + accent_color="#FF0000", # int or Color() also work spoiler=False, ) ``` diff --git a/tests/test_color_acceptance.py b/tests/test_color_acceptance.py new file mode 100644 index 0000000..f495a1d --- /dev/null +++ b/tests/test_color_acceptance.py @@ -0,0 +1,29 @@ +from disagreement.color import Color +from disagreement.models import Embed, Container, Component + + +def test_color_parse(): + assert Color.parse(0x123456).value == 0x123456 + assert Color.parse("#123456").value == 0x123456 + c = Color(0xABCDEF) + assert Color.parse(c) is c + assert Color.parse(None) is None + assert Color.parse((255, 0, 0)).value == 0xFF0000 + + +def test_embed_color_parsing(): + e = Embed({"color": "#FF0000"}) + assert e.color.value == 0xFF0000 + e = Embed({"color": Color(0x00FF00)}) + assert e.color.value == 0x00FF00 + e = Embed({"color": 0x0000FF}) + assert e.color.value == 0x0000FF + + +def test_container_accent_color_parsing(): + container = Container(components=[], accent_color="#010203") + assert container.accent_color.value == 0x010203 + container = Container(components=[], accent_color=Color(0x111111)) + assert container.accent_color.value == 0x111111 + container = Container(components=[], accent_color=0x222222) + assert container.accent_color.value == 0x222222 diff --git a/tests/test_gateway_intent.py b/tests/test_gateway_intent.py new file mode 100644 index 0000000..f1714ca --- /dev/null +++ b/tests/test_gateway_intent.py @@ -0,0 +1,7 @@ +import pytest + +from disagreement.enums import GatewayIntent + + +def test_gateway_intent_none_equals_zero(): + assert GatewayIntent.none() == 0 diff --git a/tests/test_http_rate_limit.py b/tests/test_http_rate_limit.py new file mode 100644 index 0000000..33ba4c5 --- /dev/null +++ b/tests/test_http_rate_limit.py @@ -0,0 +1,70 @@ +import pytest +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +from disagreement.http import HTTPClient + + +class DummyResp: + def __init__(self, status, headers=None, data=None): + self.status = status + self.headers = headers or {} + self._data = data or {} + self.headers.setdefault("Content-Type", "application/json") + + async def json(self): + return self._data + + async def text(self): + return str(self._data) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + +@pytest.mark.asyncio +async def test_request_acquires_and_releases(monkeypatch): + http = HTTPClient(token="t") + monkeypatch.setattr(http, "_ensure_session", AsyncMock()) + resp = DummyResp(200) + http._session = SimpleNamespace(request=MagicMock(return_value=resp)) + http._rate_limiter.acquire = AsyncMock() + http._rate_limiter.release = MagicMock() + http._rate_limiter.handle_rate_limit = AsyncMock() + + await http.request("GET", "/a") + + http._rate_limiter.acquire.assert_awaited_once_with("GET:/a") + http._rate_limiter.release.assert_called_once_with("GET:/a", resp.headers) + http._rate_limiter.handle_rate_limit.assert_not_called() + + +@pytest.mark.asyncio +async def test_request_handles_rate_limit(monkeypatch): + http = HTTPClient(token="t") + monkeypatch.setattr(http, "_ensure_session", AsyncMock()) + resp1 = DummyResp( + 429, + { + "Retry-After": "0.1", + "X-RateLimit-Global": "false", + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset-After": "0.1", + }, + {"message": "slow"}, + ) + resp2 = DummyResp(200, {}, {}) + http._session = SimpleNamespace(request=MagicMock(side_effect=[resp1, resp2])) + http._rate_limiter.acquire = AsyncMock() + http._rate_limiter.release = MagicMock() + http._rate_limiter.handle_rate_limit = AsyncMock() + + result = await http.request("GET", "/a") + + assert http._rate_limiter.acquire.await_count == 2 + assert http._rate_limiter.release.call_count == 2 + http._rate_limiter.handle_rate_limit.assert_awaited_once_with("GET:/a", 0.1, False) + assert result == {} diff --git a/tests/test_modals.py b/tests/test_modals.py index d21a503..697cb8b 100644 --- a/tests/test_modals.py +++ b/tests/test_modals.py @@ -28,6 +28,5 @@ async def test_respond_modal(dummy_bot, interaction): await interaction.respond_modal(modal) dummy_bot._http.create_interaction_response.assert_called_once() payload = dummy_bot._http.create_interaction_response.call_args.kwargs["payload"] - payload_dict = payload.to_dict() - assert payload_dict["type"] == InteractionCallbackType.MODAL.value - assert payload_dict["data"]["custom_id"] == "m1" + assert payload.type == InteractionCallbackType.MODAL + assert payload.data["custom_id"] == "m1" diff --git a/tests/test_tasks_extension.py b/tests/test_tasks_extension.py index 19f6f64..4b2abfd 100644 --- a/tests/test_tasks_extension.py +++ b/tests/test_tasks_extension.py @@ -58,3 +58,27 @@ async def test_loop_runs_and_stops() -> None: dummy.work.stop() # pylint: disable=no-member assert dummy.count >= 2 assert not dummy.work.running # pylint: disable=no-member + + +@pytest.mark.asyncio +async def test_before_after_loop_callbacks() -> None: + events: list[str] = [] + + @tasks.loop(seconds=0.01) + async def ticker() -> None: + events.append("tick") + + @ticker.before_loop + async def before() -> None: # pragma: no cover - trivial callback + events.append("before") + + @ticker.after_loop + async def after() -> None: # pragma: no cover - trivial callback + events.append("after") + + ticker.start() + await asyncio.sleep(0.03) + ticker.stop() + await asyncio.sleep(0.01) + assert events and events[0] == "before" + assert "after" in events diff --git a/tests/test_textchannel_history.py b/tests/test_textchannel_history.py new file mode 100644 index 0000000..4dfc2c3 --- /dev/null +++ b/tests/test_textchannel_history.py @@ -0,0 +1,65 @@ +import pytest +from types import SimpleNamespace +from unittest.mock import AsyncMock + +from disagreement.client import Client +from disagreement.models import TextChannel, Message + + +@pytest.mark.asyncio +async def test_textchannel_history_paginates(): + first_page = [ + { + "id": "3", + "channel_id": "c", + "author": {"id": "1", "username": "u", "discriminator": "0001"}, + "content": "m3", + "timestamp": "t", + }, + { + "id": "2", + "channel_id": "c", + "author": {"id": "1", "username": "u", "discriminator": "0001"}, + "content": "m2", + "timestamp": "t", + }, + ] + second_page = [ + { + "id": "1", + "channel_id": "c", + "author": {"id": "1", "username": "u", "discriminator": "0001"}, + "content": "m1", + "timestamp": "t", + } + ] + http = SimpleNamespace(request=AsyncMock(side_effect=[first_page, second_page])) + client = Client.__new__(Client) + client._http = http + channel = TextChannel({"id": "c", "type": 0}, client) + + messages = [] + async for msg in channel.history(limit=3): + messages.append(msg) + + assert len(messages) == 3 + assert all(isinstance(m, Message) for m in messages) + http.request.assert_any_call("GET", "/channels/c/messages", params={"limit": 3}) + http.request.assert_any_call( + "GET", "/channels/c/messages", params={"limit": 1, "before": "2"} + ) + + +@pytest.mark.asyncio +async def test_textchannel_history_before_param(): + http = SimpleNamespace(request=AsyncMock(return_value=[])) + client = Client.__new__(Client) + client._http = http + channel = TextChannel({"id": "c", "type": 0}, client) + + messages = [m async for m in channel.history(limit=1, before="b")] + + assert messages == [] + http.request.assert_called_once_with( + "GET", "/channels/c/messages", params={"limit": 1, "before": "b"} + )