diff --git a/disagreement/ext/tasks.py b/disagreement/ext/tasks.py index 986e2da..394815b 100644 --- a/disagreement/ext/tasks.py +++ b/disagreement/ext/tasks.py @@ -23,6 +23,7 @@ class Task: ) -> None: self._coro = coro self._task: Optional[asyncio.Task[None]] = None + self._current_loop = 0 if time_of_day is not None and ( seconds or minutes or hours or delta is not None ): @@ -68,6 +69,7 @@ class Task: await _maybe_call(self._on_error, exc) else: raise + self._current_loop += 1 first = False except asyncio.CancelledError: @@ -78,6 +80,7 @@ class Task: def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: if self._task is None or self._task.done(): + self._current_loop = 0 self._task = asyncio.create_task(self._run(*args, **kwargs)) return self._task @@ -90,6 +93,34 @@ class Task: def running(self) -> bool: return self._task is not None and not self._task.done() + @property + def current_loop(self) -> int: + return self._current_loop + + def change_interval( + self, + *, + seconds: float = 0.0, + minutes: float = 0.0, + hours: float = 0.0, + delta: Optional[datetime.timedelta] = None, + time_of_day: Optional[datetime.time] = None, + ) -> None: + if time_of_day is not None and ( + seconds or minutes or hours or delta is not None + ): + raise ValueError("time_of_day cannot be used with an interval") + + if delta is not None: + if not isinstance(delta, datetime.timedelta): + raise TypeError("delta must be a datetime.timedelta") + interval_seconds = delta.total_seconds() + else: + interval_seconds = seconds + minutes * 60.0 + hours * 3600.0 + + self._seconds = float(interval_seconds) + self._time_of_day = time_of_day + async def _maybe_call( func: Callable[[Exception], Awaitable[None] | None], exc: Exception @@ -181,10 +212,37 @@ class _Loop: if self._task is not None: self._task.stop() + def change_interval( + self, + *, + seconds: float = 0.0, + minutes: float = 0.0, + hours: float = 0.0, + delta: Optional[datetime.timedelta] = None, + time_of_day: Optional[datetime.time] = None, + ) -> None: + self.seconds = seconds + self.minutes = minutes + self.hours = hours + self.delta = delta + self.time_of_day = time_of_day + if self._task is not None: + self._task.change_interval( + seconds=seconds, + minutes=minutes, + hours=hours, + delta=delta, + time_of_day=time_of_day, + ) + @property def running(self) -> bool: return self._task.running if self._task else False + @property + def current_loop(self) -> int: + return self._task.current_loop if self._task else 0 + class _BoundLoop: def __init__(self, parent: _Loop, owner: Any) -> None: @@ -202,6 +260,27 @@ class _BoundLoop: def running(self) -> bool: return self._parent.running + def change_interval( + self, + *, + seconds: float = 0.0, + minutes: float = 0.0, + hours: float = 0.0, + delta: Optional[datetime.timedelta] = None, + time_of_day: Optional[datetime.time] = None, + ) -> None: + self._parent.change_interval( + seconds=seconds, + minutes=minutes, + hours=hours, + delta=delta, + time_of_day=time_of_day, + ) + + @property + def current_loop(self) -> int: + return self._parent.current_loop + def loop( *, diff --git a/tests/test_tasks_extension.py b/tests/test_tasks_extension.py index 4b2abfd..7a96b3b 100644 --- a/tests/test_tasks_extension.py +++ b/tests/test_tasks_extension.py @@ -82,3 +82,24 @@ async def test_before_after_loop_callbacks() -> None: await asyncio.sleep(0.01) assert events and events[0] == "before" assert "after" in events + + +@pytest.mark.asyncio +async def test_change_interval_and_current_loop() -> None: + count = 0 + + @tasks.loop(seconds=0.01) + async def ticker() -> None: + nonlocal count + count += 1 + + ticker.start() + await asyncio.sleep(0.03) + initial = ticker.current_loop + ticker.change_interval(seconds=0.02) + await asyncio.sleep(0.05) + ticker.stop() + + assert initial >= 2 + assert ticker.current_loop > initial + assert count == ticker.current_loop