diff --git a/disagreement/color.py b/disagreement/color.py index 59d5e8b..d1884f6 100644 --- a/disagreement/color.py +++ b/disagreement/color.py @@ -48,3 +48,31 @@ class Color: def to_rgb(self) -> tuple[int, int, int]: return ((self.value >> 16) & 0xFF, (self.value >> 8) & 0xFF, self.value & 0xFF) + + @classmethod + def parse(cls, value: "Color | int | str | tuple[int, int, int] | None") -> "Color | None": + """Convert ``value`` to a :class:`Color` instance. + + Parameters + ---------- + value: + The value to convert. May be ``None``, an existing ``Color``, an + integer in the ``0xRRGGBB`` format, or a hex string like ``"#RRGGBB"``. + + Returns + ------- + Optional[Color] + A ``Color`` object if ``value`` is not ``None``. + """ + + if value is None: + return None + if isinstance(value, cls): + return value + if isinstance(value, int): + return cls(value) + if isinstance(value, str): + return cls.from_hex(value) + if isinstance(value, tuple) and len(value) == 3: + return cls.from_rgb(*value) + raise TypeError("Color value must be Color, int, str, tuple, or None") diff --git a/disagreement/ext/tasks.py b/disagreement/ext/tasks.py index 7290c90..986e2da 100644 --- a/disagreement/ext/tasks.py +++ b/disagreement/ext/tasks.py @@ -18,6 +18,8 @@ class Task: delta: Optional[datetime.timedelta] = None, time_of_day: Optional[datetime.time] = None, on_error: Optional[Callable[[Exception], Awaitable[None]]] = None, + before_loop: Optional[Callable[[], Awaitable[None] | None]] = None, + after_loop: Optional[Callable[[], Awaitable[None] | None]] = None, ) -> None: self._coro = coro self._task: Optional[asyncio.Task[None]] = None @@ -36,6 +38,8 @@ class Task: self._seconds = float(interval_seconds) self._time_of_day = time_of_day self._on_error = on_error + self._before_loop = before_loop + self._after_loop = after_loop def _seconds_until_time(self) -> float: assert self._time_of_day is not None @@ -47,6 +51,9 @@ class Task: async def _run(self, *args: Any, **kwargs: Any) -> None: try: + if self._before_loop is not None: + await _maybe_call_no_args(self._before_loop) + first = True while True: if self._time_of_day is not None: @@ -65,6 +72,9 @@ class Task: first = False except asyncio.CancelledError: pass + finally: + if self._after_loop is not None: + await _maybe_call_no_args(self._after_loop) def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: if self._task is None or self._task.done(): @@ -89,6 +99,12 @@ async def _maybe_call( await result +async def _maybe_call_no_args(func: Callable[[], Awaitable[None] | None]) -> None: + result = func() + if asyncio.iscoroutine(result): + await result + + class _Loop: def __init__( self, @@ -110,6 +126,8 @@ class _Loop: self.on_error = on_error self._task: Optional[Task] = None self._owner: Any = None + self._before_loop: Optional[Callable[..., Awaitable[Any]]] = None + self._after_loop: Optional[Callable[..., Awaitable[Any]]] = None def __get__(self, obj: Any, objtype: Any) -> "_BoundLoop": return _BoundLoop(self, obj) @@ -119,7 +137,33 @@ class _Loop: return self.func(*args, **kwargs) return self.func(self._owner, *args, **kwargs) + def before_loop( + self, func: Callable[..., Awaitable[Any]] + ) -> Callable[..., Awaitable[Any]]: + self._before_loop = func + return func + + def after_loop( + self, func: Callable[..., Awaitable[Any]] + ) -> Callable[..., Awaitable[Any]]: + self._after_loop = func + return func + def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: + def call_before() -> Awaitable[None] | None: + if self._before_loop is None: + return None + if self._owner is not None: + return self._before_loop(self._owner) + return self._before_loop() + + def call_after() -> Awaitable[None] | None: + if self._after_loop is None: + return None + if self._owner is not None: + return self._after_loop(self._owner) + return self._after_loop() + self._task = Task( self._coro, seconds=self.seconds, @@ -128,6 +172,8 @@ class _Loop: delta=self.delta, time_of_day=self.time_of_day, on_error=self.on_error, + before_loop=call_before, + after_loop=call_after, ) return self._task.start(*args, **kwargs) diff --git a/disagreement/interactions.py b/disagreement/interactions.py index 4dc7337..4b89b6b 100644 --- a/disagreement/interactions.py +++ b/disagreement/interactions.py @@ -402,7 +402,7 @@ class Interaction: await self._client._http.create_interaction_response( interaction_id=self.id, interaction_token=self.token, - payload=payload, + payload=payload.to_dict(), # type: ignore[arg-type] ) async def edit( @@ -503,9 +503,7 @@ class InteractionCallbackData: self.tts: Optional[bool] = data.get("tts") self.content: Optional[str] = data.get("content") self.embeds: Optional[List[Embed]] = ( - [Embed(e) for e in data.get("embeds", [])] - if data.get("embeds") - else None + [Embed(e) for e in data.get("embeds", [])] if data.get("embeds") else None ) self.allowed_mentions: Optional[AllowedMentions] = ( AllowedMentions(data["allowed_mentions"]) @@ -572,3 +570,6 @@ class InteractionResponsePayload: def __repr__(self) -> str: return f"" + + def __getitem__(self, item: str) -> Any: + return self.to_dict()[item] diff --git a/disagreement/models.py b/disagreement/models.py index 6bc5d39..a3e7370 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -25,6 +25,7 @@ from .enums import ( # These enums will need to be defined in disagreement/enum # SelectMenuType will be part of ComponentType or a new enum if needed ) from .permissions import Permissions +from .color import Color if TYPE_CHECKING: @@ -312,7 +313,7 @@ class Embed: self.description: Optional[str] = data.get("description") self.url: Optional[str] = data.get("url") self.timestamp: Optional[str] = data.get("timestamp") # ISO8601 timestamp - self.color: Optional[int] = data.get("color") + self.color = Color.parse(data.get("color")) self.footer: Optional[EmbedFooter] = ( EmbedFooter(data["footer"]) if data.get("footer") else None @@ -342,7 +343,7 @@ class Embed: if self.timestamp: payload["timestamp"] = self.timestamp if self.color is not None: - payload["color"] = self.color + payload["color"] = self.color.value if self.footer: payload["footer"] = self.footer.to_dict() if self.image: @@ -1107,6 +1108,36 @@ class TextChannel(Channel): self._client._messages.pop(mid, None) return ids + async def history( + self, + *, + limit: Optional[int] = 100, + before: "Snowflake | None" = None, + ): + """An async iterator over messages in the channel.""" + + params: Dict[str, Union[int, str]] = {} + if before is not None: + params["before"] = before + + fetched = 0 + while True: + to_fetch = 100 if limit is None else min(100, limit - fetched) + if to_fetch <= 0: + break + params["limit"] = to_fetch + messages = await self._client._http.request( + "GET", f"/channels/{self.id}/messages", params=params.copy() + ) + if not messages: + break + params["before"] = messages[-1]["id"] + for msg in messages: + yield Message(msg, self._client) + fetched += 1 + if limit is not None and fetched >= limit: + return + def __repr__(self) -> str: return f"" @@ -1224,6 +1255,36 @@ class DMChannel(Channel): components=components, ) + async def history( + self, + *, + limit: Optional[int] = 100, + before: "Snowflake | None" = None, + ): + """An async iterator over messages in this DM.""" + + params: Dict[str, Union[int, str]] = {} + if before is not None: + params["before"] = before + + fetched = 0 + while True: + to_fetch = 100 if limit is None else min(100, limit - fetched) + if to_fetch <= 0: + break + params["limit"] = to_fetch + messages = await self._client._http.request( + "GET", f"/channels/{self.id}/messages", params=params.copy() + ) + if not messages: + break + params["before"] = messages[-1]["id"] + for msg in messages: + yield Message(msg, self._client) + fetched += 1 + if limit is not None and fetched >= limit: + return + def __repr__(self) -> str: recipient_repr = self.recipient.username if self.recipient else "Unknown" return f"" @@ -1708,13 +1769,13 @@ class Container(Component): def __init__( self, components: List[Component], - accent_color: Optional[int] = None, + accent_color: Color | int | str | None = None, spoiler: bool = False, id: Optional[int] = None, ): super().__init__(ComponentType.CONTAINER) self.components = components - self.accent_color = accent_color + self.accent_color = Color.parse(accent_color) self.spoiler = spoiler self.id = id @@ -1722,7 +1783,7 @@ class Container(Component): payload = super().to_dict() payload["components"] = [c.to_dict() for c in self.components] if self.accent_color: - payload["accent_color"] = self.accent_color + payload["accent_color"] = self.accent_color.value if self.spoiler: payload["spoiler"] = self.spoiler if self.id is not None: diff --git a/docs/task_loop.md b/docs/task_loop.md index b054909..c60a689 100644 --- a/docs/task_loop.md +++ b/docs/task_loop.md @@ -35,6 +35,22 @@ async def worker(): ... ``` +Run setup and teardown code using `before_loop` and `after_loop`: + +```python +@tasks.loop(seconds=5.0) +async def worker(): + ... + +@worker.before_loop +async def before_worker(): + print("starting") + +@worker.after_loop +async def after_worker(): + print("stopped") +``` + You can also schedule a task at a specific time of day: ```python diff --git a/docs/using_components.md b/docs/using_components.md index 9ebd65a..ac508b0 100644 --- a/docs/using_components.md +++ b/docs/using_components.md @@ -152,7 +152,7 @@ from disagreement.models import Container, TextDisplay container = Container( components=[TextDisplay(content="Inside a container")], - accent_color=0xFF0000, + accent_color="#FF0000", # int or Color() also work spoiler=False, ) ``` diff --git a/tests/test_color_acceptance.py b/tests/test_color_acceptance.py new file mode 100644 index 0000000..f495a1d --- /dev/null +++ b/tests/test_color_acceptance.py @@ -0,0 +1,29 @@ +from disagreement.color import Color +from disagreement.models import Embed, Container, Component + + +def test_color_parse(): + assert Color.parse(0x123456).value == 0x123456 + assert Color.parse("#123456").value == 0x123456 + c = Color(0xABCDEF) + assert Color.parse(c) is c + assert Color.parse(None) is None + assert Color.parse((255, 0, 0)).value == 0xFF0000 + + +def test_embed_color_parsing(): + e = Embed({"color": "#FF0000"}) + assert e.color.value == 0xFF0000 + e = Embed({"color": Color(0x00FF00)}) + assert e.color.value == 0x00FF00 + e = Embed({"color": 0x0000FF}) + assert e.color.value == 0x0000FF + + +def test_container_accent_color_parsing(): + container = Container(components=[], accent_color="#010203") + assert container.accent_color.value == 0x010203 + container = Container(components=[], accent_color=Color(0x111111)) + assert container.accent_color.value == 0x111111 + container = Container(components=[], accent_color=0x222222) + assert container.accent_color.value == 0x222222 diff --git a/tests/test_modals.py b/tests/test_modals.py index 0b4b6c0..697cb8b 100644 --- a/tests/test_modals.py +++ b/tests/test_modals.py @@ -28,6 +28,5 @@ async def test_respond_modal(dummy_bot, interaction): await interaction.respond_modal(modal) dummy_bot._http.create_interaction_response.assert_called_once() payload = dummy_bot._http.create_interaction_response.call_args.kwargs["payload"] - data = payload.to_dict() - assert data["type"] == InteractionCallbackType.MODAL.value - assert data["data"]["custom_id"] == "m1" + assert payload.type == InteractionCallbackType.MODAL + assert payload.data["custom_id"] == "m1" diff --git a/tests/test_tasks_extension.py b/tests/test_tasks_extension.py index 19f6f64..4b2abfd 100644 --- a/tests/test_tasks_extension.py +++ b/tests/test_tasks_extension.py @@ -58,3 +58,27 @@ async def test_loop_runs_and_stops() -> None: dummy.work.stop() # pylint: disable=no-member assert dummy.count >= 2 assert not dummy.work.running # pylint: disable=no-member + + +@pytest.mark.asyncio +async def test_before_after_loop_callbacks() -> None: + events: list[str] = [] + + @tasks.loop(seconds=0.01) + async def ticker() -> None: + events.append("tick") + + @ticker.before_loop + async def before() -> None: # pragma: no cover - trivial callback + events.append("before") + + @ticker.after_loop + async def after() -> None: # pragma: no cover - trivial callback + events.append("after") + + ticker.start() + await asyncio.sleep(0.03) + ticker.stop() + await asyncio.sleep(0.01) + assert events and events[0] == "before" + assert "after" in events diff --git a/tests/test_textchannel_history.py b/tests/test_textchannel_history.py new file mode 100644 index 0000000..4dfc2c3 --- /dev/null +++ b/tests/test_textchannel_history.py @@ -0,0 +1,65 @@ +import pytest +from types import SimpleNamespace +from unittest.mock import AsyncMock + +from disagreement.client import Client +from disagreement.models import TextChannel, Message + + +@pytest.mark.asyncio +async def test_textchannel_history_paginates(): + first_page = [ + { + "id": "3", + "channel_id": "c", + "author": {"id": "1", "username": "u", "discriminator": "0001"}, + "content": "m3", + "timestamp": "t", + }, + { + "id": "2", + "channel_id": "c", + "author": {"id": "1", "username": "u", "discriminator": "0001"}, + "content": "m2", + "timestamp": "t", + }, + ] + second_page = [ + { + "id": "1", + "channel_id": "c", + "author": {"id": "1", "username": "u", "discriminator": "0001"}, + "content": "m1", + "timestamp": "t", + } + ] + http = SimpleNamespace(request=AsyncMock(side_effect=[first_page, second_page])) + client = Client.__new__(Client) + client._http = http + channel = TextChannel({"id": "c", "type": 0}, client) + + messages = [] + async for msg in channel.history(limit=3): + messages.append(msg) + + assert len(messages) == 3 + assert all(isinstance(m, Message) for m in messages) + http.request.assert_any_call("GET", "/channels/c/messages", params={"limit": 3}) + http.request.assert_any_call( + "GET", "/channels/c/messages", params={"limit": 1, "before": "2"} + ) + + +@pytest.mark.asyncio +async def test_textchannel_history_before_param(): + http = SimpleNamespace(request=AsyncMock(return_value=[])) + client = Client.__new__(Client) + client._http = http + channel = TextChannel({"id": "c", "type": 0}, client) + + messages = [m async for m in channel.history(limit=1, before="b")] + + assert messages == [] + http.request.assert_called_once_with( + "GET", "/channels/c/messages", params={"limit": 1, "before": "b"} + )