feat(ext-loader): Support async setup functions
Allow extension `setup` functions to be asynchronous. The loader now checks if `module.setup` returns an awaitable and runs it using asyncio, handling cases where an event loop is already running or not. This enables extensions to perform asynchronous initialization tasks.
This commit is contained in:
parent
a41a301927
commit
3a264f4530
@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from importlib import import_module
|
||||
import inspect
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import Dict
|
||||
from typing import Any, Coroutine, Dict, cast
|
||||
|
||||
__all__ = ["load_extension", "unload_extension", "reload_extension"]
|
||||
|
||||
@ -25,7 +27,20 @@ def load_extension(name: str) -> ModuleType:
|
||||
if not hasattr(module, "setup"):
|
||||
raise ImportError(f"Extension '{name}' does not define a setup function")
|
||||
|
||||
module.setup()
|
||||
result = module.setup()
|
||||
if inspect.isawaitable(result):
|
||||
coro = cast(Coroutine[Any, Any, Any], result)
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
asyncio.run(coro)
|
||||
else:
|
||||
if loop.is_running():
|
||||
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||
future.result()
|
||||
else:
|
||||
loop.run_until_complete(coro)
|
||||
|
||||
_loaded_extensions[name] = module
|
||||
return module
|
||||
|
||||
|
@ -22,6 +22,22 @@ def create_dummy_module(name):
|
||||
return called
|
||||
|
||||
|
||||
def create_async_module(name):
|
||||
mod = types.ModuleType(name)
|
||||
called = {"setup": False, "teardown": False}
|
||||
|
||||
async def setup():
|
||||
called["setup"] = True
|
||||
|
||||
def teardown():
|
||||
called["teardown"] = True
|
||||
|
||||
mod.setup = setup
|
||||
mod.teardown = teardown
|
||||
sys.modules[name] = mod
|
||||
return called
|
||||
|
||||
|
||||
def test_load_and_unload_extension():
|
||||
called = create_dummy_module("dummy_ext")
|
||||
|
||||
@ -75,3 +91,13 @@ def test_reload_extension(monkeypatch):
|
||||
|
||||
loader.unload_extension("reload_ext")
|
||||
assert called_second["teardown"] is True
|
||||
|
||||
|
||||
def test_async_setup():
|
||||
called = create_async_module("async_ext")
|
||||
|
||||
loader.load_extension("async_ext")
|
||||
assert called["setup"] is True
|
||||
|
||||
loader.unload_extension("async_ext")
|
||||
assert called["teardown"] is True
|
||||
|
Loading…
x
Reference in New Issue
Block a user