diff --git a/disagreement/interactions.py b/disagreement/interactions.py index c314d96..4b89b6b 100644 --- a/disagreement/interactions.py +++ b/disagreement/interactions.py @@ -402,7 +402,7 @@ class Interaction: await self._client._http.create_interaction_response( interaction_id=self.id, interaction_token=self.token, - payload=payload, + payload=payload.to_dict(), # type: ignore[arg-type] ) async def edit( diff --git a/disagreement/models.py b/disagreement/models.py index c492569..a3e7370 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -1108,6 +1108,36 @@ class TextChannel(Channel): self._client._messages.pop(mid, None) return ids + async def history( + self, + *, + limit: Optional[int] = 100, + before: "Snowflake | None" = None, + ): + """An async iterator over messages in the channel.""" + + params: Dict[str, Union[int, str]] = {} + if before is not None: + params["before"] = before + + fetched = 0 + while True: + to_fetch = 100 if limit is None else min(100, limit - fetched) + if to_fetch <= 0: + break + params["limit"] = to_fetch + messages = await self._client._http.request( + "GET", f"/channels/{self.id}/messages", params=params.copy() + ) + if not messages: + break + params["before"] = messages[-1]["id"] + for msg in messages: + yield Message(msg, self._client) + fetched += 1 + if limit is not None and fetched >= limit: + return + def __repr__(self) -> str: return f"" @@ -1225,6 +1255,36 @@ class DMChannel(Channel): components=components, ) + async def history( + self, + *, + limit: Optional[int] = 100, + before: "Snowflake | None" = None, + ): + """An async iterator over messages in this DM.""" + + params: Dict[str, Union[int, str]] = {} + if before is not None: + params["before"] = before + + fetched = 0 + while True: + to_fetch = 100 if limit is None else min(100, limit - fetched) + if to_fetch <= 0: + break + params["limit"] = to_fetch + messages = await self._client._http.request( + "GET", f"/channels/{self.id}/messages", params=params.copy() + ) + if not messages: + break + params["before"] = messages[-1]["id"] + for msg in messages: + yield Message(msg, self._client) + fetched += 1 + if limit is not None and fetched >= limit: + return + def __repr__(self) -> str: recipient_repr = self.recipient.username if self.recipient else "Unknown" return f"" diff --git a/tests/test_textchannel_history.py b/tests/test_textchannel_history.py new file mode 100644 index 0000000..4dfc2c3 --- /dev/null +++ b/tests/test_textchannel_history.py @@ -0,0 +1,65 @@ +import pytest +from types import SimpleNamespace +from unittest.mock import AsyncMock + +from disagreement.client import Client +from disagreement.models import TextChannel, Message + + +@pytest.mark.asyncio +async def test_textchannel_history_paginates(): + first_page = [ + { + "id": "3", + "channel_id": "c", + "author": {"id": "1", "username": "u", "discriminator": "0001"}, + "content": "m3", + "timestamp": "t", + }, + { + "id": "2", + "channel_id": "c", + "author": {"id": "1", "username": "u", "discriminator": "0001"}, + "content": "m2", + "timestamp": "t", + }, + ] + second_page = [ + { + "id": "1", + "channel_id": "c", + "author": {"id": "1", "username": "u", "discriminator": "0001"}, + "content": "m1", + "timestamp": "t", + } + ] + http = SimpleNamespace(request=AsyncMock(side_effect=[first_page, second_page])) + client = Client.__new__(Client) + client._http = http + channel = TextChannel({"id": "c", "type": 0}, client) + + messages = [] + async for msg in channel.history(limit=3): + messages.append(msg) + + assert len(messages) == 3 + assert all(isinstance(m, Message) for m in messages) + http.request.assert_any_call("GET", "/channels/c/messages", params={"limit": 3}) + http.request.assert_any_call( + "GET", "/channels/c/messages", params={"limit": 1, "before": "2"} + ) + + +@pytest.mark.asyncio +async def test_textchannel_history_before_param(): + http = SimpleNamespace(request=AsyncMock(return_value=[])) + client = Client.__new__(Client) + client._http = http + channel = TextChannel({"id": "c", "type": 0}, client) + + messages = [m async for m in channel.history(limit=1, before="b")] + + assert messages == [] + http.request.assert_called_once_with( + "GET", "/channels/c/messages", params={"limit": 1, "before": "b"} + )