File size: 2,846 Bytes
5301c48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import asyncio
import time
from typing import Any, Callable, Dict, List
from copy import deepcopy
from starfish.common.logger import get_logger
from starfish.data_factory.config import TASK_RUNNER_TIMEOUT
from starfish.data_factory.constants import IDX
from starfish.data_factory.utils.errors import TimeoutErrorAsyncio

logger = get_logger(__name__)


# from starfish.common.logger_new import logger
class TaskRunner:
    """A task runner that executes asynchronous tasks with retry logic and timeout handling.

    Each attempt is bounded by ``timeout`` seconds. Ordinary exceptions are
    retried up to ``max_retries`` times with exponential backoff; a timeout
    is not retried and aborts the task immediately.

    Attributes:
        max_retries: Maximum number of retry attempts for failed tasks
        timeout: Maximum execution time allowed for each task attempt, in seconds
        master_job_id: Optional identifier for the parent job
    """

    def __init__(self, max_retries: int = 1, timeout: int = TASK_RUNNER_TIMEOUT, master_job_id: str = None):
        """Initializes the TaskRunner with configuration parameters.

        Args:
            max_retries: Maximum number of retry attempts (default: 1)
            timeout: Timeout in seconds for task execution (default: TASK_RUNNER_TIMEOUT)
            master_job_id: Optional identifier for the parent job (default: None)
        """
        self.max_retries = max_retries
        self.timeout = timeout
        self.master_job_id = master_job_id

    async def run_task(self, func: Callable, input_data: Dict, input_data_idx: str) -> List[Any]:
        """Process a single task with asyncio, retrying on failure.

        ``func`` is called with the entries of ``input_data`` (minus the ``IDX``
        key) as keyword arguments, and each attempt is awaited under
        ``asyncio.wait_for`` with ``self.timeout``.

        Args:
            func: Callable returning an awaitable; invoked as ``func(**kwargs)``.
            input_data: Keyword arguments for ``func``. It is not mutated; a deep
                copy without the ``IDX`` key is what gets passed on.
            input_data_idx: Index of the input record, used only in log messages.

        Returns:
            Whatever the awaited ``func`` call returns. ``None`` is returned if
            ``max_retries`` is negative, because no attempt is made in that case.

        Raises:
            TimeoutErrorAsyncio: If an attempt exceeds ``self.timeout`` seconds.
                Timeouts are not retried.
            Exception: The last exception raised by ``func`` once the retry
                budget is exhausted.
        """
        # Monotonic clock: the elapsed time is only logged, but it must not be
        # skewed by wall-clock adjustments.
        start_time = time.monotonic()
        result = None
        # Create a copy of input_data without 'IDX' to prevent insertion of IDX due to race condition
        copy_input = deepcopy({k: v for k, v in input_data.items() if k != IDX})
        retries = 0
        while retries <= self.max_retries:
            try:
                result = await asyncio.wait_for(func(**copy_input), timeout=self.timeout)
                logger.debug(f"Task execution completed in {time.monotonic() - start_time:.2f} seconds")
                break
            except asyncio.TimeoutError as timeout_error:
                # A timeout is surfaced immediately instead of being retried.
                logger.error(
                    f"Task execution timed out after {self.timeout} seconds, "
                    "please set the timeout in data_factory decorator like this: "
                    "task_runner_timeout=60"
                )
                raise TimeoutErrorAsyncio(f"Task execution timed out after {self.timeout} seconds") from timeout_error
            except Exception:
                retries += 1
                if retries > self.max_retries:
                    # Retry budget exhausted: propagate the original exception
                    # with its traceback intact.
                    raise
                logger.debug(f"Retry attempt {retries}/{self.max_retries} for input data index {input_data_idx}")
                # Exponential backoff: 1s, 2s, 4s, ... (retries is >= 1 here).
                await asyncio.sleep(2 ** (retries - 1))

        return result