import asyncio
import time
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional

from starfish.common.logger import get_logger
from starfish.data_factory.config import TASK_RUNNER_TIMEOUT
from starfish.data_factory.constants import IDX
from starfish.data_factory.utils.errors import TimeoutErrorAsyncio

logger = get_logger(__name__)


class TaskRunner:
    """A task runner that executes asynchronous tasks with retry logic and timeout handling.

    Attributes:
        max_retries: Maximum number of retry attempts for failed tasks
        timeout: Maximum execution time allowed for each task
        master_job_id: Optional identifier for the parent job
    """

    def __init__(
        self,
        max_retries: int = 1,
        timeout: int = TASK_RUNNER_TIMEOUT,
        master_job_id: Optional[str] = None,
    ):
        """Initializes the TaskRunner with configuration parameters.

        Args:
            max_retries: Maximum number of retry attempts (default: 1)
            timeout: Timeout in seconds for task execution (default: TASK_RUNNER_TIMEOUT)
            master_job_id: Optional identifier for the parent job (default: None)
        """
        self.max_retries = max_retries
        self.timeout = timeout
        self.master_job_id = master_job_id

    async def run_task(self, func: Callable, input_data: Dict, input_data_idx: str) -> List[Any]:
        """Run ``func`` once for ``input_data``, retrying on failure.

        ``func`` is awaited with the entries of ``input_data`` (minus the
        ``IDX`` key) passed as keyword arguments. Each attempt is bounded by
        ``self.timeout`` seconds.

        Args:
            func: Callable returning an awaitable; invoked as ``func(**kwargs)``.
            input_data: Keyword arguments for ``func``. It is not mutated; a
                deep copy without the ``IDX`` key is passed instead.
            input_data_idx: Identifier of the input record, used only in the
                retry log message.

        Returns:
            Whatever ``func`` returns on the first successful attempt. If
            ``self.max_retries`` is negative no attempt is made and ``None``
            is returned.

        Raises:
            TimeoutErrorAsyncio: If an attempt exceeds ``self.timeout``
                seconds. Timeouts are not retried.
            Exception: The last exception raised by ``func`` once the retry
                budget is exhausted.
        """
        retries = 0
        start_time = time.time()
        result = None

        # Work on a deep copy without IDX so the caller's dict is never
        # touched and IDX is not passed to func (guards against IDX being
        # inserted concurrently).
        copy_input = deepcopy({k: v for k, v in input_data.items() if k != IDX})

        while retries <= self.max_retries:
            try:
                result = await asyncio.wait_for(func(**copy_input), timeout=self.timeout)
                logger.debug(f"Task execution completed in {time.time() - start_time:.2f} seconds")
                break
            except asyncio.TimeoutError as timeout_error:
                # A timeout aborts immediately; it does not consume a retry.
                logger.error(
                    f"Task execution timed out after {self.timeout} seconds, "
                    "please set the timeout in data_factory decorator like this: "
                    "task_runner_timeout=60"
                )
                raise TimeoutErrorAsyncio(f"Task execution timed out after {self.timeout} seconds") from timeout_error
            except Exception:
                retries += 1
                if retries > self.max_retries:
                    # Retry budget exhausted: propagate the original error
                    # with its traceback intact.
                    raise
                logger.debug(f"Retry attempt {retries}/{self.max_retries} for input data index {input_data_idx}")
                # Exponential backoff: 2s, 4s, 8s, ... (the previous
                # ``1**retries`` always evaluated to 1, i.e. no backoff).
                await asyncio.sleep(2**retries)

        return result