Add cog retrieval helpers (#89)
Some checks failed
Deploy MkDocs / deploy (push) Has been cancelled

This commit is contained in:
Slipstream 2025-06-15 18:15:45 -06:00 committed by GitHub
parent 98afb89629
commit 223c86cb78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 45 additions and 4 deletions

View File

@ -62,6 +62,7 @@ if not token:
intents = disagreement.GatewayIntent.default() | disagreement.GatewayIntent.MESSAGE_CONTENT
client = disagreement.Client(token=token, command_prefix="!", intents=intents, mention_replies=True)
client.add_cog(Basics(client))
client.run()
```

View File

@ -632,7 +632,7 @@ class Client:
f"Registered app command/group '{app_cmd_obj.name}' from cog '{cog.cog_name}'."
)
def remove_cog(self, cog_name: str) -> Optional[Cog]:
def remove_cog(self, cog_name: str) -> Optional[Cog]:
"""
Removes a Cog from the bot.
@ -659,7 +659,12 @@ class Client:
# Note: AppCommandHandler.remove_command might need to be more specific if names aren't globally unique
# (e.g. if it needs type or if groups and commands can share names).
# For now, assuming name is sufficient for removal from the handler's flat list.
return removed_cog
return removed_cog
def get_cog(self, name: str) -> Optional[Cog]:
"""Return a loaded cog by name if present."""
return self.command_handler.get_cog(name)
def check(self, coro: Callable[["CommandContext"], Awaitable[bool]]):
"""

View File

@ -363,8 +363,13 @@ class CommandHandler:
self.commands.pop(alias.lower(), None)
return command
def get_command(self, name: str) -> Optional[Command]:
return self.commands.get(name.lower())
def get_command(self, name: str) -> Optional[Command]:
return self.commands.get(name.lower())
def get_cog(self, name: str) -> Optional["Cog"]:
"""Return a loaded cog by name if present."""
return self.cogs.get(name)
def add_cog(self, cog_to_add: "Cog") -> None:
from .cog import Cog

View File

@ -61,6 +61,7 @@ if not token:
intents = GatewayIntent.default() | GatewayIntent.MESSAGE_CONTENT
client = Client(token=token, command_prefix="!", intents=intents, mention_replies=True)
client.add_cog(Basics(client))
client.run()
```

29
tests/test_get_cog.py Normal file
View File

@ -0,0 +1,29 @@
import asyncio
import pytest
from disagreement.client import Client
from disagreement.ext import commands
class DummyCog(commands.Cog):
def __init__(self, client: Client) -> None:
super().__init__(client)
@pytest.mark.asyncio()
async def test_command_handler_get_cog():
bot = object()
handler = commands.core.CommandHandler(client=bot, prefix="!")
cog = DummyCog(bot) # type: ignore[arg-type]
handler.add_cog(cog)
await asyncio.sleep(0) # allow any scheduled tasks to start
assert handler.get_cog("DummyCog") is cog
@pytest.mark.asyncio()
async def test_client_get_cog():
client = Client(token="t")
cog = DummyCog(client)
client.add_cog(cog)
await asyncio.sleep(0)
assert client.get_cog("DummyCog") is cog