Add LRU support to Cache and message cache size option (#61)

parent 07daf78ef4
commit c47a7e49f8
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import time
 from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
+from collections import OrderedDict
 
 if TYPE_CHECKING:
     from .models import Channel, Guild, Member
@@ -11,15 +12,22 @@ T = TypeVar("T")
 
 
 class Cache(Generic[T]):
-    """Simple in-memory cache with optional TTL support."""
+    """Simple in-memory cache with optional TTL and max size support."""
 
-    def __init__(self, ttl: Optional[float] = None) -> None:
+    def __init__(
+        self, ttl: Optional[float] = None, maxlen: Optional[int] = None
+    ) -> None:
         self.ttl = ttl
-        self._data: Dict[str, tuple[T, Optional[float]]] = {}
+        self.maxlen = maxlen
+        self._data: "OrderedDict[str, tuple[T, Optional[float]]]" = OrderedDict()
 
     def set(self, key: str, value: T) -> None:
         expiry = time.monotonic() + self.ttl if self.ttl is not None else None
+        if key in self._data:
+            self._data.move_to_end(key)
         self._data[key] = (value, expiry)
+        if self.maxlen is not None and len(self._data) > self.maxlen:
+            self._data.popitem(last=False)
 
     def get(self, key: str) -> Optional[T]:
         item = self._data.get(key)
@@ -29,6 +37,7 @@ class Cache(Generic[T]):
         if expiry is not None and expiry < time.monotonic():
             self.invalidate(key)
             return None
+        self._data.move_to_end(key)
         return value
 
     def invalidate(self, key: str) -> None:
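Illustrative sketch (not part of the diff): how the new maxlen option behaves together with TTL, based on the hunks above. The import path is an assumption; this diff does not show the cache module's file name.

from disagreement.cache import Cache  # assumed import path, not confirmed by this diff

cache = Cache(ttl=3600, maxlen=2)  # entries expire after an hour; at most 2 entries are kept
cache.set("a", 1)
cache.set("b", 2)
assert cache.get("a") == 1   # get() now calls move_to_end(), so "a" becomes most recently used
cache.set("c", 3)            # size exceeds maxlen, so popitem(last=False) evicts the LRU entry ("b")
assert cache.get("b") is None
assert cache.get("a") == 1
assert cache.get("c") == 3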
@@ -82,6 +82,8 @@ class Client:
         http_options (Optional[Dict[str, Any]]): Extra options passed to
             :class:`HTTPClient` for creating the internal
             :class:`aiohttp.ClientSession`.
+        message_cache_maxlen (Optional[int]): Maximum number of messages to keep
+            in the cache. When ``None``, the cache size is unlimited.
     """
 
     def __init__(
@@ -99,6 +101,7 @@ class Client:
         gateway_max_retries: int = 5,
         gateway_max_backoff: float = 60.0,
         member_cache_flags: Optional[MemberCacheFlags] = None,
+        message_cache_maxlen: Optional[int] = None,
         http_options: Optional[Dict[str, Any]] = None,
     ):
         if not token:
@@ -108,6 +111,7 @@ class Client:
         self.member_cache_flags: MemberCacheFlags = (
             member_cache_flags if member_cache_flags is not None else MemberCacheFlags()
         )
+        self.message_cache_maxlen: Optional[int] = message_cache_maxlen
         self.intents: int = intents if intents is not None else GatewayIntent.default()
         if loop:
             self.loop: asyncio.AbstractEventLoop = loop
@@ -157,7 +161,7 @@ class Client:
         self._guilds: GuildCache = GuildCache()
         self._channels: ChannelCache = ChannelCache()
         self._users: Cache["User"] = Cache()
-        self._messages: Cache["Message"] = Cache(ttl=3600)  # Cache messages for an hour
+        self._messages: Cache["Message"] = Cache(ttl=3600, maxlen=message_cache_maxlen)
        self._views: Dict[Snowflake, "View"] = {}
         self._persistent_views: Dict[str, "View"] = {}
         self._voice_clients: Dict[Snowflake, VoiceClient] = {}
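Illustrative sketch (not part of the diff): passing the new message_cache_maxlen option through Client, mirroring the new tests/test_client_message_cache.py added below. The token and message payload are placeholder values, not real data.

from disagreement.client import Client

client = Client(token="example-token", message_cache_maxlen=1)  # placeholder token

payload = {
    "id": "1",
    "channel_id": "c",
    "author": {"id": "u", "username": "u", "discriminator": "0001"},
    "content": "hi",
    "timestamp": "t",
}
client.parse_message(payload)                 # message "1" is cached
client.parse_message({**payload, "id": "2"})  # maxlen=1, so message "1" is evicted
assert client._messages.get("1") is None
assert client._messages.get("2").id == "2"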
@@ -1514,35 +1518,35 @@ class Client:
         return [self.parse_invite(inv) for inv in data]
 
     def add_persistent_view(self, view: "View") -> None:
-        """
-        Registers a persistent view with the client.
+        """
+        Registers a persistent view with the client.
 
-        Persistent views have a timeout of `None` and their components must have a `custom_id`.
-        This allows the view to be re-instantiated across bot restarts.
+        Persistent views have a timeout of `None` and their components must have a `custom_id`.
+        This allows the view to be re-instantiated across bot restarts.
 
-        Args:
-            view (View): The view instance to register.
+        Args:
+            view (View): The view instance to register.
 
-        Raises:
-            ValueError: If the view is not persistent (timeout is not None) or if a component's
-                custom_id is already registered.
-        """
-        if self.is_ready():
-            print(
-                "Warning: Adding a persistent view after the client is ready. "
-                "This view will only be available for interactions on this session."
-            )
+        Raises:
+            ValueError: If the view is not persistent (timeout is not None) or if a component's
+                custom_id is already registered.
+        """
+        if self.is_ready():
+            print(
+                "Warning: Adding a persistent view after the client is ready. "
+                "This view will only be available for interactions on this session."
+            )
 
-        if view.timeout is not None:
-            raise ValueError("Persistent views must have a timeout of None.")
+        if view.timeout is not None:
+            raise ValueError("Persistent views must have a timeout of None.")
 
-        for item in view.children:
-            if item.custom_id:  # Ensure custom_id is not None
-                if item.custom_id in self._persistent_views:
-                    raise ValueError(
-                        f"A component with custom_id '{item.custom_id}' is already registered."
-                    )
-                self._persistent_views[item.custom_id] = view
+        for item in view.children:
+            if item.custom_id:  # Ensure custom_id is not None
+                if item.custom_id in self._persistent_views:
+                    raise ValueError(
+                        f"A component with custom_id '{item.custom_id}' is already registered."
+                    )
+                self._persistent_views[item.custom_id] = view
 
     # --- Application Command Methods ---
     async def process_interaction(self, interaction: Interaction) -> None:
@@ -15,3 +15,14 @@ def test_cache_ttl_expiry():
     assert cache.get("b") == 1
     time.sleep(0.02)
     assert cache.get("b") is None
+
+
+def test_cache_lru_eviction():
+    cache = Cache(maxlen=2)
+    cache.set("a", 1)
+    cache.set("b", 2)
+    assert cache.get("a") == 1
+    cache.set("c", 3)
+    assert cache.get("b") is None
+    assert cache.get("a") == 1
+    assert cache.get("c") == 3
tests/test_client_message_cache.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+import pytest
+
+from disagreement.client import Client
+
+
+def _add_message(client: Client, message_id: str) -> None:
+    data = {
+        "id": message_id,
+        "channel_id": "c",
+        "author": {"id": "u", "username": "u", "discriminator": "0001"},
+        "content": "hi",
+        "timestamp": "t",
+    }
+    client.parse_message(data)
+
+
+def test_client_message_cache_size():
+    client = Client(token="t", message_cache_maxlen=1)
+    _add_message(client, "1")
+    assert client._messages.get("1").id == "1"
+    _add_message(client, "2")
+    assert client._messages.get("1") is None
+    assert client._messages.get("2").id == "2"