Add LRU support to Cache and message cache size option (#61)

This commit is contained in:
Slipstream 2025-06-11 14:26:44 -06:00 committed by GitHub
parent 07daf78ef4
commit c47a7e49f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 76 additions and 29 deletions

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import time import time
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
from collections import OrderedDict
if TYPE_CHECKING: if TYPE_CHECKING:
from .models import Channel, Guild, Member from .models import Channel, Guild, Member
@ -11,15 +12,22 @@ T = TypeVar("T")
class Cache(Generic[T]): class Cache(Generic[T]):
"""Simple in-memory cache with optional TTL support.""" """Simple in-memory cache with optional TTL and max size support."""
def __init__(self, ttl: Optional[float] = None) -> None: def __init__(
self, ttl: Optional[float] = None, maxlen: Optional[int] = None
) -> None:
self.ttl = ttl self.ttl = ttl
self._data: Dict[str, tuple[T, Optional[float]]] = {} self.maxlen = maxlen
self._data: "OrderedDict[str, tuple[T, Optional[float]]]" = OrderedDict()
def set(self, key: str, value: T) -> None: def set(self, key: str, value: T) -> None:
expiry = time.monotonic() + self.ttl if self.ttl is not None else None expiry = time.monotonic() + self.ttl if self.ttl is not None else None
if key in self._data:
self._data.move_to_end(key)
self._data[key] = (value, expiry) self._data[key] = (value, expiry)
if self.maxlen is not None and len(self._data) > self.maxlen:
self._data.popitem(last=False)
def get(self, key: str) -> Optional[T]: def get(self, key: str) -> Optional[T]:
item = self._data.get(key) item = self._data.get(key)
@ -29,6 +37,7 @@ class Cache(Generic[T]):
if expiry is not None and expiry < time.monotonic(): if expiry is not None and expiry < time.monotonic():
self.invalidate(key) self.invalidate(key)
return None return None
self._data.move_to_end(key)
return value return value
def invalidate(self, key: str) -> None: def invalidate(self, key: str) -> None:

View File

@ -82,6 +82,8 @@ class Client:
http_options (Optional[Dict[str, Any]]): Extra options passed to http_options (Optional[Dict[str, Any]]): Extra options passed to
:class:`HTTPClient` for creating the internal :class:`HTTPClient` for creating the internal
:class:`aiohttp.ClientSession`. :class:`aiohttp.ClientSession`.
message_cache_maxlen (Optional[int]): Maximum number of messages to keep
in the cache. When ``None``, the cache size is unlimited.
""" """
def __init__( def __init__(
@ -99,6 +101,7 @@ class Client:
gateway_max_retries: int = 5, gateway_max_retries: int = 5,
gateway_max_backoff: float = 60.0, gateway_max_backoff: float = 60.0,
member_cache_flags: Optional[MemberCacheFlags] = None, member_cache_flags: Optional[MemberCacheFlags] = None,
message_cache_maxlen: Optional[int] = None,
http_options: Optional[Dict[str, Any]] = None, http_options: Optional[Dict[str, Any]] = None,
): ):
if not token: if not token:
@ -108,6 +111,7 @@ class Client:
self.member_cache_flags: MemberCacheFlags = ( self.member_cache_flags: MemberCacheFlags = (
member_cache_flags if member_cache_flags is not None else MemberCacheFlags() member_cache_flags if member_cache_flags is not None else MemberCacheFlags()
) )
self.message_cache_maxlen: Optional[int] = message_cache_maxlen
self.intents: int = intents if intents is not None else GatewayIntent.default() self.intents: int = intents if intents is not None else GatewayIntent.default()
if loop: if loop:
self.loop: asyncio.AbstractEventLoop = loop self.loop: asyncio.AbstractEventLoop = loop
@ -157,7 +161,7 @@ class Client:
self._guilds: GuildCache = GuildCache() self._guilds: GuildCache = GuildCache()
self._channels: ChannelCache = ChannelCache() self._channels: ChannelCache = ChannelCache()
self._users: Cache["User"] = Cache() self._users: Cache["User"] = Cache()
self._messages: Cache["Message"] = Cache(ttl=3600) # Cache messages for an hour self._messages: Cache["Message"] = Cache(ttl=3600, maxlen=message_cache_maxlen)
self._views: Dict[Snowflake, "View"] = {} self._views: Dict[Snowflake, "View"] = {}
self._persistent_views: Dict[str, "View"] = {} self._persistent_views: Dict[str, "View"] = {}
self._voice_clients: Dict[Snowflake, VoiceClient] = {} self._voice_clients: Dict[Snowflake, VoiceClient] = {}
@ -1514,35 +1518,35 @@ class Client:
return [self.parse_invite(inv) for inv in data] return [self.parse_invite(inv) for inv in data]
def add_persistent_view(self, view: "View") -> None: def add_persistent_view(self, view: "View") -> None:
""" """
Registers a persistent view with the client. Registers a persistent view with the client.
Persistent views have a timeout of `None` and their components must have a `custom_id`. Persistent views have a timeout of `None` and their components must have a `custom_id`.
This allows the view to be re-instantiated across bot restarts. This allows the view to be re-instantiated across bot restarts.
Args: Args:
view (View): The view instance to register. view (View): The view instance to register.
Raises: Raises:
ValueError: If the view is not persistent (timeout is not None) or if a component's ValueError: If the view is not persistent (timeout is not None) or if a component's
custom_id is already registered. custom_id is already registered.
""" """
if self.is_ready(): if self.is_ready():
print( print(
"Warning: Adding a persistent view after the client is ready. " "Warning: Adding a persistent view after the client is ready. "
"This view will only be available for interactions on this session." "This view will only be available for interactions on this session."
) )
if view.timeout is not None: if view.timeout is not None:
raise ValueError("Persistent views must have a timeout of None.") raise ValueError("Persistent views must have a timeout of None.")
for item in view.children: for item in view.children:
if item.custom_id: # Ensure custom_id is not None if item.custom_id: # Ensure custom_id is not None
if item.custom_id in self._persistent_views: if item.custom_id in self._persistent_views:
raise ValueError( raise ValueError(
f"A component with custom_id '{item.custom_id}' is already registered." f"A component with custom_id '{item.custom_id}' is already registered."
) )
self._persistent_views[item.custom_id] = view self._persistent_views[item.custom_id] = view
# --- Application Command Methods --- # --- Application Command Methods ---
async def process_interaction(self, interaction: Interaction) -> None: async def process_interaction(self, interaction: Interaction) -> None:

View File

@ -15,3 +15,14 @@ def test_cache_ttl_expiry():
assert cache.get("b") == 1 assert cache.get("b") == 1
time.sleep(0.02) time.sleep(0.02)
assert cache.get("b") is None assert cache.get("b") is None
def test_cache_lru_eviction():
cache = Cache(maxlen=2)
cache.set("a", 1)
cache.set("b", 2)
assert cache.get("a") == 1
cache.set("c", 3)
assert cache.get("b") is None
assert cache.get("a") == 1
assert cache.get("c") == 3

View File

@ -0,0 +1,23 @@
import pytest
from disagreement.client import Client
def _add_message(client: Client, message_id: str) -> None:
data = {
"id": message_id,
"channel_id": "c",
"author": {"id": "u", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
client.parse_message(data)
def test_client_message_cache_size():
client = Client(token="t", message_cache_maxlen=1)
_add_message(client, "1")
assert client._messages.get("1").id == "1"
_add_message(client, "2")
assert client._messages.get("1") is None
assert client._messages.get("2").id == "2"