Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to DictionaryCache and TTLCache. #9442

Merged
merged 7 commits into from
Mar 29, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9442.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to the caching module.
10 changes: 6 additions & 4 deletions synapse/http/federation/well_known_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@
logger = logging.getLogger(__name__)


_well_known_cache = TTLCache("well-known")
_had_valid_well_known_cache = TTLCache("had-valid-well-known")
_well_known_cache = TTLCache("well-known") # type: TTLCache[bytes, Optional[bytes]]
_had_valid_well_known_cache = TTLCache(
"had-valid-well-known"
) # type: TTLCache[bytes, bool]


@attr.s(slots=True, frozen=True)
Expand All @@ -88,8 +90,8 @@ def __init__(
reactor: IReactorTime,
agent: IAgent,
user_agent: bytes,
well_known_cache: Optional[TTLCache] = None,
had_well_known_cache: Optional[TTLCache] = None,
well_known_cache: Optional[TTLCache[bytes, Optional[bytes]]] = None,
had_well_known_cache: Optional[TTLCache[bytes, bool]] = None,
):
self._reactor = reactor
self._clock = Clock(reactor)
Expand Down
9 changes: 5 additions & 4 deletions synapse/storage/databases/state/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,13 @@ def _get_state_for_group_using_cache(self, cache, group, state_filter):
requests state from the cache, if False we need to query the DB for the
missing state.
"""
is_all, known_absent, state_dict_ids = cache.get(group)
cache_entry = cache.get(group)
state_dict_ids = cache_entry.value

if is_all or state_filter.is_full():
if cache_entry.full or state_filter.is_full():
# Either we have everything or want everything, either way
# `is_all` tells us whether we've gotten everything.
return state_filter.filter_state(state_dict_ids), is_all
return state_filter.filter_state(state_dict_ids), cache_entry.full

# tracks whether any of our requested types are missing from the cache
missing_types = False
Expand All @@ -202,7 +203,7 @@ def _get_state_for_group_using_cache(self, cache, group, state_filter):
# There aren't any wild cards, so `concrete_types()` returns the
# complete list of event types we're wanting.
for key in state_filter.concrete_types():
if key not in state_dict_ids and key not in known_absent:
if key not in state_dict_ids and key not in cache_entry.known_absent:
missing_types = True
break

Expand Down
64 changes: 43 additions & 21 deletions synapse/util/caches/dictionary_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,38 @@
import enum
import logging
import threading
from collections import namedtuple
from typing import Any
from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar

import attr

from synapse.util.caches.lrucache import LruCache

logger = logging.getLogger(__name__)


# The type of the cache keys.
KT = TypeVar("KT")
# The type of the dictionary keys.
DKT = TypeVar("DKT")


@attr.s(slots=True)
class DictionaryEntry:
    """Returned when getting an entry from the cache

    Attributes:
        full: Whether the cache has the full dict or just some keys.
            If not full then not all requested keys will necessarily be present
            in `value`
        known_absent: Keys that were looked up in the dict and were not
            there.
        value: The full or partial dict value
    """

    full = attr.ib(type=bool)
    known_absent = attr.ib()
    value = attr.ib()

    def __len__(self) -> int:
        # The "size" of an entry is the size of its dict payload; this is
        # what the LruCache size_callback (len) measures.
        return len(self.value)

Expand All @@ -45,21 +57,21 @@ class _Sentinel(enum.Enum):
sentinel = object()


class DictionaryCache(Generic[KT, DKT]):
    """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
    fetching a subset of dictionary keys for a particular key.
    """

    def __init__(self, name: str, max_entries: int = 1000):
        """
        Args:
            name: The name of the cache, used for cache metrics.
            max_entries: The maximum number of entries to keep in the
                underlying LruCache before eviction.
        """
        self.cache = LruCache(
            max_size=max_entries, cache_name=name, size_callback=len
        )  # type: LruCache[KT, DictionaryEntry]

        self.name = name
        # Sequence number, bumped on every invalidation so that updates which
        # raced with an invalidation can be detected and discarded.
        self.sequence = 0
        # The thread the cache is first used from, for sanity-checking that
        # it is only accessed from a single thread.
        self.thread = None  # type: Optional[threading.Thread]

def check_thread(self):
def check_thread(self) -> None:
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
Expand All @@ -69,12 +81,14 @@ def check_thread(self):
"Cache objects can only be accessed from the main thread"
)

def get(self, key, dict_keys=None):
def get(
self, key: KT, dict_keys: Optional[Iterable[DKT]] = None
) -> DictionaryEntry:
"""Fetch an entry out of the cache

Args:
key
dict_key(list): If given a set of keys then return only those keys
dict_key: If given a set of keys then return only those keys
that exist in the cache.

Returns:
Expand All @@ -95,27 +109,33 @@ def get(self, key, dict_keys=None):

return DictionaryEntry(False, set(), {})

def invalidate(self, key: KT) -> None:
    """Remove an entry from the cache.

    Args:
        key: the key to invalidate.
    """
    self.check_thread()

    # Increment the sequence number so that any SELECT statements that
    # raced with the INSERT don't update the cache (SYN-369)
    self.sequence += 1
    self.cache.pop(key, None)

def invalidate_all(self) -> None:
    """Drop every entry from the cache."""
    self.check_thread()
    # Bump the sequence number first so that any racing updates are
    # discarded, as in `invalidate` (SYN-369).
    self.sequence += 1
    self.cache.clear()

def update(self, sequence, key, value, fetched_keys=None):
def update(
self,
sequence: int,
key: KT,
value: Dict[DKT, Any],
fetched_keys: Optional[Set[DKT]] = None,
) -> None:
"""Updates the entry in the cache

Args:
sequence
key (K)
value (dict[X,Y]): The value to update the cache with.
fetched_keys (None|set[X]): All of the dictionary keys which were
key
value: The value to update the cache with.
fetched_keys: All of the dictionary keys which were
fetched from the database.

If None, this is the complete value for key K. Otherwise, it
Expand All @@ -131,7 +151,9 @@ def update(self, sequence, key, value, fetched_keys=None):
else:
self._update_or_insert(key, value, fetched_keys)

def _update_or_insert(self, key, value, known_absent):
def _update_or_insert(
self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]
) -> None:
# We pop and reinsert as we need to tell the cache the size may have
# changed

Expand All @@ -140,5 +162,5 @@ def _update_or_insert(self, key, value, known_absent):
entry.known_absent.update(known_absent)
self.cache[key] = entry

def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None:
    # Store a complete entry (full=True) for this key, replacing anything
    # already cached under it.
    self.cache[key] = DictionaryEntry(True, known_absent, value)
51 changes: 29 additions & 22 deletions synapse/util/caches/ttlcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import logging
import time
from typing import Callable, Dict, Generic, Tuple, TypeVar, Union

import attr
from sortedcontainers import SortedList
Expand All @@ -25,13 +26,17 @@

SENTINEL = object()

T = TypeVar("T")
KT = TypeVar("KT")
VT = TypeVar("VT")

class TTLCache:

class TTLCache(Generic[KT, VT]):
"""A key/value cache implementation where each entry has its own TTL"""

def __init__(self, cache_name, timer=time.time):
def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
# map from key to _CacheEntry
self._data = {}
self._data = {} # type: Dict[KT, _CacheEntry]

# the _CacheEntries, sorted by expiry time
self._expiry_list = SortedList() # type: SortedList[_CacheEntry]
Expand All @@ -40,26 +45,27 @@ def __init__(self, cache_name, timer=time.time):

self._metrics = register_cache("ttl", cache_name, self, resizable=False)

def set(self, key: KT, value: VT, ttl: float) -> None:
    """Add/update an entry in the cache

    Args:
        key: key for this entry
        value: value for this entry
        ttl: TTL for this entry, in seconds
    """
    expiry = self._timer() + ttl

    # Purge anything already due before we insert.
    self.expire()
    e = self._data.pop(key, SENTINEL)
    if e is not SENTINEL:
        # Replacing an existing entry: drop its old position in the
        # expiry-ordered list before re-adding.
        assert isinstance(e, _CacheEntry)
        self._expiry_list.remove(e)

    entry = _CacheEntry(expiry_time=expiry, ttl=ttl, key=key, value=value)
    self._data[key] = entry
    self._expiry_list.add(entry)

def get(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:  # type: ignore
    """Get a value from the cache

    Args:
        key: key to look up
        default: default value to return, if key is not found. If this is
            not set, and the key is not found, a KeyError will be raised.

    Returns:
        The value from the cache, or `default` if the key is absent.

    Raises:
        KeyError: if the key is not found and no default is given.
    """
    # Drop any entries that are already past their expiry time, so an
    # expired key counts as a miss.
    self.expire()
    e = self._data.get(key, SENTINEL)
    if e is SENTINEL:
        self._metrics.inc_misses()
        if default is SENTINEL:
            raise KeyError(key)
        return default
    assert isinstance(e, _CacheEntry)
    self._metrics.inc_hits()
    return e.value

def get_with_expiry(self, key):
def get_with_expiry(self, key: KT) -> Tuple[VT, float, float]:
"""Get a value, and its expiry time, from the cache

Args:
key: key to look up

Returns:
Tuple[Any, float, float]: the value from the cache, the expiry time
and the TTL
A tuple of the value from the cache, the expiry time and the TTL

Raises:
KeyError if the entry is not found
Expand All @@ -102,7 +108,7 @@ def get_with_expiry(self, key):
self._metrics.inc_hits()
return e.value, e.expiry_time, e.ttl

def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:  # type: ignore
    """Remove a value from the cache

    If key is in the cache, remove it and return its value, else return default.

    Args:
        key: key to look up
        default: default value to return, if key is not found. If this is
            not set, and the key is not found, a KeyError will be raised.

    Returns:
        The removed value, or `default` if the key was absent.

    Raises:
        KeyError: if the key is not found and no default is given.
    """
    # Purge expired entries first, so an expired key counts as a miss.
    self.expire()
    e = self._data.pop(key, SENTINEL)
    if e is SENTINEL:
        self._metrics.inc_misses()
        if default is SENTINEL:
            raise KeyError(key)
        return default
    assert isinstance(e, _CacheEntry)
    # Keep the expiry-ordered list in sync with the data dict.
    self._expiry_list.remove(e)
    self._metrics.inc_hits()
    return e.value

def __getitem__(self, key: KT) -> VT:
    # `cache[key]` — delegates to get() with no default, so a missing or
    # expired key raises KeyError.
    return self.get(key)

def __delitem__(self, key: KT) -> None:
    # `del cache[key]` — delegates to pop() with no default, so a missing
    # key raises KeyError.
    self.pop(key)

def __contains__(self, key: KT) -> bool:
    # NOTE: unlike get()/__len__, this does not run expire() first, so it
    # may report True for an entry that is past its TTL but not yet purged.
    return key in self._data

def __len__(self) -> int:
    # Purge expired entries first so the reported size is accurate.
    self.expire()
    return len(self._data)

def expire(self):
def expire(self) -> None:
"""Run the expiry on the cache. Any entries whose expiry times are due will
be removed
"""
Expand All @@ -158,7 +165,7 @@ class _CacheEntry:
"""TTLCache entry"""

# expiry_time is the first attribute, so that entries are sorted by expiry.
expiry_time = attr.ib()
ttl = attr.ib()
expiry_time = attr.ib(type=float)
ttl = attr.ib(type=float)
key = attr.ib()
value = attr.ib()
22 changes: 8 additions & 14 deletions tests/storage/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,14 +377,11 @@ def test_get_state_for_event(self):
#######################################################
# deliberately remove e2 (room name) from the _state_group_cache

(
is_all,
known_absent,
state_dict_ids,
) = self.state_datastore._state_group_cache.get(group)
cache_entry = self.state_datastore._state_group_cache.get(group)
state_dict_ids = cache_entry.value

self.assertEqual(is_all, True)
self.assertEqual(known_absent, set())
self.assertEqual(cache_entry.full, True)
self.assertEqual(cache_entry.known_absent, set())
self.assertDictEqual(
state_dict_ids,
{
Expand All @@ -403,14 +400,11 @@ def test_get_state_for_event(self):
fetched_keys=((e1.type, e1.state_key),),
)

(
is_all,
known_absent,
state_dict_ids,
) = self.state_datastore._state_group_cache.get(group)
cache_entry = self.state_datastore._state_group_cache.get(group)
state_dict_ids = cache_entry.value

self.assertEqual(is_all, False)
self.assertEqual(known_absent, {(e1.type, e1.state_key)})
self.assertEqual(cache_entry.full, False)
self.assertEqual(cache_entry.known_absent, {(e1.type, e1.state_key)})
self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})

############################################
Expand Down
4 changes: 3 additions & 1 deletion tests/util/test_dict_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def test_simple_cache_hit_full(self):
key = "test_simple_cache_hit_full"

v = self.cache.get(key)
self.assertEqual((False, set(), {}), v)
self.assertIs(v.full, False)
self.assertEqual(v.known_absent, set())
self.assertEqual({}, v.value)

seq = self.cache.sequence
test_value = {"test": "test_simple_cache_hit_full"}
Expand Down