Compare commits
2 Commits
c0066525db
...
d67097a619
Author | SHA1 | Date | |
---|---|---|---|
d67097a619 | |||
39b05bc958 |
@ -402,7 +402,7 @@ class Interaction:
|
||||
await self._client._http.create_interaction_response(
|
||||
interaction_id=self.id,
|
||||
interaction_token=self.token,
|
||||
payload=payload,
|
||||
payload=payload.to_dict(), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
async def edit(
|
||||
|
@ -1108,6 +1108,36 @@ class TextChannel(Channel):
|
||||
self._client._messages.pop(mid, None)
|
||||
return ids
|
||||
|
||||
async def history(
|
||||
self,
|
||||
*,
|
||||
limit: Optional[int] = 100,
|
||||
before: "Snowflake | None" = None,
|
||||
):
|
||||
"""An async iterator over messages in the channel."""
|
||||
|
||||
params: Dict[str, Union[int, str]] = {}
|
||||
if before is not None:
|
||||
params["before"] = before
|
||||
|
||||
fetched = 0
|
||||
while True:
|
||||
to_fetch = 100 if limit is None else min(100, limit - fetched)
|
||||
if to_fetch <= 0:
|
||||
break
|
||||
params["limit"] = to_fetch
|
||||
messages = await self._client._http.request(
|
||||
"GET", f"/channels/{self.id}/messages", params=params.copy()
|
||||
)
|
||||
if not messages:
|
||||
break
|
||||
params["before"] = messages[-1]["id"]
|
||||
for msg in messages:
|
||||
yield Message(msg, self._client)
|
||||
fetched += 1
|
||||
if limit is not None and fetched >= limit:
|
||||
return
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<TextChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>"
|
||||
|
||||
@ -1225,6 +1255,36 @@ class DMChannel(Channel):
|
||||
components=components,
|
||||
)
|
||||
|
||||
async def history(
|
||||
self,
|
||||
*,
|
||||
limit: Optional[int] = 100,
|
||||
before: "Snowflake | None" = None,
|
||||
):
|
||||
"""An async iterator over messages in this DM."""
|
||||
|
||||
params: Dict[str, Union[int, str]] = {}
|
||||
if before is not None:
|
||||
params["before"] = before
|
||||
|
||||
fetched = 0
|
||||
while True:
|
||||
to_fetch = 100 if limit is None else min(100, limit - fetched)
|
||||
if to_fetch <= 0:
|
||||
break
|
||||
params["limit"] = to_fetch
|
||||
messages = await self._client._http.request(
|
||||
"GET", f"/channels/{self.id}/messages", params=params.copy()
|
||||
)
|
||||
if not messages:
|
||||
break
|
||||
params["before"] = messages[-1]["id"]
|
||||
for msg in messages:
|
||||
yield Message(msg, self._client)
|
||||
fetched += 1
|
||||
if limit is not None and fetched >= limit:
|
||||
return
|
||||
|
||||
def __repr__(self) -> str:
|
||||
recipient_repr = self.recipient.username if self.recipient else "Unknown"
|
||||
return f"<DMChannel id='{self.id}' recipient='{recipient_repr}'>"
|
||||
|
65
tests/test_textchannel_history.py
Normal file
65
tests/test_textchannel_history.py
Normal file
@ -0,0 +1,65 @@
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.models import TextChannel, Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_textchannel_history_paginates():
|
||||
first_page = [
|
||||
{
|
||||
"id": "3",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "1", "username": "u", "discriminator": "0001"},
|
||||
"content": "m3",
|
||||
"timestamp": "t",
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "1", "username": "u", "discriminator": "0001"},
|
||||
"content": "m2",
|
||||
"timestamp": "t",
|
||||
},
|
||||
]
|
||||
second_page = [
|
||||
{
|
||||
"id": "1",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "1", "username": "u", "discriminator": "0001"},
|
||||
"content": "m1",
|
||||
"timestamp": "t",
|
||||
}
|
||||
]
|
||||
http = SimpleNamespace(request=AsyncMock(side_effect=[first_page, second_page]))
|
||||
client = Client.__new__(Client)
|
||||
client._http = http
|
||||
channel = TextChannel({"id": "c", "type": 0}, client)
|
||||
|
||||
messages = []
|
||||
async for msg in channel.history(limit=3):
|
||||
messages.append(msg)
|
||||
|
||||
assert len(messages) == 3
|
||||
assert all(isinstance(m, Message) for m in messages)
|
||||
http.request.assert_any_call("GET", "/channels/c/messages", params={"limit": 3})
|
||||
http.request.assert_any_call(
|
||||
"GET", "/channels/c/messages", params={"limit": 1, "before": "2"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_textchannel_history_before_param():
|
||||
http = SimpleNamespace(request=AsyncMock(return_value=[]))
|
||||
client = Client.__new__(Client)
|
||||
client._http = http
|
||||
channel = TextChannel({"id": "c", "type": 0}, client)
|
||||
|
||||
messages = [m async for m in channel.history(limit=1, before="b")]
|
||||
|
||||
assert messages == []
|
||||
http.request.assert_called_once_with(
|
||||
"GET", "/channels/c/messages", params={"limit": 1, "before": "b"}
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user