Spaces:
Running
Running
import asyncio
import functools
from typing import Awaitable, Callable, TypeVar, Union, cast

from starfish.common.logger import get_logger
# Module-level logger named after this module.
logger = get_logger(__name__)
# Type variable for generic return types used in the decorator annotations below.
T = TypeVar("T")
def retries(max_retries: Union[int, Callable[..., int]] = 3):
    """Decorator that adds retry logic to async functions.

    The decorated coroutine function is awaited up to ``max_retries`` times in
    total (the first call counts as an attempt). Any ``Exception`` raised by an
    attempt is logged and the call is retried; if the final attempt also fails,
    its exception is re-raised to the caller with its original traceback.

    Args:
        max_retries: Total number of attempts, either a fixed integer or a
            callable that returns an integer when invoked with the same
            arguments as the decorated function. ``None`` or values below 1
            are logged as invalid and replaced by 1.

    Returns:
        A decorator that wraps an async function with retry logic while
        preserving its name, docstring and other metadata.
    """

    def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
        @functools.wraps(func)
        async def wrapper(*args, **kwargs) -> T:
            # Resolve the attempt budget on every call so a callable can size
            # it from the actual arguments.
            attempts = (
                max_retries(*args, **kwargs) if callable(max_retries) else max_retries
            )
            if attempts is None or attempts < 1:
                logger.warning(f"Invalid max_retries value: {attempts}, defaulting to 1")
                attempts = 1

            for attempt in range(1, attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    logger.error(f"Error on attempt {attempt}/{attempts}: {str(e)}")
                    if attempt == attempts:
                        logger.error(f"All {attempts} attempts failed")
                        # Bare raise re-raises the active exception unchanged.
                        raise
                    logger.info(f"Retrying... (attempt {attempt + 1}/{attempts})")

            # Unreachable: attempts >= 1, so the loop always returns or raises.
            # An explicit raise (not an assert) survives ``python -O``.
            raise AssertionError("retry loop exited without returning or raising")

        return wrapper

    return decorator
def to_sync(async_func):
    """Decorator that turns an async function into a synchronous one.

    Each call runs the coroutine to completion with ``asyncio.run`` and returns
    its result. ``asyncio.run`` refuses to start while another event loop is
    already running in the current thread (as in a Jupyter notebook); that case
    is re-raised as a ``RuntimeError`` explaining the ``nest_asyncio``
    workaround, chained to the original error.

    Args:
        async_func: The coroutine function to wrap.

    Returns:
        A regular function that carries over ``async_func``'s name, docstring
        and other metadata.

    Raises:
        RuntimeError: If ``asyncio.run`` reports that it was called from a
            running event loop.
    """

    @functools.wraps(async_func)
    def sync_wrapper(*args, **kwargs):
        coro = async_func(*args, **kwargs)
        try:
            return asyncio.run(coro)
        except RuntimeError as e:
            if "cannot be called from a running event loop" in str(e):
                # asyncio.run bailed out before scheduling the coroutine; close
                # it so Python does not warn that it was never awaited.
                coro.close()
                raise RuntimeError(
                    "This function can't be called in Jupyter without nest_asyncio. "
                    "Please add 'import nest_asyncio; nest_asyncio.apply()' to your notebook."
                ) from e
            raise

    return sync_wrapper
def merge_structured_outputs(*lists): | |
"""Merge multiple lists of dictionaries element-wise. | |
Assumes all lists have the same length. | |
Raises an error if there are key conflicts. | |
""" | |
if not lists: | |
return [] | |
merged_list = [] | |
for elements in zip(*lists, strict=False): | |
merged_dict = {} | |
for d in elements: | |
if any(key in merged_dict for key in d): | |
raise ValueError(f"Key conflict detected in {elements}") | |
merged_dict.update(d) | |
merged_list.append(merged_dict) | |
return merged_list | |