diff --git a/disagreement/cache.py b/disagreement/cache.py index 92eef02..32c6639 100644 --- a/disagreement/cache.py +++ b/disagreement/cache.py @@ -1,7 +1,7 @@ from __future__ import annotations import time -from typing import TYPE_CHECKING, Dict, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Callable, Dict, Generic, Optional, TypeVar from collections import OrderedDict if TYPE_CHECKING: @@ -40,6 +40,14 @@ class Cache(Generic[T]): self._data.move_to_end(key) return value + def get_or_fetch(self, key: str, fetch_fn: Callable[[], T]) -> T: + """Return a cached item or fetch and store it if missing.""" + value = self.get(key) + if value is None: + value = fetch_fn() + self.set(key, value) + return value + def invalidate(self, key: str) -> None: self._data.pop(key, None) diff --git a/tests/test_cache.py b/tests/test_cache.py index 6909697..88effe9 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -26,3 +26,42 @@ def test_cache_lru_eviction(): assert cache.get("b") is None assert cache.get("a") == 1 assert cache.get("c") == 3 + + +def test_get_or_fetch_uses_cache(): + cache = Cache() + cache.set("a", 1) + + def fetch(): + raise AssertionError("fetch should not be called") + + assert cache.get_or_fetch("a", fetch) == 1 + + +def test_get_or_fetch_fetches_and_stores(): + cache = Cache() + called = False + + def fetch(): + nonlocal called + called = True + return 2 + + assert cache.get_or_fetch("b", fetch) == 2 + assert called + assert cache.get("b") == 2 + + +def test_get_or_fetch_fetches_expired_item(): + cache = Cache(ttl=0.01) + cache.set("c", 1) + time.sleep(0.02) + called = False + + def fetch(): + nonlocal called + called = True + return 3 + + assert cache.get_or_fetch("c", fetch) == 3 + assert called