From 3a264f45302e98c44407b9496bf7a2e77229446b Mon Sep 17 00:00:00 2001 From: Slipstream Date: Sun, 15 Jun 2025 15:17:42 -0600 Subject: [PATCH] feat(ext-loader): Support async setup functions Allow extension `setup` functions to be asynchronous. The loader now checks if `module.setup` returns an awaitable and runs it using asyncio, handling cases where an event loop is already running or not. This enables extensions to perform asynchronous initialization tasks. --- disagreement/ext/loader.py | 19 +++++++++++++++++-- tests/test_extension_loader.py | 26 ++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/disagreement/ext/loader.py b/disagreement/ext/loader.py index 0f14402..91267bd 100644 --- a/disagreement/ext/loader.py +++ b/disagreement/ext/loader.py @@ -1,9 +1,11 @@ from __future__ import annotations +import asyncio from importlib import import_module +import inspect import sys from types import ModuleType -from typing import Dict +from typing import Any, Coroutine, Dict, cast __all__ = ["load_extension", "unload_extension", "reload_extension"] @@ -25,7 +27,20 @@ def load_extension(name: str) -> ModuleType: if not hasattr(module, "setup"): raise ImportError(f"Extension '{name}' does not define a setup function") - module.setup() + result = module.setup() + if inspect.isawaitable(result): + coro = cast(Coroutine[Any, Any, Any], result) + try: + loop = asyncio.get_running_loop() + except RuntimeError: + asyncio.run(coro) + else: + if loop.is_running(): + future = asyncio.run_coroutine_threadsafe(coro, loop) + future.result() + else: + loop.run_until_complete(coro) + _loaded_extensions[name] = module return module diff --git a/tests/test_extension_loader.py b/tests/test_extension_loader.py index 7531af1..be9cce6 100644 --- a/tests/test_extension_loader.py +++ b/tests/test_extension_loader.py @@ -22,6 +22,22 @@ def create_dummy_module(name): return called +def create_async_module(name): + mod = types.ModuleType(name) + called = {"setup": False, "teardown": False} + + async def setup(): + called["setup"] = True + + def teardown(): + called["teardown"] = True + + mod.setup = setup + mod.teardown = teardown + sys.modules[name] = mod + return called + + def test_load_and_unload_extension(): called = create_dummy_module("dummy_ext") @@ -75,3 +91,13 @@ def test_reload_extension(monkeypatch): loader.unload_extension("reload_ext") assert called_second["teardown"] is True + + +def test_async_setup(): + called = create_async_module("async_ext") + + loader.load_extension("async_ext") + assert called["setup"] is True + + loader.unload_extension("async_ext") + assert called["teardown"] is True