Add async message pager and channel history (#32)
This commit is contained in:
parent
d423f5c03a
commit
71097c6fbe
@ -30,7 +30,7 @@ from .errors import (
|
||||
NotFound,
|
||||
)
|
||||
from .color import Color
|
||||
from .utils import utcnow
|
||||
from .utils import utcnow, message_pager
|
||||
from .enums import GatewayIntent, GatewayOpcode # Export enums
|
||||
from .error_handler import setup_global_error_handler
|
||||
from .hybrid_context import HybridContext
|
||||
|
@ -8,7 +8,7 @@ import json
|
||||
import asyncio
|
||||
import aiohttp # pylint: disable=import-error
|
||||
import asyncio
|
||||
from typing import Optional, TYPE_CHECKING, List, Dict, Any, Union
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from .errors import DisagreementException, HTTPException
|
||||
from .enums import ( # These enums will need to be defined in disagreement/enums.py
|
||||
@ -1087,6 +1087,19 @@ class TextChannel(Channel):
|
||||
)
|
||||
self.last_pin_timestamp: Optional[str] = data.get("last_pin_timestamp")
|
||||
|
||||
def history(
|
||||
self,
|
||||
*,
|
||||
limit: Optional[int] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
) -> AsyncIterator["Message"]:
|
||||
"""Return an async iterator over this channel's messages."""
|
||||
|
||||
from .utils import message_pager
|
||||
|
||||
return message_pager(self, limit=limit, before=before, after=after)
|
||||
|
||||
async def send(
|
||||
self,
|
||||
content: Optional[str] = None,
|
||||
|
@ -3,8 +3,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, AsyncIterator, Dict, Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - for type hinting only
|
||||
from .models import Message, TextChannel
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
"""Return the current timezone-aware UTC time."""
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
async def message_pager(
|
||||
channel: "TextChannel",
|
||||
*,
|
||||
limit: Optional[int] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
) -> AsyncIterator["Message"]:
|
||||
"""Asynchronously paginate a channel's messages.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
channel:
|
||||
The :class:`TextChannel` to fetch messages from.
|
||||
limit:
|
||||
The maximum number of messages to yield. ``None`` fetches until no
|
||||
more messages are returned.
|
||||
before:
|
||||
Fetch messages with IDs less than this snowflake.
|
||||
after:
|
||||
Fetch messages with IDs greater than this snowflake.
|
||||
|
||||
Yields
|
||||
------
|
||||
Message
|
||||
Messages in the channel, oldest first.
|
||||
"""
|
||||
|
||||
remaining = limit
|
||||
last_id = before
|
||||
while remaining is None or remaining > 0:
|
||||
fetch_limit = 100
|
||||
if remaining is not None:
|
||||
fetch_limit = min(fetch_limit, remaining)
|
||||
|
||||
params: Dict[str, Any] = {"limit": fetch_limit}
|
||||
if last_id is not None:
|
||||
params["before"] = last_id
|
||||
if after is not None:
|
||||
params["after"] = after
|
||||
|
||||
data = await channel._client._http.request( # type: ignore[attr-defined]
|
||||
"GET",
|
||||
f"/channels/{channel.id}/messages",
|
||||
params=params,
|
||||
)
|
||||
|
||||
if not data:
|
||||
break
|
||||
|
||||
for raw in data:
|
||||
msg = channel._client.parse_message(raw) # type: ignore[attr-defined]
|
||||
yield msg
|
||||
last_id = msg.id
|
||||
if remaining is not None:
|
||||
remaining -= 1
|
||||
if remaining == 0:
|
||||
return
|
||||
|
16
docs/message_history.md
Normal file
16
docs/message_history.md
Normal file
@ -0,0 +1,16 @@
|
||||
# Message History
|
||||
|
||||
`TextChannel.history` provides an async iterator over a channel's past messages. The iterator is powered by `utils.message_pager` which handles pagination for you.
|
||||
|
||||
```python
|
||||
channel = await client.fetch_channel(123456789012345678)
|
||||
async for message in channel.history(limit=200):
|
||||
print(message.content)
|
||||
```
|
||||
|
||||
Pass `before` or `after` to control the range of messages returned. The paginator fetches messages in batches of up to 100 until the limit is reached or Discord returns no more messages.
|
||||
|
||||
## Next Steps
|
||||
|
||||
- [Caching](caching.md)
|
||||
- [Typing Indicator](typing_indicator.md)
|
31
examples/message_history.py
Normal file
31
examples/message_history.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""Example showing how to read a channel's message history."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Allow running example from repository root
|
||||
if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)):
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.models import TextChannel
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN", "")
|
||||
CHANNEL_ID = os.environ.get("DISCORD_CHANNEL_ID", "")
|
||||
|
||||
client = Client(token=BOT_TOKEN)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
channel = await client.fetch_channel(CHANNEL_ID)
|
||||
if isinstance(channel, TextChannel):
|
||||
async for message in channel.history(limit=10):
|
||||
print(message.content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
37
tests/test_message_pager.py
Normal file
37
tests/test_message_pager.py
Normal file
@ -0,0 +1,37 @@
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.models import TextChannel
|
||||
from disagreement.utils import message_pager
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_pager_fetches_until_empty():
|
||||
calls = [
|
||||
[
|
||||
{
|
||||
"id": "1",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||
"content": "hi",
|
||||
"timestamp": "t",
|
||||
}
|
||||
],
|
||||
[],
|
||||
]
|
||||
http = SimpleNamespace(request=AsyncMock(side_effect=calls))
|
||||
client = Client.__new__(Client)
|
||||
client._http = http
|
||||
from disagreement.models import Message
|
||||
|
||||
client.parse_message = lambda d: Message(d, client_instance=client)
|
||||
channel = TextChannel({"id": "c", "type": 0}, client)
|
||||
|
||||
messages = []
|
||||
async for m in message_pager(channel):
|
||||
messages.append(m)
|
||||
|
||||
assert len(messages) == 1
|
||||
http.request.assert_awaited()
|
@ -1,65 +1,24 @@
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.models import TextChannel, Message
|
||||
from disagreement.models import TextChannel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_textchannel_history_paginates():
|
||||
first_page = [
|
||||
{
|
||||
"id": "3",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "1", "username": "u", "discriminator": "0001"},
|
||||
"content": "m3",
|
||||
"timestamp": "t",
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "1", "username": "u", "discriminator": "0001"},
|
||||
"content": "m2",
|
||||
"timestamp": "t",
|
||||
},
|
||||
]
|
||||
second_page = [
|
||||
{
|
||||
"id": "1",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "1", "username": "u", "discriminator": "0001"},
|
||||
"content": "m1",
|
||||
"timestamp": "t",
|
||||
}
|
||||
]
|
||||
http = SimpleNamespace(request=AsyncMock(side_effect=[first_page, second_page]))
|
||||
async def test_textchannel_history_delegates(monkeypatch):
|
||||
called = {}
|
||||
|
||||
async def fake_pager(channel, *, limit=None, before=None, after=None):
|
||||
called["args"] = (channel, limit, before, after)
|
||||
if False:
|
||||
yield None
|
||||
|
||||
monkeypatch.setattr("disagreement.utils.message_pager", fake_pager)
|
||||
client = Client.__new__(Client)
|
||||
client._http = http
|
||||
channel = TextChannel({"id": "c", "type": 0}, client)
|
||||
|
||||
messages = []
|
||||
async for msg in channel.history(limit=3):
|
||||
messages.append(msg)
|
||||
hist = channel.history(limit=2, before="b")
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await hist.__anext__()
|
||||
|
||||
assert len(messages) == 3
|
||||
assert all(isinstance(m, Message) for m in messages)
|
||||
http.request.assert_any_call("GET", "/channels/c/messages", params={"limit": 3})
|
||||
http.request.assert_any_call(
|
||||
"GET", "/channels/c/messages", params={"limit": 1, "before": "2"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_textchannel_history_before_param():
|
||||
http = SimpleNamespace(request=AsyncMock(return_value=[]))
|
||||
client = Client.__new__(Client)
|
||||
client._http = http
|
||||
channel = TextChannel({"id": "c", "type": 0}, client)
|
||||
|
||||
messages = [m async for m in channel.history(limit=1, before="b")]
|
||||
|
||||
assert messages == []
|
||||
http.request.assert_called_once_with(
|
||||
"GET", "/channels/c/messages", params={"limit": 1, "before": "b"}
|
||||
)
|
||||
assert called["args"] == (channel, 2, "b", None)
|
||||
|
Loading…
x
Reference in New Issue
Block a user