dlaima committed on
Commit
dc77905
·
verified ·
1 Parent(s): bc758d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -8
app.py CHANGED
@@ -4,14 +4,14 @@ import gradio as gr
4
  import requests
5
  import pandas as pd
6
 
7
- from smolagents import InferenceClientModel, ToolCallingAgent
 
8
  from audio_transcriber import AudioTranscriptionTool
9
  from image_analyzer import ImageAnalysisTool
10
  from wikipedia_searcher import WikipediaSearcher
11
 
12
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
13
 
14
- # Zephyr-compatible system prompt to prepend manually
15
  SYSTEM_PROMPT = (
16
  "You are an agent solving the GAIA benchmark and must provide exact answers.\n"
17
  "Rules:\n"
@@ -30,14 +30,30 @@ SYSTEM_PROMPT = (
30
  "Never say 'the answer is...'. Only return the answer.\n"
31
  )
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  class GaiaAgent:
34
  def __init__(self):
35
  print("Gaia Agent Initialized")
36
-
37
- self.model = InferenceClientModel(
38
- model_id="HuggingFaceH4/zephyr-7b-beta",
39
- token=os.getenv("HF_API_TOKEN", "").strip()
40
- )
41
 
42
  self.tools = [
43
  AudioTranscriptionTool(),
@@ -164,7 +180,6 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
164
  except Exception as e:
165
  return f"An unexpected error occurred during submission: {e}", pd.DataFrame(results_log)
166
 
167
- # Gradio UI
168
  with gr.Blocks() as demo:
169
  gr.Markdown("# Basic Agent Evaluation Runner")
170
  gr.Markdown("""
 
4
  import requests
5
  import pandas as pd
6
 
7
+ from transformers import BartForConditionalGeneration, BartTokenizer
8
+ from smolagents import ToolCallingAgent
9
  from audio_transcriber import AudioTranscriptionTool
10
  from image_analyzer import ImageAnalysisTool
11
  from wikipedia_searcher import WikipediaSearcher
12
 
13
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
14
 
 
15
  SYSTEM_PROMPT = (
16
  "You are an agent solving the GAIA benchmark and must provide exact answers.\n"
17
  "Rules:\n"
 
30
  "Never say 'the answer is...'. Only return the answer.\n"
31
  )
32
 
33
class LocalBartModel:
    """Thin callable wrapper around a locally loaded BART seq2seq model.

    Exposes a simple ``prompt -> text`` interface so it can stand in for an
    inference-client model object inside the agent.
    """

    def __init__(self, model_name="facebook/bart-base", device=None):
        """Load tokenizer and model weights.

        Args:
            model_name: Hugging Face model id or local path to load.
            device: explicit torch device string; when None, CUDA is used
                if available, otherwise CPU.
        """
        # Deferred import keeps module import cheap until a model is built.
        import torch

        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        self.model = BartForConditionalGeneration.from_pretrained(model_name).to(self.device)

    def __call__(self, prompt: str) -> str:
        """Generate a completion for *prompt* and return it stripped of
        surrounding whitespace."""
        import torch

        # Truncate long prompts to a 512-token encoder window.
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(self.device)
        # Inference only: run generation under no_grad so no autograd graph
        # is built (the original imported torch here without using it).
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=128,
                num_beams=5,
                early_stopping=True,
            )
        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return decoded.strip()
52
+
53
  class GaiaAgent:
54
  def __init__(self):
55
  print("Gaia Agent Initialized")
56
+ self.model = LocalBartModel()
 
 
 
 
57
 
58
  self.tools = [
59
  AudioTranscriptionTool(),
 
180
  except Exception as e:
181
  return f"An unexpected error occurred during submission: {e}", pd.DataFrame(results_log)
182
 
 
183
  with gr.Blocks() as demo:
184
  gr.Markdown("# Basic Agent Evaluation Runner")
185
  gr.Markdown("""