metric / utils.py
Elron's picture
Upload folder using huggingface_hub
fe70438 verified
raw
history blame
10.4 kB
import copy
import importlib.util
import json
import os
import re
import threading
from collections import OrderedDict
from functools import lru_cache
from typing import Any, Dict
from .text_utils import is_made_of_sub_strings
class Singleton(type):
    """Metaclass that guarantees at most one instance per class.

    The first call to the class constructs the instance; every later call
    returns that same cached object. Instances are kept in a dict keyed by
    the class itself, shared across all classes using this metaclass.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        # EAFP: return the cached instance if it exists, otherwise build it.
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance
class LRUCache:
    """An LRU (Least Recently Used) cache that stores a limited number of items.

    This cache automatically removes the least recently used item when it
    exceeds its max size. It behaves similarly to a dictionary, allowing
    items to be added and accessed using `[]` syntax.

    This implementation is thread-safe, using a lock to ensure that only one
    thread can modify or access the cache at any time.

    Attributes:
        max_size (int): The maximum number of items to store in the cache.
            Items exceeding this limit are automatically removed based on least
            recent usage.
    """

    def __init__(self, max_size=10):
        self._max_size = max_size
        self._cache = OrderedDict()
        self._lock = threading.Lock()  # Lock to ensure thread safety

    def _evict(self):
        # Drop least-recently-used entries until the cache fits.
        # Caller must hold self._lock (the lock is not reentrant).
        while len(self._cache) > self._max_size:
            self._cache.popitem(last=False)

    def _put(self, key, value):
        # Insert/refresh an entry as most-recently-used, then trim.
        # Caller must hold self._lock.
        if key in self._cache:
            self._cache.pop(key)
        self._cache[key] = value
        self._evict()

    @property
    def max_size(self):
        with self._lock:
            return self._max_size

    @max_size.setter
    def max_size(self, size):
        with self._lock:
            self._max_size = size
            # Shrink immediately if the new size is smaller than the
            # current number of items.
            self._evict()

    def __setitem__(self, key, value):
        with self._lock:
            self._put(key, value)

    def set(self, key, value):
        """Sets a key-value pair in the cache."""
        with self._lock:
            self._put(key, value)

    def __getitem__(self, key):
        with self._lock:
            if key in self._cache:
                # Mark as most recently used.
                self._cache.move_to_end(key)
                return self._cache[key]
            raise KeyError(f"{key} not found in cache")

    def get(self, key, default=None):
        """Gets a value from the cache by key, returning `default` if the key is not found."""
        with self._lock:
            if key in self._cache:
                # Mark as most recently used.
                self._cache.move_to_end(key)
                return self._cache[key]
            return default

    def __contains__(self, key):
        with self._lock:
            return key in self._cache

    def __len__(self):
        with self._lock:
            return len(self._cache)

    def __repr__(self):
        with self._lock:
            return f"LRUCache(max_size={self._max_size}, items={list(self._cache.items())})"
def flatten_dict(
    d: Dict[str, Any], parent_key: str = "", sep: str = "_"
) -> Dict[str, Any]:
    """Flatten a nested dict into a single level, joining keys with `sep`.

    Nested dict values are recursed into; their keys are prefixed with the
    path of parent keys, e.g. {"a": {"b": 1}} -> {"a_b": 1}.
    """
    flat: Dict[str, Any] = {}
    for key, value in d.items():
        compound_key = (parent_key + sep + key) if parent_key else key
        if isinstance(value, dict):
            flat.update(flatten_dict(value, compound_key, sep=sep))
        else:
            flat[compound_key] = value
    return flat
@lru_cache(maxsize=None)
def artifacts_json_cache(artifact_path):
    """Return the parsed JSON at `artifact_path`, memoized per path.

    NOTE: the cache is unbounded; each distinct path stays cached for the
    lifetime of the process.
    """
    return load_json(artifact_path)
def load_json(path):
    """Load and parse the JSON file at `path`.

    Raises:
        RuntimeError: if the file is not valid JSON; the message includes
            the raw file content to ease debugging. The original
            JSONDecodeError is chained as the cause.
    """
    # JSON is UTF-8 by specification (RFC 8259), so decode explicitly
    # rather than relying on the platform default encoding.
    with open(path, encoding="utf-8") as f:
        try:
            return json.load(f)
        except json.decoder.JSONDecodeError as e:
            # Re-read the raw text for the error message. Use read() —
            # joining readlines() with "\n" would double every newline,
            # since readlines() keeps the trailing "\n" on each line.
            with open(path, encoding="utf-8") as f:
                file_content = f.read()
            raise RuntimeError(
                f"Failed to decode json file at '{path}' with file content:\n{file_content}"
            ) from e
def save_to_file(path, data):
    """Write the string `data` to `path`, followed by a trailing newline."""
    with open(path, "w") as out:
        out.write(data + "\n")
def json_dump(data):
    """Serialize `data` as a 4-space-indented JSON string, preserving non-ASCII characters."""
    return json.dumps(data, indent=4, ensure_ascii=False)
def is_package_installed(package_name):
    """Check if a package is installed.

    Parameters:
    - package_name (str): The name of the package to check.

    Returns:
    - bool: True if the package is installed, False otherwise.
    """
    try:
        return importlib.util.find_spec(package_name) is not None
    except ModuleNotFoundError:
        # find_spec raises (rather than returning None) when a dotted name's
        # parent package is missing; treat that as "not installed" so this
        # function honors its documented bool contract.
        return False
def is_module_available(module_name):
    """Check if a module is available in the current Python environment.

    Parameters:
    - module_name (str): The name of the module to check.

    Returns:
    - bool: True if the module is available, False otherwise.
    """
    try:
        __import__(module_name)
    except ImportError:
        return False
    return True
def remove_numerics_and_quoted_texts(input_str):
    """Strip numeric literals and quoted string literals from `input_str`.

    Removal order matters:
    1. floats before integers, so no stray decimal point is left behind;
    2. triple-quoted (possibly multiline) strings before single/double
       quoted ones — otherwise the '...'/"..." patterns consume pieces of
       the triple-quote delimiters and leave mangled fragments behind.
    """
    # Remove floats first to avoid leaving stray periods
    input_str = re.sub(r"\d+\.\d+", "", input_str)
    # Remove integers
    input_str = re.sub(r"\d+", "", input_str)
    # Remove triple-quoted strings (may span lines) BEFORE the
    # single/double quote passes, which would otherwise mangle them
    input_str = re.sub(r'""".*?"""', "", input_str, flags=re.DOTALL)
    input_str = re.sub(r"'''.*?'''", "", input_str, flags=re.DOTALL)
    # Remove strings in single quotes
    input_str = re.sub(r"'.*?'", "", input_str)
    # Remove strings in double quotes
    return re.sub(r'".*?"', "", input_str)
def safe_eval(expression: str, context: dict, allowed_tokens: list) -> any:
    """Evaluates a given expression in a restricted environment, allowing only specified tokens and context variables.

    Args:
        expression (str): The expression to evaluate.
        context (dict): A dictionary mapping variable names to their values, which
                        can be used in the expression.
        allowed_tokens (list): A list of strings representing allowed tokens (such as
                               operators, function names, etc.) that can be used in the expression.

    Returns:
        any: The result of evaluating the expression.

    Raises:
        ValueError: If the expression contains tokens not in the allowed list or context keys.

    Note:
        This function should be used carefully, as it employs `eval`, which can
        execute arbitrary code. The function attempts to mitigate security risks
        by restricting the available tokens and not exposing built-in functions.
    """
    allowed_sub_strings = list(context.keys()) + allowed_tokens
    # Numbers and string literals are always permitted, so strip them before
    # checking that everything left is built from the allowed vocabulary.
    stripped_expression = remove_numerics_and_quoted_texts(expression)
    if not is_made_of_sub_strings(stripped_expression, allowed_sub_strings):
        raise ValueError(
            f"The expression '{expression}' can not be evaluated because it contains tokens outside the allowed list of {allowed_sub_strings}."
        )
    # Builtins are withheld to limit what the expression can reach.
    return eval(expression, {"__builtins__": {}}, context)
def import_module_from_file(file_path):
    """Dynamically import the Python source file at `file_path`.

    The module name is taken from the file name (without extension).
    Returns the fully executed module object.
    """
    # Derive the module name from the file name, extension dropped.
    base_name = os.path.basename(file_path)
    module_name, _ = os.path.splitext(base_name)
    # Build a spec for the file, materialize a module from it, then run
    # the module's code to populate its namespace.
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
def deep_copy(obj):
    """Return a deep copy of `obj`.

    Args:
        obj: The object to be deep copied.

    Returns:
        A new object with all nested contents copied recursively.
    """
    return copy.deepcopy(obj)
def shallow_copy(obj):
    """Return a shallow copy of `obj`.

    Args:
        obj: The object to be shallow copied.

    Returns:
        A new top-level object; nested contents are shared with the original.
    """
    return copy.copy(obj)
def recursive_copy(obj, internal_copy=None):
    """Recursively copies an object with a selective copy method.

    Containers (`dict`, `list`, `tuple`, named tuples) are rebuilt with their
    contents copied recursively. Every non-container (leaf) value is passed
    through `internal_copy`; when `internal_copy` is `None`, leaves are
    returned as is (shared with the original).

    Args:
        obj: The object to be copied.
        internal_copy (callable, optional): The copy function applied to
            non-container objects. If `None`, leaves are not copied.

    Returns:
        The recursively copied object.
    """
    # Dictionaries: rebuild with the same concrete type.
    if isinstance(obj, dict):
        return type(obj)(
            {k: recursive_copy(v, internal_copy) for k, v in obj.items()}
        )
    # Named tuples expose _fields and must be rebuilt positionally.
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        return type(obj)(*(recursive_copy(member, internal_copy) for member in obj))
    # Plain tuples and lists: rebuild from a generator of copied members.
    if isinstance(obj, (list, tuple)):
        return type(obj)(recursive_copy(member, internal_copy) for member in obj)
    # Leaf value: delegate to internal_copy, or share the original.
    return obj if internal_copy is None else internal_copy(obj)
def recursive_deep_copy(obj):
    """Recursively copy `obj`, deep-copying every non-container leaf.

    Containers are rebuilt by `recursive_copy`; leaves go through `deep_copy`.

    Args:
        obj: The object to be deep copied.

    Returns:
        A recursively deep-copied version of the original object.
    """
    return recursive_copy(obj, internal_copy=deep_copy)
def recursive_shallow_copy(obj):
    """Recursively copy `obj`, shallow-copying every non-container leaf.

    Containers are rebuilt by `recursive_copy`; leaves go through `shallow_copy`.

    Args:
        obj: The object to be shallow copied.

    Returns:
        A recursively shallow-copied version of the original object.
    """
    return recursive_copy(obj, internal_copy=shallow_copy)
class LongString(str):
    """A `str` subclass whose `repr` can be overridden at construction.

    Useful for very long strings whose default repr would flood logs or
    debugger output: pass `repr_str` to show a short placeholder instead.
    The string value itself behaves exactly like a normal `str`.
    """

    def __new__(cls, value, *, repr_str=None):
        instance = super().__new__(cls, value)
        # Stored per instance; None means "fall back to the normal str repr".
        instance._repr_str = repr_str
        return instance

    def __repr__(self):
        if self._repr_str is None:
            return super().__repr__()
        return self._repr_str