# Original author: John-Jiang
# Commit 5301c48: "init commit"
import asyncio
import functools
from typing import Awaitable, Callable, TypeVar, Union, cast

from starfish.common.logger import get_logger
# Generic type variable used in the return-type annotations of this module's decorators.
T = TypeVar("T")

# Module-level logger, named after this module via the project's logging helper.
logger = get_logger(__name__)
def retries(max_retries: Union[int, Callable[..., int]] = 3):
    """Decorate an async function so that failed calls are attempted again.

    Despite its name, ``max_retries`` is the *total* number of attempts: with
    ``max_retries=3`` the wrapped coroutine function is awaited at most three
    times (one initial call plus up to two retries).

    Args:
        max_retries: Total number of attempts. This is either a fixed integer
            or a callable that returns an integer. The callable is invoked on
            every call with the same ``*args``/``**kwargs`` as the decorated
            function. ``None`` or a value below 1 is logged as a warning and
            treated as 1.

    Returns:
        A decorator that turns an ``async def`` function into one that awaits
        it up to ``max_retries`` times. The decorated function returns the
        first successful result.

    Raises:
        The exception from the final attempt, re-raised unchanged once every
        attempt has failed. Only ``Exception`` subclasses trigger a retry, so
        ``BaseException``-only errors such as ``asyncio.CancelledError``
        propagate immediately.
    """

    def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
        @functools.wraps(func)
        async def wrapper(*args, **kwargs) -> T:
            # Resolve the budget on every call so a callable can size it from the arguments.
            attempts = max_retries(*args, **kwargs) if callable(max_retries) else max_retries
            if attempts is None or attempts < 1:
                logger.warning(f"Invalid max_retries value: {attempts}, defaulting to 1")
                attempts = 1

            attempt = 0
            # `while True` (instead of a bounded for-loop) makes every path visibly end in
            # a return or a raise, so no unreachable fallback is needed after the loop.
            while True:
                attempt += 1
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    logger.error(f"Error on attempt {attempt}/{attempts}: {str(e)}")
                    if attempt >= attempts:
                        logger.error(f"All {attempts} attempts failed")
                        # Bare `raise` re-raises the final error with its original traceback.
                        raise
                    logger.info(f"Retrying... (attempt {attempt + 1}/{attempts})")

        return cast(Callable[..., Awaitable[T]], wrapper)

    return decorator
def to_sync(async_func):
    """Wrap an ``async def`` function so it can be called synchronously.

    Each call creates the coroutine and drives it to completion with
    ``asyncio.run``, returning its result.

    If ``asyncio.run`` refuses to start because an event loop is already
    running in this thread (for example inside a Jupyter notebook), a
    ``RuntimeError`` pointing the user at ``nest_asyncio`` is raised instead.

    Args:
        async_func: The coroutine function to wrap.

    Returns:
        A plain function with ``async_func``'s metadata (via ``functools.wraps``).

    Raises:
        RuntimeError: When called from inside a running event loop. The
            original ``asyncio`` error is chained as ``__cause__``. Other
            exceptions propagate unchanged.
    """

    @functools.wraps(async_func)
    def sync_wrapper(*args, **kwargs):
        coro = async_func(*args, **kwargs)
        try:
            return asyncio.run(coro)
        except RuntimeError as e:
            if "cannot be called from a running event loop" not in str(e):
                raise
            # When asyncio.run refuses to start, the coroutine was never scheduled.
            # Close it so it is not collected with a "never awaited" RuntimeWarning.
            # Closing an already-finished coroutine is a harmless no-op.
            if asyncio.iscoroutine(coro):
                coro.close()
            raise RuntimeError(
                "This function can't be called in Jupyter without nest_asyncio. "
                "Please add 'import nest_asyncio; nest_asyncio.apply()' to your notebook."
            ) from e

    return sync_wrapper
def merge_structured_outputs(*lists):
"""Merge multiple lists of dictionaries element-wise.
Assumes all lists have the same length.
Raises an error if there are key conflicts.
"""
if not lists:
return []
merged_list = []
for elements in zip(*lists, strict=False):
merged_dict = {}
for d in elements:
if any(key in merged_dict for key in d):
raise ValueError(f"Key conflict detected in {elements}")
merged_dict.update(d)
merged_list.append(merged_dict)
return merged_list