From df77a3fcec5826ccd9f16c72619689f8194c2c56 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Tue, 10 Jun 2025 17:52:05 -0600 Subject: [PATCH] Add before/after loop callbacks --- disagreement/ext/tasks.py | 46 +++++++++++++++++++++++++++++++++++ docs/task_loop.md | 16 ++++++++++++ tests/test_modals.py | 4 +-- tests/test_tasks_extension.py | 24 ++++++++++++++++++ 4 files changed, 88 insertions(+), 2 deletions(-) diff --git a/disagreement/ext/tasks.py b/disagreement/ext/tasks.py index 7290c90..986e2da 100644 --- a/disagreement/ext/tasks.py +++ b/disagreement/ext/tasks.py @@ -18,6 +18,8 @@ class Task: delta: Optional[datetime.timedelta] = None, time_of_day: Optional[datetime.time] = None, on_error: Optional[Callable[[Exception], Awaitable[None]]] = None, + before_loop: Optional[Callable[[], Awaitable[None] | None]] = None, + after_loop: Optional[Callable[[], Awaitable[None] | None]] = None, ) -> None: self._coro = coro self._task: Optional[asyncio.Task[None]] = None @@ -36,6 +38,8 @@ class Task: self._seconds = float(interval_seconds) self._time_of_day = time_of_day self._on_error = on_error + self._before_loop = before_loop + self._after_loop = after_loop def _seconds_until_time(self) -> float: assert self._time_of_day is not None @@ -47,6 +51,9 @@ class Task: async def _run(self, *args: Any, **kwargs: Any) -> None: try: + if self._before_loop is not None: + await _maybe_call_no_args(self._before_loop) + first = True while True: if self._time_of_day is not None: @@ -65,6 +72,9 @@ class Task: first = False except asyncio.CancelledError: pass + finally: + if self._after_loop is not None: + await _maybe_call_no_args(self._after_loop) def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: if self._task is None or self._task.done(): @@ -89,6 +99,12 @@ async def _maybe_call( await result +async def _maybe_call_no_args(func: Callable[[], Awaitable[None] | None]) -> None: + result = func() + if asyncio.iscoroutine(result): + await result + + class _Loop: def __init__( self, @@ -110,6 +126,8 @@ class _Loop: self.on_error = on_error self._task: Optional[Task] = None self._owner: Any = None + self._before_loop: Optional[Callable[..., Awaitable[Any]]] = None + self._after_loop: Optional[Callable[..., Awaitable[Any]]] = None def __get__(self, obj: Any, objtype: Any) -> "_BoundLoop": return _BoundLoop(self, obj) @@ -119,7 +137,33 @@ class _Loop: return self.func(*args, **kwargs) return self.func(self._owner, *args, **kwargs) + def before_loop( + self, func: Callable[..., Awaitable[Any]] + ) -> Callable[..., Awaitable[Any]]: + self._before_loop = func + return func + + def after_loop( + self, func: Callable[..., Awaitable[Any]] + ) -> Callable[..., Awaitable[Any]]: + self._after_loop = func + return func + def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: + def call_before() -> Awaitable[None] | None: + if self._before_loop is None: + return None + if self._owner is not None: + return self._before_loop(self._owner) + return self._before_loop() + + def call_after() -> Awaitable[None] | None: + if self._after_loop is None: + return None + if self._owner is not None: + return self._after_loop(self._owner) + return self._after_loop() + self._task = Task( self._coro, seconds=self.seconds, @@ -128,6 +172,8 @@ class _Loop: delta=self.delta, time_of_day=self.time_of_day, on_error=self.on_error, + before_loop=call_before, + after_loop=call_after, ) return self._task.start(*args, **kwargs) diff --git a/docs/task_loop.md b/docs/task_loop.md index b054909..c60a689 100644 --- a/docs/task_loop.md +++ b/docs/task_loop.md @@ -35,6 +35,22 @@ async def worker(): ... ``` +Run setup and teardown code using `before_loop` and `after_loop`: + +```python +@tasks.loop(seconds=5.0) +async def worker(): + ... + +@worker.before_loop +async def before_worker(): + print("starting") + +@worker.after_loop +async def after_worker(): + print("stopped") +``` + You can also schedule a task at a specific time of day: ```python diff --git a/tests/test_modals.py b/tests/test_modals.py index 3d1fdff..697cb8b 100644 --- a/tests/test_modals.py +++ b/tests/test_modals.py @@ -28,5 +28,5 @@ async def test_respond_modal(dummy_bot, interaction): await interaction.respond_modal(modal) dummy_bot._http.create_interaction_response.assert_called_once() payload = dummy_bot._http.create_interaction_response.call_args.kwargs["payload"] - assert payload["type"] == InteractionCallbackType.MODAL.value - assert payload["data"]["custom_id"] == "m1" + assert payload.type == InteractionCallbackType.MODAL + assert payload.data["custom_id"] == "m1" diff --git a/tests/test_tasks_extension.py b/tests/test_tasks_extension.py index 19f6f64..4b2abfd 100644 --- a/tests/test_tasks_extension.py +++ b/tests/test_tasks_extension.py @@ -58,3 +58,27 @@ async def test_loop_runs_and_stops() -> None: dummy.work.stop() # pylint: disable=no-member assert dummy.count >= 2 assert not dummy.work.running # pylint: disable=no-member + + +@pytest.mark.asyncio +async def test_before_after_loop_callbacks() -> None: + events: list[str] = [] + + @tasks.loop(seconds=0.01) + async def ticker() -> None: + events.append("tick") + + @ticker.before_loop + async def before() -> None: # pragma: no cover - trivial callback + events.append("before") + + @ticker.after_loop + async def after() -> None: # pragma: no cover - trivial callback + events.append("after") + + ticker.start() + await asyncio.sleep(0.03) + ticker.stop() + await asyncio.sleep(0.01) + assert events and events[0] == "before" + assert "after" in events