Add before/after loop callbacks
This commit is contained in:
parent
534b5b3980
commit
df77a3fcec
@ -18,6 +18,8 @@ class Task:
|
|||||||
delta: Optional[datetime.timedelta] = None,
|
delta: Optional[datetime.timedelta] = None,
|
||||||
time_of_day: Optional[datetime.time] = None,
|
time_of_day: Optional[datetime.time] = None,
|
||||||
on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
|
on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
|
||||||
|
before_loop: Optional[Callable[[], Awaitable[None] | None]] = None,
|
||||||
|
after_loop: Optional[Callable[[], Awaitable[None] | None]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._coro = coro
|
self._coro = coro
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
self._task: Optional[asyncio.Task[None]] = None
|
||||||
@ -36,6 +38,8 @@ class Task:
|
|||||||
self._seconds = float(interval_seconds)
|
self._seconds = float(interval_seconds)
|
||||||
self._time_of_day = time_of_day
|
self._time_of_day = time_of_day
|
||||||
self._on_error = on_error
|
self._on_error = on_error
|
||||||
|
self._before_loop = before_loop
|
||||||
|
self._after_loop = after_loop
|
||||||
|
|
||||||
def _seconds_until_time(self) -> float:
|
def _seconds_until_time(self) -> float:
|
||||||
assert self._time_of_day is not None
|
assert self._time_of_day is not None
|
||||||
@ -47,6 +51,9 @@ class Task:
|
|||||||
|
|
||||||
async def _run(self, *args: Any, **kwargs: Any) -> None:
|
async def _run(self, *args: Any, **kwargs: Any) -> None:
|
||||||
try:
|
try:
|
||||||
|
if self._before_loop is not None:
|
||||||
|
await _maybe_call_no_args(self._before_loop)
|
||||||
|
|
||||||
first = True
|
first = True
|
||||||
while True:
|
while True:
|
||||||
if self._time_of_day is not None:
|
if self._time_of_day is not None:
|
||||||
@ -65,6 +72,9 @@ class Task:
|
|||||||
first = False
|
first = False
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
finally:
|
||||||
|
if self._after_loop is not None:
|
||||||
|
await _maybe_call_no_args(self._after_loop)
|
||||||
|
|
||||||
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
||||||
if self._task is None or self._task.done():
|
if self._task is None or self._task.done():
|
||||||
@ -89,6 +99,12 @@ async def _maybe_call(
|
|||||||
await result
|
await result
|
||||||
|
|
||||||
|
|
||||||
|
async def _maybe_call_no_args(func: Callable[[], Awaitable[None] | None]) -> None:
|
||||||
|
result = func()
|
||||||
|
if asyncio.iscoroutine(result):
|
||||||
|
await result
|
||||||
|
|
||||||
|
|
||||||
class _Loop:
|
class _Loop:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -110,6 +126,8 @@ class _Loop:
|
|||||||
self.on_error = on_error
|
self.on_error = on_error
|
||||||
self._task: Optional[Task] = None
|
self._task: Optional[Task] = None
|
||||||
self._owner: Any = None
|
self._owner: Any = None
|
||||||
|
self._before_loop: Optional[Callable[..., Awaitable[Any]]] = None
|
||||||
|
self._after_loop: Optional[Callable[..., Awaitable[Any]]] = None
|
||||||
|
|
||||||
def __get__(self, obj: Any, objtype: Any) -> "_BoundLoop":
|
def __get__(self, obj: Any, objtype: Any) -> "_BoundLoop":
|
||||||
return _BoundLoop(self, obj)
|
return _BoundLoop(self, obj)
|
||||||
@ -119,7 +137,33 @@ class _Loop:
|
|||||||
return self.func(*args, **kwargs)
|
return self.func(*args, **kwargs)
|
||||||
return self.func(self._owner, *args, **kwargs)
|
return self.func(self._owner, *args, **kwargs)
|
||||||
|
|
||||||
|
def before_loop(
|
||||||
|
self, func: Callable[..., Awaitable[Any]]
|
||||||
|
) -> Callable[..., Awaitable[Any]]:
|
||||||
|
self._before_loop = func
|
||||||
|
return func
|
||||||
|
|
||||||
|
def after_loop(
|
||||||
|
self, func: Callable[..., Awaitable[Any]]
|
||||||
|
) -> Callable[..., Awaitable[Any]]:
|
||||||
|
self._after_loop = func
|
||||||
|
return func
|
||||||
|
|
||||||
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
||||||
|
def call_before() -> Awaitable[None] | None:
|
||||||
|
if self._before_loop is None:
|
||||||
|
return None
|
||||||
|
if self._owner is not None:
|
||||||
|
return self._before_loop(self._owner)
|
||||||
|
return self._before_loop()
|
||||||
|
|
||||||
|
def call_after() -> Awaitable[None] | None:
|
||||||
|
if self._after_loop is None:
|
||||||
|
return None
|
||||||
|
if self._owner is not None:
|
||||||
|
return self._after_loop(self._owner)
|
||||||
|
return self._after_loop()
|
||||||
|
|
||||||
self._task = Task(
|
self._task = Task(
|
||||||
self._coro,
|
self._coro,
|
||||||
seconds=self.seconds,
|
seconds=self.seconds,
|
||||||
@ -128,6 +172,8 @@ class _Loop:
|
|||||||
delta=self.delta,
|
delta=self.delta,
|
||||||
time_of_day=self.time_of_day,
|
time_of_day=self.time_of_day,
|
||||||
on_error=self.on_error,
|
on_error=self.on_error,
|
||||||
|
before_loop=call_before,
|
||||||
|
after_loop=call_after,
|
||||||
)
|
)
|
||||||
return self._task.start(*args, **kwargs)
|
return self._task.start(*args, **kwargs)
|
||||||
|
|
||||||
|
@ -35,6 +35,22 @@ async def worker():
|
|||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Run setup and teardown code using `before_loop` and `after_loop`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@tasks.loop(seconds=5.0)
|
||||||
|
async def worker():
|
||||||
|
...
|
||||||
|
|
||||||
|
@worker.before_loop
|
||||||
|
async def before_worker():
|
||||||
|
print("starting")
|
||||||
|
|
||||||
|
@worker.after_loop
|
||||||
|
async def after_worker():
|
||||||
|
print("stopped")
|
||||||
|
```
|
||||||
|
|
||||||
You can also schedule a task at a specific time of day:
|
You can also schedule a task at a specific time of day:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
@ -28,5 +28,5 @@ async def test_respond_modal(dummy_bot, interaction):
|
|||||||
await interaction.respond_modal(modal)
|
await interaction.respond_modal(modal)
|
||||||
dummy_bot._http.create_interaction_response.assert_called_once()
|
dummy_bot._http.create_interaction_response.assert_called_once()
|
||||||
payload = dummy_bot._http.create_interaction_response.call_args.kwargs["payload"]
|
payload = dummy_bot._http.create_interaction_response.call_args.kwargs["payload"]
|
||||||
assert payload["type"] == InteractionCallbackType.MODAL.value
|
assert payload.type == InteractionCallbackType.MODAL
|
||||||
assert payload["data"]["custom_id"] == "m1"
|
assert payload.data["custom_id"] == "m1"
|
||||||
|
@ -58,3 +58,27 @@ async def test_loop_runs_and_stops() -> None:
|
|||||||
dummy.work.stop() # pylint: disable=no-member
|
dummy.work.stop() # pylint: disable=no-member
|
||||||
assert dummy.count >= 2
|
assert dummy.count >= 2
|
||||||
assert not dummy.work.running # pylint: disable=no-member
|
assert not dummy.work.running # pylint: disable=no-member
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_before_after_loop_callbacks() -> None:
|
||||||
|
events: list[str] = []
|
||||||
|
|
||||||
|
@tasks.loop(seconds=0.01)
|
||||||
|
async def ticker() -> None:
|
||||||
|
events.append("tick")
|
||||||
|
|
||||||
|
@ticker.before_loop
|
||||||
|
async def before() -> None: # pragma: no cover - trivial callback
|
||||||
|
events.append("before")
|
||||||
|
|
||||||
|
@ticker.after_loop
|
||||||
|
async def after() -> None: # pragma: no cover - trivial callback
|
||||||
|
events.append("after")
|
||||||
|
|
||||||
|
ticker.start()
|
||||||
|
await asyncio.sleep(0.03)
|
||||||
|
ticker.stop()
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
assert events and events[0] == "before"
|
||||||
|
assert "after" in events
|
||||||
|
Loading…
x
Reference in New Issue
Block a user