From 8e88aaec2f360ce0091776636c9d4cc96186f6d5 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Sun, 15 Jun 2025 20:42:21 -0600 Subject: [PATCH] Implement channel and member aggregation (#117) --- disagreement/cache.py | 4 +- disagreement/caching.py | 8 ++ disagreement/client.py | 28 +++- disagreement/http.py | 173 ++++++++++++------------ disagreement/models.py | 286 ++++++++++++++++++++-------------------- tests/test_cache.py | 73 ++++++++++ 6 files changed, 339 insertions(+), 233 deletions(-) diff --git a/disagreement/cache.py b/disagreement/cache.py index 32c6639..456797c 100644 --- a/disagreement/cache.py +++ b/disagreement/cache.py @@ -84,9 +84,9 @@ class MemberCache(Cache["Member"]): def _should_cache(self, member: Member) -> bool: """Determines if a member should be cached based on the flags.""" - if self.flags.all: + if self.flags.all_enabled: return True - if self.flags.none: + if self.flags.no_flags: return False if self.flags.online and member.status != "offline": diff --git a/disagreement/caching.py b/disagreement/caching.py index e9a481f..9decdd2 100644 --- a/disagreement/caching.py +++ b/disagreement/caching.py @@ -74,6 +74,14 @@ class MemberCacheFlags: for name in self.VALID_FLAGS: yield name, getattr(self, name) + @property + def all_enabled(self) -> bool: + return self.value == self.ALL_FLAGS + + @property + def no_flags(self) -> bool: + return self.value == 0 + def __int__(self) -> int: return self.value diff --git a/disagreement/client.py b/disagreement/client.py index fecc20e..2c4d4a9 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -1448,10 +1448,30 @@ class Client: return self._channels.get(channel_id) - def get_message(self, message_id: Snowflake) -> Optional["Message"]: - """Returns a message from the internal cache.""" - - return self._messages.get(message_id) + def get_message(self, message_id: Snowflake) -> Optional["Message"]: + """Returns a message from the internal cache.""" + + return self._messages.get(message_id) + + def get_all_channels(self) -> List["Channel"]: + """Return all channels cached in every guild.""" + + channels: List["Channel"] = [] + for guild in self._guilds.values(): + channels.extend(guild._channels.values()) + return channels + + def get_all_members(self) -> List["Member"]: + """Return all cached members across all guilds. + + When member caching is disabled via :class:`MemberCacheFlags.none`, this + list will always be empty. + """ + + members: List["Member"] = [] + for guild in self._guilds.values(): + members.extend(guild._members.values()) + return members async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]: """Fetches a guild by ID from Discord and caches it.""" diff --git a/disagreement/http.py b/disagreement/http.py index 2cb9bca..1831a6f 100644 --- a/disagreement/http.py +++ b/disagreement/http.py @@ -656,21 +656,21 @@ class HTTPClient: await self.request("PUT", f"/channels/{channel_id}/pins/{message_id}") - async def unpin_message( - self, channel_id: "Snowflake", message_id: "Snowflake" - ) -> None: - """Unpins a message from a channel.""" - - await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}") - - async def crosspost_message( - self, channel_id: "Snowflake", message_id: "Snowflake" - ) -> Dict[str, Any]: - """Crossposts a message to any following channels.""" - - return await self.request( - "POST", f"/channels/{channel_id}/messages/{message_id}/crosspost" - ) + async def unpin_message( + self, channel_id: "Snowflake", message_id: "Snowflake" + ) -> None: + """Unpins a message from a channel.""" + + await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}") + + async def crosspost_message( + self, channel_id: "Snowflake", message_id: "Snowflake" + ) -> Dict[str, Any]: + """Crossposts a message to any following channels.""" + + return await self.request( + "POST", f"/channels/{channel_id}/messages/{message_id}/crosspost" + ) async def delete_channel( self, channel_id: str, reason: Optional[str] = None @@ -734,63 +734,68 @@ class HTTPClient: return await self.request("GET", f"/channels/{channel_id}/invites") - async def create_invite( - self, channel_id: "Snowflake", payload: Dict[str, Any] - ) -> "Invite": - """Creates an invite for a channel.""" + async def create_invite( + self, channel_id: "Snowflake", payload: Dict[str, Any] + ) -> "Invite": + """Creates an invite for a channel.""" data = await self.request( "POST", f"/channels/{channel_id}/invites", payload=payload ) - from .models import Invite - - return Invite.from_dict(data) - - async def create_channel_invite( - self, - channel_id: "Snowflake", - payload: Dict[str, Any], - *, - reason: Optional[str] = None, - ) -> "Invite": - """Creates an invite for a channel with an optional audit log reason.""" - - headers = {"X-Audit-Log-Reason": reason} if reason else None - data = await self.request( - "POST", - f"/channels/{channel_id}/invites", - payload=payload, - custom_headers=headers, - ) - from .models import Invite - - return Invite.from_dict(data) + from .models import Invite - async def delete_invite(self, code: str) -> None: - """Deletes an invite by code.""" - - await self.request("DELETE", f"/invites/{code}") - - async def get_invite(self, code: "Snowflake") -> Dict[str, Any]: - """Fetches a single invite by its code.""" - - return await self.request("GET", f"/invites/{code}") + return Invite.from_dict(data) - async def create_webhook( - self, channel_id: "Snowflake", payload: Dict[str, Any] - ) -> "Webhook": - """Creates a webhook in the specified channel.""" + async def create_channel_invite( + self, + channel_id: "Snowflake", + payload: Dict[str, Any], + *, + reason: Optional[str] = None, + ) -> "Invite": + """Creates an invite for a channel with an optional audit log reason.""" + + headers = {"X-Audit-Log-Reason": reason} if reason else None + data = await self.request( + "POST", + f"/channels/{channel_id}/invites", + payload=payload, + custom_headers=headers, + ) + from .models import Invite + + return Invite.from_dict(data) + + async def delete_invite(self, code: str) -> None: + """Deletes an invite by code.""" + + await self.request("DELETE", f"/invites/{code}") + + async def get_invite(self, code: "Snowflake") -> Dict[str, Any]: + """Fetches a single invite by its code.""" + + return await self.request("GET", f"/invites/{code}") + + async def create_webhook( + self, channel_id: "Snowflake", payload: Dict[str, Any] + ) -> "Webhook": + """Creates a webhook in the specified channel.""" data = await self.request( "POST", f"/channels/{channel_id}/webhooks", payload=payload ) - from .models import Webhook - - return Webhook(data) - - async def edit_webhook( - self, webhook_id: "Snowflake", payload: Dict[str, Any] - ) -> "Webhook": + from .models import Webhook + + return Webhook(data) + + async def get_webhook(self, webhook_id: "Snowflake") -> Dict[str, Any]: + """Fetches a webhook by ID and returns the raw payload.""" + + return await self.request("GET", f"/webhooks/{webhook_id}") + + async def edit_webhook( + self, webhook_id: "Snowflake", payload: Dict[str, Any] + ) -> "Webhook": """Edits an existing webhook.""" data = await self.request("PATCH", f"/webhooks/{webhook_id}", payload=payload) @@ -798,28 +803,28 @@ class HTTPClient: return Webhook(data) - async def delete_webhook(self, webhook_id: "Snowflake") -> None: - """Deletes a webhook.""" - - await self.request("DELETE", f"/webhooks/{webhook_id}") - - async def get_webhook( - self, webhook_id: "Snowflake", token: Optional[str] = None - ) -> "Webhook": - """Fetches a webhook by ID, optionally using its token.""" - - endpoint = f"/webhooks/{webhook_id}" - use_auth = True - if token is not None: - endpoint += f"/{token}" - use_auth = False - if use_auth: - data = await self.request("GET", endpoint) - else: - data = await self.request("GET", endpoint, use_auth_header=False) - from .models import Webhook - - return Webhook(data) + async def delete_webhook(self, webhook_id: "Snowflake") -> None: + """Deletes a webhook.""" + + await self.request("DELETE", f"/webhooks/{webhook_id}") + + async def get_webhook_with_token( + self, webhook_id: "Snowflake", token: Optional[str] = None + ) -> "Webhook": + """Fetches a webhook by ID, optionally using its token.""" + + endpoint = f"/webhooks/{webhook_id}" + use_auth = True + if token is not None: + endpoint += f"/{token}" + use_auth = False + if use_auth: + data = await self.request("GET", endpoint) + else: + data = await self.request("GET", endpoint, use_auth_header=False) + from .models import Webhook + + return Webhook(data) async def execute_webhook( self, diff --git a/disagreement/models.py b/disagreement/models.py index 1f09d7c..f302664 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -1,8 +1,8 @@ -""" -Data models for Discord objects. -""" - -from __future__ import annotations +""" +Data models for Discord objects. +""" + +from __future__ import annotations import asyncio import datetime @@ -49,14 +49,14 @@ from .enums import ( # These enums will need to be defined in disagreement/enum from .permissions import Permissions -if TYPE_CHECKING: - from .client import Client # For type hinting to avoid circular imports - from .enums import OverwriteType # For PermissionOverwrite model - from .ui.view import View - from .interactions import Snowflake - from .typing import Typing - from .shard_manager import Shard - from .asset import Asset +if TYPE_CHECKING: + from .client import Client # For type hinting to avoid circular imports + from .enums import OverwriteType # For PermissionOverwrite model + from .ui.view import View + from .interactions import Snowflake + from .typing import Typing + from .shard_manager import Shard + from .asset import Asset # Forward reference Message if it were used in type hints before its definition # from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc. @@ -104,23 +104,23 @@ class User(HashableById): return f"" @property - def avatar(self) -> Optional["Asset"]: - """Return the user's avatar as an :class:`Asset`.""" - - if self._avatar: - from .asset import Asset - - return Asset(self._avatar, self._client) - return None - - @avatar.setter - def avatar(self, value: Optional[Union[str, "Asset"]]) -> None: - if isinstance(value, str): - self._avatar = value - elif value is None: - self._avatar = None - else: - self._avatar = value.url + def avatar(self) -> Optional["Asset"]: + """Return the user's avatar as an :class:`Asset`.""" + + if self._avatar: + from .asset import Asset + + return Asset(self._avatar, self._client) + return None + + @avatar.setter + def avatar(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._avatar = value + elif value is None: + self._avatar = None + else: + self._avatar = value.url async def send( self, @@ -843,21 +843,21 @@ class Role: return f"" @property - def icon(self) -> Optional["Asset"]: - if self._icon: - from .asset import Asset - - return Asset(self._icon, None) - return None - - @icon.setter - def icon(self, value: Optional[Union[str, "Asset"]]) -> None: - if isinstance(value, str): - self._icon = value - elif value is None: - self._icon = None - else: - self._icon = value.url + def icon(self) -> Optional["Asset"]: + if self._icon: + from .asset import Asset + + return Asset(self._icon, None) + return None + + @icon.setter + def icon(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._icon = value + elif value is None: + self._icon = None + else: + self._icon = value.url class Member(User): # Member inherits from User @@ -924,23 +924,23 @@ class Member(User): # Member inherits from User return f"" @property - def avatar(self) -> Optional["Asset"]: - """Return the member's avatar as an :class:`Asset`.""" - - if self._avatar: - from .asset import Asset - - return Asset(self._avatar, self._client) - return None - - @avatar.setter - def avatar(self, value: Optional[Union[str, "Asset"]]) -> None: - if isinstance(value, str): - self._avatar = value - elif value is None: - self._avatar = None - else: - self._avatar = value.url + def avatar(self) -> Optional["Asset"]: + """Return the member's avatar as an :class:`Asset`.""" + + if self._avatar: + from .asset import Asset + + return Asset(self._avatar, self._client) + return None + + @avatar.setter + def avatar(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._avatar = value + elif value is None: + self._avatar = None + else: + self._avatar = value.url @property def display_name(self) -> str: @@ -1034,10 +1034,10 @@ class Member(User): # Member inherits from User return Permissions(~0) return base - - @property - def voice(self) -> Optional["VoiceState"]: - """Return the member's cached voice state as a :class:`VoiceState`.""" + + @property + def voice(self) -> Optional["VoiceState"]: + """Return the member's cached voice state as a :class:`VoiceState`.""" if self.voice_state is None: return None @@ -1466,72 +1466,72 @@ class Guild(HashableById): return f"" @property - def icon(self) -> Optional["Asset"]: - if self._icon: - from .asset import Asset - - return Asset(self._icon, self._client) - return None - - @icon.setter - def icon(self, value: Optional[Union[str, "Asset"]]) -> None: - if isinstance(value, str): - self._icon = value - elif value is None: - self._icon = None - else: - self._icon = value.url + def icon(self) -> Optional["Asset"]: + if self._icon: + from .asset import Asset + + return Asset(self._icon, self._client) + return None + + @icon.setter + def icon(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._icon = value + elif value is None: + self._icon = None + else: + self._icon = value.url @property - def splash(self) -> Optional["Asset"]: - if self._splash: - from .asset import Asset - - return Asset(self._splash, self._client) - return None - - @splash.setter - def splash(self, value: Optional[Union[str, "Asset"]]) -> None: - if isinstance(value, str): - self._splash = value - elif value is None: - self._splash = None - else: - self._splash = value.url + def splash(self) -> Optional["Asset"]: + if self._splash: + from .asset import Asset + + return Asset(self._splash, self._client) + return None + + @splash.setter + def splash(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._splash = value + elif value is None: + self._splash = None + else: + self._splash = value.url @property - def discovery_splash(self) -> Optional["Asset"]: - if self._discovery_splash: - from .asset import Asset - - return Asset(self._discovery_splash, self._client) - return None - - @discovery_splash.setter - def discovery_splash(self, value: Optional[Union[str, "Asset"]]) -> None: - if isinstance(value, str): - self._discovery_splash = value - elif value is None: - self._discovery_splash = None - else: - self._discovery_splash = value.url + def discovery_splash(self) -> Optional["Asset"]: + if self._discovery_splash: + from .asset import Asset + + return Asset(self._discovery_splash, self._client) + return None + + @discovery_splash.setter + def discovery_splash(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._discovery_splash = value + elif value is None: + self._discovery_splash = None + else: + self._discovery_splash = value.url @property - def banner(self) -> Optional["Asset"]: - if self._banner: - from .asset import Asset - - return Asset(self._banner, self._client) - return None - - @banner.setter - def banner(self, value: Optional[Union[str, "Asset"]]) -> None: - if isinstance(value, str): - self._banner = value - elif value is None: - self._banner = None - else: - self._banner = value.url + def banner(self) -> Optional["Asset"]: + if self._banner: + from .asset import Asset + + return Asset(self._banner, self._client) + return None + + @banner.setter + def banner(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._banner = value + elif value is None: + self._banner = None + else: + self._banner = value.url async def fetch_widget(self) -> Dict[str, Any]: """|coro| Fetch this guild's widget settings.""" @@ -2280,23 +2280,23 @@ class Webhook: return f"" @property - def avatar(self) -> Optional["Asset"]: - """Return the webhook's avatar as an :class:`Asset`.""" - - if self._avatar: - from .asset import Asset - - return Asset(self._avatar, self._client) - return None - - @avatar.setter - def avatar(self, value: Optional[Union[str, "Asset"]]) -> None: - if isinstance(value, str): - self._avatar = value - elif value is None: - self._avatar = None - else: - self._avatar = value.url + def avatar(self) -> Optional["Asset"]: + """Return the webhook's avatar as an :class:`Asset`.""" + + if self._avatar: + from .asset import Asset + + return Asset(self._avatar, self._client) + return None + + @avatar.setter + def avatar(self, value: Optional[Union[str, "Asset"]]) -> None: + if isinstance(value, str): + self._avatar = value + elif value is None: + self._avatar = None + else: + self._avatar = value.url @classmethod def from_url( diff --git a/tests/test_cache.py b/tests/test_cache.py index 88effe9..5dbef27 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,6 +1,60 @@ import time from disagreement.cache import Cache +from disagreement.client import Client +from disagreement.caching import MemberCacheFlags +from disagreement.enums import ( + ChannelType, + ExplicitContentFilterLevel, + GuildNSFWLevel, + MFALevel, + MessageNotificationLevel, + PremiumTier, + VerificationLevel, +) + + +def _guild_payload(gid: str, channel_count: int, member_count: int) -> dict: + base = { + "id": gid, + "name": f"g{gid}", + "owner_id": "1", + "afk_timeout": 60, + "verification_level": VerificationLevel.NONE.value, + "default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value, + "explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value, + "roles": [], + "emojis": [], + "features": [], + "mfa_level": MFALevel.NONE.value, + "system_channel_flags": 0, + "premium_tier": PremiumTier.NONE.value, + "nsfw_level": GuildNSFWLevel.DEFAULT.value, + "channels": [], + "members": [], + } + for i in range(channel_count): + base["channels"].append( + { + "id": f"{gid}-c{i}", + "type": ChannelType.GUILD_TEXT.value, + "guild_id": gid, + "permission_overwrites": [], + } + ) + for i in range(member_count): + base["members"].append( + { + "user": { + "id": f"{gid}-m{i}", + "username": f"u{i}", + "discriminator": "0001", + }, + "joined_at": "t", + "roles": [], + } + ) + return base def test_cache_store_and_get(): @@ -65,3 +119,22 @@ def test_get_or_fetch_fetches_expired_item(): assert cache.get_or_fetch("c", fetch) == 3 assert called + + +def test_client_get_all_channels_and_members(): + client = Client(token="t") + client.parse_guild(_guild_payload("1", 2, 2)) + client.parse_guild(_guild_payload("2", 1, 1)) + + channels = {c.id for c in client.get_all_channels()} + members = {m.id for m in client.get_all_members()} + + assert channels == {"1-c0", "1-c1", "2-c0"} + assert members == {"1-m0", "1-m1", "2-m0"} + + +def test_client_get_all_members_disabled_cache(): + client = Client(token="t", member_cache_flags=MemberCacheFlags.none()) + client.parse_guild(_guild_payload("1", 1, 2)) + + assert client.get_all_members() == []