Merge remote-tracking branch 'origin/master' into codex/implement-ratelimiter-in-httpclient

This commit is contained in:
Slipstream 2025-06-10 17:55:09 -06:00
commit 0630c8b916
Signed by: slipstream
GPG Key ID: 13E498CE010AC6FD
10 changed files with 281 additions and 9 deletions

View File

@ -48,3 +48,31 @@ class Color:
def to_rgb(self) -> tuple[int, int, int]:
return ((self.value >> 16) & 0xFF, (self.value >> 8) & 0xFF, self.value & 0xFF)
@classmethod
def parse(cls, value: "Color | int | str | tuple[int, int, int] | None") -> "Color | None":
"""Convert ``value`` to a :class:`Color` instance.
Parameters
----------
value:
The value to convert. May be ``None``, an existing ``Color``, an
integer in the ``0xRRGGBB`` format, or a hex string like ``"#RRGGBB"``.
Returns
-------
Optional[Color]
A ``Color`` object if ``value`` is not ``None``.
"""
if value is None:
return None
if isinstance(value, cls):
return value
if isinstance(value, int):
return cls(value)
if isinstance(value, str):
return cls.from_hex(value)
if isinstance(value, tuple) and len(value) == 3:
return cls.from_rgb(*value)
raise TypeError("Color value must be Color, int, str, tuple, or None")

View File

@ -18,6 +18,8 @@ class Task:
delta: Optional[datetime.timedelta] = None,
time_of_day: Optional[datetime.time] = None,
on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
before_loop: Optional[Callable[[], Awaitable[None] | None]] = None,
after_loop: Optional[Callable[[], Awaitable[None] | None]] = None,
) -> None:
self._coro = coro
self._task: Optional[asyncio.Task[None]] = None
@ -36,6 +38,8 @@ class Task:
self._seconds = float(interval_seconds)
self._time_of_day = time_of_day
self._on_error = on_error
self._before_loop = before_loop
self._after_loop = after_loop
def _seconds_until_time(self) -> float:
assert self._time_of_day is not None
@ -47,6 +51,9 @@ class Task:
async def _run(self, *args: Any, **kwargs: Any) -> None:
try:
if self._before_loop is not None:
await _maybe_call_no_args(self._before_loop)
first = True
while True:
if self._time_of_day is not None:
@ -65,6 +72,9 @@ class Task:
first = False
except asyncio.CancelledError:
pass
finally:
if self._after_loop is not None:
await _maybe_call_no_args(self._after_loop)
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
if self._task is None or self._task.done():
@ -89,6 +99,12 @@ async def _maybe_call(
await result
async def _maybe_call_no_args(func: Callable[[], Awaitable[None] | None]) -> None:
result = func()
if asyncio.iscoroutine(result):
await result
class _Loop:
def __init__(
self,
@ -110,6 +126,8 @@ class _Loop:
self.on_error = on_error
self._task: Optional[Task] = None
self._owner: Any = None
self._before_loop: Optional[Callable[..., Awaitable[Any]]] = None
self._after_loop: Optional[Callable[..., Awaitable[Any]]] = None
def __get__(self, obj: Any, objtype: Any) -> "_BoundLoop":
return _BoundLoop(self, obj)
@ -119,7 +137,33 @@ class _Loop:
return self.func(*args, **kwargs)
return self.func(self._owner, *args, **kwargs)
def before_loop(
self, func: Callable[..., Awaitable[Any]]
) -> Callable[..., Awaitable[Any]]:
self._before_loop = func
return func
def after_loop(
self, func: Callable[..., Awaitable[Any]]
) -> Callable[..., Awaitable[Any]]:
self._after_loop = func
return func
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
def call_before() -> Awaitable[None] | None:
if self._before_loop is None:
return None
if self._owner is not None:
return self._before_loop(self._owner)
return self._before_loop()
def call_after() -> Awaitable[None] | None:
if self._after_loop is None:
return None
if self._owner is not None:
return self._after_loop(self._owner)
return self._after_loop()
self._task = Task(
self._coro,
seconds=self.seconds,
@ -128,6 +172,8 @@ class _Loop:
delta=self.delta,
time_of_day=self.time_of_day,
on_error=self.on_error,
before_loop=call_before,
after_loop=call_after,
)
return self._task.start(*args, **kwargs)

View File

@ -402,7 +402,7 @@ class Interaction:
await self._client._http.create_interaction_response(
interaction_id=self.id,
interaction_token=self.token,
payload=payload.to_dict(),
payload=payload.to_dict(), # type: ignore[arg-type]
)
async def edit(
@ -570,3 +570,6 @@ class InteractionResponsePayload:
def __repr__(self) -> str:
return f"<InteractionResponsePayload type={self.type!r}>"
def __getitem__(self, item: str) -> Any:
return self.to_dict()[item]

View File

@ -25,6 +25,7 @@ from .enums import ( # These enums will need to be defined in disagreement/enum
# SelectMenuType will be part of ComponentType or a new enum if needed
)
from .permissions import Permissions
from .color import Color
if TYPE_CHECKING:
@ -312,7 +313,7 @@ class Embed:
self.description: Optional[str] = data.get("description")
self.url: Optional[str] = data.get("url")
self.timestamp: Optional[str] = data.get("timestamp") # ISO8601 timestamp
self.color: Optional[int] = data.get("color")
self.color = Color.parse(data.get("color"))
self.footer: Optional[EmbedFooter] = (
EmbedFooter(data["footer"]) if data.get("footer") else None
@ -342,7 +343,7 @@ class Embed:
if self.timestamp:
payload["timestamp"] = self.timestamp
if self.color is not None:
payload["color"] = self.color
payload["color"] = self.color.value
if self.footer:
payload["footer"] = self.footer.to_dict()
if self.image:
@ -1107,6 +1108,36 @@ class TextChannel(Channel):
self._client._messages.pop(mid, None)
return ids
async def history(
self,
*,
limit: Optional[int] = 100,
before: "Snowflake | None" = None,
):
"""An async iterator over messages in the channel."""
params: Dict[str, Union[int, str]] = {}
if before is not None:
params["before"] = before
fetched = 0
while True:
to_fetch = 100 if limit is None else min(100, limit - fetched)
if to_fetch <= 0:
break
params["limit"] = to_fetch
messages = await self._client._http.request(
"GET", f"/channels/{self.id}/messages", params=params.copy()
)
if not messages:
break
params["before"] = messages[-1]["id"]
for msg in messages:
yield Message(msg, self._client)
fetched += 1
if limit is not None and fetched >= limit:
return
def __repr__(self) -> str:
return f"<TextChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>"
@ -1224,6 +1255,36 @@ class DMChannel(Channel):
components=components,
)
async def history(
self,
*,
limit: Optional[int] = 100,
before: "Snowflake | None" = None,
):
"""An async iterator over messages in this DM."""
params: Dict[str, Union[int, str]] = {}
if before is not None:
params["before"] = before
fetched = 0
while True:
to_fetch = 100 if limit is None else min(100, limit - fetched)
if to_fetch <= 0:
break
params["limit"] = to_fetch
messages = await self._client._http.request(
"GET", f"/channels/{self.id}/messages", params=params.copy()
)
if not messages:
break
params["before"] = messages[-1]["id"]
for msg in messages:
yield Message(msg, self._client)
fetched += 1
if limit is not None and fetched >= limit:
return
def __repr__(self) -> str:
recipient_repr = self.recipient.username if self.recipient else "Unknown"
return f"<DMChannel id='{self.id}' recipient='{recipient_repr}'>"
@ -1708,13 +1769,13 @@ class Container(Component):
def __init__(
self,
components: List[Component],
accent_color: Optional[int] = None,
accent_color: Color | int | str | None = None,
spoiler: bool = False,
id: Optional[int] = None,
):
super().__init__(ComponentType.CONTAINER)
self.components = components
self.accent_color = accent_color
self.accent_color = Color.parse(accent_color)
self.spoiler = spoiler
self.id = id
@ -1722,7 +1783,7 @@ class Container(Component):
payload = super().to_dict()
payload["components"] = [c.to_dict() for c in self.components]
if self.accent_color:
payload["accent_color"] = self.accent_color
payload["accent_color"] = self.accent_color.value
if self.spoiler:
payload["spoiler"] = self.spoiler
if self.id is not None:

View File

@ -35,6 +35,22 @@ async def worker():
...
```
Run setup and teardown code using `before_loop` and `after_loop`:
```python
@tasks.loop(seconds=5.0)
async def worker():
...
@worker.before_loop
async def before_worker():
print("starting")
@worker.after_loop
async def after_worker():
print("stopped")
```
You can also schedule a task at a specific time of day:
```python

View File

@ -152,7 +152,7 @@ from disagreement.models import Container, TextDisplay
container = Container(
components=[TextDisplay(content="Inside a container")],
accent_color=0xFF0000,
accent_color="#FF0000", # int or Color() also work
spoiler=False,
)
```

View File

@ -0,0 +1,29 @@
from disagreement.color import Color
from disagreement.models import Embed, Container, Component
def test_color_parse():
assert Color.parse(0x123456).value == 0x123456
assert Color.parse("#123456").value == 0x123456
c = Color(0xABCDEF)
assert Color.parse(c) is c
assert Color.parse(None) is None
assert Color.parse((255, 0, 0)).value == 0xFF0000
def test_embed_color_parsing():
e = Embed({"color": "#FF0000"})
assert e.color.value == 0xFF0000
e = Embed({"color": Color(0x00FF00)})
assert e.color.value == 0x00FF00
e = Embed({"color": 0x0000FF})
assert e.color.value == 0x0000FF
def test_container_accent_color_parsing():
container = Container(components=[], accent_color="#010203")
assert container.accent_color.value == 0x010203
container = Container(components=[], accent_color=Color(0x111111))
assert container.accent_color.value == 0x111111
container = Container(components=[], accent_color=0x222222)
assert container.accent_color.value == 0x222222

View File

@ -28,5 +28,5 @@ async def test_respond_modal(dummy_bot, interaction):
await interaction.respond_modal(modal)
dummy_bot._http.create_interaction_response.assert_called_once()
payload = dummy_bot._http.create_interaction_response.call_args.kwargs["payload"]
assert payload["type"] == InteractionCallbackType.MODAL.value
assert payload["data"]["custom_id"] == "m1"
assert payload.type == InteractionCallbackType.MODAL
assert payload.data["custom_id"] == "m1"

View File

@ -58,3 +58,27 @@ async def test_loop_runs_and_stops() -> None:
dummy.work.stop() # pylint: disable=no-member
assert dummy.count >= 2
assert not dummy.work.running # pylint: disable=no-member
@pytest.mark.asyncio
async def test_before_after_loop_callbacks() -> None:
events: list[str] = []
@tasks.loop(seconds=0.01)
async def ticker() -> None:
events.append("tick")
@ticker.before_loop
async def before() -> None: # pragma: no cover - trivial callback
events.append("before")
@ticker.after_loop
async def after() -> None: # pragma: no cover - trivial callback
events.append("after")
ticker.start()
await asyncio.sleep(0.03)
ticker.stop()
await asyncio.sleep(0.01)
assert events and events[0] == "before"
assert "after" in events

View File

@ -0,0 +1,65 @@
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock
from disagreement.client import Client
from disagreement.models import TextChannel, Message
@pytest.mark.asyncio
async def test_textchannel_history_paginates():
first_page = [
{
"id": "3",
"channel_id": "c",
"author": {"id": "1", "username": "u", "discriminator": "0001"},
"content": "m3",
"timestamp": "t",
},
{
"id": "2",
"channel_id": "c",
"author": {"id": "1", "username": "u", "discriminator": "0001"},
"content": "m2",
"timestamp": "t",
},
]
second_page = [
{
"id": "1",
"channel_id": "c",
"author": {"id": "1", "username": "u", "discriminator": "0001"},
"content": "m1",
"timestamp": "t",
}
]
http = SimpleNamespace(request=AsyncMock(side_effect=[first_page, second_page]))
client = Client.__new__(Client)
client._http = http
channel = TextChannel({"id": "c", "type": 0}, client)
messages = []
async for msg in channel.history(limit=3):
messages.append(msg)
assert len(messages) == 3
assert all(isinstance(m, Message) for m in messages)
http.request.assert_any_call("GET", "/channels/c/messages", params={"limit": 3})
http.request.assert_any_call(
"GET", "/channels/c/messages", params={"limit": 1, "before": "2"}
)
@pytest.mark.asyncio
async def test_textchannel_history_before_param():
http = SimpleNamespace(request=AsyncMock(return_value=[]))
client = Client.__new__(Client)
client._http = http
channel = TextChannel({"id": "c", "type": 0}, client)
messages = [m async for m in channel.history(limit=1, before="b")]
assert messages == []
http.request.assert_called_once_with(
"GET", "/channels/c/messages", params={"limit": 1, "before": "b"}
)