Compare commits

...

2 Commits

3 changed files with 126 additions and 1 deletions

View File

@ -402,7 +402,7 @@ class Interaction:
await self._client._http.create_interaction_response(
interaction_id=self.id,
interaction_token=self.token,
payload=payload,
payload=payload.to_dict(), # type: ignore[arg-type]
)
async def edit(

View File

@ -1108,6 +1108,36 @@ class TextChannel(Channel):
self._client._messages.pop(mid, None)
return ids
async def history(
self,
*,
limit: Optional[int] = 100,
before: "Snowflake | None" = None,
):
"""An async iterator over messages in the channel."""
params: Dict[str, Union[int, str]] = {}
if before is not None:
params["before"] = before
fetched = 0
while True:
to_fetch = 100 if limit is None else min(100, limit - fetched)
if to_fetch <= 0:
break
params["limit"] = to_fetch
messages = await self._client._http.request(
"GET", f"/channels/{self.id}/messages", params=params.copy()
)
if not messages:
break
params["before"] = messages[-1]["id"]
for msg in messages:
yield Message(msg, self._client)
fetched += 1
if limit is not None and fetched >= limit:
return
def __repr__(self) -> str:
return f"<TextChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>"
@ -1225,6 +1255,36 @@ class DMChannel(Channel):
components=components,
)
async def history(
self,
*,
limit: Optional[int] = 100,
before: "Snowflake | None" = None,
):
"""An async iterator over messages in this DM."""
params: Dict[str, Union[int, str]] = {}
if before is not None:
params["before"] = before
fetched = 0
while True:
to_fetch = 100 if limit is None else min(100, limit - fetched)
if to_fetch <= 0:
break
params["limit"] = to_fetch
messages = await self._client._http.request(
"GET", f"/channels/{self.id}/messages", params=params.copy()
)
if not messages:
break
params["before"] = messages[-1]["id"]
for msg in messages:
yield Message(msg, self._client)
fetched += 1
if limit is not None and fetched >= limit:
return
def __repr__(self) -> str:
recipient_repr = self.recipient.username if self.recipient else "Unknown"
return f"<DMChannel id='{self.id}' recipient='{recipient_repr}'>"

View File

@ -0,0 +1,65 @@
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock
from disagreement.client import Client
from disagreement.models import TextChannel, Message
@pytest.mark.asyncio
async def test_textchannel_history_paginates():
first_page = [
{
"id": "3",
"channel_id": "c",
"author": {"id": "1", "username": "u", "discriminator": "0001"},
"content": "m3",
"timestamp": "t",
},
{
"id": "2",
"channel_id": "c",
"author": {"id": "1", "username": "u", "discriminator": "0001"},
"content": "m2",
"timestamp": "t",
},
]
second_page = [
{
"id": "1",
"channel_id": "c",
"author": {"id": "1", "username": "u", "discriminator": "0001"},
"content": "m1",
"timestamp": "t",
}
]
http = SimpleNamespace(request=AsyncMock(side_effect=[first_page, second_page]))
client = Client.__new__(Client)
client._http = http
channel = TextChannel({"id": "c", "type": 0}, client)
messages = []
async for msg in channel.history(limit=3):
messages.append(msg)
assert len(messages) == 3
assert all(isinstance(m, Message) for m in messages)
http.request.assert_any_call("GET", "/channels/c/messages", params={"limit": 3})
http.request.assert_any_call(
"GET", "/channels/c/messages", params={"limit": 1, "before": "2"}
)
@pytest.mark.asyncio
async def test_textchannel_history_before_param():
http = SimpleNamespace(request=AsyncMock(return_value=[]))
client = Client.__new__(Client)
client._http = http
channel = TextChannel({"id": "c", "type": 0}, client)
messages = [m async for m in channel.history(limit=1, before="b")]
assert messages == []
http.request.assert_called_once_with(
"GET", "/channels/c/messages", params={"limit": 1, "before": "b"}
)