import asyncio import logging from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any, Literal, Self from omegaconf import OmegaConf from pydantic import BaseModel from proxy_lite.environments import ( Action, BaseEnvironment, EnvironmentConfigTypes, Environments, EventType, Observation, ) from proxy_lite.logger import create_logger from proxy_lite.recorder import DataRecorder, Run from proxy_lite.solvers import ( BaseSolver, SolverConfigTypes, Solvers, ) @asynccontextmanager async def async_timeout(timeout: float, task_name: str = "timeout"): try: async with asyncio.TaskGroup() as tg: async def timeout_task(): await asyncio.sleep(timeout) raise TimeoutError( f"Operation {task_name} timed out after {timeout} seconds", ) # Create the timeout task timeout_handle = tg.create_task(timeout_task()) try: yield finally: timeout_handle.cancel() except* asyncio.TimeoutError as eg: for e in eg.exceptions: raise e except* Exception as eg: for e in eg.exceptions: raise e class RunnerConfig(BaseModel): environment: EnvironmentConfigTypes solver: SolverConfigTypes save_every_step: bool = True max_steps: int = 50 action_timeout: float = 600.0 environment_timeout: float = 300.0 task_timeout: float = 1800.0 logger_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO" detailed_logger_name: bool = False @classmethod def from_dict(cls, config_dict: dict) -> Self: conf = OmegaConf.create(config_dict) resolved_config = OmegaConf.to_container(conf, resolve=True) # Type cast to satisfy linter - OmegaConf.to_container with resolve=True returns dict for dict inputs return cls(**dict(resolved_config)) # type: ignore @classmethod def from_yaml(cls, yaml_path: str) -> Self: conf = OmegaConf.load(yaml_path) resolved_config = OmegaConf.to_container(conf, resolve=True) # Type cast to satisfy linter - OmegaConf.to_container with resolve=True returns dict for dict inputs return cls(**dict(resolved_config)) # type: ignore class Runner(BaseModel): config: RunnerConfig recorder: DataRecorder | None = None environment: type[BaseEnvironment] | None = None solver: type[BaseSolver] | None = None logger: logging.Logger | None = None _run: Run | None = None class Config: arbitrary_types_allowed = True def model_post_init(self, __context: Any) -> None: super().model_post_init(__context) self.environment = Environments.get(self.config.environment.name) self.solver = Solvers.get(self.config.solver.name) self.recorder = DataRecorder() self.logger = create_logger( name=f"([bold purple]{self.config.solver.name}[/]-[bold blue]{self.config.environment.name}[/])", level=self.config.logger_level, detailed_name=self.config.detailed_logger_name, ) async def run_generator(self, task: str) -> AsyncIterator[Run]: # Assert that attributes are initialized (they are set in model_post_init) assert self.logger is not None, "Logger not initialized" assert self.recorder is not None, "Recorder not initialized" assert self.environment is not None, "Environment not initialized" assert self.solver is not None, "Solver not initialized" async with async_timeout(self.config.task_timeout, "Task"): if self.config.logger_level is not None: self.logger.setLevel(self.config.logger_level) run = self.recorder.initialise_run(task) run.environment = self.config.environment run.solver = self.config.solver self.logger.debug(f"Run intialised: {run.run_id}") event_queue = asyncio.Queue() async with ( self.environment( config=self.config.environment, logger=self.logger, ) as environment, self.solver(config=self.config.solver, logger=self.logger) as solver, ): run.env_info = await environment.get_info() # Initialize environment first (includes login) before solver gets control initial_observation: Observation = await environment.initialise() self.logger.debug("Environment initialised.") # Only after successful environment initialization (login) should solver get initialized await solver.initialise( task, environment.tools, environment.info_for_user, ) self.logger.debug("Solver initialised.") run.solver_history = solver.history # type: ignore await event_queue.put(initial_observation) step_count = 0 while step_count < self.config.max_steps: event = await event_queue.get() self.logger.debug(f"🤖 [bold purple]Processing event:[/] {event.type}") match event.type: case EventType.OBSERVATION: current_observation: Observation = event run.record( observation=current_observation, solver_history=solver.history, #type: ignore ) async with async_timeout( self.config.action_timeout, "Action decision", ): action_result: Action = await solver.act(current_observation) await event_queue.put(action_result) case EventType.ACTION: current_action: Action = event self.logger.debug(f"Tool calls: {current_action.tool_calls}") run.record(action=current_action, solver_history=solver.history) # type: ignore run.complete = await solver.is_complete(current_observation) if self.config.save_every_step: await self.recorder.save(run) if run.complete: run.result = current_action.text self.logger.info(f"🤖 [bold purple]Task complete.[/] ✨ \n{run.result}") break self.logger.debug(f"DEBUG: Using environment_timeout: {self.config.environment_timeout} seconds") async with async_timeout( self.config.environment_timeout, "Environment response", ): next_observation: Observation = await environment.execute_action(current_action) step_count += 1 await event_queue.put(next_observation) yield run if not run.complete: self.logger.warning("🤖 [bold purple]Ran out of steps!") await self.recorder.terminate(run, save=True) yield run async def run(self, task: str) -> Run: async for run in self.run_generator(task): self._run = run return run def run_concurrent(self, tasks: list[str]) -> list[Run | BaseException]: async def gather_runs(): return await asyncio.gather( *[self.run(task) for task in tasks], return_exceptions=True, ) return asyncio.run(gather_runs()) @property def complete(self) -> bool: if self._run is None: raise RuntimeError("Run not initialised") return self._run.complete @property def run_id(self) -> str: if self._run is None: raise RuntimeError("Run not initialised") return self._run.run_id @property def run_result(self) -> str: if self._run is None: raise RuntimeError("Run not initialised") return self._run.result or "" if __name__ == "__main__": from proxy_lite.logger import logger config = RunnerConfig.from_dict( { "environment": { "name": "webbrowser", "homepage": "https://www.google.com", "viewport_width": 1280, "viewport_height": 1920, "screenshot_delay": 1, "headless": False, }, "solver": { "name": "simple", "agent": { "name": "proxy_lite", "client": { "name": "convergence", "model_id": "convergence-ai/proxy-lite", "api_base": "https://convergence-ai-demo-api.hf.space/v1", }, }, }, "max_steps": 150, "action_timeout": 1800, "environment_timeout": 1800, "task_timeout": 18000, "logger_level": "DEBUG", }, ) logger.info(f"🤖 [bold purple]Config:[/] {config}") runner = Runner(config=config) result = asyncio.run(runner.run("Tell me the tesla stock price.")) print(runner.run_result) print(runner.complete)