# Trisha Tomy
# fixes + cleanup flow
# c73909d
import asyncio
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, Literal, Self
from omegaconf import OmegaConf
from pydantic import BaseModel
from proxy_lite.environments import (
Action,
BaseEnvironment,
EnvironmentConfigTypes,
Environments,
EventType,
Observation,
)
from proxy_lite.logger import create_logger
from proxy_lite.recorder import DataRecorder, Run
from proxy_lite.solvers import (
BaseSolver,
SolverConfigTypes,
Solvers,
)
@asynccontextmanager
async def async_timeout(timeout: float, task_name: str = "timeout"):
try:
async with asyncio.TaskGroup() as tg:
async def timeout_task():
await asyncio.sleep(timeout)
raise TimeoutError(
f"Operation {task_name} timed out after {timeout} seconds",
)
# Create the timeout task
timeout_handle = tg.create_task(timeout_task())
try:
yield
finally:
timeout_handle.cancel()
except* asyncio.TimeoutError as eg:
for e in eg.exceptions:
raise e
except* Exception as eg:
for e in eg.exceptions:
raise e
class RunnerConfig(BaseModel):
    """Settings consumed by ``Runner``: component configs, limits and logging."""

    # Component configurations; their ``name`` attribute selects the class to use.
    environment: EnvironmentConfigTypes
    solver: SolverConfigTypes
    # Persist the run after every recorded action.
    save_every_step: bool = True
    # Upper bound on executed actions before the loop gives up.
    max_steps: int = 50
    # Timeouts, in seconds, for the solver decision, the environment response
    # and the task as a whole.
    action_timeout: float = 600.0
    environment_timeout: float = 300.0
    task_timeout: float = 1800.0
    logger_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"
    detailed_logger_name: bool = False

    @classmethod
    def from_dict(cls, config_dict: dict) -> Self:
        """Build a config from a plain dict, resolving OmegaConf interpolations."""
        container = OmegaConf.to_container(
            OmegaConf.create(config_dict),
            resolve=True,
        )
        # to_container is typed as a broad union; for a dict input it is a dict.
        settings = dict(container)  # type: ignore
        return cls(**settings)

    @classmethod
    def from_yaml(cls, yaml_path: str) -> Self:
        """Build a config from a YAML file, resolving OmegaConf interpolations."""
        container = OmegaConf.to_container(
            OmegaConf.load(yaml_path),
            resolve=True,
        )
        # to_container is typed as a broad union; a mapping is expected here.
        settings = dict(container)  # type: ignore
        return cls(**settings)
class Runner(BaseModel):
config: RunnerConfig
recorder: DataRecorder | None = None
environment: type[BaseEnvironment] | None = None
solver: type[BaseSolver] | None = None
logger: logging.Logger | None = None
_run: Run | None = None
class Config:
arbitrary_types_allowed = True
def model_post_init(self, __context: Any) -> None:
super().model_post_init(__context)
self.environment = Environments.get(self.config.environment.name)
self.solver = Solvers.get(self.config.solver.name)
self.recorder = DataRecorder()
self.logger = create_logger(
name=f"([bold purple]{self.config.solver.name}[/]-[bold blue]{self.config.environment.name}[/])",
level=self.config.logger_level,
detailed_name=self.config.detailed_logger_name,
)
async def run_generator(self, task: str) -> AsyncIterator[Run]:
# Assert that attributes are initialized (they are set in model_post_init)
assert self.logger is not None, "Logger not initialized"
assert self.recorder is not None, "Recorder not initialized"
assert self.environment is not None, "Environment not initialized"
assert self.solver is not None, "Solver not initialized"
async with async_timeout(self.config.task_timeout, "Task"):
if self.config.logger_level is not None:
self.logger.setLevel(self.config.logger_level)
run = self.recorder.initialise_run(task)
run.environment = self.config.environment
run.solver = self.config.solver
self.logger.debug(f"Run intialised: {run.run_id}")
event_queue = asyncio.Queue()
async with (
self.environment(
config=self.config.environment,
logger=self.logger,
) as environment,
self.solver(config=self.config.solver, logger=self.logger) as solver,
):
run.env_info = await environment.get_info()
# Initialize environment first (includes login) before solver gets control
initial_observation: Observation = await environment.initialise()
self.logger.debug("Environment initialised.")
# Only after successful environment initialization (login) should solver get initialized
await solver.initialise(
task,
environment.tools,
environment.info_for_user,
)
self.logger.debug("Solver initialised.")
run.solver_history = solver.history # type: ignore
await event_queue.put(initial_observation)
step_count = 0
while step_count < self.config.max_steps:
event = await event_queue.get()
self.logger.debug(f"πŸ€– [bold purple]Processing event:[/] {event.type}")
match event.type:
case EventType.OBSERVATION:
current_observation: Observation = event
run.record(
observation=current_observation,
solver_history=solver.history, #type: ignore
)
async with async_timeout(
self.config.action_timeout,
"Action decision",
):
action_result: Action = await solver.act(current_observation)
await event_queue.put(action_result)
case EventType.ACTION:
current_action: Action = event
self.logger.debug(f"Tool calls: {current_action.tool_calls}")
run.record(action=current_action, solver_history=solver.history) # type: ignore
run.complete = await solver.is_complete(current_observation)
if self.config.save_every_step:
await self.recorder.save(run)
if run.complete:
run.result = current_action.text
self.logger.info(f"πŸ€– [bold purple]Task complete.[/] ✨ \n{run.result}")
break
self.logger.debug(f"DEBUG: Using environment_timeout: {self.config.environment_timeout} seconds")
async with async_timeout(
self.config.environment_timeout,
"Environment response",
):
next_observation: Observation = await environment.execute_action(current_action)
step_count += 1
await event_queue.put(next_observation)
yield run
if not run.complete:
self.logger.warning("πŸ€– [bold purple]Ran out of steps!")
await self.recorder.terminate(run, save=True)
yield run
async def run(self, task: str) -> Run:
async for run in self.run_generator(task):
self._run = run
return run
def run_concurrent(self, tasks: list[str]) -> list[Run | BaseException]:
async def gather_runs():
return await asyncio.gather(
*[self.run(task) for task in tasks],
return_exceptions=True,
)
return asyncio.run(gather_runs())
@property
def complete(self) -> bool:
if self._run is None:
raise RuntimeError("Run not initialised")
return self._run.complete
@property
def run_id(self) -> str:
if self._run is None:
raise RuntimeError("Run not initialised")
return self._run.run_id
@property
def run_result(self) -> str:
if self._run is None:
raise RuntimeError("Run not initialised")
return self._run.result or ""
if __name__ == "__main__":
    # Manual smoke run: build a config in code, run one task and print the
    # outcome.
    from proxy_lite.logger import logger

    config = RunnerConfig.from_dict(
        {
            "environment": {
                "name": "webbrowser",
                "homepage": "https://www.google.com",
                "viewport_width": 1280,
                "viewport_height": 1920,
                "screenshot_delay": 1,
                "headless": False,
            },
            "solver": {
                "name": "simple",
                "agent": {
                    "name": "proxy_lite",
                    "client": {
                        "name": "convergence",
                        "model_id": "convergence-ai/proxy-lite",
                        "api_base": "https://convergence-ai-demo-api.hf.space/v1",
                    },
                },
            },
            # Limits well above the RunnerConfig defaults.
            "max_steps": 150,
            "action_timeout": 1800,
            "environment_timeout": 1800,
            "task_timeout": 18000,
            "logger_level": "DEBUG",
        },
    )
    logger.info(f"🤖 [bold purple]Config:[/] {config}")
    runner = Runner(config=config)
    # The returned Run is not needed here; the runner exposes the outcome
    # through its properties.
    asyncio.run(runner.run("Tell me the tesla stock price."))
    print(runner.run_result)
    print(runner.complete)