Add get_or_fetch cache helper and tests (#72)
This commit is contained in:
parent
0a3f680e7f
commit
235ea8fc69
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar
|
from typing import TYPE_CHECKING, Callable, Dict, Generic, Optional, TypeVar
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -40,6 +40,14 @@ class Cache(Generic[T]):
|
|||||||
self._data.move_to_end(key)
|
self._data.move_to_end(key)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
def get_or_fetch(self, key: str, fetch_fn: Callable[[], T]) -> T:
|
||||||
|
"""Return a cached item or fetch and store it if missing."""
|
||||||
|
value = self.get(key)
|
||||||
|
if value is None:
|
||||||
|
value = fetch_fn()
|
||||||
|
self.set(key, value)
|
||||||
|
return value
|
||||||
|
|
||||||
def invalidate(self, key: str) -> None:
|
def invalidate(self, key: str) -> None:
|
||||||
self._data.pop(key, None)
|
self._data.pop(key, None)
|
||||||
|
|
||||||
|
@ -26,3 +26,42 @@ def test_cache_lru_eviction():
|
|||||||
assert cache.get("b") is None
|
assert cache.get("b") is None
|
||||||
assert cache.get("a") == 1
|
assert cache.get("a") == 1
|
||||||
assert cache.get("c") == 3
|
assert cache.get("c") == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_or_fetch_uses_cache():
|
||||||
|
cache = Cache()
|
||||||
|
cache.set("a", 1)
|
||||||
|
|
||||||
|
def fetch():
|
||||||
|
raise AssertionError("fetch should not be called")
|
||||||
|
|
||||||
|
assert cache.get_or_fetch("a", fetch) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_or_fetch_fetches_and_stores():
|
||||||
|
cache = Cache()
|
||||||
|
called = False
|
||||||
|
|
||||||
|
def fetch():
|
||||||
|
nonlocal called
|
||||||
|
called = True
|
||||||
|
return 2
|
||||||
|
|
||||||
|
assert cache.get_or_fetch("b", fetch) == 2
|
||||||
|
assert called
|
||||||
|
assert cache.get("b") == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_or_fetch_fetches_expired_item():
|
||||||
|
cache = Cache(ttl=0.01)
|
||||||
|
cache.set("c", 1)
|
||||||
|
time.sleep(0.02)
|
||||||
|
called = False
|
||||||
|
|
||||||
|
def fetch():
|
||||||
|
nonlocal called
|
||||||
|
called = True
|
||||||
|
return 3
|
||||||
|
|
||||||
|
assert cache.get_or_fetch("c", fetch) == 3
|
||||||
|
assert called
|
||||||
|
Loading…
x
Reference in New Issue
Block a user