Implements caching system with TTL and member filtering
Introduces a flexible caching infrastructure with time-to-live support and configurable member caching based on status, voice state, and join events. Adds AudioSink abstract base class to support audio output handling in voice connections. Replaces direct dictionary access with cache objects throughout the client, enabling automatic expiration and intelligent member filtering based on user-defined flags. Updates guild parsing to incorporate presence and voice state data for more accurate member caching decisions.
This commit is contained in:
parent
0151526d07
commit
ed83a9da85
@ -114,3 +114,20 @@ class FFmpegAudioSource(AudioSource):
|
||||
if isinstance(self.source, io.IOBase):
|
||||
with contextlib.suppress(Exception):
|
||||
self.source.close()
|
||||
|
||||
class AudioSink:
|
||||
"""Abstract base class for audio sinks."""
|
||||
|
||||
def write(self, user, data):
|
||||
"""Write a chunk of PCM audio.
|
||||
|
||||
Subclasses must implement this. The data is raw PCM at 48kHz
|
||||
stereo.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self) -> None:
|
||||
"""Cleanup the sink when the voice client disconnects."""
|
||||
|
||||
return None
|
||||
|
@ -4,7 +4,8 @@ import time
|
||||
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .models import Channel, Guild
|
||||
from .models import Channel, Guild, Member
|
||||
from .caching import MemberCacheFlags
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@ -53,3 +54,32 @@ class GuildCache(Cache["Guild"]):
|
||||
|
||||
class ChannelCache(Cache["Channel"]):
|
||||
"""Cache specifically for :class:`Channel` objects."""
|
||||
|
||||
|
||||
class MemberCache(Cache["Member"]):
|
||||
"""
|
||||
A cache for :class:`Member` objects that respects :class:`MemberCacheFlags`.
|
||||
"""
|
||||
|
||||
def __init__(self, flags: MemberCacheFlags, ttl: Optional[float] = None) -> None:
|
||||
super().__init__(ttl)
|
||||
self.flags = flags
|
||||
|
||||
def _should_cache(self, member: Member) -> bool:
|
||||
"""Determines if a member should be cached based on the flags."""
|
||||
if self.flags.all:
|
||||
return True
|
||||
if self.flags.none:
|
||||
return False
|
||||
|
||||
if self.flags.online and member.status != "offline":
|
||||
return True
|
||||
if self.flags.voice and member.voice_state is not None:
|
||||
return True
|
||||
if self.flags.joined and getattr(member, "_just_joined", False):
|
||||
return True
|
||||
return False
|
||||
|
||||
def set(self, key: str, value: Member) -> None:
|
||||
if self._should_cache(value):
|
||||
super().set(key, value)
|
||||
|
120
disagreement/caching.py
Normal file
120
disagreement/caching.py
Normal file
@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from typing import Any, Callable, ClassVar, Dict, Iterator, Tuple
|
||||
|
||||
|
||||
class _MemberCacheFlagValue:
|
||||
flag: int
|
||||
|
||||
def __init__(self, func: Callable[[Any], bool]):
|
||||
self.flag = getattr(func, 'flag', 0)
|
||||
self.__doc__ = func.__doc__
|
||||
|
||||
def __get__(self, instance: 'MemberCacheFlags', owner: type) -> Any:
|
||||
if instance is None:
|
||||
return self
|
||||
return instance.value & self.flag != 0
|
||||
|
||||
def __set__(self, instance: Any, value: bool) -> None:
|
||||
if value:
|
||||
instance.value |= self.flag
|
||||
else:
|
||||
instance.value &= ~self.flag
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__} flag={self.flag}>'
|
||||
|
||||
|
||||
def flag_value(flag: int) -> Callable[[Callable[[Any], bool]], _MemberCacheFlagValue]:
|
||||
def decorator(func: Callable[[Any], bool]) -> _MemberCacheFlagValue:
|
||||
setattr(func, 'flag', flag)
|
||||
return _MemberCacheFlagValue(func)
|
||||
return decorator
|
||||
|
||||
|
||||
class MemberCacheFlags:
|
||||
__slots__ = ('value',)
|
||||
|
||||
VALID_FLAGS: ClassVar[Dict[str, int]] = {
|
||||
'joined': 1 << 0,
|
||||
'voice': 1 << 1,
|
||||
'online': 1 << 2,
|
||||
}
|
||||
DEFAULT_FLAGS: ClassVar[int] = 1 | 2 | 4
|
||||
ALL_FLAGS: ClassVar[int] = sum(VALID_FLAGS.values())
|
||||
|
||||
def __init__(self, **kwargs: bool):
|
||||
self.value = self.DEFAULT_FLAGS
|
||||
for key, value in kwargs.items():
|
||||
if key not in self.VALID_FLAGS:
|
||||
raise TypeError(f'{key!r} is not a valid member cache flag.')
|
||||
setattr(self, key, value)
|
||||
|
||||
@classmethod
|
||||
def _from_value(cls, value: int) -> MemberCacheFlags:
|
||||
self = cls.__new__(cls)
|
||||
self.value = value
|
||||
return self
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, MemberCacheFlags) and self.value == other.value
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<MemberCacheFlags value={self.value}>'
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple[str, bool]]:
|
||||
for name in self.VALID_FLAGS:
|
||||
yield name, getattr(self, name)
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.value
|
||||
|
||||
def __index__(self) -> int:
|
||||
return self.value
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> MemberCacheFlags:
|
||||
"""A factory method that creates a :class:`MemberCacheFlags` with all flags enabled."""
|
||||
return cls._from_value(cls.ALL_FLAGS)
|
||||
|
||||
@classmethod
|
||||
def none(cls) -> MemberCacheFlags:
|
||||
"""A factory method that creates a :class:`MemberCacheFlags` with all flags disabled."""
|
||||
return cls._from_value(0)
|
||||
|
||||
@classmethod
|
||||
def only_joined(cls) -> MemberCacheFlags:
|
||||
"""A factory method that creates a :class:`MemberCacheFlags` with only the `joined` flag enabled."""
|
||||
return cls._from_value(cls.VALID_FLAGS['joined'])
|
||||
|
||||
@classmethod
|
||||
def only_voice(cls) -> MemberCacheFlags:
|
||||
"""A factory method that creates a :class:`MemberCacheFlags` with only the `voice` flag enabled."""
|
||||
return cls._from_value(cls.VALID_FLAGS['voice'])
|
||||
|
||||
@classmethod
|
||||
def only_online(cls) -> MemberCacheFlags:
|
||||
"""A factory method that creates a :class:`MemberCacheFlags` with only the `online` flag enabled."""
|
||||
return cls._from_value(cls.VALID_FLAGS['online'])
|
||||
|
||||
@flag_value(1 << 0)
|
||||
def joined(self) -> bool:
|
||||
"""Whether to cache members that have just joined the guild."""
|
||||
return False
|
||||
|
||||
@flag_value(1 << 1)
|
||||
def voice(self) -> bool:
|
||||
"""Whether to cache members that are in a voice channel."""
|
||||
return False
|
||||
|
||||
@flag_value(1 << 2)
|
||||
def online(self) -> bool:
|
||||
"""Whether to cache members that are online."""
|
||||
return False
|
@ -26,7 +26,9 @@ from .event_dispatcher import EventDispatcher
|
||||
from .enums import GatewayIntent, InteractionType, GatewayOpcode, VoiceRegion
|
||||
from .errors import DisagreementException, AuthenticationError
|
||||
from .typing import Typing
|
||||
from .ext.commands.core import CommandHandler
|
||||
from .caching import MemberCacheFlags
|
||||
from .cache import Cache, GuildCache, ChannelCache, MemberCache
|
||||
from .ext.commands.core import Command, CommandHandler, Group
|
||||
from .ext.commands.cog import Cog
|
||||
from .ext.app_commands.handler import AppCommandHandler
|
||||
from .ext.app_commands.context import AppCommandContext
|
||||
@ -96,12 +98,16 @@ class Client:
|
||||
shard_count: Optional[int] = None,
|
||||
gateway_max_retries: int = 5,
|
||||
gateway_max_backoff: float = 60.0,
|
||||
member_cache_flags: Optional[MemberCacheFlags] = None,
|
||||
http_options: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
if not token:
|
||||
raise ValueError("A bot token must be provided.")
|
||||
|
||||
self.token: str = token
|
||||
self.member_cache_flags: MemberCacheFlags = (
|
||||
member_cache_flags if member_cache_flags is not None else MemberCacheFlags()
|
||||
)
|
||||
self.intents: int = intents if intents is not None else GatewayIntent.default()
|
||||
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
|
||||
self.application_id: Optional[Snowflake] = (
|
||||
@ -141,15 +147,12 @@ class Client:
|
||||
)
|
||||
|
||||
# Internal Caches
|
||||
self._guilds: Dict[Snowflake, "Guild"] = {}
|
||||
self._channels: Dict[Snowflake, "Channel"] = (
|
||||
{}
|
||||
) # Stores all channel types by ID
|
||||
self._users: Dict[Snowflake, Any] = (
|
||||
{}
|
||||
) # Placeholder for User model cache if needed
|
||||
self._messages: Dict[Snowflake, "Message"] = {}
|
||||
self._guilds: GuildCache = GuildCache()
|
||||
self._channels: ChannelCache = ChannelCache()
|
||||
self._users: Cache["User"] = Cache()
|
||||
self._messages: Cache["Message"] = Cache(ttl=3600) # Cache messages for an hour
|
||||
self._views: Dict[Snowflake, "View"] = {}
|
||||
self._persistent_views: Dict[str, "View"] = {}
|
||||
self._voice_clients: Dict[Snowflake, VoiceClient] = {}
|
||||
self._webhooks: Dict[Snowflake, "Webhook"] = {}
|
||||
|
||||
@ -718,7 +721,7 @@ class Client:
|
||||
from .models import User # Ensure User model is available
|
||||
|
||||
user = User(data)
|
||||
self._users[user.id] = user # Cache the user
|
||||
self._users.set(user.id, user) # Cache the user
|
||||
return user
|
||||
|
||||
def parse_channel(self, data: Dict[str, Any]) -> "Channel":
|
||||
@ -727,11 +730,11 @@ class Client:
|
||||
from .models import channel_factory
|
||||
|
||||
channel = channel_factory(data, self)
|
||||
self._channels[channel.id] = channel
|
||||
self._channels.set(channel.id, channel)
|
||||
if channel.guild_id:
|
||||
guild = self._guilds.get(channel.guild_id)
|
||||
if guild:
|
||||
guild._channels[channel.id] = channel
|
||||
guild._channels.set(channel.id, channel)
|
||||
return channel
|
||||
|
||||
def parse_message(self, data: Dict[str, Any]) -> "Message":
|
||||
@ -740,7 +743,7 @@ class Client:
|
||||
from .models import Message
|
||||
|
||||
message = Message(data, client_instance=self)
|
||||
self._messages[message.id] = message
|
||||
self._messages.set(message.id, message)
|
||||
return message
|
||||
|
||||
def parse_webhook(self, data: Union[Dict[str, Any], "Webhook"]) -> "Webhook":
|
||||
@ -797,7 +800,7 @@ class Client:
|
||||
|
||||
cached_user = self._users.get(user_id)
|
||||
if cached_user:
|
||||
return cached_user # Return cached if available, though fetch implies wanting fresh
|
||||
return cached_user
|
||||
|
||||
try:
|
||||
user_data = await self._http.get_user(user_id)
|
||||
@ -827,23 +830,26 @@ class Client:
|
||||
)
|
||||
return None
|
||||
|
||||
def parse_member(self, data: Dict[str, Any], guild_id: Snowflake) -> "Member":
|
||||
def parse_member(
|
||||
self, data: Dict[str, Any], guild_id: Snowflake, *, just_joined: bool = False
|
||||
) -> "Member":
|
||||
"""Parses member data and returns a Member object, updating relevant caches."""
|
||||
from .models import Member # Ensure Member model is available
|
||||
from .models import Member
|
||||
|
||||
# Member's __init__ should handle the nested 'user' data.
|
||||
member = Member(data, client_instance=self)
|
||||
member.guild_id = str(guild_id)
|
||||
|
||||
# Cache the member in the guild's member cache
|
||||
if just_joined:
|
||||
setattr(member, "_just_joined", True)
|
||||
|
||||
guild = self._guilds.get(guild_id)
|
||||
if guild:
|
||||
guild._members[member.id] = member # Assuming Guild has _members dict
|
||||
guild._members.set(member.id, member)
|
||||
|
||||
# Also cache the user part if not already cached or if this is newer
|
||||
# Since Member inherits from User, the member object itself is the user.
|
||||
self._users[member.id] = member
|
||||
# If 'user' was in data and Member.__init__ used it, it's already part of 'member'.
|
||||
if just_joined and hasattr(member, "_just_joined"):
|
||||
delattr(member, "_just_joined")
|
||||
|
||||
self._users.set(member.id, member)
|
||||
return member
|
||||
|
||||
async def fetch_member(
|
||||
@ -887,20 +893,29 @@ class Client:
|
||||
|
||||
def parse_guild(self, data: Dict[str, Any]) -> "Guild":
|
||||
"""Parses guild data and returns a Guild object, updating cache."""
|
||||
|
||||
from .models import Guild
|
||||
|
||||
guild = Guild(data, client_instance=self)
|
||||
self._guilds[guild.id] = guild
|
||||
self._guilds.set(guild.id, guild)
|
||||
|
||||
# Populate channel and member caches if provided
|
||||
for ch in data.get("channels", []):
|
||||
channel_obj = self.parse_channel(ch)
|
||||
guild._channels[channel_obj.id] = channel_obj
|
||||
presences = {p["user"]["id"]: p for p in data.get("presences", [])}
|
||||
voice_states = {vs["user_id"]: vs for vs in data.get("voice_states", [])}
|
||||
|
||||
for member in data.get("members", []):
|
||||
member_obj = self.parse_member(member, guild.id)
|
||||
guild._members[member_obj.id] = member_obj
|
||||
for ch_data in data.get("channels", []):
|
||||
self.parse_channel(ch_data)
|
||||
|
||||
for member_data in data.get("members", []):
|
||||
user_id = member_data.get("user", {}).get("id")
|
||||
if user_id:
|
||||
presence = presences.get(user_id)
|
||||
if presence:
|
||||
member_data["status"] = presence.get("status", "offline")
|
||||
|
||||
voice_state = voice_states.get(user_id)
|
||||
if voice_state:
|
||||
member_data["voice_state"] = voice_state
|
||||
|
||||
self.parse_member(member_data, guild.id)
|
||||
|
||||
return guild
|
||||
|
||||
@ -1305,7 +1320,7 @@ class Client:
|
||||
|
||||
channel = channel_factory(channel_data, self)
|
||||
|
||||
self._channels[channel.id] = channel
|
||||
self._channels.set(channel.id, channel)
|
||||
return channel
|
||||
|
||||
except DisagreementException as e: # Includes HTTPException
|
||||
|
@ -76,7 +76,7 @@ class EventDispatcher:
|
||||
"""Parses MESSAGE_DELETE and updates message cache."""
|
||||
message_id = data.get("id")
|
||||
if message_id:
|
||||
self._client._messages.pop(message_id, None)
|
||||
self._client._messages.invalidate(message_id)
|
||||
return data
|
||||
|
||||
def _parse_message_reaction_raw(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@ -124,7 +124,7 @@ class EventDispatcher:
|
||||
"""Parses GUILD_MEMBER_ADD into a Member object."""
|
||||
|
||||
guild_id = str(data.get("guild_id"))
|
||||
return self._client.parse_member(data, guild_id)
|
||||
return self._client.parse_member(data, guild_id, just_joined=True)
|
||||
|
||||
def _parse_guild_member_remove(self, data: Dict[str, Any]):
|
||||
"""Parses GUILD_MEMBER_REMOVE into a GuildMemberRemove model."""
|
||||
|
@ -7,7 +7,9 @@ Data models for Discord objects.
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast
|
||||
|
||||
from .cache import ChannelCache, MemberCache
|
||||
|
||||
import aiohttp # pylint: disable=import-error
|
||||
from .color import Color
|
||||
@ -656,6 +658,8 @@ class Member(User): # Member inherits from User
|
||||
):
|
||||
self._client: Optional["Client"] = client_instance
|
||||
self.guild_id: Optional[str] = None
|
||||
self.status: Optional[str] = None
|
||||
self.voice_state: Optional[Dict[str, Any]] = None
|
||||
# User part is nested under 'user' key in member data from gateway/API
|
||||
user_data = data.get("user", {})
|
||||
# If 'id' is not in user_data but is top-level (e.g. from interaction resolved member without user object)
|
||||
@ -1082,8 +1086,8 @@ class Guild:
|
||||
)
|
||||
|
||||
# Internal caches, populated by events or specific fetches
|
||||
self._channels: Dict[str, "Channel"] = {}
|
||||
self._members: Dict[str, Member] = {}
|
||||
self._channels: ChannelCache = ChannelCache()
|
||||
self._members: MemberCache = MemberCache(client_instance.member_cache_flags)
|
||||
self._threads: Dict[str, "Thread"] = {}
|
||||
|
||||
def get_channel(self, channel_id: str) -> Optional["Channel"]:
|
||||
@ -1338,12 +1342,98 @@ class TextChannel(Channel):
|
||||
|
||||
await self._client._http.bulk_delete_messages(self.id, ids)
|
||||
for mid in ids:
|
||||
self._client._messages.pop(mid, None)
|
||||
self._client._messages.invalidate(mid)
|
||||
return ids
|
||||
|
||||
def get_partial_message(self, id: int) -> "PartialMessage":
|
||||
"""Returns a :class:`PartialMessage` for the given ID.
|
||||
|
||||
This allows performing actions on a message without fetching it first.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
id: int
|
||||
The ID of the message to get a partial instance of.
|
||||
|
||||
Returns
|
||||
-------
|
||||
PartialMessage
|
||||
The partial message instance.
|
||||
"""
|
||||
return PartialMessage(id=str(id), channel=self)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<TextChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>"
|
||||
|
||||
async def pins(self) -> List["Message"]:
|
||||
"""|coro|
|
||||
|
||||
Fetches all pinned messages in this channel.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[Message]
|
||||
The pinned messages.
|
||||
|
||||
Raises
|
||||
------
|
||||
HTTPException
|
||||
Fetching the pinned messages failed.
|
||||
"""
|
||||
|
||||
messages_data = await self._client._http.get_pinned_messages(self.id)
|
||||
return [self._client.parse_message(m) for m in messages_data]
|
||||
|
||||
async def create_thread(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
type: ChannelType = ChannelType.PUBLIC_THREAD,
|
||||
auto_archive_duration: Optional[int] = None,
|
||||
invitable: Optional[bool] = None,
|
||||
rate_limit_per_user: Optional[int] = None,
|
||||
reason: Optional[str] = None,
|
||||
) -> "Thread":
|
||||
"""|coro|
|
||||
|
||||
Creates a new thread in this channel.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the thread.
|
||||
type: ChannelType
|
||||
The type of thread to create. Defaults to PUBLIC_THREAD.
|
||||
Can be PUBLIC_THREAD, PRIVATE_THREAD, or ANNOUNCEMENT_THREAD.
|
||||
auto_archive_duration: Optional[int]
|
||||
The duration in minutes to automatically archive the thread after recent activity.
|
||||
invitable: Optional[bool]
|
||||
Whether non-moderators can invite other non-moderators to a private thread.
|
||||
Only applicable to private threads.
|
||||
rate_limit_per_user: Optional[int]
|
||||
The number of seconds a user has to wait before sending another message.
|
||||
reason: Optional[str]
|
||||
The reason for creating the thread.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Thread
|
||||
The created thread.
|
||||
"""
|
||||
payload: Dict[str, Any] = {
|
||||
"name": name,
|
||||
"type": type.value,
|
||||
}
|
||||
if auto_archive_duration is not None:
|
||||
payload["auto_archive_duration"] = auto_archive_duration
|
||||
if invitable is not None and type == ChannelType.PRIVATE_THREAD:
|
||||
payload["invitable"] = invitable
|
||||
if rate_limit_per_user is not None:
|
||||
payload["rate_limit_per_user"] = rate_limit_per_user
|
||||
|
||||
data = await self._client._http.start_thread_without_message(self.id, payload)
|
||||
return cast("Thread", self._client.parse_channel(data))
|
||||
|
||||
|
||||
class VoiceChannel(Channel):
|
||||
"""Represents a guild voice channel or stage voice channel."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user