Add LRU support to Cache and message cache size option (#61)

This commit is contained in:
Slipstream 2025-06-11 14:26:44 -06:00 committed by GitHub
parent 07daf78ef4
commit c47a7e49f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 76 additions and 29 deletions

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import time
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
from collections import OrderedDict
if TYPE_CHECKING:
from .models import Channel, Guild, Member
@ -11,15 +12,22 @@ T = TypeVar("T")
class Cache(Generic[T]):
"""Simple in-memory cache with optional TTL support."""
"""Simple in-memory cache with optional TTL and max size support."""
def __init__(self, ttl: Optional[float] = None) -> None:
def __init__(
self, ttl: Optional[float] = None, maxlen: Optional[int] = None
) -> None:
self.ttl = ttl
self._data: Dict[str, tuple[T, Optional[float]]] = {}
self.maxlen = maxlen
self._data: "OrderedDict[str, tuple[T, Optional[float]]]" = OrderedDict()
def set(self, key: str, value: T) -> None:
expiry = time.monotonic() + self.ttl if self.ttl is not None else None
if key in self._data:
self._data.move_to_end(key)
self._data[key] = (value, expiry)
if self.maxlen is not None and len(self._data) > self.maxlen:
self._data.popitem(last=False)
def get(self, key: str) -> Optional[T]:
item = self._data.get(key)
@ -29,6 +37,7 @@ class Cache(Generic[T]):
if expiry is not None and expiry < time.monotonic():
self.invalidate(key)
return None
self._data.move_to_end(key)
return value
def invalidate(self, key: str) -> None:

View File

@ -82,6 +82,8 @@ class Client:
http_options (Optional[Dict[str, Any]]): Extra options passed to
:class:`HTTPClient` for creating the internal
:class:`aiohttp.ClientSession`.
message_cache_maxlen (Optional[int]): Maximum number of messages to keep
in the cache. When ``None``, the cache size is unlimited.
"""
def __init__(
@ -99,6 +101,7 @@ class Client:
gateway_max_retries: int = 5,
gateway_max_backoff: float = 60.0,
member_cache_flags: Optional[MemberCacheFlags] = None,
message_cache_maxlen: Optional[int] = None,
http_options: Optional[Dict[str, Any]] = None,
):
if not token:
@ -108,6 +111,7 @@ class Client:
self.member_cache_flags: MemberCacheFlags = (
member_cache_flags if member_cache_flags is not None else MemberCacheFlags()
)
self.message_cache_maxlen: Optional[int] = message_cache_maxlen
self.intents: int = intents if intents is not None else GatewayIntent.default()
if loop:
self.loop: asyncio.AbstractEventLoop = loop
@ -157,7 +161,7 @@ class Client:
self._guilds: GuildCache = GuildCache()
self._channels: ChannelCache = ChannelCache()
self._users: Cache["User"] = Cache()
self._messages: Cache["Message"] = Cache(ttl=3600) # Cache messages for an hour
self._messages: Cache["Message"] = Cache(ttl=3600, maxlen=message_cache_maxlen)
self._views: Dict[Snowflake, "View"] = {}
self._persistent_views: Dict[str, "View"] = {}
self._voice_clients: Dict[Snowflake, VoiceClient] = {}
@ -693,7 +697,7 @@ class Client:
)
# import traceback
# traceback.print_exception(type(error.original), error.original, error.original.__traceback__)
async def on_command_completion(self, ctx: "CommandContext") -> None:
"""
Default command completion handler. Called when a command has successfully completed.
@ -1514,35 +1518,35 @@ class Client:
return [self.parse_invite(inv) for inv in data]
def add_persistent_view(self, view: "View") -> None:
"""
Registers a persistent view with the client.
"""
Registers a persistent view with the client.
Persistent views have a timeout of `None` and their components must have a `custom_id`.
This allows the view to be re-instantiated across bot restarts.
Persistent views have a timeout of `None` and their components must have a `custom_id`.
This allows the view to be re-instantiated across bot restarts.
Args:
view (View): The view instance to register.
Args:
view (View): The view instance to register.
Raises:
ValueError: If the view is not persistent (timeout is not None) or if a component's
custom_id is already registered.
"""
if self.is_ready():
print(
"Warning: Adding a persistent view after the client is ready. "
"This view will only be available for interactions on this session."
)
Raises:
ValueError: If the view is not persistent (timeout is not None) or if a component's
custom_id is already registered.
"""
if self.is_ready():
print(
"Warning: Adding a persistent view after the client is ready. "
"This view will only be available for interactions on this session."
)
if view.timeout is not None:
raise ValueError("Persistent views must have a timeout of None.")
if view.timeout is not None:
raise ValueError("Persistent views must have a timeout of None.")
for item in view.children:
if item.custom_id: # Ensure custom_id is not None
if item.custom_id in self._persistent_views:
raise ValueError(
f"A component with custom_id '{item.custom_id}' is already registered."
)
self._persistent_views[item.custom_id] = view
for item in view.children:
if item.custom_id: # Ensure custom_id is not None
if item.custom_id in self._persistent_views:
raise ValueError(
f"A component with custom_id '{item.custom_id}' is already registered."
)
self._persistent_views[item.custom_id] = view
# --- Application Command Methods ---
async def process_interaction(self, interaction: Interaction) -> None:

View File

@ -15,3 +15,14 @@ def test_cache_ttl_expiry():
assert cache.get("b") == 1
time.sleep(0.02)
assert cache.get("b") is None
def test_cache_lru_eviction():
cache = Cache(maxlen=2)
cache.set("a", 1)
cache.set("b", 2)
assert cache.get("a") == 1
cache.set("c", 3)
assert cache.get("b") is None
assert cache.get("a") == 1
assert cache.get("c") == 3

View File

@ -0,0 +1,23 @@
import pytest
from disagreement.client import Client
def _add_message(client: Client, message_id: str) -> None:
data = {
"id": message_id,
"channel_id": "c",
"author": {"id": "u", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
client.parse_message(data)
def test_client_message_cache_size():
client = Client(token="t", message_cache_maxlen=1)
_add_message(client, "1")
assert client._messages.get("1").id == "1"
_add_message(client, "2")
assert client._messages.get("1") is None
assert client._messages.get("2").id == "2"