Spaces:
Paused
Paused
import os | |
import json | |
import yaml | |
import logging | |
from dotenv import load_dotenv | |
from huggingface_hub import login | |
from selenium import webdriver | |
from selenium.webdriver.common.by import By | |
from selenium.webdriver.common.keys import Keys | |
from io import BytesIO | |
from PIL import Image | |
from datetime import datetime | |
import tempfile | |
import helium | |
from smolagents import CodeAgent, LiteLLMModel | |
from smolagents.agents import ActionStep | |
from tools.search_item_ctrl_f import SearchItemCtrlFTool | |
from tools.go_back import GoBackTool | |
from tools.close_popups import ClosePopupsTool | |
from tools.final_answer import FinalAnswerTool | |
from GRADIO_UI import GradioUI | |
# Configure root logging at INFO level (basicConfig is a no-op if the root
# logger already has handlers), then create this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
# Load environment variables (python-dotenv also picks up a local .env file).
load_dotenv()


def _require_env(name: str) -> str:
    """Return the value of environment variable *name*.

    Raises:
        ValueError: if the variable is unset or empty. The message names the
            exact variable that was read.
    """
    value = os.getenv(name)
    if not value:
        raise ValueError(f"{name} environment variable not set.")
    return value


# HF_TOKEN is validated first, preserving the original check order.
hf_token = _require_env("HF_TOKEN")
# The Gemini API key is read from GOOGLE_API_KEY. The old error message wrongly
# told users to set GEMINI_API_KEY, a variable this code never reads.
gemini_api_key = _require_env("GOOGLE_API_KEY")
login(hf_token, add_to_git_credential=False)
# Initialize a headless Chrome driver and register it with helium, so helium
# calls and the Selenium-based tools (which receive `driver` directly) operate
# on the same browser.
_CHROME_ARGUMENTS = (
    "--force-device-scale-factor=1",
    "--window-size=1000,1350",
    "--disable-pdf-viewer",
    "--no-sandbox",
    "--disable-dev-shm-usage",
    "--window-position=0,0",
    "--headless=new",
)
try:
    chrome_options = webdriver.ChromeOptions()
    for _chrome_arg in _CHROME_ARGUMENTS:
        chrome_options.add_argument(_chrome_arg)
    driver = webdriver.Chrome(options=chrome_options)
    helium.set_driver(driver)
    logger.info("Chrome driver initialized successfully.")
except Exception as exc:
    # logger.exception logs at ERROR level and includes the traceback.
    logger.exception("Failed to initialize Chrome driver: %s", exc)
    raise
# Screenshot callback
def save_screenshot(memory_step: ActionStep, agent: CodeAgent) -> "str | None":
    """Step callback: save a screenshot of the browser and note it in the step.

    The callback does the following:
    - waits one second before capturing;
    - clears ``observations_images`` on every earlier ActionStep in the agent's memory;
    - writes a PNG of the current page to ``<tempdir>/web_agent_screenshots``;
    - appends the current URL and the screenshot path to ``memory_step.observations``.

    Args:
        memory_step: The step that just finished.
        agent: The running agent; its ``memory.steps`` are scanned for earlier steps.

    Returns:
        The path of the saved screenshot, or None when helium has no driver
        (in that case nothing is modified).
    """
    from time import sleep

    sleep(1.0)  # give the page a moment to settle before capturing
    driver = helium.get_driver()
    if driver is None:
        return None

    current_step = memory_step.step_number
    # Clear old screenshots from earlier steps (presumably to keep the agent's
    # memory / model context small — confirm intent).
    for previous_step in agent.memory.steps:
        if isinstance(previous_step, ActionStep) and previous_step.step_number < current_step:
            previous_step.observations_images = None

    # get_screenshot_as_png already returns PNG-encoded bytes. Write them as-is
    # instead of decoding and re-encoding through PIL.
    png_bytes = driver.get_screenshot_as_png()
    screenshot_dir = os.path.join(tempfile.gettempdir(), "web_agent_screenshots")
    os.makedirs(screenshot_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    screenshot_path = os.path.join(
        screenshot_dir, f"screenshot_step_{current_step}_{timestamp}.png"
    )
    with open(screenshot_path, "wb") as png_file:
        png_file.write(png_bytes)
    logger.info("Saved screenshot to: %s", screenshot_path)

    # Update observations
    url_info = f"Current url: {driver.current_url}\nScreenshot saved at: {screenshot_path}"
    if memory_step.observations is None:
        memory_step.observations = url_info
    else:
        memory_step.observations = memory_step.observations + "\n" + url_info
    return screenshot_path
# Load custom prompt templates from prompts.yaml. The path is relative to the
# current working directory. A missing or empty file yields an empty dict,
# which is passed to CodeAgent as `prompt_templates`.
# NOTE(review): this presumably makes smolagents fall back to its built-in
# templates — confirm against the installed version.
try:
    with open("prompts.yaml", "r", encoding="utf-8") as stream:
        # safe_load returns None for an empty file; normalise that to {}.
        prompt_templates = yaml.safe_load(stream) or {}
except FileNotFoundError:
    logger.warning("prompts.yaml not found; passing empty prompt_templates to the agent.")
    prompt_templates = {}
# Initialize tools: the three browser tools are constructed with the shared
# Selenium driver; FinalAnswerTool is constructed without it and comes last.
_BROWSER_TOOL_TYPES = (SearchItemCtrlFTool, GoBackTool, ClosePopupsTool)
tools = [tool_type(driver=driver) for tool_type in _BROWSER_TOOL_TYPES]
tools.append(FinalAnswerTool())
# Initialize the model: Gemini 2.0 Flash via LiteLLM, authenticated with
# `gemini_api_key`.
model = LiteLLMModel(
    # smolagents' LiteLLMModel selects the model through `model_id`. An unknown
    # `model_name` keyword does not select the model.
    model_id="gemini/gemini-2.0-flash",
    api_key=gemini_api_key,
    max_tokens=2096,  # NOTE(review): kept as-is; confirm 2096 (not 2048) is intended
    temperature=0.5,
)
# Modules that code generated by the agent may import, in addition to the
# executor's defaults.
_EXTRA_AUTHORIZED_IMPORTS = [
    "helium",
    "unicodedata",
    "stat",
    "datetime",
    "random",
    "pandas",
    "itertools",
    "math",
    "statistics",
    "queue",
    "time",
    "collections",
    "re",
]

# Initialize agent: at most 20 steps, a screenshot saved after each step.
agent = CodeAgent(
    model=model,
    tools=tools,
    max_steps=20,
    verbosity_level=2,
    prompt_templates=prompt_templates,
    step_callbacks=[save_screenshot],
    additional_authorized_imports=_EXTRA_AUTHORIZED_IMPORTS,
)

# Pre-run a star import of helium inside the agent's Python executor, so
# generated code can call helium functions without importing them first.
agent.python_executor("from helium import *")
# Launch the Gradio UI. The Chrome driver is always shut down when the UI stops,
# whatever the reason.
try:
    GradioUI(agent, file_upload_folder=os.path.join(tempfile.gettempdir(), "uploads")).launch()
except KeyboardInterrupt:
    # Ctrl-C is the expected way to stop the app: exit quietly, as before.
    pass
finally:
    # Previously the driver was quit only on KeyboardInterrupt. A normal return
    # or any other exception left the browser running.
    driver.quit()
    logger.info("Chrome driver closed on exit.")