Extend tasks with scheduling options (#23)

This commit is contained in:
Slipstream 2025-06-10 15:44:35 -06:00 committed by GitHub
parent 484f091897
commit f9a7895ecb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 172 additions and 10 deletions

View File

@ -1,4 +1,5 @@
import asyncio
import datetime
from typing import Any, Awaitable, Callable, Optional
__all__ = ["loop", "Task"]
@ -7,16 +8,61 @@ __all__ = ["loop", "Task"]
class Task:
"""Simple repeating task."""
def __init__(self, coro: Callable[..., Awaitable[Any]], *, seconds: float) -> None:
def __init__(
self,
coro: Callable[..., Awaitable[Any]],
*,
seconds: float = 0.0,
minutes: float = 0.0,
hours: float = 0.0,
delta: Optional[datetime.timedelta] = None,
time_of_day: Optional[datetime.time] = None,
on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
) -> None:
self._coro = coro
self._seconds = float(seconds)
self._task: Optional[asyncio.Task[None]] = None
if time_of_day is not None and (
seconds or minutes or hours or delta is not None
):
raise ValueError("time_of_day cannot be used with an interval")
if delta is not None:
if not isinstance(delta, datetime.timedelta):
raise TypeError("delta must be a datetime.timedelta")
interval_seconds = delta.total_seconds()
else:
interval_seconds = seconds + minutes * 60.0 + hours * 3600.0
self._seconds = float(interval_seconds)
self._time_of_day = time_of_day
self._on_error = on_error
def _seconds_until_time(self) -> float:
assert self._time_of_day is not None
now = datetime.datetime.now()
target = datetime.datetime.combine(now.date(), self._time_of_day)
if target <= now:
target += datetime.timedelta(days=1)
return (target - now).total_seconds()
async def _run(self, *args: Any, **kwargs: Any) -> None:
try:
first = True
while True:
await self._coro(*args, **kwargs)
await asyncio.sleep(self._seconds)
if self._time_of_day is not None:
await asyncio.sleep(self._seconds_until_time())
elif not first:
await asyncio.sleep(self._seconds)
try:
await self._coro(*args, **kwargs)
except Exception as exc: # noqa: BLE001
if self._on_error is not None:
await _maybe_call(self._on_error, exc)
else:
raise
first = False
except asyncio.CancelledError:
pass
@ -35,10 +81,33 @@ class Task:
return self._task is not None and not self._task.done()
async def _maybe_call(
func: Callable[[Exception], Awaitable[None] | None], exc: Exception
) -> None:
result = func(exc)
if asyncio.iscoroutine(result):
await result
class _Loop:
def __init__(self, func: Callable[..., Awaitable[Any]], seconds: float) -> None:
def __init__(
self,
func: Callable[..., Awaitable[Any]],
*,
seconds: float = 0.0,
minutes: float = 0.0,
hours: float = 0.0,
delta: Optional[datetime.timedelta] = None,
time_of_day: Optional[datetime.time] = None,
on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
) -> None:
self.func = func
self.seconds = seconds
self.minutes = minutes
self.hours = hours
self.delta = delta
self.time_of_day = time_of_day
self.on_error = on_error
self._task: Optional[Task] = None
self._owner: Any = None
@ -51,7 +120,15 @@ class _Loop:
return self.func(self._owner, *args, **kwargs)
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
self._task = Task(self._coro, seconds=self.seconds)
self._task = Task(
self._coro,
seconds=self.seconds,
minutes=self.minutes,
hours=self.hours,
delta=self.delta,
time_of_day=self.time_of_day,
on_error=self.on_error,
)
return self._task.start(*args, **kwargs)
def stop(self) -> None:
@ -80,10 +157,26 @@ class _BoundLoop:
return self._parent.running
def loop(*, seconds: float) -> Callable[[Callable[..., Awaitable[Any]]], _Loop]:
def loop(
*,
seconds: float = 0.0,
minutes: float = 0.0,
hours: float = 0.0,
delta: Optional[datetime.timedelta] = None,
time_of_day: Optional[datetime.time] = None,
on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
) -> Callable[[Callable[..., Awaitable[Any]]], _Loop]:
"""Decorator to create a looping task."""
def decorator(func: Callable[..., Awaitable[Any]]) -> _Loop:
return _Loop(func, seconds)
return _Loop(
func,
seconds=seconds,
minutes=minutes,
hours=hours,
delta=delta,
time_of_day=time_of_day,
on_error=on_error,
)
return decorator

View File

@ -1,11 +1,11 @@
# Task Loops
The tasks extension allows you to run functions periodically. Decorate an async function with `@tasks.loop(seconds=...)` and start it using `.start()`.
The tasks extension allows you to run functions periodically. Decorate an async function with `@tasks.loop` and start it using `.start()`.
```python
from disagreement.ext import tasks
@tasks.loop(seconds=5.0)
@tasks.loop(minutes=1.0)
async def announce():
print("Hello from a loop")
@ -13,3 +13,36 @@ announce.start()
```
Stop the loop with `.stop()` when you no longer need it.
You can provide the interval in seconds, minutes, hours or as a `datetime.timedelta`:
```python
import datetime
@tasks.loop(delta=datetime.timedelta(seconds=30))
async def ping():
...
```
Handle exceptions raised by the looped coroutine using `on_error`:
```python
async def log_error(exc: Exception) -> None:
print("Loop failed:", exc)
@tasks.loop(seconds=5.0, on_error=log_error)
async def worker():
...
```
You can also schedule a task at a specific time of day:
```python
from datetime import datetime, timedelta
time_to_run = (datetime.now() + timedelta(seconds=5)).time()
@tasks.loop(time_of_day=time_to_run)
async def daily_task():
...
```

View File

@ -1,4 +1,5 @@
import asyncio
import datetime
import pytest
@ -14,6 +15,41 @@ class Dummy:
self.count += 1
@pytest.mark.asyncio
async def test_loop_on_error_callback_called() -> None:
called = False
def handler(exc: Exception) -> None: # pragma: no cover - simple callback
nonlocal called
called = True
@tasks.loop(seconds=0.01, on_error=handler)
async def failing() -> None:
raise RuntimeError("fail")
failing.start()
await asyncio.sleep(0.03)
failing.stop()
assert called
@pytest.mark.asyncio
async def test_loop_time_of_day() -> None:
run_count = 0
target_time = (datetime.datetime.now() + datetime.timedelta(seconds=0.05)).time()
@tasks.loop(time_of_day=target_time)
async def daily() -> None:
nonlocal run_count
run_count += 1
daily.start()
await asyncio.sleep(0.1)
daily.stop()
assert run_count >= 1
@pytest.mark.asyncio
async def test_loop_runs_and_stops() -> None:
dummy = Dummy()