Spaces:
Paused
Paused
| import os | |
| import json | |
| import yaml | |
| import logging | |
| from dotenv import load_dotenv | |
| from huggingface_hub import login | |
| from selenium import webdriver | |
| from selenium.webdriver.common.by import By | |
| from selenium.webdriver.common.keys import Keys | |
| from io import BytesIO | |
| from PIL import Image | |
| from datetime import datetime | |
| import tempfile | |
| import helium | |
| from smolagents import CodeAgent, LiteLLMModel | |
| from smolagents.agents import ActionStep | |
| from tools.search_item_ctrl_f import SearchItemCtrlFTool | |
| from tools.go_back import GoBackTool | |
| from tools.close_popups import ClosePopupsTool | |
| from tools.final_answer import FinalAnswerTool | |
| from GRADIO_UI import GradioUI | |
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
gemini_api_key = os.getenv("GOOGLE_API_KEY")

# Fail fast, before any browser or model is created, when a credential is missing.
if not hf_token:
    raise ValueError("HF_TOKEN environment variable not set.")
if not gemini_api_key:
    # The key is read from GOOGLE_API_KEY, so the error must name that variable;
    # the previous message pointed at GEMINI_API_KEY, which this script never reads.
    raise ValueError("GOOGLE_API_KEY environment variable not set.")

login(hf_token, add_to_git_credential=False)
# Initialize Chrome driver
try:
    chrome_options = webdriver.ChromeOptions()
    # Command-line switches are applied in this order.
    for _chrome_flag in (
        "--force-device-scale-factor=1",
        "--window-size=1000,1350",
        "--disable-pdf-viewer",
        "--no-sandbox",
        "--disable-dev-shm-usage",
        "--window-position=0,0",
        "--headless=new",
    ):
        chrome_options.add_argument(_chrome_flag)
    driver = webdriver.Chrome(options=chrome_options)
    # Register the session with helium so helium calls act on this browser.
    helium.set_driver(driver)
    logger.info("Chrome driver initialized successfully.")
except Exception as exc:
    # Log the failure, then re-raise: the script cannot continue without a browser.
    logger.error(f"Failed to initialize Chrome driver: {str(exc)}")
    raise
# Screenshot callback
def save_screenshot(memory_step: ActionStep, agent: CodeAgent) -> "str | None":
    """Capture the browser viewport after an agent step and record where it was saved.

    The callback sleeps for one second, clears ``observations_images`` on every
    earlier ``ActionStep`` in the agent's memory, writes a PNG screenshot into
    ``<system temp dir>/web_agent_screenshots`` and appends the current URL and
    the screenshot path to ``memory_step.observations``.

    Args:
        memory_step: The step that just ran; its ``observations`` text is
            extended in place.
        agent: The running agent, whose ``memory.steps`` are scanned for
            earlier steps.

    Returns:
        The path of the saved PNG, or ``None`` when helium reports no active
        driver (in which case nothing is captured or modified).
    """
    from time import sleep

    sleep(1.0)  # presumably lets the page settle before the capture
    driver = helium.get_driver()
    current_step = memory_step.step_number
    if driver is None:
        # No browser session: make the "nothing captured" result explicit.
        return None

    # Clear old screenshots from earlier steps
    for previous_memory_step in agent.memory.steps:
        if isinstance(previous_memory_step, ActionStep) and previous_memory_step.step_number < current_step:
            previous_memory_step.observations_images = None

    # Save new screenshot
    png_bytes = driver.get_screenshot_as_png()
    image = Image.open(BytesIO(png_bytes))
    screenshot_dir = os.path.join(tempfile.gettempdir(), "web_agent_screenshots")
    os.makedirs(screenshot_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    screenshot_filename = f"screenshot_step_{current_step}_{timestamp}.png"
    screenshot_path = os.path.join(screenshot_dir, screenshot_filename)
    image.save(screenshot_path)
    logger.info("Saved screenshot to: %s", screenshot_path)

    # Update observations
    url_info = f"Current url: {driver.current_url}\nScreenshot saved at: {screenshot_path}"
    if memory_step.observations is None:
        memory_step.observations = url_info
    else:
        memory_step.observations = memory_step.observations + "\n" + url_info
    return screenshot_path
# Load prompt templates
try:
    # Decode as UTF-8 explicitly so the result does not depend on the host's
    # locale encoding.
    with open("prompts.yaml", "r", encoding="utf-8") as stream:
        # safe_load returns None for an empty document; normalise that to the
        # same empty mapping used when the file is missing.
        prompt_templates = yaml.safe_load(stream) or {}
except FileNotFoundError:
    # Best-effort: a missing file is tolerated, but say so instead of
    # falling back silently.
    logger.warning("prompts.yaml not found; continuing with empty prompt templates.")
    prompt_templates = {}
# Initialize tools
# The three browser tools are each constructed with the shared Chrome driver;
# the final-answer tool takes no arguments and goes last.
tools = [
    browser_tool(driver=driver)
    for browser_tool in (SearchItemCtrlFTool, GoBackTool, ClosePopupsTool)
]
tools.append(FinalAnswerTool())
# Initialize model
# LiteLLMModel selects the backend through ``model_id``. The previous
# ``model_name=`` keyword is not a constructor parameter, so the Gemini
# identifier never reached the argument that selects the model.
model = LiteLLMModel(
    model_id="gemini/gemini-2.0-flash",
    api_key=gemini_api_key,
    max_tokens=2096,
    temperature=0.5,
)
# Initialize agent
# Modules the agent-generated code is allowed to import, in addition to the
# agent's built-in defaults.
authorized_imports = [
    "helium",
    "unicodedata",
    "stat",
    "datetime",
    "random",
    "pandas",
    "itertools",
    "math",
    "statistics",
    "queue",
    "time",
    "collections",
    "re",
]
agent = CodeAgent(
    tools=tools,
    model=model,
    prompt_templates=prompt_templates,
    additional_authorized_imports=authorized_imports,
    step_callbacks=[save_screenshot],  # runs the screenshot callback for each step
    max_steps=20,
    verbosity_level=2,
)
# Run a helium star-import inside the agent's Python executor before any task starts.
agent.python_executor("from helium import *")
# Launch Gradio UI
try:
    GradioUI(agent, file_upload_folder=os.path.join(tempfile.gettempdir(), "uploads")).launch()
except KeyboardInterrupt:
    # Ctrl+C is an expected way to stop the app; swallow it as before so the
    # script exits without a traceback.
    pass
finally:
    # Quit Chrome on every exit path (normal return, Ctrl+C or an error), not
    # only on KeyboardInterrupt, so the browser process is not left running.
    driver.quit()
    logger.info("Chrome driver closed on exit.")