Implement channel and member aggregation (#117)

This commit is contained in:
Slipstream 2025-06-15 20:42:21 -06:00 committed by GitHub
parent d710487fc2
commit 8e88aaec2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 339 additions and 233 deletions

View File

@ -84,9 +84,9 @@ class MemberCache(Cache["Member"]):
def _should_cache(self, member: Member) -> bool:
"""Determines if a member should be cached based on the flags."""
if self.flags.all:
if self.flags.all_enabled:
return True
if self.flags.none:
if self.flags.no_flags:
return False
if self.flags.online and member.status != "offline":

View File

@ -74,6 +74,14 @@ class MemberCacheFlags:
for name in self.VALID_FLAGS:
yield name, getattr(self, name)
@property
def all_enabled(self) -> bool:
return self.value == self.ALL_FLAGS
@property
def no_flags(self) -> bool:
return self.value == 0
def __int__(self) -> int:
return self.value

View File

@ -1448,10 +1448,30 @@ class Client:
return self._channels.get(channel_id)
def get_message(self, message_id: Snowflake) -> Optional["Message"]:
"""Returns a message from the internal cache."""
return self._messages.get(message_id)
def get_message(self, message_id: Snowflake) -> Optional["Message"]:
"""Returns a message from the internal cache."""
return self._messages.get(message_id)
def get_all_channels(self) -> List["Channel"]:
"""Return all channels cached in every guild."""
channels: List["Channel"] = []
for guild in self._guilds.values():
channels.extend(guild._channels.values())
return channels
def get_all_members(self) -> List["Member"]:
"""Return all cached members across all guilds.
When member caching is disabled via :class:`MemberCacheFlags.none`, this
list will always be empty.
"""
members: List["Member"] = []
for guild in self._guilds.values():
members.extend(guild._members.values())
return members
async def fetch_guild(self, guild_id: Snowflake) -> Optional["Guild"]:
"""Fetches a guild by ID from Discord and caches it."""

View File

@ -656,21 +656,21 @@ class HTTPClient:
await self.request("PUT", f"/channels/{channel_id}/pins/{message_id}")
async def unpin_message(
self, channel_id: "Snowflake", message_id: "Snowflake"
) -> None:
"""Unpins a message from a channel."""
await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}")
async def crosspost_message(
self, channel_id: "Snowflake", message_id: "Snowflake"
) -> Dict[str, Any]:
"""Crossposts a message to any following channels."""
return await self.request(
"POST", f"/channels/{channel_id}/messages/{message_id}/crosspost"
)
async def unpin_message(
self, channel_id: "Snowflake", message_id: "Snowflake"
) -> None:
"""Unpins a message from a channel."""
await self.request("DELETE", f"/channels/{channel_id}/pins/{message_id}")
async def crosspost_message(
self, channel_id: "Snowflake", message_id: "Snowflake"
) -> Dict[str, Any]:
"""Crossposts a message to any following channels."""
return await self.request(
"POST", f"/channels/{channel_id}/messages/{message_id}/crosspost"
)
async def delete_channel(
self, channel_id: str, reason: Optional[str] = None
@ -734,63 +734,68 @@ class HTTPClient:
return await self.request("GET", f"/channels/{channel_id}/invites")
async def create_invite(
self, channel_id: "Snowflake", payload: Dict[str, Any]
) -> "Invite":
"""Creates an invite for a channel."""
async def create_invite(
self, channel_id: "Snowflake", payload: Dict[str, Any]
) -> "Invite":
"""Creates an invite for a channel."""
data = await self.request(
"POST", f"/channels/{channel_id}/invites", payload=payload
)
from .models import Invite
return Invite.from_dict(data)
async def create_channel_invite(
self,
channel_id: "Snowflake",
payload: Dict[str, Any],
*,
reason: Optional[str] = None,
) -> "Invite":
"""Creates an invite for a channel with an optional audit log reason."""
headers = {"X-Audit-Log-Reason": reason} if reason else None
data = await self.request(
"POST",
f"/channels/{channel_id}/invites",
payload=payload,
custom_headers=headers,
)
from .models import Invite
return Invite.from_dict(data)
from .models import Invite
async def delete_invite(self, code: str) -> None:
"""Deletes an invite by code."""
await self.request("DELETE", f"/invites/{code}")
async def get_invite(self, code: "Snowflake") -> Dict[str, Any]:
"""Fetches a single invite by its code."""
return await self.request("GET", f"/invites/{code}")
return Invite.from_dict(data)
async def create_webhook(
self, channel_id: "Snowflake", payload: Dict[str, Any]
) -> "Webhook":
"""Creates a webhook in the specified channel."""
async def create_channel_invite(
self,
channel_id: "Snowflake",
payload: Dict[str, Any],
*,
reason: Optional[str] = None,
) -> "Invite":
"""Creates an invite for a channel with an optional audit log reason."""
headers = {"X-Audit-Log-Reason": reason} if reason else None
data = await self.request(
"POST",
f"/channels/{channel_id}/invites",
payload=payload,
custom_headers=headers,
)
from .models import Invite
return Invite.from_dict(data)
async def delete_invite(self, code: str) -> None:
"""Deletes an invite by code."""
await self.request("DELETE", f"/invites/{code}")
async def get_invite(self, code: "Snowflake") -> Dict[str, Any]:
"""Fetches a single invite by its code."""
return await self.request("GET", f"/invites/{code}")
async def create_webhook(
self, channel_id: "Snowflake", payload: Dict[str, Any]
) -> "Webhook":
"""Creates a webhook in the specified channel."""
data = await self.request(
"POST", f"/channels/{channel_id}/webhooks", payload=payload
)
from .models import Webhook
return Webhook(data)
async def edit_webhook(
self, webhook_id: "Snowflake", payload: Dict[str, Any]
) -> "Webhook":
from .models import Webhook
return Webhook(data)
async def get_webhook(self, webhook_id: "Snowflake") -> Dict[str, Any]:
"""Fetches a webhook by ID and returns the raw payload."""
return await self.request("GET", f"/webhooks/{webhook_id}")
async def edit_webhook(
self, webhook_id: "Snowflake", payload: Dict[str, Any]
) -> "Webhook":
"""Edits an existing webhook."""
data = await self.request("PATCH", f"/webhooks/{webhook_id}", payload=payload)
@ -798,28 +803,28 @@ class HTTPClient:
return Webhook(data)
async def delete_webhook(self, webhook_id: "Snowflake") -> None:
"""Deletes a webhook."""
await self.request("DELETE", f"/webhooks/{webhook_id}")
async def get_webhook(
self, webhook_id: "Snowflake", token: Optional[str] = None
) -> "Webhook":
"""Fetches a webhook by ID, optionally using its token."""
endpoint = f"/webhooks/{webhook_id}"
use_auth = True
if token is not None:
endpoint += f"/{token}"
use_auth = False
if use_auth:
data = await self.request("GET", endpoint)
else:
data = await self.request("GET", endpoint, use_auth_header=False)
from .models import Webhook
return Webhook(data)
async def delete_webhook(self, webhook_id: "Snowflake") -> None:
"""Deletes a webhook."""
await self.request("DELETE", f"/webhooks/{webhook_id}")
async def get_webhook_with_token(
self, webhook_id: "Snowflake", token: Optional[str] = None
) -> "Webhook":
"""Fetches a webhook by ID, optionally using its token."""
endpoint = f"/webhooks/{webhook_id}"
use_auth = True
if token is not None:
endpoint += f"/{token}"
use_auth = False
if use_auth:
data = await self.request("GET", endpoint)
else:
data = await self.request("GET", endpoint, use_auth_header=False)
from .models import Webhook
return Webhook(data)
async def execute_webhook(
self,

View File

@ -1,8 +1,8 @@
"""
Data models for Discord objects.
"""
from __future__ import annotations
"""
Data models for Discord objects.
"""
from __future__ import annotations
import asyncio
import datetime
@ -49,14 +49,14 @@ from .enums import ( # These enums will need to be defined in disagreement/enum
from .permissions import Permissions
if TYPE_CHECKING:
from .client import Client # For type hinting to avoid circular imports
from .enums import OverwriteType # For PermissionOverwrite model
from .ui.view import View
from .interactions import Snowflake
from .typing import Typing
from .shard_manager import Shard
from .asset import Asset
if TYPE_CHECKING:
from .client import Client # For type hinting to avoid circular imports
from .enums import OverwriteType # For PermissionOverwrite model
from .ui.view import View
from .interactions import Snowflake
from .typing import Typing
from .shard_manager import Shard
from .asset import Asset
# Forward reference Message if it were used in type hints before its definition
# from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc.
@ -104,23 +104,23 @@ class User(HashableById):
return f"<User id='{self.id}' username='{username}' discriminator='{disc}'>"
@property
def avatar(self) -> Optional["Asset"]:
"""Return the user's avatar as an :class:`Asset`."""
if self._avatar:
from .asset import Asset
return Asset(self._avatar, self._client)
return None
@avatar.setter
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._avatar = value
elif value is None:
self._avatar = None
else:
self._avatar = value.url
def avatar(self) -> Optional["Asset"]:
"""Return the user's avatar as an :class:`Asset`."""
if self._avatar:
from .asset import Asset
return Asset(self._avatar, self._client)
return None
@avatar.setter
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._avatar = value
elif value is None:
self._avatar = None
else:
self._avatar = value.url
async def send(
self,
@ -843,21 +843,21 @@ class Role:
return f"<Role id='{self.id}' name='{self.name}'>"
@property
def icon(self) -> Optional["Asset"]:
if self._icon:
from .asset import Asset
return Asset(self._icon, None)
return None
@icon.setter
def icon(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._icon = value
elif value is None:
self._icon = None
else:
self._icon = value.url
def icon(self) -> Optional["Asset"]:
if self._icon:
from .asset import Asset
return Asset(self._icon, None)
return None
@icon.setter
def icon(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._icon = value
elif value is None:
self._icon = None
else:
self._icon = value.url
class Member(User): # Member inherits from User
@ -924,23 +924,23 @@ class Member(User): # Member inherits from User
return f"<Member id='{self.id}' username='{self.username}' nick='{self.nick}'>"
@property
def avatar(self) -> Optional["Asset"]:
"""Return the member's avatar as an :class:`Asset`."""
if self._avatar:
from .asset import Asset
return Asset(self._avatar, self._client)
return None
@avatar.setter
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._avatar = value
elif value is None:
self._avatar = None
else:
self._avatar = value.url
def avatar(self) -> Optional["Asset"]:
"""Return the member's avatar as an :class:`Asset`."""
if self._avatar:
from .asset import Asset
return Asset(self._avatar, self._client)
return None
@avatar.setter
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._avatar = value
elif value is None:
self._avatar = None
else:
self._avatar = value.url
@property
def display_name(self) -> str:
@ -1034,10 +1034,10 @@ class Member(User): # Member inherits from User
return Permissions(~0)
return base
@property
def voice(self) -> Optional["VoiceState"]:
"""Return the member's cached voice state as a :class:`VoiceState`."""
@property
def voice(self) -> Optional["VoiceState"]:
"""Return the member's cached voice state as a :class:`VoiceState`."""
if self.voice_state is None:
return None
@ -1466,72 +1466,72 @@ class Guild(HashableById):
return f"<Guild id='{self.id}' name='{self.name}'>"
@property
def icon(self) -> Optional["Asset"]:
if self._icon:
from .asset import Asset
return Asset(self._icon, self._client)
return None
@icon.setter
def icon(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._icon = value
elif value is None:
self._icon = None
else:
self._icon = value.url
def icon(self) -> Optional["Asset"]:
if self._icon:
from .asset import Asset
return Asset(self._icon, self._client)
return None
@icon.setter
def icon(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._icon = value
elif value is None:
self._icon = None
else:
self._icon = value.url
@property
def splash(self) -> Optional["Asset"]:
if self._splash:
from .asset import Asset
return Asset(self._splash, self._client)
return None
@splash.setter
def splash(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._splash = value
elif value is None:
self._splash = None
else:
self._splash = value.url
def splash(self) -> Optional["Asset"]:
if self._splash:
from .asset import Asset
return Asset(self._splash, self._client)
return None
@splash.setter
def splash(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._splash = value
elif value is None:
self._splash = None
else:
self._splash = value.url
@property
def discovery_splash(self) -> Optional["Asset"]:
if self._discovery_splash:
from .asset import Asset
return Asset(self._discovery_splash, self._client)
return None
@discovery_splash.setter
def discovery_splash(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._discovery_splash = value
elif value is None:
self._discovery_splash = None
else:
self._discovery_splash = value.url
def discovery_splash(self) -> Optional["Asset"]:
if self._discovery_splash:
from .asset import Asset
return Asset(self._discovery_splash, self._client)
return None
@discovery_splash.setter
def discovery_splash(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._discovery_splash = value
elif value is None:
self._discovery_splash = None
else:
self._discovery_splash = value.url
@property
def banner(self) -> Optional["Asset"]:
if self._banner:
from .asset import Asset
return Asset(self._banner, self._client)
return None
@banner.setter
def banner(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._banner = value
elif value is None:
self._banner = None
else:
self._banner = value.url
def banner(self) -> Optional["Asset"]:
if self._banner:
from .asset import Asset
return Asset(self._banner, self._client)
return None
@banner.setter
def banner(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._banner = value
elif value is None:
self._banner = None
else:
self._banner = value.url
async def fetch_widget(self) -> Dict[str, Any]:
"""|coro| Fetch this guild's widget settings."""
@ -2280,23 +2280,23 @@ class Webhook:
return f"<Webhook id='{self.id}' name='{self.name}'>"
@property
def avatar(self) -> Optional["Asset"]:
"""Return the webhook's avatar as an :class:`Asset`."""
if self._avatar:
from .asset import Asset
return Asset(self._avatar, self._client)
return None
@avatar.setter
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._avatar = value
elif value is None:
self._avatar = None
else:
self._avatar = value.url
def avatar(self) -> Optional["Asset"]:
"""Return the webhook's avatar as an :class:`Asset`."""
if self._avatar:
from .asset import Asset
return Asset(self._avatar, self._client)
return None
@avatar.setter
def avatar(self, value: Optional[Union[str, "Asset"]]) -> None:
if isinstance(value, str):
self._avatar = value
elif value is None:
self._avatar = None
else:
self._avatar = value.url
@classmethod
def from_url(

View File

@ -1,6 +1,60 @@
import time
from disagreement.cache import Cache
from disagreement.client import Client
from disagreement.caching import MemberCacheFlags
from disagreement.enums import (
ChannelType,
ExplicitContentFilterLevel,
GuildNSFWLevel,
MFALevel,
MessageNotificationLevel,
PremiumTier,
VerificationLevel,
)
def _guild_payload(gid: str, channel_count: int, member_count: int) -> dict:
base = {
"id": gid,
"name": f"g{gid}",
"owner_id": "1",
"afk_timeout": 60,
"verification_level": VerificationLevel.NONE.value,
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
"roles": [],
"emojis": [],
"features": [],
"mfa_level": MFALevel.NONE.value,
"system_channel_flags": 0,
"premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
"channels": [],
"members": [],
}
for i in range(channel_count):
base["channels"].append(
{
"id": f"{gid}-c{i}",
"type": ChannelType.GUILD_TEXT.value,
"guild_id": gid,
"permission_overwrites": [],
}
)
for i in range(member_count):
base["members"].append(
{
"user": {
"id": f"{gid}-m{i}",
"username": f"u{i}",
"discriminator": "0001",
},
"joined_at": "t",
"roles": [],
}
)
return base
def test_cache_store_and_get():
@ -65,3 +119,22 @@ def test_get_or_fetch_fetches_expired_item():
assert cache.get_or_fetch("c", fetch) == 3
assert called
def test_client_get_all_channels_and_members():
client = Client(token="t")
client.parse_guild(_guild_payload("1", 2, 2))
client.parse_guild(_guild_payload("2", 1, 1))
channels = {c.id for c in client.get_all_channels()}
members = {m.id for m in client.get_all_members()}
assert channels == {"1-c0", "1-c1", "2-c0"}
assert members == {"1-m0", "1-m1", "2-m0"}
def test_client_get_all_members_disabled_cache():
client = Client(token="t", member_cache_flags=MemberCacheFlags.none())
client.parse_guild(_guild_payload("1", 1, 2))
assert client.get_all_members() == []