Merge remote-tracking branch 'origin/master' into codex/add-display_name-property-on-member
This commit is contained in:
commit
b375dc7d05
@ -48,3 +48,31 @@ class Color:
|
||||
|
||||
def to_rgb(self) -> tuple[int, int, int]:
|
||||
return ((self.value >> 16) & 0xFF, (self.value >> 8) & 0xFF, self.value & 0xFF)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, value: "Color | int | str | tuple[int, int, int] | None") -> "Color | None":
|
||||
"""Convert ``value`` to a :class:`Color` instance.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value:
|
||||
The value to convert. May be ``None``, an existing ``Color``, an
|
||||
integer in the ``0xRRGGBB`` format, or a hex string like ``"#RRGGBB"``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[Color]
|
||||
A ``Color`` object if ``value`` is not ``None``.
|
||||
"""
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, cls):
|
||||
return value
|
||||
if isinstance(value, int):
|
||||
return cls(value)
|
||||
if isinstance(value, str):
|
||||
return cls.from_hex(value)
|
||||
if isinstance(value, tuple) and len(value) == 3:
|
||||
return cls.from_rgb(*value)
|
||||
raise TypeError("Color value must be Color, int, str, tuple, or None")
|
||||
|
@ -49,6 +49,11 @@ class GatewayIntent(IntEnum):
|
||||
AUTO_MODERATION_CONFIGURATION = 1 << 20
|
||||
AUTO_MODERATION_EXECUTION = 1 << 21
|
||||
|
||||
@classmethod
|
||||
def none(cls) -> int:
|
||||
"""Return a bitmask representing no intents."""
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def default(cls) -> int:
|
||||
"""Returns default intents (excluding privileged ones like members, presences, message content)."""
|
||||
|
@ -18,6 +18,8 @@ class Task:
|
||||
delta: Optional[datetime.timedelta] = None,
|
||||
time_of_day: Optional[datetime.time] = None,
|
||||
on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
|
||||
before_loop: Optional[Callable[[], Awaitable[None] | None]] = None,
|
||||
after_loop: Optional[Callable[[], Awaitable[None] | None]] = None,
|
||||
) -> None:
|
||||
self._coro = coro
|
||||
self._task: Optional[asyncio.Task[None]] = None
|
||||
@ -36,6 +38,8 @@ class Task:
|
||||
self._seconds = float(interval_seconds)
|
||||
self._time_of_day = time_of_day
|
||||
self._on_error = on_error
|
||||
self._before_loop = before_loop
|
||||
self._after_loop = after_loop
|
||||
|
||||
def _seconds_until_time(self) -> float:
|
||||
assert self._time_of_day is not None
|
||||
@ -47,6 +51,9 @@ class Task:
|
||||
|
||||
async def _run(self, *args: Any, **kwargs: Any) -> None:
|
||||
try:
|
||||
if self._before_loop is not None:
|
||||
await _maybe_call_no_args(self._before_loop)
|
||||
|
||||
first = True
|
||||
while True:
|
||||
if self._time_of_day is not None:
|
||||
@ -65,6 +72,9 @@ class Task:
|
||||
first = False
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
if self._after_loop is not None:
|
||||
await _maybe_call_no_args(self._after_loop)
|
||||
|
||||
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
||||
if self._task is None or self._task.done():
|
||||
@ -89,6 +99,12 @@ async def _maybe_call(
|
||||
await result
|
||||
|
||||
|
||||
async def _maybe_call_no_args(func: Callable[[], Awaitable[None] | None]) -> None:
|
||||
result = func()
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
|
||||
|
||||
class _Loop:
|
||||
def __init__(
|
||||
self,
|
||||
@ -110,6 +126,8 @@ class _Loop:
|
||||
self.on_error = on_error
|
||||
self._task: Optional[Task] = None
|
||||
self._owner: Any = None
|
||||
self._before_loop: Optional[Callable[..., Awaitable[Any]]] = None
|
||||
self._after_loop: Optional[Callable[..., Awaitable[Any]]] = None
|
||||
|
||||
def __get__(self, obj: Any, objtype: Any) -> "_BoundLoop":
|
||||
return _BoundLoop(self, obj)
|
||||
@ -119,7 +137,33 @@ class _Loop:
|
||||
return self.func(*args, **kwargs)
|
||||
return self.func(self._owner, *args, **kwargs)
|
||||
|
||||
def before_loop(
|
||||
self, func: Callable[..., Awaitable[Any]]
|
||||
) -> Callable[..., Awaitable[Any]]:
|
||||
self._before_loop = func
|
||||
return func
|
||||
|
||||
def after_loop(
|
||||
self, func: Callable[..., Awaitable[Any]]
|
||||
) -> Callable[..., Awaitable[Any]]:
|
||||
self._after_loop = func
|
||||
return func
|
||||
|
||||
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
||||
def call_before() -> Awaitable[None] | None:
|
||||
if self._before_loop is None:
|
||||
return None
|
||||
if self._owner is not None:
|
||||
return self._before_loop(self._owner)
|
||||
return self._before_loop()
|
||||
|
||||
def call_after() -> Awaitable[None] | None:
|
||||
if self._after_loop is None:
|
||||
return None
|
||||
if self._owner is not None:
|
||||
return self._after_loop(self._owner)
|
||||
return self._after_loop()
|
||||
|
||||
self._task = Task(
|
||||
self._coro,
|
||||
seconds=self.seconds,
|
||||
@ -128,6 +172,8 @@ class _Loop:
|
||||
delta=self.delta,
|
||||
time_of_day=self.time_of_day,
|
||||
on_error=self.on_error,
|
||||
before_loop=call_before,
|
||||
after_loop=call_after,
|
||||
)
|
||||
return self._task.start(*args, **kwargs)
|
||||
|
||||
|
@ -17,11 +17,13 @@ from .errors import (
|
||||
DisagreementException,
|
||||
)
|
||||
from . import __version__ # For User-Agent
|
||||
from .rate_limiter import RateLimiter
|
||||
from .interactions import InteractionResponsePayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client
|
||||
from .models import Message, Webhook, File
|
||||
from .interactions import ApplicationCommand, InteractionResponsePayload, Snowflake
|
||||
from .interactions import ApplicationCommand, Snowflake
|
||||
|
||||
# Discord API constants
|
||||
API_BASE_URL = "https://discord.com/api/v10" # Using API v10
|
||||
@ -44,8 +46,7 @@ class HTTPClient:
|
||||
|
||||
self.verbose = verbose
|
||||
|
||||
self._global_rate_limit_lock = asyncio.Event()
|
||||
self._global_rate_limit_lock.set() # Initially unlocked
|
||||
self._rate_limiter = RateLimiter()
|
||||
|
||||
async def _ensure_session(self):
|
||||
if self._session is None or self._session.closed:
|
||||
@ -87,10 +88,10 @@ class HTTPClient:
|
||||
if self.verbose:
|
||||
print(f"HTTP REQUEST: {method} {url} | payload={payload} params={params}")
|
||||
|
||||
# Global rate limit handling
|
||||
await self._global_rate_limit_lock.wait()
|
||||
route = f"{method.upper()}:{endpoint}"
|
||||
|
||||
for attempt in range(5): # Max 5 retries for rate limits
|
||||
await self._rate_limiter.acquire(route)
|
||||
assert self._session is not None, "ClientSession not initialized"
|
||||
async with self._session.request(
|
||||
method,
|
||||
@ -120,6 +121,8 @@ class HTTPClient:
|
||||
if self.verbose:
|
||||
print(f"HTTP RESPONSE: {response.status} {url} | {data}")
|
||||
|
||||
self._rate_limiter.release(route, response.headers)
|
||||
|
||||
if 200 <= response.status < 300:
|
||||
if response.status == 204:
|
||||
return None
|
||||
@ -142,12 +145,9 @@ class HTTPClient:
|
||||
if data and isinstance(data, dict) and "message" in data:
|
||||
error_message += f" Discord says: {data['message']}"
|
||||
|
||||
if is_global:
|
||||
self._global_rate_limit_lock.clear()
|
||||
await asyncio.sleep(retry_after)
|
||||
self._global_rate_limit_lock.set()
|
||||
else:
|
||||
await asyncio.sleep(retry_after)
|
||||
await self._rate_limiter.handle_rate_limit(
|
||||
route, retry_after, is_global
|
||||
)
|
||||
|
||||
if attempt < 4: # Don't log on the last attempt before raising
|
||||
print(
|
||||
@ -657,7 +657,7 @@ class HTTPClient:
|
||||
self,
|
||||
interaction_id: "Snowflake",
|
||||
interaction_token: str,
|
||||
payload: "InteractionResponsePayload",
|
||||
payload: Union["InteractionResponsePayload", Dict[str, Any]],
|
||||
*,
|
||||
ephemeral: bool = False,
|
||||
) -> None:
|
||||
@ -670,10 +670,16 @@ class HTTPClient:
|
||||
"""
|
||||
# Interaction responses do not use the bot token in the Authorization header.
|
||||
# They are authenticated by the interaction_token in the URL.
|
||||
payload_data: Dict[str, Any]
|
||||
if isinstance(payload, InteractionResponsePayload):
|
||||
payload_data = payload.to_dict()
|
||||
else:
|
||||
payload_data = payload
|
||||
|
||||
await self.request(
|
||||
"POST",
|
||||
f"/interactions/{interaction_id}/{interaction_token}/callback",
|
||||
payload=payload.to_dict(),
|
||||
payload=payload_data,
|
||||
use_auth_header=False,
|
||||
)
|
||||
|
||||
|
@ -402,7 +402,7 @@ class Interaction:
|
||||
await self._client._http.create_interaction_response(
|
||||
interaction_id=self.id,
|
||||
interaction_token=self.token,
|
||||
payload=payload,
|
||||
payload=payload.to_dict(), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
async def edit(
|
||||
@ -503,9 +503,7 @@ class InteractionCallbackData:
|
||||
self.tts: Optional[bool] = data.get("tts")
|
||||
self.content: Optional[str] = data.get("content")
|
||||
self.embeds: Optional[List[Embed]] = (
|
||||
[Embed(e) for e in data.get("embeds", [])]
|
||||
if data.get("embeds")
|
||||
else None
|
||||
[Embed(e) for e in data.get("embeds", [])] if data.get("embeds") else None
|
||||
)
|
||||
self.allowed_mentions: Optional[AllowedMentions] = (
|
||||
AllowedMentions(data["allowed_mentions"])
|
||||
@ -572,3 +570,6 @@ class InteractionResponsePayload:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<InteractionResponsePayload type={self.type!r}>"
|
||||
|
||||
def __getitem__(self, item: str) -> Any:
|
||||
return self.to_dict()[item]
|
||||
|
@ -25,6 +25,7 @@ from .enums import ( # These enums will need to be defined in disagreement/enum
|
||||
# SelectMenuType will be part of ComponentType or a new enum if needed
|
||||
)
|
||||
from .permissions import Permissions
|
||||
from .color import Color
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -312,7 +313,7 @@ class Embed:
|
||||
self.description: Optional[str] = data.get("description")
|
||||
self.url: Optional[str] = data.get("url")
|
||||
self.timestamp: Optional[str] = data.get("timestamp") # ISO8601 timestamp
|
||||
self.color: Optional[int] = data.get("color")
|
||||
self.color = Color.parse(data.get("color"))
|
||||
|
||||
self.footer: Optional[EmbedFooter] = (
|
||||
EmbedFooter(data["footer"]) if data.get("footer") else None
|
||||
@ -342,7 +343,7 @@ class Embed:
|
||||
if self.timestamp:
|
||||
payload["timestamp"] = self.timestamp
|
||||
if self.color is not None:
|
||||
payload["color"] = self.color
|
||||
payload["color"] = self.color.value
|
||||
if self.footer:
|
||||
payload["footer"] = self.footer.to_dict()
|
||||
if self.image:
|
||||
@ -1113,6 +1114,36 @@ class TextChannel(Channel):
|
||||
self._client._messages.pop(mid, None)
|
||||
return ids
|
||||
|
||||
async def history(
|
||||
self,
|
||||
*,
|
||||
limit: Optional[int] = 100,
|
||||
before: "Snowflake | None" = None,
|
||||
):
|
||||
"""An async iterator over messages in the channel."""
|
||||
|
||||
params: Dict[str, Union[int, str]] = {}
|
||||
if before is not None:
|
||||
params["before"] = before
|
||||
|
||||
fetched = 0
|
||||
while True:
|
||||
to_fetch = 100 if limit is None else min(100, limit - fetched)
|
||||
if to_fetch <= 0:
|
||||
break
|
||||
params["limit"] = to_fetch
|
||||
messages = await self._client._http.request(
|
||||
"GET", f"/channels/{self.id}/messages", params=params.copy()
|
||||
)
|
||||
if not messages:
|
||||
break
|
||||
params["before"] = messages[-1]["id"]
|
||||
for msg in messages:
|
||||
yield Message(msg, self._client)
|
||||
fetched += 1
|
||||
if limit is not None and fetched >= limit:
|
||||
return
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<TextChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>"
|
||||
|
||||
@ -1230,6 +1261,36 @@ class DMChannel(Channel):
|
||||
components=components,
|
||||
)
|
||||
|
||||
async def history(
|
||||
self,
|
||||
*,
|
||||
limit: Optional[int] = 100,
|
||||
before: "Snowflake | None" = None,
|
||||
):
|
||||
"""An async iterator over messages in this DM."""
|
||||
|
||||
params: Dict[str, Union[int, str]] = {}
|
||||
if before is not None:
|
||||
params["before"] = before
|
||||
|
||||
fetched = 0
|
||||
while True:
|
||||
to_fetch = 100 if limit is None else min(100, limit - fetched)
|
||||
if to_fetch <= 0:
|
||||
break
|
||||
params["limit"] = to_fetch
|
||||
messages = await self._client._http.request(
|
||||
"GET", f"/channels/{self.id}/messages", params=params.copy()
|
||||
)
|
||||
if not messages:
|
||||
break
|
||||
params["before"] = messages[-1]["id"]
|
||||
for msg in messages:
|
||||
yield Message(msg, self._client)
|
||||
fetched += 1
|
||||
if limit is not None and fetched >= limit:
|
||||
return
|
||||
|
||||
def __repr__(self) -> str:
|
||||
recipient_repr = self.recipient.username if self.recipient else "Unknown"
|
||||
return f"<DMChannel id='{self.id}' recipient='{recipient_repr}'>"
|
||||
@ -1714,13 +1775,13 @@ class Container(Component):
|
||||
def __init__(
|
||||
self,
|
||||
components: List[Component],
|
||||
accent_color: Optional[int] = None,
|
||||
accent_color: Color | int | str | None = None,
|
||||
spoiler: bool = False,
|
||||
id: Optional[int] = None,
|
||||
):
|
||||
super().__init__(ComponentType.CONTAINER)
|
||||
self.components = components
|
||||
self.accent_color = accent_color
|
||||
self.accent_color = Color.parse(accent_color)
|
||||
self.spoiler = spoiler
|
||||
self.id = id
|
||||
|
||||
@ -1728,7 +1789,7 @@ class Container(Component):
|
||||
payload = super().to_dict()
|
||||
payload["components"] = [c.to_dict() for c in self.components]
|
||||
if self.accent_color:
|
||||
payload["accent_color"] = self.accent_color
|
||||
payload["accent_color"] = self.accent_color.value
|
||||
if self.spoiler:
|
||||
payload["spoiler"] = self.spoiler
|
||||
if self.id is not None:
|
||||
|
@ -16,3 +16,9 @@ bot = Client(
|
||||
```
|
||||
|
||||
These values are passed to `GatewayClient` and applied whenever the connection needs to be re-established.
|
||||
|
||||
## Gateway Intents
|
||||
|
||||
`GatewayIntent` values control which events your bot receives from the Gateway. Use
|
||||
`GatewayIntent.none()` to opt out of all events entirely. It returns `0`, which
|
||||
represents a bitmask with no intents enabled.
|
||||
|
@ -35,6 +35,22 @@ async def worker():
|
||||
...
|
||||
```
|
||||
|
||||
Run setup and teardown code using `before_loop` and `after_loop`:
|
||||
|
||||
```python
|
||||
@tasks.loop(seconds=5.0)
|
||||
async def worker():
|
||||
...
|
||||
|
||||
@worker.before_loop
|
||||
async def before_worker():
|
||||
print("starting")
|
||||
|
||||
@worker.after_loop
|
||||
async def after_worker():
|
||||
print("stopped")
|
||||
```
|
||||
|
||||
You can also schedule a task at a specific time of day:
|
||||
|
||||
```python
|
||||
|
@ -152,7 +152,7 @@ from disagreement.models import Container, TextDisplay
|
||||
|
||||
container = Container(
|
||||
components=[TextDisplay(content="Inside a container")],
|
||||
accent_color=0xFF0000,
|
||||
accent_color="#FF0000", # int or Color() also work
|
||||
spoiler=False,
|
||||
)
|
||||
```
|
||||
|
29
tests/test_color_acceptance.py
Normal file
29
tests/test_color_acceptance.py
Normal file
@ -0,0 +1,29 @@
|
||||
from disagreement.color import Color
|
||||
from disagreement.models import Embed, Container, Component
|
||||
|
||||
|
||||
def test_color_parse():
|
||||
assert Color.parse(0x123456).value == 0x123456
|
||||
assert Color.parse("#123456").value == 0x123456
|
||||
c = Color(0xABCDEF)
|
||||
assert Color.parse(c) is c
|
||||
assert Color.parse(None) is None
|
||||
assert Color.parse((255, 0, 0)).value == 0xFF0000
|
||||
|
||||
|
||||
def test_embed_color_parsing():
|
||||
e = Embed({"color": "#FF0000"})
|
||||
assert e.color.value == 0xFF0000
|
||||
e = Embed({"color": Color(0x00FF00)})
|
||||
assert e.color.value == 0x00FF00
|
||||
e = Embed({"color": 0x0000FF})
|
||||
assert e.color.value == 0x0000FF
|
||||
|
||||
|
||||
def test_container_accent_color_parsing():
|
||||
container = Container(components=[], accent_color="#010203")
|
||||
assert container.accent_color.value == 0x010203
|
||||
container = Container(components=[], accent_color=Color(0x111111))
|
||||
assert container.accent_color.value == 0x111111
|
||||
container = Container(components=[], accent_color=0x222222)
|
||||
assert container.accent_color.value == 0x222222
|
7
tests/test_gateway_intent.py
Normal file
7
tests/test_gateway_intent.py
Normal file
@ -0,0 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from disagreement.enums import GatewayIntent
|
||||
|
||||
|
||||
def test_gateway_intent_none_equals_zero():
|
||||
assert GatewayIntent.none() == 0
|
70
tests/test_http_rate_limit.py
Normal file
70
tests/test_http_rate_limit.py
Normal file
@ -0,0 +1,70 @@
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from disagreement.http import HTTPClient
|
||||
|
||||
|
||||
class DummyResp:
|
||||
def __init__(self, status, headers=None, data=None):
|
||||
self.status = status
|
||||
self.headers = headers or {}
|
||||
self._data = data or {}
|
||||
self.headers.setdefault("Content-Type", "application/json")
|
||||
|
||||
async def json(self):
|
||||
return self._data
|
||||
|
||||
async def text(self):
|
||||
return str(self._data)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_acquires_and_releases(monkeypatch):
|
||||
http = HTTPClient(token="t")
|
||||
monkeypatch.setattr(http, "_ensure_session", AsyncMock())
|
||||
resp = DummyResp(200)
|
||||
http._session = SimpleNamespace(request=MagicMock(return_value=resp))
|
||||
http._rate_limiter.acquire = AsyncMock()
|
||||
http._rate_limiter.release = MagicMock()
|
||||
http._rate_limiter.handle_rate_limit = AsyncMock()
|
||||
|
||||
await http.request("GET", "/a")
|
||||
|
||||
http._rate_limiter.acquire.assert_awaited_once_with("GET:/a")
|
||||
http._rate_limiter.release.assert_called_once_with("GET:/a", resp.headers)
|
||||
http._rate_limiter.handle_rate_limit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_handles_rate_limit(monkeypatch):
|
||||
http = HTTPClient(token="t")
|
||||
monkeypatch.setattr(http, "_ensure_session", AsyncMock())
|
||||
resp1 = DummyResp(
|
||||
429,
|
||||
{
|
||||
"Retry-After": "0.1",
|
||||
"X-RateLimit-Global": "false",
|
||||
"X-RateLimit-Remaining": "0",
|
||||
"X-RateLimit-Reset-After": "0.1",
|
||||
},
|
||||
{"message": "slow"},
|
||||
)
|
||||
resp2 = DummyResp(200, {}, {})
|
||||
http._session = SimpleNamespace(request=MagicMock(side_effect=[resp1, resp2]))
|
||||
http._rate_limiter.acquire = AsyncMock()
|
||||
http._rate_limiter.release = MagicMock()
|
||||
http._rate_limiter.handle_rate_limit = AsyncMock()
|
||||
|
||||
result = await http.request("GET", "/a")
|
||||
|
||||
assert http._rate_limiter.acquire.await_count == 2
|
||||
assert http._rate_limiter.release.call_count == 2
|
||||
http._rate_limiter.handle_rate_limit.assert_awaited_once_with("GET:/a", 0.1, False)
|
||||
assert result == {}
|
@ -28,6 +28,5 @@ async def test_respond_modal(dummy_bot, interaction):
|
||||
await interaction.respond_modal(modal)
|
||||
dummy_bot._http.create_interaction_response.assert_called_once()
|
||||
payload = dummy_bot._http.create_interaction_response.call_args.kwargs["payload"]
|
||||
payload_dict = payload.to_dict()
|
||||
assert payload_dict["type"] == InteractionCallbackType.MODAL.value
|
||||
assert payload_dict["data"]["custom_id"] == "m1"
|
||||
assert payload.type == InteractionCallbackType.MODAL
|
||||
assert payload.data["custom_id"] == "m1"
|
||||
|
@ -58,3 +58,27 @@ async def test_loop_runs_and_stops() -> None:
|
||||
dummy.work.stop() # pylint: disable=no-member
|
||||
assert dummy.count >= 2
|
||||
assert not dummy.work.running # pylint: disable=no-member
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_before_after_loop_callbacks() -> None:
|
||||
events: list[str] = []
|
||||
|
||||
@tasks.loop(seconds=0.01)
|
||||
async def ticker() -> None:
|
||||
events.append("tick")
|
||||
|
||||
@ticker.before_loop
|
||||
async def before() -> None: # pragma: no cover - trivial callback
|
||||
events.append("before")
|
||||
|
||||
@ticker.after_loop
|
||||
async def after() -> None: # pragma: no cover - trivial callback
|
||||
events.append("after")
|
||||
|
||||
ticker.start()
|
||||
await asyncio.sleep(0.03)
|
||||
ticker.stop()
|
||||
await asyncio.sleep(0.01)
|
||||
assert events and events[0] == "before"
|
||||
assert "after" in events
|
||||
|
65
tests/test_textchannel_history.py
Normal file
65
tests/test_textchannel_history.py
Normal file
@ -0,0 +1,65 @@
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.models import TextChannel, Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_textchannel_history_paginates():
|
||||
first_page = [
|
||||
{
|
||||
"id": "3",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "1", "username": "u", "discriminator": "0001"},
|
||||
"content": "m3",
|
||||
"timestamp": "t",
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "1", "username": "u", "discriminator": "0001"},
|
||||
"content": "m2",
|
||||
"timestamp": "t",
|
||||
},
|
||||
]
|
||||
second_page = [
|
||||
{
|
||||
"id": "1",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "1", "username": "u", "discriminator": "0001"},
|
||||
"content": "m1",
|
||||
"timestamp": "t",
|
||||
}
|
||||
]
|
||||
http = SimpleNamespace(request=AsyncMock(side_effect=[first_page, second_page]))
|
||||
client = Client.__new__(Client)
|
||||
client._http = http
|
||||
channel = TextChannel({"id": "c", "type": 0}, client)
|
||||
|
||||
messages = []
|
||||
async for msg in channel.history(limit=3):
|
||||
messages.append(msg)
|
||||
|
||||
assert len(messages) == 3
|
||||
assert all(isinstance(m, Message) for m in messages)
|
||||
http.request.assert_any_call("GET", "/channels/c/messages", params={"limit": 3})
|
||||
http.request.assert_any_call(
|
||||
"GET", "/channels/c/messages", params={"limit": 1, "before": "2"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_textchannel_history_before_param():
|
||||
http = SimpleNamespace(request=AsyncMock(return_value=[]))
|
||||
client = Client.__new__(Client)
|
||||
client._http = http
|
||||
channel = TextChannel({"id": "c", "type": 0}, client)
|
||||
|
||||
messages = [m async for m in channel.history(limit=1, before="b")]
|
||||
|
||||
assert messages == []
|
||||
http.request.assert_called_once_with(
|
||||
"GET", "/channels/c/messages", params={"limit": 1, "before": "b"}
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user