from __future__ import annotations

from functools import wraps
from typing import TYPE_CHECKING, Any, TypeVar

from bot.cache.serialization import AbstractSerializer, PickleSerializer
from bot.core.loader import redis_client

if TYPE_CHECKING:
    from collections.abc import Awaitable
    from datetime import timedelta
    from typing import Callable

    from redis.asyncio import Redis

DEFAULT_TTL = 10

_Func = TypeVar("_Func")

Args = str | int  # basically only user_id is used as identifier
Kwargs = Any


def build_key(*args: Args, **kwargs: Kwargs) -> str:
    """Build a string key based on provided arguments and keyword arguments."""
    args_str = ":".join(map(str, args))
    kwargs_str = ":".join(f"{key}={value}" for key, value in sorted(kwargs.items()))
    return f"{args_str}:{kwargs_str}"

async def set_redis_value(
    key: bytes | str,
    value: bytes | str,
    ttl: int | timedelta | None = DEFAULT_TTL,
    is_transaction: bool = False,
) -> None:
    """Set a value in Redis with an optional time-to-live (TTL)."""
    async with redis_client.pipeline(transaction=is_transaction) as pipeline:
        await pipeline.set(key, value)
        if ttl:
            await pipeline.expire(key, ttl)
        await pipeline.execute()
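
# Illustrative only: cache a plain string for five minutes; the key and value
# here are made up, and datetime.timedelta must be imported at the call site:
#     await set_redis_value("greeting:42", "hello", ttl=timedelta(minutes=5))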

def cached(
    ttl: int | timedelta = DEFAULT_TTL,
    namespace: str = "main",
    cache: Redis = redis_client,
    key_builder: Callable[..., str] = build_key,
    serializer: AbstractSerializer | None = None,
) -> Callable[[Callable[..., Awaitable[_Func]]], Callable[..., Awaitable[_Func]]]:
    """Cache the function's return value under a key generated from the module name, function name, and arguments.

    Args:
        ttl (int | timedelta): Time-to-live for the cached value.
        namespace (str): Namespace for cache keys.
        cache (Redis): Redis instance for storing cached data.
        key_builder (Callable[..., str]): Function to build cache keys.
        serializer (AbstractSerializer | None): Serializer for cache data.

    Returns:
        Callable: A decorator that wraps the original function with caching logic.
    """
    if serializer is None:
        serializer = PickleSerializer()

    def decorator(func: Callable[..., Awaitable[_Func]]) -> Callable[..., Awaitable[_Func]]:
        @wraps(func)
        async def wrapper(*args: Args, **kwargs: Kwargs) -> Any:
            key = key_builder(*args, **kwargs)
            key = f"{namespace}:{func.__module__}:{func.__name__}:{key}"

            # Check if the key is in the cache
            cached_value = await cache.get(key)
            if cached_value is not None:
                return serializer.deserialize(cached_value)

            # If not in cache, call the original function
            result = await func(*args, **kwargs)

            # Store the result in Redis
            await set_redis_value(
                key=key,
                value=serializer.serialize(result),
                ttl=ttl,
            )
            return result

        return wrapper

    return decorator
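
# Illustrative usage of @cached; get_user, its TTL, and the "users" namespace
# are assumptions for this sketch, not part of this module:
#
#     @cached(ttl=60, namespace="users")
#     async def get_user(user_id: int) -> dict[str, Any]:
#         ...  # expensive database or API call
#
# The first call stores the pickled result under
# "users:<module>:get_user:<user_id>:"; repeat calls within 60 seconds are
# served from Redis without calling the function again.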

async def clear_cache(
    func: Callable[..., Awaitable[Any]],
    *args: Args,
    **kwargs: Kwargs,
) -> None:
    """Clear the cache for a specific function and arguments.

    Args:
        func (Callable): The target function whose cached value should be cleared.
        *args (Args): Positional arguments originally passed to the function.
        **kwargs (Kwargs): Keyword arguments originally passed to the function.
            The special keyword ``namespace`` (str, optional) selects the cache
            namespace and defaults to "main".
    """
    # Pop "namespace" so it is not folded into the generated key; @cached never
    # passes the namespace to the key builder either.
    namespace = kwargs.pop("namespace", "main")
    key = build_key(*args, **kwargs)
    key = f"{namespace}:{func.__module__}:{func.__name__}:{key}"
    await redis_client.delete(key)
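
# Illustrative invalidation for the assumed get_user example above:
#     await clear_cache(get_user, user_id, namespace="users")
# deletes "users:<module>:get_user:<user_id>:" so the next call recomputes it.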