You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
159 lines
4.5 KiB
159 lines
4.5 KiB
""" |
|
Utility functions. |
|
""" |
|
|
|
from collections import OrderedDict, abc |
|
from typing import List, Iterator, TypeVar, Generic, Union, Optional, Type, \ |
|
TYPE_CHECKING |
|
|
|
# Type variables for the generic helpers in this module. Each one is bound
# to a module-level name identical to the name it is created with, as static
# type checkers require.
T = TypeVar('T')  # class whose type hints ``with_typehint`` borrows
K = TypeVar('K')  # key type of an ``LRUCache``
V = TypeVar('V')  # value type of an ``LRUCache``
D = TypeVar('D')  # type of the fallback value passed to ``LRUCache.get``
|
|
|
# Public names exported by ``from ... import *``.
__all__ = (
    'LRUCache',
    'freeze',
    'with_typehint',
)
|
|
|
|
|
def with_typehint(baseclass: Type[T]):
    """
    Return a base class that carries the type hints of ``baseclass``.

    Usage:

    >>> class Foo(with_typehint(Bar)):
    ...     pass

    Static type checkers then treat ``Foo`` as if it inherited from ``Bar``,
    while at runtime ``Foo`` simply derives from ``object``.

    PyCharm and Pyright (for VS Code) understand this pattern out of the box,
    but MyPy does not. TinyDB therefore ships a MyPy plugin in
    ``mypy_plugin.py`` that adds support for it.
    """
    # ``TYPE_CHECKING`` is only true while a static analyser inspects the
    # code; at runtime the target class gets a plain ``object`` base.
    return baseclass if TYPE_CHECKING else object
|
|
|
|
|
class LRUCache(abc.MutableMapping, Generic[K, V]):
    """
    A least-recently used (LRU) cache with a fixed cache size.

    This class acts as a dictionary but has a limited size. If the number of
    entries in the cache exceeds the cache size, the least-recently accessed
    entry will be discarded.

    This is implemented using an ``OrderedDict`` kept in order from least- to
    most-recently used. On every access the accessed entry is moved to the
    end of the ``OrderedDict``. When adding a new entry exceeds the cache
    size, the first (least-recently used) entry is discarded.

    A ``capacity`` of ``None`` makes the cache unbounded.
    """

    def __init__(self, capacity=None) -> None:
        # Maximum number of entries; ``None`` disables eviction.
        self.capacity = capacity
        # Entries ordered from least- to most-recently used.
        self.cache: OrderedDict[K, V] = OrderedDict()

    @property
    def lru(self) -> List[K]:
        """Keys ordered from least- to most-recently used."""
        return list(self.cache.keys())

    @property
    def length(self) -> int:
        """Number of entries currently stored."""
        return len(self.cache)

    def clear(self) -> None:
        """Remove all entries from the cache."""
        self.cache.clear()

    def __len__(self) -> int:
        return self.length

    def __contains__(self, key: object) -> bool:
        # A membership test does not count as an access: recency is unchanged.
        return key in self.cache

    def __setitem__(self, key: K, value: V) -> None:
        self.set(key, value)

    def __delitem__(self, key: K) -> None:
        del self.cache[key]

    def __getitem__(self, key: K) -> V:
        """
        Return the value stored for ``key`` and mark it as most-recently used.

        :raises KeyError: if ``key`` is not in the cache
        """
        # Check membership explicitly so that a stored ``None`` value is
        # returned rather than being mistaken for a missing entry.
        if key not in self.cache:
            raise KeyError(key)
        self.cache.move_to_end(key, last=True)
        return self.cache[key]

    def __iter__(self) -> Iterator[K]:
        return iter(self.cache)

    def get(self, key: K, default: Optional[D] = None) -> Optional[Union[V, D]]:
        """
        Return the value stored for ``key`` and mark it as most-recently used.

        Return ``default`` if ``key`` is not in the cache; in that case the
        cache order is left untouched.
        """
        if key not in self.cache:
            return default
        self.cache.move_to_end(key, last=True)
        return self.cache[key]

    def set(self, key: K, value: V) -> None:
        """
        Store ``value`` under ``key`` and mark it as most-recently used.

        Inserting a new key into a full cache evicts the least-recently used
        entry.
        """
        if key in self.cache:
            # Existing entry: replace the value and refresh its recency. The
            # number of entries is unchanged, so nothing needs evicting.
            self.cache[key] = value
            self.cache.move_to_end(key, last=True)
            return

        self.cache[key] = value

        # Evict the oldest entry once the cache has grown past its capacity.
        # A capacity of ``None`` means the cache is unbounded.
        if self.capacity is not None and self.length > self.capacity:
            self.cache.popitem(last=False)
|
|
|
|
|
class FrozenDict(dict):
    """
    An immutable dictionary.

    This is used to generate stable hashes for queries that contain dicts.
    Usually, Python dicts are not hashable because they are mutable. This
    class removes the mutability and implements the ``__hash__`` method.

    All values must themselves be hashable for ``hash()`` to succeed.
    """

    def __hash__(self):
        # Calculate the hash from an (unordered) frozenset of all dict items:
        # insertion order does not matter, which mirrors dict equality, and
        # unlike sorting it also works for keys of mutually incomparable
        # types (e.g. ``1`` and ``'a'``).
        return hash(frozenset(self.items()))

    def _immutable(self, *args, **kws):
        raise TypeError('object is immutable')

    # Disable write access to the dict
    __setitem__ = _immutable
    __delitem__ = _immutable
    # ``d |= other`` (Python 3.9+) would otherwise mutate the dict in place
    __ior__ = _immutable  # type: ignore
    clear = _immutable
    setdefault = _immutable  # type: ignore
    popitem = _immutable

    def update(self, e=None, **f):
        raise TypeError('object is immutable')

    def pop(self, k, d=None):
        raise TypeError('object is immutable')
|
|
|
|
|
def freeze(obj):
    """
    Return an immutable, hashable counterpart of ``obj``.

    Dicts become ``FrozenDict`` instances and lists become tuples, with their
    contents frozen recursively; sets become ``frozenset`` objects. Any other
    object is returned unchanged.
    """
    if isinstance(obj, dict):
        return FrozenDict(
            (key, freeze(item)) for key, item in obj.items()
        )

    if isinstance(obj, list):
        return tuple(map(freeze, obj))

    if isinstance(obj, set):
        return frozenset(obj)

    # Everything else is passed through as-is
    return obj
|
|
|