Implement channel and member aggregation (#117)

This commit is contained in:
Slipstream 2025-06-15 20:42:21 -06:00 committed by GitHub
parent d710487fc2
commit 8e88aaec2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 339 additions and 233 deletions

View File

@ -84,9 +84,9 @@ class MemberCache(Cache["Member"]):
def _should_cache(self, member: Member) -> bool:
"""Determines if a member should be cached based on the flags."""
if self.flags.all:
if self.flags.all_enabled:
return True
if self.flags.none:
if self.flags.no_flags:
return False
if self.flags.online and member.status != "offline":

View File

@ -74,6 +74,14 @@ class MemberCacheFlags:
for name in self.VALID_FLAGS:
yield name, getattr(self, name)
@property
def all_enabled(self) -> bool:
return self.value == self.ALL_FLAGS
@property
def no_flags(self) -> bool:
return self.value == 0
def __int__(self) -> int:
return self.value

View File

@ -1453,6 +1453,26 @@ class Client:
return self._messages.get(message_id)
def get_all_channels(self) -> List["Channel"]:
"""Return all channels cached in every guild."""
channels: List["Channel"] = []
for guild in self._guilds.values():
channels.extend(guild._channels.values())
return channels
def get_all_members(self) -> List["Member"]:
"""Return all cached members across all guilds.
When member caching is disabled via :class:`MemberCacheFlags.none`, this
list will always be empty.
"""
members: List["Member"] = []
for guild in self._guilds.values():
members.extend(guild._members.values())
return members
async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]:
"""Fetches a guild by ID from Discord and caches it."""

View File

@ -788,6 +788,11 @@ class HTTPClient:
return Webhook(data)
async def get_webhook(self, webhook_id: "Snowflake") -> Dict[str, Any]:
"""Fetches a webhook by ID and returns the raw payload."""
return await self.request("GET", f"/webhooks/{webhook_id}")
async def edit_webhook(
self, webhook_id: "Snowflake", payload: Dict[str, Any]
) -> "Webhook":
@ -803,7 +808,7 @@ class HTTPClient:
await self.request("DELETE", f"/webhooks/{webhook_id}")
async def get_webhook(
async def get_webhook_with_token(
self, webhook_id: "Snowflake", token: Optional[str] = None
) -> "Webhook":
"""Fetches a webhook by ID, optionally using its token."""

View File

@ -1,6 +1,60 @@
import time
from disagreement.cache import Cache
from disagreement.client import Client
from disagreement.caching import MemberCacheFlags
from disagreement.enums import (
ChannelType,
ExplicitContentFilterLevel,
GuildNSFWLevel,
MFALevel,
MessageNotificationLevel,
PremiumTier,
VerificationLevel,
)
def _guild_payload(gid: str, channel_count: int, member_count: int) -> dict:
base = {
"id": gid,
"name": f"g{gid}",
"owner_id": "1",
"afk_timeout": 60,
"verification_level": VerificationLevel.NONE.value,
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
"roles": [],
"emojis": [],
"features": [],
"mfa_level": MFALevel.NONE.value,
"system_channel_flags": 0,
"premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
"channels": [],
"members": [],
}
for i in range(channel_count):
base["channels"].append(
{
"id": f"{gid}-c{i}",
"type": ChannelType.GUILD_TEXT.value,
"guild_id": gid,
"permission_overwrites": [],
}
)
for i in range(member_count):
base["members"].append(
{
"user": {
"id": f"{gid}-m{i}",
"username": f"u{i}",
"discriminator": "0001",
},
"joined_at": "t",
"roles": [],
}
)
return base
def test_cache_store_and_get():
@ -65,3 +119,22 @@ def test_get_or_fetch_fetches_expired_item():
assert cache.get_or_fetch("c", fetch) == 3
assert called
def test_client_get_all_channels_and_members():
client = Client(token="t")
client.parse_guild(_guild_payload("1", 2, 2))
client.parse_guild(_guild_payload("2", 1, 1))
channels = {c.id for c in client.get_all_channels()}
members = {m.id for m in client.get_all_members()}
assert channels == {"1-c0", "1-c1", "2-c0"}
assert members == {"1-m0", "1-m1", "2-m0"}
def test_client_get_all_members_disabled_cache():
client = Client(token="t", member_cache_flags=MemberCacheFlags.none())
client.parse_guild(_guild_payload("1", 1, 2))
assert client.get_all_members() == []