# Source: ZeroTimo's HF Space — app.py (commit e3d9b1d, verified; 3.82 kB)
# ------------------------------------------------------------
# fast async app.py (korrekte Zuordnung + Gemini-Throttle)
# ------------------------------------------------------------
import os, asyncio, concurrent.futures, functools, json
from pathlib import Path
import gradio as gr, requests, pandas as pd
from langchain_core.messages import HumanMessage
from agent import agent_executor
# Scoring endpoint of the GAIA course unit-4 challenge server.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
MAX_PAR_LLM = 3 # 3 simultaneous requests stays < 15 req/min (Gemini quota)
# Module-level throttle shared by all agent calls (see run_agent_async).
SEMAPHORE = asyncio.Semaphore(MAX_PAR_LLM)
# ---------- synchroner Agent-Aufruf ---------------------------------
def run_agent_sync(task_id: str, question: str) -> str:
    """Invoke the agent synchronously for one task and return its final answer.

    Any failure is folded into an ``AGENT ERROR: ...`` string instead of
    propagating, so a single bad task cannot abort the whole batch.
    """
    request = {
        "messages": [HumanMessage(content=question)],
        "task_id": task_id,
    }
    try:
        outcome = agent_executor.invoke(request)
        return outcome["messages"][-1].content.strip()
    except Exception as exc:  # best-effort: report the error as the answer
        return f"AGENT ERROR: {exc}"
# ---------- async Wrapper + Throttle --------------------------------
async def run_agent_async(executor, task_id: str, question: str) -> str:
    """Run ``run_agent_sync`` in *executor* without blocking the event loop.

    The module-level ``SEMAPHORE`` caps concurrency at ``MAX_PAR_LLM`` so the
    Gemini request quota is not exceeded.

    Args:
        executor: ``concurrent.futures.Executor`` to run the blocking call in.
        task_id: GAIA task identifier, passed through to the agent.
        question: Question text, passed through to the agent.

    Returns:
        The agent's answer string (or an ``AGENT ERROR: ...`` string).
    """
    # FIX: get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated here (3.10+) and can bind a stale loop.
    loop = asyncio.get_running_loop()
    async with SEMAPHORE:  # Gemini quota guard
        return await loop.run_in_executor(
            executor, functools.partial(run_agent_sync, task_id, question)
        )
# ---------- Main Gradio Callback ------------------------------------
async def run_and_submit_all(profile: gr.OAuthProfile | None, progress=gr.Progress()):
    """Fetch all GAIA questions, answer them concurrently, and submit the results.

    Args:
        profile: OAuth profile injected by Gradio's login button, or ``None``
            when the user is not logged in.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Tuple of (status message, DataFrame of per-task logs); the DataFrame
        is ``None`` when the run aborts before any task is executed.
    """
    if not profile:
        return "Please login with your HF account.", None
    username = profile.username

    # 1) Fetch the question list.
    try:
        questions = requests.get(f"{DEFAULT_API_URL}/questions", timeout=15).json()
    except Exception as e:
        return f"Error fetching questions: {e}", None
    # FIX: guard against an empty payload — submitting zero answers is pointless.
    if not questions:
        return "Fetched question list is empty — nothing to submit.", None
    progress(0, desc=f"Fetched {len(questions)} questions – processing …")

    answers, logs = [], []
    work = [(q["task_id"], q["question"]) for q in questions]

    # 2) Run all tasks concurrently. asyncio.gather() preserves the order of
    #    the awaitables, so results[i] always maps back to work[i].
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_PAR_LLM) as ex:
        tasks = [
            run_agent_async(ex, tid, qst)  # each yields an answer string
            for tid, qst in work
        ]
        results = await asyncio.gather(*tasks)
        for idx, answer in enumerate(results):
            tid, qst = work[idx]
            answers.append({"task_id": tid, "submitted_answer": answer})
            logs.append({"Task ID": tid, "Question": qst, "Answer": answer})
            progress((idx + 1) / len(work), desc=f"{idx+1}/{len(work)} done")

    # 3) Submit the collected answers for scoring.
    submit_url = f"{DEFAULT_API_URL}/submit"
    payload = {
        "username": username,
        "agent_code": f"https://huggingface.co/spaces/{os.getenv('SPACE_ID')}/tree/main",
        "answers": answers,
    }
    try:
        res = requests.post(submit_url, json=payload, timeout=60).json()
        status = (f"Submission OK – Score {res.get('score','?')} % "
                  f"({res.get('correct_count','?')}/{res.get('total_attempted','?')})")
    except Exception as e:
        status = f"Submission failed: {e}"
    return status, pd.DataFrame(logs)
# ---------- Gradio UI -----------------------------------------------
# Minimal one-button UI: login, run all tasks, show score + per-task table.
with gr.Blocks() as demo:
    gr.Markdown("# Fast GAIA Agent Runner (async × progress)")
    # NOTE(review): LoginButton appears to supply the gr.OAuthProfile that
    # run_and_submit_all declares — confirm against Gradio's OAuth docs.
    gr.LoginButton()
    run_btn = gr.Button("Run & Submit")
    out_status = gr.Textbox(label="Status / Score", lines=3)
    out_table = gr.DataFrame(label="Answers", wrap=True)
    # No explicit inputs: presumably Gradio injects the OAuth profile from
    # the login state based on the handler's type annotation — verify.
    run_btn.click(run_and_submit_all, outputs=[out_status, out_table])

if __name__ == "__main__":
    demo.launch(debug=True, share=False)