Add async message pager and channel history (#32)

This commit is contained in:
Slipstream 2025-06-10 18:01:21 -06:00 committed by GitHub
parent d423f5c03a
commit 71097c6fbe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 176 additions and 57 deletions

View File

@ -30,7 +30,7 @@ from .errors import (
NotFound,
)
from .color import Color
from .utils import utcnow
from .utils import utcnow, message_pager
from .enums import GatewayIntent, GatewayOpcode # Export enums
from .error_handler import setup_global_error_handler
from .hybrid_context import HybridContext

View File

@ -8,7 +8,7 @@ import json
import asyncio
import aiohttp # pylint: disable=import-error
import asyncio
from typing import Optional, TYPE_CHECKING, List, Dict, Any, Union
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union
from .errors import DisagreementException, HTTPException
from .enums import ( # These enums will need to be defined in disagreement/enums.py
@ -1087,6 +1087,19 @@ class TextChannel(Channel):
)
self.last_pin_timestamp: Optional[str] = data.get("last_pin_timestamp")
def history(
self,
*,
limit: Optional[int] = None,
before: Optional[str] = None,
after: Optional[str] = None,
) -> AsyncIterator["Message"]:
"""Return an async iterator over this channel's messages."""
from .utils import message_pager
return message_pager(self, limit=limit, before=before, after=after)
async def send(
self,
content: Optional[str] = None,

View File

@ -3,8 +3,71 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, AsyncIterator, Dict, Optional, TYPE_CHECKING
if TYPE_CHECKING: # pragma: no cover - for type hinting only
from .models import Message, TextChannel
def utcnow() -> datetime:
"""Return the current timezone-aware UTC time."""
return datetime.now(timezone.utc)
async def message_pager(
channel: "TextChannel",
*,
limit: Optional[int] = None,
before: Optional[str] = None,
after: Optional[str] = None,
) -> AsyncIterator["Message"]:
"""Asynchronously paginate a channel's messages.
Parameters
----------
channel:
The :class:`TextChannel` to fetch messages from.
limit:
The maximum number of messages to yield. ``None`` fetches until no
more messages are returned.
before:
Fetch messages with IDs less than this snowflake.
after:
Fetch messages with IDs greater than this snowflake.
Yields
------
Message
Messages in the channel, oldest first.
"""
remaining = limit
last_id = before
while remaining is None or remaining > 0:
fetch_limit = 100
if remaining is not None:
fetch_limit = min(fetch_limit, remaining)
params: Dict[str, Any] = {"limit": fetch_limit}
if last_id is not None:
params["before"] = last_id
if after is not None:
params["after"] = after
data = await channel._client._http.request( # type: ignore[attr-defined]
"GET",
f"/channels/{channel.id}/messages",
params=params,
)
if not data:
break
for raw in data:
msg = channel._client.parse_message(raw) # type: ignore[attr-defined]
yield msg
last_id = msg.id
if remaining is not None:
remaining -= 1
if remaining == 0:
return

16
docs/message_history.md Normal file
View File

@ -0,0 +1,16 @@
# Message History
`TextChannel.history` provides an async iterator over a channel's past messages. The iterator is powered by `utils.message_pager` which handles pagination for you.
```python
channel = await client.fetch_channel(123456789012345678)
async for message in channel.history(limit=200):
print(message.content)
```
Pass `before` or `after` to control the range of messages returned. The paginator fetches messages in batches of up to 100 until the limit is reached or Discord returns no more messages.
## Next Steps
- [Caching](caching.md)
- [Typing Indicator](typing_indicator.md)

View File

@ -0,0 +1,31 @@
"""Example showing how to read a channel's message history."""
import asyncio
import os
import sys
# Allow running example from repository root
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from disagreement.client import Client
from disagreement.models import TextChannel
from dotenv import load_dotenv
load_dotenv()
BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "")
CHANNEL_ID = os.environ.get("DISCORD_CHANNEL_ID", "")
client = Client(token=BOT_TOKEN)
async def main() -> None:
channel = await client.fetch_channel(CHANNEL_ID)
if isinstance(channel, TextChannel):
async for message in channel.history(limit=10):
print(message.content)
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,37 @@
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock
from disagreement.client import Client
from disagreement.models import TextChannel
from disagreement.utils import message_pager
@pytest.mark.asyncio
async def test_message_pager_fetches_until_empty():
calls = [
[
{
"id": "1",
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
],
[],
]
http = SimpleNamespace(request=AsyncMock(side_effect=calls))
client = Client.__new__(Client)
client._http = http
from disagreement.models import Message
client.parse_message = lambda d: Message(d, client_instance=client)
channel = TextChannel({"id": "c", "type": 0}, client)
messages = []
async for m in message_pager(channel):
messages.append(m)
assert len(messages) == 1
http.request.assert_awaited()

View File

@ -1,65 +1,24 @@
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock
from disagreement.client import Client
from disagreement.models import TextChannel, Message
from disagreement.models import TextChannel
@pytest.mark.asyncio
async def test_textchannel_history_paginates():
first_page = [
{
"id": "3",
"channel_id": "c",
"author": {"id": "1", "username": "u", "discriminator": "0001"},
"content": "m3",
"timestamp": "t",
},
{
"id": "2",
"channel_id": "c",
"author": {"id": "1", "username": "u", "discriminator": "0001"},
"content": "m2",
"timestamp": "t",
},
]
second_page = [
{
"id": "1",
"channel_id": "c",
"author": {"id": "1", "username": "u", "discriminator": "0001"},
"content": "m1",
"timestamp": "t",
}
]
http = SimpleNamespace(request=AsyncMock(side_effect=[first_page, second_page]))
async def test_textchannel_history_delegates(monkeypatch):
called = {}
async def fake_pager(channel, *, limit=None, before=None, after=None):
called["args"] = (channel, limit, before, after)
if False:
yield None
monkeypatch.setattr("disagreement.utils.message_pager", fake_pager)
client = Client.__new__(Client)
client._http = http
channel = TextChannel({"id": "c", "type": 0}, client)
messages = []
async for msg in channel.history(limit=3):
messages.append(msg)
hist = channel.history(limit=2, before="b")
with pytest.raises(StopAsyncIteration):
await hist.__anext__()
assert len(messages) == 3
assert all(isinstance(m, Message) for m in messages)
http.request.assert_any_call("GET", "/channels/c/messages", params={"limit": 3})
http.request.assert_any_call(
"GET", "/channels/c/messages", params={"limit": 1, "before": "2"}
)
@pytest.mark.asyncio
async def test_textchannel_history_before_param():
http = SimpleNamespace(request=AsyncMock(return_value=[]))
client = Client.__new__(Client)
client._http = http
channel = TextChannel({"id": "c", "type": 0}, client)
messages = [m async for m in channel.history(limit=1, before="b")]
assert messages == []
http.request.assert_called_once_with(
"GET", "/channels/c/messages", params={"limit": 1, "before": "b"}
)
assert called["args"] == (channel, 2, "b", None)