feat(tasks): allow runtime interval change (#106)

This commit is contained in:
Slipstream 2025-06-15 18:49:45 -06:00 committed by GitHub
parent a222dec661
commit 7f9647a442
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 100 additions and 0 deletions

View File

@ -23,6 +23,7 @@ class Task:
) -> None: ) -> None:
self._coro = coro self._coro = coro
self._task: Optional[asyncio.Task[None]] = None self._task: Optional[asyncio.Task[None]] = None
self._current_loop = 0
if time_of_day is not None and ( if time_of_day is not None and (
seconds or minutes or hours or delta is not None seconds or minutes or hours or delta is not None
): ):
@ -68,6 +69,7 @@ class Task:
await _maybe_call(self._on_error, exc) await _maybe_call(self._on_error, exc)
else: else:
raise raise
self._current_loop += 1
first = False first = False
except asyncio.CancelledError: except asyncio.CancelledError:
@ -78,6 +80,7 @@ class Task:
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
if self._task is None or self._task.done(): if self._task is None or self._task.done():
self._current_loop = 0
self._task = asyncio.create_task(self._run(*args, **kwargs)) self._task = asyncio.create_task(self._run(*args, **kwargs))
return self._task return self._task
@ -90,6 +93,34 @@ class Task:
def running(self) -> bool: def running(self) -> bool:
return self._task is not None and not self._task.done() return self._task is not None and not self._task.done()
@property
def current_loop(self) -> int:
return self._current_loop
def change_interval(
self,
*,
seconds: float = 0.0,
minutes: float = 0.0,
hours: float = 0.0,
delta: Optional[datetime.timedelta] = None,
time_of_day: Optional[datetime.time] = None,
) -> None:
if time_of_day is not None and (
seconds or minutes or hours or delta is not None
):
raise ValueError("time_of_day cannot be used with an interval")
if delta is not None:
if not isinstance(delta, datetime.timedelta):
raise TypeError("delta must be a datetime.timedelta")
interval_seconds = delta.total_seconds()
else:
interval_seconds = seconds + minutes * 60.0 + hours * 3600.0
self._seconds = float(interval_seconds)
self._time_of_day = time_of_day
async def _maybe_call( async def _maybe_call(
func: Callable[[Exception], Awaitable[None] | None], exc: Exception func: Callable[[Exception], Awaitable[None] | None], exc: Exception
@ -181,10 +212,37 @@ class _Loop:
if self._task is not None: if self._task is not None:
self._task.stop() self._task.stop()
def change_interval(
self,
*,
seconds: float = 0.0,
minutes: float = 0.0,
hours: float = 0.0,
delta: Optional[datetime.timedelta] = None,
time_of_day: Optional[datetime.time] = None,
) -> None:
self.seconds = seconds
self.minutes = minutes
self.hours = hours
self.delta = delta
self.time_of_day = time_of_day
if self._task is not None:
self._task.change_interval(
seconds=seconds,
minutes=minutes,
hours=hours,
delta=delta,
time_of_day=time_of_day,
)
@property @property
def running(self) -> bool: def running(self) -> bool:
return self._task.running if self._task else False return self._task.running if self._task else False
@property
def current_loop(self) -> int:
return self._task.current_loop if self._task else 0
class _BoundLoop: class _BoundLoop:
def __init__(self, parent: _Loop, owner: Any) -> None: def __init__(self, parent: _Loop, owner: Any) -> None:
@ -202,6 +260,27 @@ class _BoundLoop:
def running(self) -> bool: def running(self) -> bool:
return self._parent.running return self._parent.running
def change_interval(
self,
*,
seconds: float = 0.0,
minutes: float = 0.0,
hours: float = 0.0,
delta: Optional[datetime.timedelta] = None,
time_of_day: Optional[datetime.time] = None,
) -> None:
self._parent.change_interval(
seconds=seconds,
minutes=minutes,
hours=hours,
delta=delta,
time_of_day=time_of_day,
)
@property
def current_loop(self) -> int:
return self._parent.current_loop
def loop( def loop(
*, *,

View File

@ -82,3 +82,24 @@ async def test_before_after_loop_callbacks() -> None:
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
assert events and events[0] == "before" assert events and events[0] == "before"
assert "after" in events assert "after" in events
@pytest.mark.asyncio
async def test_change_interval_and_current_loop() -> None:
count = 0
@tasks.loop(seconds=0.01)
async def ticker() -> None:
nonlocal count
count += 1
ticker.start()
await asyncio.sleep(0.03)
initial = ticker.current_loop
ticker.change_interval(seconds=0.02)
await asyncio.sleep(0.05)
ticker.stop()
assert initial >= 2
assert ticker.current_loop > initial
assert count == ticker.current_loop