Implement channel and member aggregation (#117)
This commit is contained in:
parent
d710487fc2
commit
8e88aaec2f
@ -84,9 +84,9 @@ class MemberCache(Cache["Member"]):
|
||||
|
||||
def _should_cache(self, member: Member) -> bool:
|
||||
"""Determines if a member should be cached based on the flags."""
|
||||
if self.flags.all:
|
||||
if self.flags.all_enabled:
|
||||
return True
|
||||
if self.flags.none:
|
||||
if self.flags.no_flags:
|
||||
return False
|
||||
|
||||
if self.flags.online and member.status != "offline":
|
||||
|
@ -74,6 +74,14 @@ class MemberCacheFlags:
|
||||
for name in self.VALID_FLAGS:
|
||||
yield name, getattr(self, name)
|
||||
|
||||
@property
|
||||
def all_enabled(self) -> bool:
|
||||
return self.value == self.ALL_FLAGS
|
||||
|
||||
@property
|
||||
def no_flags(self) -> bool:
|
||||
return self.value == 0
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.value
|
||||
|
||||
|
@ -1453,6 +1453,26 @@ class Client:
|
||||
|
||||
return self._messages.get(message_id)
|
||||
|
||||
def get_all_channels(self) -> List["Channel"]:
|
||||
"""Return all channels cached in every guild."""
|
||||
|
||||
channels: List["Channel"] = []
|
||||
for guild in self._guilds.values():
|
||||
channels.extend(guild._channels.values())
|
||||
return channels
|
||||
|
||||
def get_all_members(self) -> List["Member"]:
|
||||
"""Return all cached members across all guilds.
|
||||
|
||||
When member caching is disabled via :class:`MemberCacheFlags.none`, this
|
||||
list will always be empty.
|
||||
"""
|
||||
|
||||
members: List["Member"] = []
|
||||
for guild in self._guilds.values():
|
||||
members.extend(guild._members.values())
|
||||
return members
|
||||
|
||||
async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]:
|
||||
"""Fetches a guild by ID from Discord and caches it."""
|
||||
|
||||
|
@ -788,6 +788,11 @@ class HTTPClient:
|
||||
|
||||
return Webhook(data)
|
||||
|
||||
async def get_webhook(self, webhook_id: "Snowflake") -> Dict[str, Any]:
|
||||
"""Fetches a webhook by ID and returns the raw payload."""
|
||||
|
||||
return await self.request("GET", f"/webhooks/{webhook_id}")
|
||||
|
||||
async def edit_webhook(
|
||||
self, webhook_id: "Snowflake", payload: Dict[str, Any]
|
||||
) -> "Webhook":
|
||||
@ -803,7 +808,7 @@ class HTTPClient:
|
||||
|
||||
await self.request("DELETE", f"/webhooks/{webhook_id}")
|
||||
|
||||
async def get_webhook(
|
||||
async def get_webhook_with_token(
|
||||
self, webhook_id: "Snowflake", token: Optional[str] = None
|
||||
) -> "Webhook":
|
||||
"""Fetches a webhook by ID, optionally using its token."""
|
||||
|
@ -1,6 +1,60 @@
|
||||
import time
|
||||
|
||||
from disagreement.cache import Cache
|
||||
from disagreement.client import Client
|
||||
from disagreement.caching import MemberCacheFlags
|
||||
from disagreement.enums import (
|
||||
ChannelType,
|
||||
ExplicitContentFilterLevel,
|
||||
GuildNSFWLevel,
|
||||
MFALevel,
|
||||
MessageNotificationLevel,
|
||||
PremiumTier,
|
||||
VerificationLevel,
|
||||
)
|
||||
|
||||
|
||||
def _guild_payload(gid: str, channel_count: int, member_count: int) -> dict:
|
||||
base = {
|
||||
"id": gid,
|
||||
"name": f"g{gid}",
|
||||
"owner_id": "1",
|
||||
"afk_timeout": 60,
|
||||
"verification_level": VerificationLevel.NONE.value,
|
||||
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||
"roles": [],
|
||||
"emojis": [],
|
||||
"features": [],
|
||||
"mfa_level": MFALevel.NONE.value,
|
||||
"system_channel_flags": 0,
|
||||
"premium_tier": PremiumTier.NONE.value,
|
||||
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||
"channels": [],
|
||||
"members": [],
|
||||
}
|
||||
for i in range(channel_count):
|
||||
base["channels"].append(
|
||||
{
|
||||
"id": f"{gid}-c{i}",
|
||||
"type": ChannelType.GUILD_TEXT.value,
|
||||
"guild_id": gid,
|
||||
"permission_overwrites": [],
|
||||
}
|
||||
)
|
||||
for i in range(member_count):
|
||||
base["members"].append(
|
||||
{
|
||||
"user": {
|
||||
"id": f"{gid}-m{i}",
|
||||
"username": f"u{i}",
|
||||
"discriminator": "0001",
|
||||
},
|
||||
"joined_at": "t",
|
||||
"roles": [],
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
|
||||
def test_cache_store_and_get():
|
||||
@ -65,3 +119,22 @@ def test_get_or_fetch_fetches_expired_item():
|
||||
|
||||
assert cache.get_or_fetch("c", fetch) == 3
|
||||
assert called
|
||||
|
||||
|
||||
def test_client_get_all_channels_and_members():
|
||||
client = Client(token="t")
|
||||
client.parse_guild(_guild_payload("1", 2, 2))
|
||||
client.parse_guild(_guild_payload("2", 1, 1))
|
||||
|
||||
channels = {c.id for c in client.get_all_channels()}
|
||||
members = {m.id for m in client.get_all_members()}
|
||||
|
||||
assert channels == {"1-c0", "1-c1", "2-c0"}
|
||||
assert members == {"1-m0", "1-m1", "2-m0"}
|
||||
|
||||
|
||||
def test_client_get_all_members_disabled_cache():
|
||||
client = Client(token="t", member_cache_flags=MemberCacheFlags.none())
|
||||
client.parse_guild(_guild_payload("1", 1, 2))
|
||||
|
||||
assert client.get_all_members() == []
|
||||
|
Loading…
x
Reference in New Issue
Block a user