From 39b05bc958c170b3bbb33ce312c71871812d9581 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Tue, 10 Jun 2025 17:53:14 -0600 Subject: [PATCH] Add message history iterator --- disagreement/interactions.py | 6 +-- disagreement/models.py | 60 ++++++++++++++++++++++++++++ tests/test_textchannel_history.py | 65 +++++++++++++++++++++++++++++++ 3 files changed, 127 insertions(+), 4 deletions(-) create mode 100644 tests/test_textchannel_history.py diff --git a/disagreement/interactions.py b/disagreement/interactions.py index 4dc7337..eeaac57 100644 --- a/disagreement/interactions.py +++ b/disagreement/interactions.py @@ -402,7 +402,7 @@ class Interaction: await self._client._http.create_interaction_response( interaction_id=self.id, interaction_token=self.token, - payload=payload, + payload=payload.to_dict(), # type: ignore[arg-type] ) async def edit( @@ -503,9 +503,7 @@ class InteractionCallbackData: self.tts: Optional[bool] = data.get("tts") self.content: Optional[str] = data.get("content") self.embeds: Optional[List[Embed]] = ( - [Embed(e) for e in data.get("embeds", [])] - if data.get("embeds") - else None + [Embed(e) for e in data.get("embeds", [])] if data.get("embeds") else None ) self.allowed_mentions: Optional[AllowedMentions] = ( AllowedMentions(data["allowed_mentions"]) diff --git a/disagreement/models.py b/disagreement/models.py index 6bc5d39..36852a3 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -1107,6 +1107,36 @@ class TextChannel(Channel): self._client._messages.pop(mid, None) return ids + async def history( + self, + *, + limit: Optional[int] = 100, + before: "Snowflake | None" = None, + ): + """An async iterator over messages in the channel.""" + + params: Dict[str, Union[int, str]] = {} + if before is not None: + params["before"] = before + + fetched = 0 + while True: + to_fetch = 100 if limit is None else min(100, limit - fetched) + if to_fetch <= 0: + break + params["limit"] = to_fetch + messages = await self._client._http.request( + "GET", f"/channels/{self.id}/messages", params=params.copy() + ) + if not messages: + break + params["before"] = messages[-1]["id"] + for msg in messages: + yield Message(msg, self._client) + fetched += 1 + if limit is not None and fetched >= limit: + return + def __repr__(self) -> str: return f"" @@ -1224,6 +1254,36 @@ class DMChannel(Channel): components=components, ) + async def history( + self, + *, + limit: Optional[int] = 100, + before: "Snowflake | None" = None, + ): + """An async iterator over messages in this DM.""" + + params: Dict[str, Union[int, str]] = {} + if before is not None: + params["before"] = before + + fetched = 0 + while True: + to_fetch = 100 if limit is None else min(100, limit - fetched) + if to_fetch <= 0: + break + params["limit"] = to_fetch + messages = await self._client._http.request( + "GET", f"/channels/{self.id}/messages", params=params.copy() + ) + if not messages: + break + params["before"] = messages[-1]["id"] + for msg in messages: + yield Message(msg, self._client) + fetched += 1 + if limit is not None and fetched >= limit: + return + def __repr__(self) -> str: recipient_repr = self.recipient.username if self.recipient else "Unknown" return f"" diff --git a/tests/test_textchannel_history.py b/tests/test_textchannel_history.py new file mode 100644 index 0000000..4dfc2c3 --- /dev/null +++ b/tests/test_textchannel_history.py @@ -0,0 +1,65 @@ +import pytest +from types import SimpleNamespace +from unittest.mock import AsyncMock + +from disagreement.client import Client +from disagreement.models import TextChannel, Message + + +@pytest.mark.asyncio +async def test_textchannel_history_paginates(): + first_page = [ + { + "id": "3", + "channel_id": "c", + "author": {"id": "1", "username": "u", "discriminator": "0001"}, + "content": "m3", + "timestamp": "t", + }, + { + "id": "2", + "channel_id": "c", + "author": {"id": "1", "username": "u", "discriminator": "0001"}, + "content": "m2", + "timestamp": "t", + }, + ] + second_page = [ + { + "id": "1", + "channel_id": "c", + "author": {"id": "1", "username": "u", "discriminator": "0001"}, + "content": "m1", + "timestamp": "t", + } + ] + http = SimpleNamespace(request=AsyncMock(side_effect=[first_page, second_page])) + client = Client.__new__(Client) + client._http = http + channel = TextChannel({"id": "c", "type": 0}, client) + + messages = [] + async for msg in channel.history(limit=3): + messages.append(msg) + + assert len(messages) == 3 + assert all(isinstance(m, Message) for m in messages) + http.request.assert_any_call("GET", "/channels/c/messages", params={"limit": 3}) + http.request.assert_any_call( + "GET", "/channels/c/messages", params={"limit": 1, "before": "2"} + ) + + +@pytest.mark.asyncio +async def test_textchannel_history_before_param(): + http = SimpleNamespace(request=AsyncMock(return_value=[])) + client = Client.__new__(Client) + client._http = http + channel = TextChannel({"id": "c", "type": 0}, client) + + messages = [m async for m in channel.history(limit=1, before="b")] + + assert messages == [] + http.request.assert_called_once_with( + "GET", "/channels/c/messages", params={"limit": 1, "before": "b"} + )