File size: 3,320 Bytes
5301c48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import asyncio
import functools
from typing import Awaitable, Callable, TypeVar, Union, cast

from starfish.common.logger import get_logger

# Module-level logger (from the project's get_logger helper); used by the
# retry decorator below to report invalid settings and failed attempts.
logger = get_logger(__name__)

# Type variable for generic return types: the result type produced by the
# coroutine that `retries` wraps.
T = TypeVar("T")


def retries(max_retries: Union[int, Callable[..., int]] = 3):
    """Decorator that re-awaits an async function when it raises.

    Despite the parameter name, ``max_retries`` is the *total* number of
    attempts: with ``max_retries=3`` the function is awaited at most three
    times (one initial call plus up to two retries). Retries happen
    immediately, with no delay between attempts.

    Args:
        max_retries: Total number of attempts, either a fixed integer or a
            callable that returns an integer when invoked with the same
            arguments as the decorated function. ``None`` or values below 1
            are logged as a warning and replaced with 1 (a single attempt).

    Returns:
        A decorator for ``async def`` functions. The wrapped coroutine returns
        the first successful result; once every attempt has failed it
        re-raises the exception from the final attempt.
    """

    def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
        @functools.wraps(func)
        async def wrapper(*args, **kwargs) -> T:
            # Resolve the attempt budget per call so it may depend on the arguments.
            attempts = max_retries(*args, **kwargs) if callable(max_retries) else max_retries

            if attempts is None or attempts < 1:
                logger.warning(f"Invalid max_retries value: {attempts}, defaulting to 1")
                attempts = 1

            for attempt in range(attempts):
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    # Only Exception subclasses are retried; BaseException types
                    # (e.g. KeyboardInterrupt) propagate on the first occurrence.
                    logger.error(f"Error on attempt {attempt+1}/{attempts}: {str(e)}")
                    if attempt < attempts - 1:
                        logger.info(f"Retrying... (attempt {attempt+2}/{attempts})")
                        continue
                    logger.error(f"All {attempts} attempts failed")
                    # Bare raise re-raises the final attempt's exception with its
                    # original traceback intact.
                    raise

            # Unreachable: attempts >= 1 and every iteration returns or raises.
            raise AssertionError("retries: loop exited without returning or raising")

        return cast(Callable[..., Awaitable[T]], wrapper)

    return decorator


def to_sync(async_func):
    """Decorator that exposes an async function as an ordinary blocking call.

    Each call creates the coroutine and drives it to completion with
    ``asyncio.run``, returning its result. ``asyncio.run`` refuses to start
    when an event loop is already running in the current thread (as in a
    Jupyter notebook). In that case a RuntimeError is raised that tells the
    user to apply ``nest_asyncio``.

    Args:
        async_func: The ``async def`` function to wrap.

    Returns:
        A synchronous function carrying ``async_func``'s metadata
        (name, docstring) via ``functools.wraps``.

    Raises:
        RuntimeError: If ``asyncio.run`` reports it was called from a running
            event loop; the original error is chained as ``__cause__``. Any
            other exception from ``async_func`` propagates unchanged.
    """

    @functools.wraps(async_func)
    def sync_wrapper(*args, **kwargs):
        coro = None
        try:
            coro = async_func(*args, **kwargs)
            return asyncio.run(coro)
        except RuntimeError as e:
            # Detect the condition from asyncio.run's message rather than
            # pre-checking for a running loop: nest_asyncio.apply() (the fix we
            # recommend) is meant to make asyncio.run usable inside a running
            # loop, and a pre-check would block that workaround.
            if "cannot be called from a running event loop" in str(e):
                # If asyncio.run rejected the call up front, the coroutine never
                # started; close it so Python does not emit a "coroutine ... was
                # never awaited" RuntimeWarning. Closing a finished coroutine is
                # a no-op.
                if asyncio.iscoroutine(coro):
                    coro.close()
                raise RuntimeError(
                    "This function can't be called in Jupyter without nest_asyncio. " "Please add 'import nest_asyncio; nest_asyncio.apply()' to your notebook."
                ) from e
            raise

    return sync_wrapper


def merge_structured_outputs(*lists):
    """Merge multiple lists of dictionaries element-wise.
    Assumes all lists have the same length.
    Raises an error if there are key conflicts.
    """
    if not lists:
        return []

    merged_list = []
    for elements in zip(*lists, strict=False):
        merged_dict = {}
        for d in elements:
            if any(key in merged_dict for key in d):
                raise ValueError(f"Key conflict detected in {elements}")
            merged_dict.update(d)
        merged_list.append(merged_dict)
    return merged_list