dlaima committed on
Commit
79fcd3e
·
verified ·
1 Parent(s): abf7526

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -18
app.py CHANGED
@@ -3,13 +3,13 @@ import os
3
  import gradio as gr
4
  import requests
5
  import pandas as pd
 
6
  import torch
7
- from transformers import BartForConditionalGeneration, BartTokenizer
8
 
 
9
  from audio_transcriber import AudioTranscriptionTool
10
  from image_analyzer import ImageAnalysisTool
11
  from wikipedia_searcher import WikipediaSearcher
12
- from smolagents import ToolCallingAgent
13
 
14
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
15
 
@@ -32,32 +32,38 @@ SYSTEM_PROMPT = (
32
  )
33
 
34
  class LocalBartModel:
35
- def __init__(self, model_name="facebook/bart-base"):
36
- self.tokenizer = BartTokenizer.from_pretrained(model_name)
37
- self.model = BartForConditionalGeneration.from_pretrained(model_name)
38
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
  self.model.to(self.device)
 
40
 
41
  def generate(self, inputs, **generate_kwargs):
 
 
 
42
  input_ids = inputs.get("input_ids")
43
  attention_mask = inputs.get("attention_mask")
44
-
45
- if input_ids is None:
46
- raise ValueError("input_ids missing from tokenizer output")
47
 
48
  input_ids = input_ids.to(self.device)
49
- if attention_mask is not None:
50
- attention_mask = attention_mask.to(self.device)
51
 
52
- return self.model.generate(
53
- input_ids=input_ids,
54
- attention_mask=attention_mask,
55
- **generate_kwargs
56
- )
 
 
57
 
58
- def __call__(self, prompt: str) -> str:
59
- inputs = self.tokenizer(prompt, return_tensors="pt")
 
60
 
 
61
  output_ids = self.generate(
62
  inputs,
63
  max_length=100,
@@ -71,11 +77,13 @@ class GaiaAgent:
71
  def __init__(self):
72
  print("Gaia Agent Initialized")
73
  self.model = LocalBartModel()
 
74
  self.tools = [
75
  AudioTranscriptionTool(),
76
  ImageAnalysisTool(),
77
  WikipediaSearcher()
78
  ]
 
79
  self.agent = ToolCallingAgent(
80
  tools=self.tools,
81
  model=self.model
@@ -83,18 +91,19 @@ class GaiaAgent:
83
 
84
  def __call__(self, question: str) -> str:
85
  print(f"Agent received question (first 50 chars): {question[:50]}...")
86
-
87
  full_prompt = f"{SYSTEM_PROMPT}\nQUESTION:\n{question}"
88
 
89
  try:
90
  result = self.agent.run(full_prompt)
91
  print(f"Raw result from agent: {result}")
92
 
 
93
  if isinstance(result, dict) and "answer" in result:
94
  return str(result["answer"]).strip()
95
  elif isinstance(result, str):
96
  return result.strip()
97
  elif isinstance(result, list):
 
98
  for item in reversed(result):
99
  if isinstance(item, dict) and item.get("role") == "assistant" and "content" in item:
100
  return item["content"].strip()
 
3
  import gradio as gr
4
  import requests
5
  import pandas as pd
6
+ from transformers import BartTokenizer, BartForConditionalGeneration
7
  import torch
 
8
 
9
+ from smolagents import ToolCallingAgent
10
  from audio_transcriber import AudioTranscriptionTool
11
  from image_analyzer import ImageAnalysisTool
12
  from wikipedia_searcher import WikipediaSearcher
 
13
 
14
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
15
 
 
32
  )
33
 
34
  class LocalBartModel:
35
+ def __init__(self):
36
+ self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
37
+ self.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
38
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
  self.model.to(self.device)
40
+ self.model.eval()
41
 
42
  def generate(self, inputs, **generate_kwargs):
43
+ # inputs must be dict with input_ids and attention_mask
44
+ if not isinstance(inputs, dict):
45
+ raise ValueError(f"Expected dict input but got {type(inputs)}")
46
  input_ids = inputs.get("input_ids")
47
  attention_mask = inputs.get("attention_mask")
48
+ if input_ids is None or attention_mask is None:
49
+ raise ValueError("input_ids and attention_mask are required in inputs dict")
 
50
 
51
  input_ids = input_ids.to(self.device)
52
+ attention_mask = attention_mask.to(self.device)
 
53
 
54
+ with torch.no_grad():
55
+ outputs = self.model.generate(
56
+ input_ids=input_ids,
57
+ attention_mask=attention_mask,
58
+ **generate_kwargs
59
+ )
60
+ return outputs
61
 
62
+ def __call__(self, prompt):
63
+ if not isinstance(prompt, str):
64
+ raise ValueError(f"LocalBartModel expects a string prompt, got {type(prompt)}")
65
 
66
+ inputs = self.tokenizer(prompt, return_tensors="pt")
67
  output_ids = self.generate(
68
  inputs,
69
  max_length=100,
 
77
  def __init__(self):
78
  print("Gaia Agent Initialized")
79
  self.model = LocalBartModel()
80
+
81
  self.tools = [
82
  AudioTranscriptionTool(),
83
  ImageAnalysisTool(),
84
  WikipediaSearcher()
85
  ]
86
+
87
  self.agent = ToolCallingAgent(
88
  tools=self.tools,
89
  model=self.model
 
91
 
92
  def __call__(self, question: str) -> str:
93
  print(f"Agent received question (first 50 chars): {question[:50]}...")
 
94
  full_prompt = f"{SYSTEM_PROMPT}\nQUESTION:\n{question}"
95
 
96
  try:
97
  result = self.agent.run(full_prompt)
98
  print(f"Raw result from agent: {result}")
99
 
100
+ # Handle different result types robustly
101
  if isinstance(result, dict) and "answer" in result:
102
  return str(result["answer"]).strip()
103
  elif isinstance(result, str):
104
  return result.strip()
105
  elif isinstance(result, list):
106
+ # Try to extract assistant content from list
107
  for item in reversed(result):
108
  if isinstance(item, dict) and item.get("role") == "assistant" and "content" in item:
109
  return item["content"].strip()