Add LRU support to Cache and message cache size option (#61)
This commit is contained in:
parent
07daf78ef4
commit
c47a7e49f8
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
|
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .models import Channel, Guild, Member
|
from .models import Channel, Guild, Member
|
||||||
@ -11,15 +12,22 @@ T = TypeVar("T")
|
|||||||
|
|
||||||
|
|
||||||
class Cache(Generic[T]):
|
class Cache(Generic[T]):
|
||||||
"""Simple in-memory cache with optional TTL support."""
|
"""Simple in-memory cache with optional TTL and max size support."""
|
||||||
|
|
||||||
def __init__(self, ttl: Optional[float] = None) -> None:
|
def __init__(
|
||||||
|
self, ttl: Optional[float] = None, maxlen: Optional[int] = None
|
||||||
|
) -> None:
|
||||||
self.ttl = ttl
|
self.ttl = ttl
|
||||||
self._data: Dict[str, tuple[T, Optional[float]]] = {}
|
self.maxlen = maxlen
|
||||||
|
self._data: "OrderedDict[str, tuple[T, Optional[float]]]" = OrderedDict()
|
||||||
|
|
||||||
def set(self, key: str, value: T) -> None:
|
def set(self, key: str, value: T) -> None:
|
||||||
expiry = time.monotonic() + self.ttl if self.ttl is not None else None
|
expiry = time.monotonic() + self.ttl if self.ttl is not None else None
|
||||||
|
if key in self._data:
|
||||||
|
self._data.move_to_end(key)
|
||||||
self._data[key] = (value, expiry)
|
self._data[key] = (value, expiry)
|
||||||
|
if self.maxlen is not None and len(self._data) > self.maxlen:
|
||||||
|
self._data.popitem(last=False)
|
||||||
|
|
||||||
def get(self, key: str) -> Optional[T]:
|
def get(self, key: str) -> Optional[T]:
|
||||||
item = self._data.get(key)
|
item = self._data.get(key)
|
||||||
@ -29,6 +37,7 @@ class Cache(Generic[T]):
|
|||||||
if expiry is not None and expiry < time.monotonic():
|
if expiry is not None and expiry < time.monotonic():
|
||||||
self.invalidate(key)
|
self.invalidate(key)
|
||||||
return None
|
return None
|
||||||
|
self._data.move_to_end(key)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def invalidate(self, key: str) -> None:
|
def invalidate(self, key: str) -> None:
|
||||||
|
@ -82,6 +82,8 @@ class Client:
|
|||||||
http_options (Optional[Dict[str, Any]]): Extra options passed to
|
http_options (Optional[Dict[str, Any]]): Extra options passed to
|
||||||
:class:`HTTPClient` for creating the internal
|
:class:`HTTPClient` for creating the internal
|
||||||
:class:`aiohttp.ClientSession`.
|
:class:`aiohttp.ClientSession`.
|
||||||
|
message_cache_maxlen (Optional[int]): Maximum number of messages to keep
|
||||||
|
in the cache. When ``None``, the cache size is unlimited.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -99,6 +101,7 @@ class Client:
|
|||||||
gateway_max_retries: int = 5,
|
gateway_max_retries: int = 5,
|
||||||
gateway_max_backoff: float = 60.0,
|
gateway_max_backoff: float = 60.0,
|
||||||
member_cache_flags: Optional[MemberCacheFlags] = None,
|
member_cache_flags: Optional[MemberCacheFlags] = None,
|
||||||
|
message_cache_maxlen: Optional[int] = None,
|
||||||
http_options: Optional[Dict[str, Any]] = None,
|
http_options: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
if not token:
|
if not token:
|
||||||
@ -108,6 +111,7 @@ class Client:
|
|||||||
self.member_cache_flags: MemberCacheFlags = (
|
self.member_cache_flags: MemberCacheFlags = (
|
||||||
member_cache_flags if member_cache_flags is not None else MemberCacheFlags()
|
member_cache_flags if member_cache_flags is not None else MemberCacheFlags()
|
||||||
)
|
)
|
||||||
|
self.message_cache_maxlen: Optional[int] = message_cache_maxlen
|
||||||
self.intents: int = intents if intents is not None else GatewayIntent.default()
|
self.intents: int = intents if intents is not None else GatewayIntent.default()
|
||||||
if loop:
|
if loop:
|
||||||
self.loop: asyncio.AbstractEventLoop = loop
|
self.loop: asyncio.AbstractEventLoop = loop
|
||||||
@ -157,7 +161,7 @@ class Client:
|
|||||||
self._guilds: GuildCache = GuildCache()
|
self._guilds: GuildCache = GuildCache()
|
||||||
self._channels: ChannelCache = ChannelCache()
|
self._channels: ChannelCache = ChannelCache()
|
||||||
self._users: Cache["User"] = Cache()
|
self._users: Cache["User"] = Cache()
|
||||||
self._messages: Cache["Message"] = Cache(ttl=3600) # Cache messages for an hour
|
self._messages: Cache["Message"] = Cache(ttl=3600, maxlen=message_cache_maxlen)
|
||||||
self._views: Dict[Snowflake, "View"] = {}
|
self._views: Dict[Snowflake, "View"] = {}
|
||||||
self._persistent_views: Dict[str, "View"] = {}
|
self._persistent_views: Dict[str, "View"] = {}
|
||||||
self._voice_clients: Dict[Snowflake, VoiceClient] = {}
|
self._voice_clients: Dict[Snowflake, VoiceClient] = {}
|
||||||
@ -693,7 +697,7 @@ class Client:
|
|||||||
)
|
)
|
||||||
# import traceback
|
# import traceback
|
||||||
# traceback.print_exception(type(error.original), error.original, error.original.__traceback__)
|
# traceback.print_exception(type(error.original), error.original, error.original.__traceback__)
|
||||||
|
|
||||||
async def on_command_completion(self, ctx: "CommandContext") -> None:
|
async def on_command_completion(self, ctx: "CommandContext") -> None:
|
||||||
"""
|
"""
|
||||||
Default command completion handler. Called when a command has successfully completed.
|
Default command completion handler. Called when a command has successfully completed.
|
||||||
@ -1514,35 +1518,35 @@ class Client:
|
|||||||
return [self.parse_invite(inv) for inv in data]
|
return [self.parse_invite(inv) for inv in data]
|
||||||
|
|
||||||
def add_persistent_view(self, view: "View") -> None:
|
def add_persistent_view(self, view: "View") -> None:
|
||||||
"""
|
"""
|
||||||
Registers a persistent view with the client.
|
Registers a persistent view with the client.
|
||||||
|
|
||||||
Persistent views have a timeout of `None` and their components must have a `custom_id`.
|
Persistent views have a timeout of `None` and their components must have a `custom_id`.
|
||||||
This allows the view to be re-instantiated across bot restarts.
|
This allows the view to be re-instantiated across bot restarts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
view (View): The view instance to register.
|
view (View): The view instance to register.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the view is not persistent (timeout is not None) or if a component's
|
ValueError: If the view is not persistent (timeout is not None) or if a component's
|
||||||
custom_id is already registered.
|
custom_id is already registered.
|
||||||
"""
|
"""
|
||||||
if self.is_ready():
|
if self.is_ready():
|
||||||
print(
|
print(
|
||||||
"Warning: Adding a persistent view after the client is ready. "
|
"Warning: Adding a persistent view after the client is ready. "
|
||||||
"This view will only be available for interactions on this session."
|
"This view will only be available for interactions on this session."
|
||||||
)
|
)
|
||||||
|
|
||||||
if view.timeout is not None:
|
if view.timeout is not None:
|
||||||
raise ValueError("Persistent views must have a timeout of None.")
|
raise ValueError("Persistent views must have a timeout of None.")
|
||||||
|
|
||||||
for item in view.children:
|
for item in view.children:
|
||||||
if item.custom_id: # Ensure custom_id is not None
|
if item.custom_id: # Ensure custom_id is not None
|
||||||
if item.custom_id in self._persistent_views:
|
if item.custom_id in self._persistent_views:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"A component with custom_id '{item.custom_id}' is already registered."
|
f"A component with custom_id '{item.custom_id}' is already registered."
|
||||||
)
|
)
|
||||||
self._persistent_views[item.custom_id] = view
|
self._persistent_views[item.custom_id] = view
|
||||||
|
|
||||||
# --- Application Command Methods ---
|
# --- Application Command Methods ---
|
||||||
async def process_interaction(self, interaction: Interaction) -> None:
|
async def process_interaction(self, interaction: Interaction) -> None:
|
||||||
|
@ -15,3 +15,14 @@ def test_cache_ttl_expiry():
|
|||||||
assert cache.get("b") == 1
|
assert cache.get("b") == 1
|
||||||
time.sleep(0.02)
|
time.sleep(0.02)
|
||||||
assert cache.get("b") is None
|
assert cache.get("b") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_lru_eviction():
|
||||||
|
cache = Cache(maxlen=2)
|
||||||
|
cache.set("a", 1)
|
||||||
|
cache.set("b", 2)
|
||||||
|
assert cache.get("a") == 1
|
||||||
|
cache.set("c", 3)
|
||||||
|
assert cache.get("b") is None
|
||||||
|
assert cache.get("a") == 1
|
||||||
|
assert cache.get("c") == 3
|
||||||
|
23
tests/test_client_message_cache.py
Normal file
23
tests/test_client_message_cache.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from disagreement.client import Client
|
||||||
|
|
||||||
|
|
||||||
|
def _add_message(client: Client, message_id: str) -> None:
|
||||||
|
data = {
|
||||||
|
"id": message_id,
|
||||||
|
"channel_id": "c",
|
||||||
|
"author": {"id": "u", "username": "u", "discriminator": "0001"},
|
||||||
|
"content": "hi",
|
||||||
|
"timestamp": "t",
|
||||||
|
}
|
||||||
|
client.parse_message(data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_message_cache_size():
|
||||||
|
client = Client(token="t", message_cache_maxlen=1)
|
||||||
|
_add_message(client, "1")
|
||||||
|
assert client._messages.get("1").id == "1"
|
||||||
|
_add_message(client, "2")
|
||||||
|
assert client._messages.get("1") is None
|
||||||
|
assert client._messages.get("2").id == "2"
|
Loading…
x
Reference in New Issue
Block a user