wt002 commited on
Commit
3985578
·
verified ·
1 Parent(s): eda9ba9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +182 -293
app.py CHANGED
@@ -1,321 +1,210 @@
1
  import os
2
- from typing import Annotated, Optional, TypedDict
3
  import gradio as gr
4
- from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
5
- from langchain_openai import ChatOpenAI
6
- from langgraph.graph.message import add_messages
7
- from langgraph.graph import StateGraph, START
8
- from langgraph.prebuilt import tools_condition, ToolNode
9
  import requests
10
  import pandas as pd
11
- from langchain.tools import Tool
12
- from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- import arxiv
15
- from chess_algebraic_notation_retriever import ChessAlgebraicNotationMoveRetriever
16
- from excel_file_reader import ExcelFileReader
17
- from image_question_answer_tool import ImageQuestionAnswerTool
18
- from python_code_question_answer_tool import PythonCodeQuestionAnswerTool
19
- from tavily_searcher import TavilySearcher
20
- from transcriber import Transcriber
21
- from wikipedia_searcher import WikipediaSearcher
22
- from youtube_video_question_answer_tool import YoutubeVideoQuestionAnswerTool
23
 
24
- load_dotenv()
25
 
26
  # (Keep Constants as is)
27
  # --- Constants ---
28
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
29
- ASSOCIATED_FILE_ENDPOINT = f"{DEFAULT_API_URL}/files/"
30
 
31
- # --- Basic Agent Definition ---
32
- # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
33
- #search_tool = DuckDuckGoSearchRun()
34
-
35
- #search_tool = DuckDuckGoSearcherTool()
36
 
37
- def retrieve_task_file(task_id: str) -> Optional[bytes]:
 
38
  """
39
- Retrieve the task file for a given task ID.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  """
41
  try:
42
- response = requests.get(ASSOCIATED_FILE_ENDPOINT + task_id, timeout=15)
43
- response.raise_for_status()
44
- if response.status_code != 200:
45
- print(f"Error fetching file: {response.status_code}")
46
- return None
47
- #print(f"Fetched file: {response.content}")
48
- return response.content
49
- except requests.exceptions.RequestException as e:
50
- print(f"Error fetching file: {e}")
51
- return None
 
 
 
 
 
 
 
 
 
 
52
  except Exception as e:
53
- print(f"An unexpected error occurred fetching file: {e}")
54
- return None
55
 
56
- def retrieve_next_chess_move_in_algebraic_notation(task_file_path: str, is_black_turn: bool) -> str:
57
- """
58
- Retrieve the next chess move in algebraic notation from an image path.
59
- """
60
- if task_file_path is None:
61
- return "Error: Task file not found."
62
- # Retrieve the next chess move in algebraic notation
63
- next_chess_move = ChessAlgebraicNotationMoveRetriever().retrieve(task_file_path, is_black_turn)
64
- return next_chess_move
65
-
66
- # Initialize the tool
67
- retrieve_next_chess_move_in_algebraic_notation_tool = Tool(
68
- name="retrieve_next_chess_move_in_algebraic_notation",
69
- func=retrieve_next_chess_move_in_algebraic_notation,
70
- description="Retrieve the next chess move in algebraic notation from an image path."
71
- )
72
-
73
- def transcribe_audio(file_path: str) -> str:
74
- if file_path is None:
75
- return "Error: Audio path not found."
76
- # Transcribe the audio
77
- return Transcriber().transcribe(file_path)
78
-
79
- # Initialize the tool
80
- transcribe_audio_tool = Tool(
81
- name="transcribe_audio",
82
- func=transcribe_audio,
83
- description="Transcribe the audio from an audio path."
84
- )
85
-
86
- # Initialize the tool
87
- answer_python_code_tool = PythonCodeQuestionAnswerTool()
88
-
89
- # Initialize the tool
90
- answer_image_question_tool = ImageQuestionAnswerTool()
91
-
92
- # Initialize the tool
93
- answer_youtube_video_question_tool = YoutubeVideoQuestionAnswerTool()
94
-
95
- '''def answer_youtube_video_question(youtube_video_url: str, question: str) -> str:
96
- """
97
- Answer the question based on the youtube video.
98
  """
99
- if youtube_video_url is None:
100
- return "Error: Video not found."
101
- # Download the video
102
- video_path = YoutubeVideoDownloader().download_video(youtube_video_url)
103
- # Answer the question
104
- return VideoQuestionAnswer().answer(video_path, question)
105
- # Initialize the tool
106
- answer_youtube_video_question_tool = Tool(
107
- name="answer_youtube_video_question",
108
- func=answer_youtube_video_question,
109
- description="Answer the question based on the youtube video."
110
- )'''
111
-
112
- def read_excel_file(file_path: str) -> str:
113
- if file_path is None:
114
- return "Error: File not found."
115
- return ExcelFileReader().read_file(file_path)
116
-
117
- # Initialize the tool
118
- read_excel_file_tool = Tool(
119
- name="read_excel_file",
120
- func=read_excel_file,
121
- description="Read the excel file."
122
- )
123
-
124
- # Initialize the tool
125
- wikipedia_search_tool = Tool(
126
- name="wikipedia_search",
127
- func=WikipediaSearcher().search,
128
- description="Search Wikipedia for a given query."
129
- )
130
-
131
- # Initialize the tool
132
- arxiv_search_tool = Tool(
133
- name="arxiv_search",
134
- func=ArxivSearcher().search,
135
- description="Search Arxiv for a given query."
136
- )
137
-
138
- tavily_search_tool = Tool(
139
- name="tavily_search",
140
- func=TavilySearcher().search,
141
- description="Search the web for a given query."
142
- )
143
-
144
- def format_gaia_answer(answer: str) -> str:
145
- llm = ChatOpenAI(model="o3-mini", openai_api_key=os.getenv("OPENAI_API_KEY"))
146
- prompt = f"""
147
- You are formatting answers for the GAIA benchmark, which requires responses to be concise and unambiguous.
148
- Given the answer: {answer}
149
- Return the answer in the correct GAIA format:
150
- - If the answer is a single word or number, return it without any additional text or formatting.
151
- - If the answer is a list, return a comma-separated list without any additional text or formatting.
152
- - If the answer is a string, return it without any additional text or formatting.
153
- Do not include any prefixes, dots, enumerations, explanations, or quotation marks.
154
- Do not include any additional text or formatting.
155
  """
156
- response = llm.invoke(prompt)
157
- # Delete double quotes
158
- return response.content.strip().replace('"', '')
159
-
160
- class AgentState(TypedDict):
161
- # The document provided
162
- messages: Annotated[list[AnyMessage], add_messages]
163
- file_path: Optional[str]
164
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  class BasicAgent:
 
166
  def __init__(self):
167
- tools = [
168
- tavily_search_tool,
169
- arxiv_search_tool,
170
- wikipedia_search_tool,
171
- transcribe_audio_tool,
172
- answer_python_code_tool,
173
- answer_image_question_tool,
174
- answer_youtube_video_question_tool,
175
- read_excel_file_tool
176
- ]
177
- '''llm = ChatGoogleGenerativeAI(
178
- model="gemini-2.0-flash",
179
- temperature=0.2,
180
- api_key=os.getenv("GEMINI_API_KEY")
181
- )'''
182
- llm = ChatOpenAI(model="o3-mini", openai_api_key=os.getenv("OPENAI_API_KEY"))
183
- self.llm_with_tools = llm.bind_tools(tools)
184
- builder = StateGraph(AgentState)
185
-
186
- # Define nodes: these do the work
187
- builder.add_node("assistant", self.assistant)
188
- builder.add_node("tools", ToolNode(tools))
189
-
190
- # Define edges: these determine how the control flow moves
191
- builder.add_edge(START, "assistant")
192
- builder.add_conditional_edges(
193
- "assistant",
194
- # If the latest message requires a tool, route to tools
195
- # Otherwise, provide a direct response
196
- tools_condition,
197
- )
198
- builder.add_edge("tools", "assistant")
199
- self.agent = builder.compile()
200
 
 
201
  print("BasicAgent initialized.")
202
 
203
- def assistant(self, state: AgentState):
204
- # System message
205
- textual_description_of_tools="""
206
- tavily_search(query: str) -> str:
207
- Search the web for a given query.
208
- Args:
209
- query: Query to search the web for (string).
210
- Returns:
211
- A single string containing the information found on the web.
212
- arxiv_search(query: str) -> str:
213
- Search Arxiv, that contains scientific papers, for a given query.
214
- Args:
215
- query: Query to search Arxiv for (string).
216
- Returns:
217
- A single string containing the answer to the question.
218
- wikipedia_search(query: str) -> str:
219
- Search Wikipedia for a given query.
220
- Args:
221
- query: Query to search Wikipedia for (string).
222
- Returns:
223
- A single string containing the answer to the question.
224
- transcribe_audio(file_path: str) -> str:
225
- Transcribe the audio from an audio path.
226
- Args:
227
- file_path: File path of the audio file (string).
228
- Returns:
229
- A single string containing the transcribed text from the audio.
230
-
231
- answer_python_code(file_path: str, question: str) -> str:
232
- Answer the question based on the python code.
233
- Args:
234
- file_path: File path of the python file (string).
235
- question: Question to answer (string).
236
- Returns:
237
- A single string containing the answer to the question.
238
-
239
- answer_image_question(file_path: str, question: str) -> str:
240
- Answer the question based on the image.
241
- Args:
242
- file_path: File path of the image (string).
243
- question: Question to answer (string).
244
- Returns:
245
- A single string containing the answer to the question.
246
-
247
- download_youtube_video(youtube_video_url: str) -> str:
248
- Download the Youtube video into a local file based on the URL
249
- Args:
250
- youtube_video_url: A youtube video url (string).
251
- Returns:
252
- A single string containing the file path of the downloaded youtube video.
253
- answer_youtube_video_question(file_path: str, question: str) -> str:
254
- Answer the question based on file path of the downloaded youtube video
255
- Args:
256
- file_path: File path of the downloaded youtube video (string).
257
- question: Question to answer (string).
258
- Returns:
259
- A single string containing the answer to the question.
260
-
261
- read_excel_file(file_path: str) -> str:
262
- Read the excel file.
263
- Args:
264
- file_path: File path of the excel file (string).
265
- Returns:
266
- A markdown formatted string containing the contents of the excel file.
267
- """
268
- file_path=state["file_path"]
269
- prompt = f"""
270
- You are a helpful assistant that can analyse images, videos, excel files and Python scripts and run computations with provided tools:
271
- {textual_description_of_tools}
272
- You have access to the file path of the attached file in case it's informed. Currently the file path is: {file_path}
273
- Be direct and specific. GAIA benchmark requires exact matching answers.
274
- For example, if asked "What is the capital of France?", respond simply with "Paris".
275
- Do not include any prefixes, dots, enumerations, explanations, or quotation marks.
276
- Do not include any additional text or formatting.
277
- If you are required a number, return a number, not the items.
278
- """
279
- sys_msg = SystemMessage(content=prompt)
280
 
281
- return {
282
- "messages": [self.llm_with_tools.invoke([sys_msg] + state["messages"], config={"configurable": {"file_path": state["file_path"]}})],
283
- "file_path": state["file_path"]
284
- }
285
- '''return {
286
- "messages": [self.llm_with_tools.invoke(
287
- state["messages"],
288
- config={"configurable": {"file_path": state["file_path"]}} # Aquí pasas el task_id
289
- )],
290
- "file_path": state["file_path"]
291
- }'''
292
-
293
- def __call__(self, question: str, task_id: str, file_name: str) -> str:
294
- print(f"######################### Agent received question (first 50 chars): {question[:50]}... with file_name: {file_name}")
295
-
296
- # Get the file path
297
- tmp_file_path = None
298
- if file_name is not None and file_name != "":
299
- file_content = retrieve_task_file(task_id)
300
- if file_content is not None:
301
- print(f"Saving file {file_name} to tmp folder")
302
- tmp_file_path = f"tmp/{file_name}"
303
- with open(tmp_file_path, "wb") as f:
304
- f.write(file_content)
305
- # Show the file path
306
- print(f"File path: {tmp_file_path}")
307
-
308
- messages = self.agent.invoke({"messages": [HumanMessage(question)], "file_path": tmp_file_path})
309
- # Show the messages
310
- for m in messages['messages']:
311
- m.pretty_print()
312
- answer = messages["messages"][-1].content
313
- answer = format_gaia_answer(answer)
314
- print(f"######################### Agent returning answer: {answer}\n")
315
- # Delete the file
316
- if tmp_file_path is not None:
317
- os.remove(tmp_file_path)
318
- return answer
319
 
320
  def run_and_submit_all( profile: gr.OAuthProfile | None):
321
  """
 
1
  import os
 
2
  import gradio as gr
 
 
 
 
 
3
  import requests
4
  import pandas as pd
5
+ from smolagents import CodeAgent, OpenAIServerModel, DuckDuckGoSearchTool, VisitWebpageTool, tool, \
6
+ FinalAnswerTool, PythonInterpreterTool, SpeechToTextTool, ToolCallingAgent
7
+ import yaml
8
+ import importlib
9
+ from io import BytesIO
10
+ import tempfile
11
+ import base64
12
+ from youtube_transcript_api import YouTubeTranscriptApi
13
+ from youtube_transcript_api._errors import TranscriptsDisabled, NoTranscriptFound, VideoUnavailable
14
+ from urllib.parse import urlparse, parse_qs
15
+ import json
16
+ import whisper
17
+ import re
18
 
 
 
 
 
 
 
 
 
 
19
 
 
20
 
21
  # (Keep Constants as is)
22
  # --- Constants ---
23
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
24
 
 
 
 
 
 
25
 
26
+ @tool
27
+ def transcribe_audio_file(file_path: str) -> str:
28
  """
29
+ Transcribes a local MP3 audio file using Whisper.
30
+ Args:
31
+ file_path: Full path to the .mp3 audio file.
32
+ Returns:
33
+ A JSON-formatted string containing either the transcript or an error message.
34
+ {
35
+ "success": true,
36
+ "transcript": [
37
+ {"start": 0.0, "end": 5.2, "text": "Hello and welcome"},
38
+ ...
39
+ ]
40
+ }
41
+ OR
42
+ {
43
+ "success": false,
44
+ "error": "Reason why transcription failed"
45
+ }
46
  """
47
  try:
48
+ if not os.path.exists(file_path):
49
+ return json.dumps({"success": False, "error": "File does not exist."})
50
+
51
+ if not file_path.lower().endswith(".mp3"):
52
+ return json.dumps({"success": False, "error": "Invalid file type. Only MP3 files are supported."})
53
+
54
+ model = whisper.load_model("base") # You can use 'tiny', 'base', 'small', 'medium', or 'large'
55
+ result = model.transcribe(file_path, verbose=False, word_timestamps=False)
56
+
57
+ transcript_data = [
58
+ {
59
+ "start": segment["start"],
60
+ "end": segment["end"],
61
+ "text": segment["text"].strip()
62
+ }
63
+ for segment in result["segments"]
64
+ ]
65
+
66
+ return json.dumps({"success": True, "transcript": transcript_data})
67
+
68
  except Exception as e:
69
+ return json.dumps({"success": False, "error": str(e)})
 
70
 
71
+
72
+ @tool
73
+ def get_youtube_transcript(video_url: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  """
75
+ Retrieves the transcript from a YouTube video URL, including timestamps.
76
+ This tool fetches the English transcript for a given YouTube video. Automatically generated subtitles
77
+ are also supported. The result includes each snippet's start time, duration, and text.
78
+ Args:
79
+ video_url: The full URL of the YouTube video (e.g., https://www.youtube.com/watch?v=12345)
80
+ Returns:
81
+ A JSON-formatted string containing either the transcript with timestamps or an error message.
82
+ {
83
+ "success": true,
84
+ "transcript": [
85
+ {"start": 0.0, "duration": 1.54, "text": "Hey there"},
86
+ {"start": 1.54, "duration": 4.16, "text": "how are you"},
87
+ ...
88
+ ]
89
+ }
90
+ OR
91
+ {
92
+ "success": false,
93
+ "error": "Reason why the transcript could not be retrieved"
94
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  """
96
+ try:
97
+ # Extract video ID from URL
98
+ parsed_url = urlparse(video_url)
99
+ query_params = parse_qs(parsed_url.query)
100
+ video_id = query_params.get("v", [None])[0]
101
+
102
+ if not video_id:
103
+ return json.dumps({"success": False, "error": "Invalid YouTube URL. Could not extract video ID."})
104
+
105
+ fetched_transcript = YouTubeTranscriptApi().fetch(video_id)
106
+ transcript_data = [
107
+ {
108
+ "start": snippet.start,
109
+ "duration": snippet.duration,
110
+ "text": snippet.text
111
+ }
112
+ for snippet in fetched_transcript
113
+ ]
114
+
115
+ return json.dumps({"success": True, "transcript": transcript_data})
116
+
117
+ except VideoUnavailable:
118
+ return json.dumps({"success": False, "error": "The video is unavailable."})
119
+ except TranscriptsDisabled:
120
+ return json.dumps({"success": False, "error": "Transcripts are disabled for this video."})
121
+ except NoTranscriptFound:
122
+ return json.dumps({"success": False, "error": "No transcript found for this video."})
123
+ except Exception as e:
124
+ return json.dumps({"success": False, "error": str(e)})
125
+
126
+ # --- Basic Agent Definition ---
127
+ # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
128
  class BasicAgent:
129
+
130
  def __init__(self):
131
+ model = OpenAIServerModel(api_key=os.environ.get("OPENAI_API_KEY"), model_id="gpt-4o")
132
+
133
+ self.code_agent = CodeAgent(
134
+ tools=[PythonInterpreterTool(), DuckDuckGoSearchTool(), VisitWebpageTool(), transcribe_audio_file,
135
+ get_youtube_transcript,
136
+ FinalAnswerTool()],
137
+ model=model,
138
+ max_steps=20,
139
+ name="hf_agent_course_final_assignment_solver",
140
+ prompt_templates=yaml.safe_load(
141
+ importlib.resources.files("prompts").joinpath("code_agent.yaml").read_text()
142
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
+ )
145
  print("BasicAgent initialized.")
146
 
147
+ def __call__(self, task_id: str, question: str, file_name: str) -> str:
148
+ if file_name:
149
+ question = self.enrich_question_with_associated_file_details(task_id, question, file_name)
150
+
151
+ final_result = self.code_agent.run(question)
152
+
153
+ # Extract text after "FINAL ANSWER:" (case-insensitive, and trims whitespace)
154
+ match = re.search(r'final answer:\s*(.*)', str(final_result), re.IGNORECASE | re.DOTALL)
155
+ if match:
156
+ return match.group(1).strip()
157
+
158
+ # Fallback in case the pattern is not found
159
+ return str(final_result).strip()
160
+
161
+ def enrich_question_with_associated_file_details(self, task_id:str, question: str, file_name: str) -> str:
162
+ api_url = DEFAULT_API_URL
163
+ get_associated_files_url = f"{api_url}/files/{task_id}"
164
+ response = requests.get(get_associated_files_url, timeout=15)
165
+ response.raise_for_status()
166
+
167
+ if file_name.endswith(".mp3"):
168
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
169
+ tmp_file.write(response.content)
170
+ file_path = tmp_file.name
171
+ return question + "\n\nMentioned .mp3 file local path is: " + file_path
172
+ elif file_name.endswith(".py"):
173
+ file_content = response.text
174
+ return question + "\n\nBelow is mentioned Python file:\n\n```python\n" + file_content + "\n```\n"
175
+ elif file_name.endswith(".xlsx"):
176
+ xlsx_io = BytesIO(response.content)
177
+ df = pd.read_excel(xlsx_io)
178
+ file_content = df.to_csv(index=False)
179
+ return question + "\n\nBelow is mentioned excel file in CSV format:\n\n```csv\n" + file_content + "\n```\n"
180
+ elif file_name.endswith(".png"):
181
+ base64_str = base64.b64encode(response.content).decode('utf-8')
182
+ return question + "\n\nBelow is the .png image in base64 format:\n\n```base64\n" + base64_str + "\n```\n"
183
+
184
+
185
+ def enrich_question_with_associated_file_details(self, task_id:str, question: str, file_name: str) -> str:
186
+ api_url = DEFAULT_API_URL
187
+ get_associated_files_url = f"{api_url}/files/{task_id}"
188
+ response = requests.get(get_associated_files_url, timeout=15)
189
+ response.raise_for_status()
190
+
191
+ if file_name.endswith(".mp3"):
192
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
193
+ tmp_file.write(response.content)
194
+ file_path = tmp_file.name
195
+ return question + "\n\nMentioned .mp3 file local path is: " + file_path
196
+ elif file_name.endswith(".py"):
197
+ file_content = response.text
198
+ return question + "\n\nBelow is mentioned Python file:\n\n```python\n" + file_content + "\n```\n"
199
+ elif file_name.endswith(".xlsx"):
200
+ xlsx_io = BytesIO(response.content)
201
+ df = pd.read_excel(xlsx_io)
202
+ file_content = df.to_csv(index=False)
203
+ return question + "\n\nBelow is mentioned excel file in CSV format:\n\n```csv\n" + file_content + "\n```\n"
204
+ elif file_name.endswith(".png"):
205
+ base64_str = base64.b64encode(response.content).decode('utf-8')
206
+ return question + "\n\nBelow is the .png image in base64 format:\n\n```base64\n" + base64_str + "\n```\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  def run_and_submit_all( profile: gr.OAuthProfile | None):
210
  """