diff --git a/disagreement/cache.py b/disagreement/cache.py index 178e8ae..92eef02 100644 --- a/disagreement/cache.py +++ b/disagreement/cache.py @@ -2,6 +2,7 @@ from __future__ import annotations import time from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar +from collections import OrderedDict if TYPE_CHECKING: from .models import Channel, Guild, Member @@ -11,15 +12,22 @@ T = TypeVar("T") class Cache(Generic[T]): - """Simple in-memory cache with optional TTL support.""" + """Simple in-memory cache with optional TTL and max size support.""" - def __init__(self, ttl: Optional[float] = None) -> None: + def __init__( + self, ttl: Optional[float] = None, maxlen: Optional[int] = None + ) -> None: self.ttl = ttl - self._data: Dict[str, tuple[T, Optional[float]]] = {} + self.maxlen = maxlen + self._data: "OrderedDict[str, tuple[T, Optional[float]]]" = OrderedDict() def set(self, key: str, value: T) -> None: expiry = time.monotonic() + self.ttl if self.ttl is not None else None + if key in self._data: + self._data.move_to_end(key) self._data[key] = (value, expiry) + if self.maxlen is not None and len(self._data) > self.maxlen: + self._data.popitem(last=False) def get(self, key: str) -> Optional[T]: item = self._data.get(key) @@ -29,6 +37,7 @@ class Cache(Generic[T]): if expiry is not None and expiry < time.monotonic(): self.invalidate(key) return None + self._data.move_to_end(key) return value def invalidate(self, key: str) -> None: diff --git a/disagreement/client.py b/disagreement/client.py index 6ed4462..b1316b7 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -82,6 +82,8 @@ class Client: http_options (Optional[Dict[str, Any]]): Extra options passed to :class:`HTTPClient` for creating the internal :class:`aiohttp.ClientSession`. + message_cache_maxlen (Optional[int]): Maximum number of messages to keep + in the cache. When ``None``, the cache size is unlimited. """ def __init__( @@ -99,6 +101,7 @@ class Client: gateway_max_retries: int = 5, gateway_max_backoff: float = 60.0, member_cache_flags: Optional[MemberCacheFlags] = None, + message_cache_maxlen: Optional[int] = None, http_options: Optional[Dict[str, Any]] = None, ): if not token: @@ -108,6 +111,7 @@ class Client: self.member_cache_flags: MemberCacheFlags = ( member_cache_flags if member_cache_flags is not None else MemberCacheFlags() ) + self.message_cache_maxlen: Optional[int] = message_cache_maxlen self.intents: int = intents if intents is not None else GatewayIntent.default() if loop: self.loop: asyncio.AbstractEventLoop = loop @@ -157,7 +161,7 @@ class Client: self._guilds: GuildCache = GuildCache() self._channels: ChannelCache = ChannelCache() self._users: Cache["User"] = Cache() - self._messages: Cache["Message"] = Cache(ttl=3600) # Cache messages for an hour + self._messages: Cache["Message"] = Cache(ttl=3600, maxlen=message_cache_maxlen) self._views: Dict[Snowflake, "View"] = {} self._persistent_views: Dict[str, "View"] = {} self._voice_clients: Dict[Snowflake, VoiceClient] = {} @@ -693,7 +697,7 @@ class Client: ) # import traceback # traceback.print_exception(type(error.original), error.original, error.original.__traceback__) - + async def on_command_completion(self, ctx: "CommandContext") -> None: """ Default command completion handler. Called when a command has successfully completed. @@ -1514,35 +1518,35 @@ class Client: return [self.parse_invite(inv) for inv in data] def add_persistent_view(self, view: "View") -> None: - """ - Registers a persistent view with the client. + """ + Registers a persistent view with the client. - Persistent views have a timeout of `None` and their components must have a `custom_id`. - This allows the view to be re-instantiated across bot restarts. + Persistent views have a timeout of `None` and their components must have a `custom_id`. + This allows the view to be re-instantiated across bot restarts. - Args: - view (View): The view instance to register. + Args: + view (View): The view instance to register. - Raises: - ValueError: If the view is not persistent (timeout is not None) or if a component's - custom_id is already registered. - """ - if self.is_ready(): - print( - "Warning: Adding a persistent view after the client is ready. " - "This view will only be available for interactions on this session." - ) + Raises: + ValueError: If the view is not persistent (timeout is not None) or if a component's + custom_id is already registered. + """ + if self.is_ready(): + print( + "Warning: Adding a persistent view after the client is ready. " + "This view will only be available for interactions on this session." + ) - if view.timeout is not None: - raise ValueError("Persistent views must have a timeout of None.") + if view.timeout is not None: + raise ValueError("Persistent views must have a timeout of None.") - for item in view.children: - if item.custom_id: # Ensure custom_id is not None - if item.custom_id in self._persistent_views: - raise ValueError( - f"A component with custom_id '{item.custom_id}' is already registered." - ) - self._persistent_views[item.custom_id] = view + for item in view.children: + if item.custom_id: # Ensure custom_id is not None + if item.custom_id in self._persistent_views: + raise ValueError( + f"A component with custom_id '{item.custom_id}' is already registered." + ) + self._persistent_views[item.custom_id] = view # --- Application Command Methods --- async def process_interaction(self, interaction: Interaction) -> None: diff --git a/tests/test_cache.py b/tests/test_cache.py index 234077e..6909697 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -15,3 +15,14 @@ def test_cache_ttl_expiry(): assert cache.get("b") == 1 time.sleep(0.02) assert cache.get("b") is None + + +def test_cache_lru_eviction(): + cache = Cache(maxlen=2) + cache.set("a", 1) + cache.set("b", 2) + assert cache.get("a") == 1 + cache.set("c", 3) + assert cache.get("b") is None + assert cache.get("a") == 1 + assert cache.get("c") == 3 diff --git a/tests/test_client_message_cache.py b/tests/test_client_message_cache.py new file mode 100644 index 0000000..f89d98c --- /dev/null +++ b/tests/test_client_message_cache.py @@ -0,0 +1,23 @@ +import pytest + +from disagreement.client import Client + + +def _add_message(client: Client, message_id: str) -> None: + data = { + "id": message_id, + "channel_id": "c", + "author": {"id": "u", "username": "u", "discriminator": "0001"}, + "content": "hi", + "timestamp": "t", + } + client.parse_message(data) + + +def test_client_message_cache_size(): + client = Client(token="t", message_cache_maxlen=1) + _add_message(client, "1") + assert client._messages.get("1").id == "1" + _add_message(client, "2") + assert client._messages.get("1") is None + assert client._messages.get("2").id == "2"