Spaces:
Configuration error
Configuration error
import itertools | |
from typing import Optional | |
class _SimpleLRUCache(dict):
    """Minimal LRU mapping used when ``cachetools`` is not installed.

    Relies on dict insertion order: the first key is the least recently used
    one. Reading an entry with ``[]`` moves it to the most-recently-used end,
    and inserting a new key into a full cache evicts the oldest entry first.
    Plain ``dict`` methods (``pop``, ``items``, ``get``, ``in``) are inherited
    unchanged and do not affect recency.
    """

    def __init__(self, maxsize):
        super().__init__()
        self.maxsize = maxsize

    def __getitem__(self, key):
        # Re-insert the entry so it becomes the most recently used one.
        value = dict.pop(self, key)
        dict.__setitem__(self, key, value)
        return value

    def __setitem__(self, key, value):
        if key in self:
            # Overwrite: drop the old position so the key moves to the end.
            dict.pop(self, key)
        else:
            # Evict least-recently-used entries until there is room.
            # The ``self and`` guard avoids popping from an empty dict
            # when maxsize is not positive.
            while self and len(self) >= self.maxsize:
                dict.pop(self, next(iter(self)))
        dict.__setitem__(self, key, value)


class TaggedCache:
    """Key/value cache partitioned into per-tag LRU stores.

    Values are stored as ``(tag, payload)`` tuples. The tag selects which
    store holds the entry, and each store has its own size limit, taken from
    ``tag_settings`` or from a built-in default. A key lives in at most one
    store at a time: storing it again under another tag moves it.
    """

    def __init__(self, tag_settings: Optional[dict] = None):
        # Per-tag maximum store sizes.
        # Keep the caller's dict even when it is empty. ``tag_settings or {}``
        # would swap an empty dict for a private one, so settings the caller
        # adds later would be ignored.
        self._tag_settings = tag_settings if tag_settings is not None else {}
        self._data = {}  # tag -> LRU mapping of key -> (tag, payload)

    @staticmethod
    def _default_size(tag):
        """Return the built-in capacity for ``tag`` when no setting overrides it."""
        if 'ckpt' in tag:
            return 5
        if tag in ('latent', 'image'):
            return 100
        return 20

    def _new_store(self, tag):
        """Create the bounded LRU mapping that will hold entries for ``tag``."""
        maxsize = self._tag_settings.get(tag, self._default_size(tag))
        try:
            from cachetools import LRUCache
        except ImportError:  # also covers ModuleNotFoundError (a subclass)
            return _SimpleLRUCache(maxsize)
        return LRUCache(maxsize=maxsize)

    def __getitem__(self, key):
        """Return the stored ``(tag, payload)`` tuple for ``key``.

        Raises:
            KeyError: if no store contains ``key``.
        """
        for store in self._data.values():
            if key in store:
                return store[key]
        raise KeyError(f'Key `{key}` does not exist')

    def __setitem__(self, key, value: tuple):
        """Store ``value``, a ``(tag, payload)`` tuple, under ``key``.

        Any existing entry for ``key`` is removed first, even if it was stored
        under a different tag. The store for a tag is created on first use.
        """
        for store in self._data.values():
            if key in store:
                store.pop(key, None)
                break
        tag = value[0]
        store = self._data.get(tag)
        if store is None:
            store = self._data[tag] = self._new_store(tag)
        store[key] = value

    def __delitem__(self, key):
        """Remove ``key`` from whichever store holds it.

        Raises:
            KeyError: if no store contains ``key``.
        """
        for store in self._data.values():
            if key in store:
                del store[key]
                return
        raise KeyError(f'Key `{key}` does not exist')

    def __contains__(self, key):
        """Return True if any tag store contains ``key``."""
        return any(key in store for store in self._data.values())

    def items(self):
        """Yield every ``(key, value)`` pair, grouped by tag store."""
        # Snapshot the store list up front, as the original eager chain did.
        stores = list(self._data.values())
        for store in stores:
            yield from store.items()

    def get(self, key, default=None):
        """D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None."""
        for store in self._data.values():
            if key in store:
                return store[key]
        return default

    def clear(self):
        """Drop every tag store, and with them all cached entries."""
        self._data = {}
# Per-tag size overrides handed to TaggedCache; empty means every tag uses
# TaggedCache's built-in default size.
cache_settings = dict()
# Shared module-level cache; remove_cache('*') rebinds it to a fresh instance.
cache = TaggedCache(cache_settings)
# Per-key update counter maintained by update_cache.
cache_count = dict()
def update_cache(k, tag, v):
    """Store ``v`` under key ``k`` with tag ``tag`` and bump the key's counter.

    The first store of a key sets its counter in ``cache_count`` to 0; every
    later store of the same key increments that counter by one.
    """
    cache[k] = (tag, v)
    previous = cache_count.get(k)
    cache_count[k] = 0 if previous is None else previous + 1
def remove_cache(key):
    """Drop ``key`` from the shared cache, or reset the whole cache for ``'*'``.

    An unknown key is reported on stdout instead of raising.
    """
    global cache
    if key == '*':
        # Rebind the module-level cache to a fresh, empty instance.
        cache = TaggedCache(cache_settings)
        return
    if key not in cache:
        print(f"invalid {key}")
        return
    del cache[key]