diff --git a/README.md b/README.md index ef2aceb..26e5c55 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,7 @@ if not token: intents = disagreement.GatewayIntent.default() | disagreement.GatewayIntent.MESSAGE_CONTENT client = disagreement.Client(token=token, command_prefix="!", intents=intents, mention_replies=True) + client.add_cog(Basics(client)) client.run() ``` diff --git a/disagreement/client.py b/disagreement/client.py index c0e8240..7efbb11 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -632,7 +632,7 @@ class Client: f"Registered app command/group '{app_cmd_obj.name}' from cog '{cog.cog_name}'." ) - def remove_cog(self, cog_name: str) -> Optional[Cog]: + def remove_cog(self, cog_name: str) -> Optional[Cog]: """ Removes a Cog from the bot. @@ -659,7 +659,12 @@ class Client: # Note: AppCommandHandler.remove_command might need to be more specific if names aren't globally unique # (e.g. if it needs type or if groups and commands can share names). # For now, assuming name is sufficient for removal from the handler's flat list. - return removed_cog + return removed_cog + + def get_cog(self, name: str) -> Optional[Cog]: + """Return a loaded cog by name if present.""" + + return self.command_handler.get_cog(name) def check(self, coro: Callable[["CommandContext"], Awaitable[bool]]): """ diff --git a/disagreement/ext/commands/core.py b/disagreement/ext/commands/core.py index 0e76282..663cb65 100644 --- a/disagreement/ext/commands/core.py +++ b/disagreement/ext/commands/core.py @@ -363,8 +363,13 @@ class CommandHandler: self.commands.pop(alias.lower(), None) return command - def get_command(self, name: str) -> Optional[Command]: - return self.commands.get(name.lower()) + def get_command(self, name: str) -> Optional[Command]: + return self.commands.get(name.lower()) + + def get_cog(self, name: str) -> Optional["Cog"]: + """Return a loaded cog by name if present.""" + + return self.cogs.get(name) def add_cog(self, cog_to_add: "Cog") -> None: from .cog import Cog diff --git a/docs/introduction.md b/docs/introduction.md index c2748f1..b558c15 100644 --- a/docs/introduction.md +++ b/docs/introduction.md @@ -61,6 +61,7 @@ if not token: intents = GatewayIntent.default() | GatewayIntent.MESSAGE_CONTENT client = Client(token=token, command_prefix="!", intents=intents, mention_replies=True) + client.add_cog(Basics(client)) client.run() ``` diff --git a/tests/test_get_cog.py b/tests/test_get_cog.py new file mode 100644 index 0000000..7bf7b3a --- /dev/null +++ b/tests/test_get_cog.py @@ -0,0 +1,29 @@ +import asyncio +import pytest + +from disagreement.client import Client +from disagreement.ext import commands + + +class DummyCog(commands.Cog): + def __init__(self, client: Client) -> None: + super().__init__(client) + + +@pytest.mark.asyncio() +async def test_command_handler_get_cog(): + bot = object() + handler = commands.core.CommandHandler(client=bot, prefix="!") + cog = DummyCog(bot) # type: ignore[arg-type] + handler.add_cog(cog) + await asyncio.sleep(0) # allow any scheduled tasks to start + assert handler.get_cog("DummyCog") is cog + + +@pytest.mark.asyncio() +async def test_client_get_cog(): + client = Client(token="t") + cog = DummyCog(client) + client.add_cog(cog) + await asyncio.sleep(0) + assert client.get_cog("DummyCog") is cog