from typing import Any, Callable, Dict, List, Optional, cast
from starfish.common.logger import get_logger
from starfish.data_factory.config import NOT_COMPLETED_THRESHOLD, TASK_RUNNER_TIMEOUT
from starfish.data_factory.constants import STORAGE_TYPE_LOCAL
from starfish.data_factory.factory_ import Factory
from starfish.data_factory.factory_wrapper import FactoryWrapper, DataFactoryProtocol, P, T
from starfish.data_factory.factory_executor_manager import FactoryExecutorManager
from starfish.data_factory.utils.data_class import FactoryMasterConfig
from starfish.data_factory.utils.state import MutableSharedState

logger = get_logger(__name__)


def data_factory(
    storage: str = STORAGE_TYPE_LOCAL,
    batch_size: int = 1,
    target_count: int = 0,
    dead_queue_threshold: int = 3,
    max_concurrency: int = 10,
    initial_state_values: Optional[Dict[str, Any]] = None,
    on_record_complete: Optional[List[Callable]] = None,
    on_record_error: Optional[List[Callable]] = None,
    show_progress: bool = True,
    task_runner_timeout: int = TASK_RUNNER_TIMEOUT,
    job_run_stop_threshold: int = NOT_COMPLETED_THRESHOLD,
) -> Callable[[Callable[P, T]], DataFactoryProtocol[P, T]]:
    """Decorator for creating data processing pipelines.

    Args:
        storage: Storage backend to use ('local' or 'in_memory')
        batch_size: Number of records to process in each batch
        target_count: Target number of records to generate (0 means process all input)
        dead_queue_threshold: Number of failures allowed before a record is routed to the dead queue
        max_concurrency: Maximum number of concurrent tasks
        initial_state_values: Initial values for shared state
        on_record_complete: Callbacks to execute after successful record processing
        on_record_error: Callbacks to execute after failed record processing
        show_progress: Whether to display progress bar
        task_runner_timeout: Timeout in seconds for task execution
        job_run_stop_threshold: Threshold for stopping job if too many records fail

    Returns:
        Decorated function with additional execution methods
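
    Example:
        A minimal sketch of typical usage; ``generate_city_info`` and its
        ``city_name`` input are illustrative names, and ``run`` is assumed to
        be one of the execution methods exposed by the returned
        ``FactoryWrapper``::

            @data_factory(max_concurrency=5)
            async def generate_city_info(city_name: str):
                return [{"city": city_name}]

            results = generate_city_info.run(city_name=["Paris", "Tokyo"])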
    """
    # Initialize default values
    on_record_error = on_record_error or []
    on_record_complete = on_record_complete or []
    initial_state_values = initial_state_values or {}

    # Create configuration
    config = FactoryMasterConfig(
        storage=storage,
        batch_size=batch_size,
        target_count=target_count,
        dead_queue_threshold=dead_queue_threshold,
        max_concurrency=max_concurrency,
        show_progress=show_progress,
        task_runner_timeout=task_runner_timeout,
        on_record_complete=on_record_complete,
        on_record_error=on_record_error,
        job_run_stop_threshold=job_run_stop_threshold,
    )

    # Factory instance, created lazily by the decorator below
    _factory: Optional[Factory] = None

    def decorator(func: Callable[P, T]) -> DataFactoryProtocol[P, T]:
        """Actual decorator that wraps the function."""
        nonlocal _factory
        _factory = _initialize_or_update_factory(_factory, config, func, initial_state_values)
        wrapper = FactoryWrapper(_factory, func)
        return cast(DataFactoryProtocol[P, T], wrapper)

    # Attach resume capability as an attribute of the decorator function itself
    data_factory.resume_from_checkpoint = resume_from_checkpoint

    return decorator


def _initialize_or_update_factory(
    factory: Optional[Factory], config: FactoryMasterConfig, func: Callable[P, T], initial_state_values: Dict[str, Any]
) -> Factory:
    """Initialize or update a Factory instance."""
    if factory is None:
        factory = Factory(config, func)
        factory.state = MutableSharedState(initial_data=initial_state_values)
    else:
        factory.config = config
        factory.func = func
        factory.state = MutableSharedState(initial_data=initial_state_values)
    return factory


def resume_from_checkpoint(*args, **kwargs) -> List[Dict[str, Any]]:
    """Resume a previously started data factory job from its checkpoint.

    Accepts the same configuration options as the data_factory decorator,
    plus the identifier of the job to resume:

    Args:
        master_job_id: ID of the master job to resume
        storage: Storage backend to use ('local' or 'in_memory')
        batch_size: Number of records to process in each batch
        target_count: Target number of records to generate (0 means process all input)
        max_concurrency: Maximum number of concurrent tasks
        initial_state_values: Initial values for shared state
        on_record_complete: Callbacks to execute after successful record processing
        on_record_error: Callbacks to execute after failed record processing
        show_progress: Whether to display progress bar
        task_runner_timeout: Timeout in seconds for task execution
        job_run_stop_threshold: Threshold for stopping job if too many records fail

    Returns:
        List[Dict[str, Any]]: Records produced by the resumed job
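
    Example:
        A minimal sketch; the master job ID shown is a placeholder for one
        emitted by an earlier run::

            results = data_factory.resume_from_checkpoint(
                master_job_id="<master-job-id-from-previous-run>",
            )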
    """
    return FactoryExecutorManager.resume(*args, **kwargs)