diff --git a/disagreement/client.py b/disagreement/client.py index 4ddbaaf..11694bb 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -16,6 +16,7 @@ from typing import ( List, Dict, ) +from types import ModuleType from .http import HTTPClient from .gateway import GatewayClient @@ -28,6 +29,7 @@ from .ext.commands.core import CommandHandler from .ext.commands.cog import Cog from .ext.app_commands.handler import AppCommandHandler from .ext.app_commands.context import AppCommandContext +from .ext import loader as ext_loader from .interactions import Interaction, Snowflake from .error_handler import setup_global_error_handler from .voice_client import VoiceClient @@ -638,6 +640,23 @@ class Client: # import traceback # traceback.print_exception(type(error.original), error.original, error.original.__traceback__) + # --- Extension Management Methods --- + + def load_extension(self, name: str) -> ModuleType: + """Load an extension by name using :mod:`disagreement.ext.loader`.""" + + return ext_loader.load_extension(name) + + def unload_extension(self, name: str) -> None: + """Unload a previously loaded extension.""" + + ext_loader.unload_extension(name) + + def reload_extension(self, name: str) -> ModuleType: + """Reload an extension by name.""" + + return ext_loader.reload_extension(name) + # --- Model Parsing and Fetching --- def parse_user(self, data: Dict[str, Any]) -> "User": diff --git a/disagreement/interactions.py b/disagreement/interactions.py index 4b89b6b..ee3ae4a 100644 --- a/disagreement/interactions.py +++ b/disagreement/interactions.py @@ -395,6 +395,8 @@ class Interaction: async def respond_modal(self, modal: "Modal") -> None: """|coro| Send a modal in response to this interaction.""" + from typing import Any, cast + payload = InteractionResponsePayload( type=InteractionCallbackType.MODAL, data=modal.to_dict(), @@ -402,7 +404,7 @@ class Interaction: await self._client._http.create_interaction_response( interaction_id=self.id, interaction_token=self.token, - payload=payload.to_dict(), # type: ignore[arg-type] + payload=cast(Any, payload.to_dict()), ) async def edit( diff --git a/examples/extension_management.py b/examples/extension_management.py new file mode 100644 index 0000000..8ba86cc --- /dev/null +++ b/examples/extension_management.py @@ -0,0 +1,45 @@ +"""Demonstrates dynamic extension loading using Client.load_extension.""" + +import asyncio +import os +import sys + +from dotenv import load_dotenv + +# Allow running from the examples folder without installing +if os.path.join(os.getcwd(), "examples") == os.path.dirname(os.path.abspath(__file__)): + sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from disagreement import Client + +load_dotenv() + +TOKEN = os.environ.get("DISCORD_BOT_TOKEN") + + +async def main() -> None: + if not TOKEN: + print("DISCORD_BOT_TOKEN environment variable not set") + return + + client = Client(token=TOKEN) + + # Load the extension which starts a simple ticker task + client.load_extension("examples.sample_extension") + + await client.connect() + await asyncio.sleep(6) + + # Reload the extension to restart the ticker + client.reload_extension("examples.sample_extension") + await asyncio.sleep(6) + + # Unload the extension and stop the ticker + client.unload_extension("examples.sample_extension") + + await asyncio.sleep(1) + await client.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/sample_extension.py b/examples/sample_extension.py new file mode 100644 index 0000000..e67da7b --- /dev/null +++ b/examples/sample_extension.py @@ -0,0 +1,16 @@ +from disagreement.ext import tasks + + +@tasks.loop(seconds=2.0) +async def ticker() -> None: + print("Extension tick") + + +def setup() -> None: + print("sample_extension setup") + ticker.start() + + +def teardown() -> None: + print("sample_extension teardown") + ticker.stop()