diff --git a/disagreement/ext/loader.py b/disagreement/ext/loader.py index 91267bd..4d4c0b8 100644 --- a/disagreement/ext/loader.py +++ b/disagreement/ext/loader.py @@ -53,7 +53,19 @@ def unload_extension(name: str) -> None: raise ValueError(f"Extension '{name}' is not loaded") if hasattr(module, "teardown"): - module.teardown() + result = module.teardown() + if inspect.isawaitable(result): + coro = cast(Coroutine[Any, Any, Any], result) + try: + loop = asyncio.get_running_loop() + except RuntimeError: + asyncio.run(coro) + else: + if loop.is_running(): + future = asyncio.run_coroutine_threadsafe(coro, loop) + future.result() + else: + loop.run_until_complete(coro) sys.modules.pop(name, None) diff --git a/tests/test_extension_loader.py b/tests/test_extension_loader.py index be9cce6..f476bad 100644 --- a/tests/test_extension_loader.py +++ b/tests/test_extension_loader.py @@ -38,6 +38,22 @@ def create_async_module(name): return called +def create_async_teardown_module(name): + mod = types.ModuleType(name) + called = {"setup": False, "teardown": False} + + def setup(): + called["setup"] = True + + async def teardown(): + called["teardown"] = True + + mod.setup = setup + mod.teardown = teardown + sys.modules[name] = mod + return called + + def test_load_and_unload_extension(): called = create_dummy_module("dummy_ext") @@ -101,3 +117,13 @@ def test_async_setup(): loader.unload_extension("async_ext") assert called["teardown"] is True + + +def test_async_teardown(): + called = create_async_teardown_module("async_teardown_ext") + + loader.load_extension("async_teardown_ext") + assert called["setup"] is True + + loader.unload_extension("async_teardown_ext") + assert called["teardown"] is True