Michele De Stefano committed
Commit: 1b8aef5
Parent(s): 81917a3

Adapted the code so that it can run locally
Browse files
- README.md +20 -0
- agent_factory.py +221 -0
- app.py +130 -38
- data/__init__.py +0 -0
- local_setup/__init__.py +0 -0
- local_setup/nltk_config.py +12 -0
- question_retriever.py +42 -0
- requirements.txt +31 -2
- tests/__init__.py +0 -0
- tests/data/__init__.py +0 -0
- tests/download_data.py +31 -0
- tests/resources/__init__.py +0 -0
- tests/resources/african-penguins-kelp-gull.jpg +0 -0
- tests/resources/penguin.jpeg +0 -0
- tests/test_agent.py +27 -0
- tests/test_download_questions_and_files.py +6 -0
- tests/tools/__init__.py +0 -0
- tests/tools/test_audio_transcriber.py +53 -0
- tests/tools/test_bird_classifier.py +43 -0
- tests/tools/test_excel_table_content_retriever.py +29 -0
- tests/tools/test_math_reasoning.py +30 -0
- tests/tools/test_python_script_executor.py +29 -0
- tests/tools/test_string_reverser.py +27 -0
- tests/tools/test_web_page_info_retriever.py +21 -0
- tests/tools/test_web_search.py +28 -0
- tests/tools/test_whisper.py +53 -0
- tests/tools/test_youtube_transcript.py +28 -0
- tests/tools/test_youtube_video_analysis.py +28 -0
- tests/tools/test_youtube_video_frame_sampler.py +31 -0
- tools/__init__.py +8 -0
- tools/audio_transcriber.py +54 -0
- tools/data_helpers.py +16 -0
- tools/excel_table_content_retriever.py +23 -0
- tools/math_tools.py +15 -0
- tools/python_script_executor.py +27 -0
- tools/string_reverser.py +16 -0
- tools/video_sampling.py +102 -0
- tools/web_page_info_retriever.py +59 -0
- tools/youtube_helpers.py +62 -0
- tools/youtube_video_transcript_retriever.py +24 -0
README.md
CHANGED
```diff
@@ -11,5 +11,25 @@ hf_oauth: true
 # optional, default duration is 8 hours/480 minutes. Max duration is 30 days/43200 minutes.
 hf_oauth_expiration_minutes: 480
 ---
+# Evaluation application for Unit 4 of the HuggingFace Agents course
+This is my implementation of the evaluation application. Unlike the original
+application I cloned from, this one is meant to be run locally, because I am
+using Ollama.
 
+When running locally, you have to create a `.env` file in the root of the
+project. This file is read by the `dotenv.load_dotenv()` call and must
+contain the following variables:
+```commandline
+HF_USERNAME="<your HuggingFace user name>"
+HF_ACCESS_TOKEN="<your HuggingFace access token>"
+SPACE_HOST="localhost"
+SPACE_ID="<your space ID>"
+```
+You can infer the space ID from the address of your space when you access
+the `Files` section. For example, if you read
+```commandline
+https://huggingface.co/spaces/aaa/bbb/tree/main
+```
+then the `SPACE_ID` is `aaa/bbb` (where `aaa` should be the user name).
+# Configuration
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
```
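For reference, these variables are consumed in `app.py` via `python-dotenv`. A minimal sketch of the lookup, mirroring the calls in this commit's `app.py`:

```python
import os
import dotenv

# Reads the .env file in the project root into the process environment.
dotenv.load_dotenv()

HF_USERNAME = os.getenv("HF_USERNAME")
HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")
SPACE_ID = os.getenv("SPACE_ID")  # e.g. "aaa/bbb"
```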
agent_factory.py
ADDED
@@ -0,0 +1,221 @@
```python
import re
from typing import Any, Literal

from langchain_community.tools import DuckDuckGoSearchResults
from langchain_core.messages import SystemMessage, AnyMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_ollama import ChatOllama
from langgraph.constants import START, END
from langgraph.graph import MessagesState, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel

from tools import (
    execute_python_script,
    get_excel_table_content,
    get_youtube_video_transcript,
    reverse_string,
    sum_list,
    transcribe_audio_file,
    web_page_info_retriever,
    youtube_video_to_frame_captions,
)


class AgentFactory:
    """
    A factory for the agent. It is assumed that an Ollama server is running
    on the machine where the factory is used.
    """

    __system_prompt: str = (
        "You have to answer test questions and you need to score high.\n"
        "Sometimes auxiliary files may be attached to the question, so the\n"
        "question itself is presented as a JSON string with the following\n"
        "fields:\n"
        "1. task_id: unique hash identifier of the question.\n"
        "2. question: the text of the question.\n"
        "3. Level: a number with the question difficulty level. You can ignore "
        "this field.\n"
        "4. file_name: the name of the file needed to answer the question. "
        "This is empty if the question does not refer to any file. "
        "IMPORTANT: The text of the question may mention a file name that is "
        "different from what is reported in the \"file_name\" JSON field. "
        "YOU HAVE TO IGNORE THE FILE NAME MENTIONED IN \"question\" AND "
        "YOU MUST USE THE FILE NAME PROVIDED IN THE \"file_name\" FIELD.\n"
        "\n"
        "Depending on the question, the format of your answer is a number OR "
        "as few words as possible OR a comma separated list of numbers "
        "and/or strings. If you are asked for a number, don't use commas to "
        "write your number and don't use units such as $ or percent signs "
        "unless specified otherwise. If you are asked for a string, don't "
        "use articles or abbreviations (e.g. for cities), and write digits "
        "in plain text unless specified otherwise. If you are asked for a "
        "comma separated list, apply the above rules depending on whether "
        "the element to be put in the list is a number or a string.\n"
        "When you have to perform a sum, DON'T try to do that yourself. "
        "Exploit the tool that is able to sum a list of numbers. If you "
        "have to sum the results of previous sums, use the same tool again, "
        "recursively. NEVER do the sums yourself.\n"
        "Achieve the solution by dividing your reasoning into steps, and "
        "provide an explanation for each step.\n"
        "You are advised to cycle between reasoning and tool calling, even "
        "multiple times. Provide an answer only when you are sure you don't "
        "have to call any tool again. Provide the answer between "
        "<ANSWER> and </ANSWER> tags. I stress that the final answer must "
        "follow the rules explained above.\n"
    )

    __llm_for_decision: Runnable
    __llm: Runnable
    __tools: list[BaseTool]

    def __init__(
        self,
        model: str = "qwen2.5-coder:32b",
        # model: str = "mistral-small3.1",
        # model: str = "phi4-mini",
        temperature: float = 0.0,
        num_ctx: int = 8192
    ) -> None:
        """
        Constructor.

        Args:
            model: The name of the Ollama model to use.
            temperature: Temperature parameter.
            num_ctx: Size of the context window used to generate the
                next token.
        """
        search_tool = DuckDuckGoSearchResults(
            description=(
                "A wrapper around Duck Duck Go Search. Useful for when you "
                "need to answer questions about information you can find on "
                "the web. Input should be a search query. It is advisable to "
                "use this tool to retrieve web page URLs and use another tool "
                "to analyze the pages. If the web source is suggested by the "
                "user query, prefer retrieving information from that source. "
                "For example, the query may suggest to search on Wikipedia or "
                "Medium. In those cases, prepend the query with "
                "'site: <name of the source>'. For example: "
                "'site: wikipedia.org'"
            ),
            output_format="list"
        )
        # NOTE: with_retry() returns a new runnable and does not mutate
        # search_tool; as written here, the result is discarded.
        search_tool.with_retry()
        self.__tools = [
            execute_python_script,
            get_excel_table_content,
            get_youtube_video_transcript,
            reverse_string,
            search_tool,
            sum_list,
            transcribe_audio_file,
            web_page_info_retriever,
            youtube_video_to_frame_captions
        ]
        self.__llm_for_decision = ChatOllama(
            model=model,
            temperature=1.0,
            num_ctx=num_ctx
        )
        self.__llm = ChatOllama(
            model=model,
            temperature=temperature,
            num_ctx=num_ctx
        ).bind_tools(tools=self.__tools)

    # NOTE: currently not wired into the graph built by get().
    def __decide_for_code_agent(self, state: MessagesState) -> str:
        decision_messages = [
            SystemMessage(
                content="Answer only yes or no. "
                        "If you think the question can be easily answered "
                        "by writing Python code and executing it then answer "
                        "yes. If you think you can answer by exploiting other "
                        "resources then answer no."
            ),
            state["messages"][-1]
        ]
        answer = self.__llm_for_decision.invoke(decision_messages)
        return answer.content

    def __run_llm(self, state: MessagesState) -> dict[str, Any]:
        answer = self.__llm.invoke(state["messages"])
        # Remove the thinking pattern if present
        pattern = r'\n*<think>.*?</think>\n*'
        answer.content = re.sub(
            pattern, "", answer.content, flags=re.DOTALL
        )
        return {"messages": [answer]}

    @staticmethod
    def __extract_last_message(
        state: list[AnyMessage] | dict[str, Any] | BaseModel,
        messages_key: str
    ) -> AnyMessage:
        if isinstance(state, list):
            last_message = state[-1]
        elif isinstance(state, dict) and (messages := state.get(messages_key, [])):
            last_message = messages[-1]
        elif messages := getattr(state, messages_key, []):
            last_message = messages[-1]
        else:
            raise ValueError(f"No messages found in input state to tool_edge: {state}")
        return last_message

    def __route_from_llm(
        self,
        state: list[AnyMessage] | dict[str, Any] | BaseModel,
        messages_key: str = "messages",
    ) -> Literal["tools", "extract_final_answer"]:
        ai_message = self.__extract_last_message(state, messages_key)
        if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
            return "tools"
        return "extract_final_answer"

    @staticmethod
    def __extract_final_answer(state: MessagesState) -> dict[str, Any]:
        last_message = state["messages"][-1].content
        pattern = r"<ANSWER>(?P<answer>.*?)</ANSWER>"
        m = re.search(pattern, last_message, flags=re.DOTALL)
        answer = m.group("answer").strip() if m else ""
        return {"messages": [answer]}

    @property
    def system_prompt(self) -> SystemMessage:
        """
        Returns:
            The system prompt to use with the agent.
        """
        return SystemMessage(content=self.__system_prompt)

    def get(self) -> CompiledGraph:
        """
        Factory method.

        Returns:
            The instance of the agent.
        """
        graph_builder = StateGraph(MessagesState)

        graph_builder.add_node("LLM", self.__run_llm)
        graph_builder.add_node("tools", ToolNode(tools=self.__tools))
        graph_builder.add_node(
            "extract_final_answer",
            self.__extract_final_answer
        )

        graph_builder.add_edge(start_key=START, end_key="LLM")
        graph_builder.add_conditional_edges(
            source="LLM",
            path=self.__route_from_llm,
            path_map={
                "tools": "tools",
                "extract_final_answer": "extract_final_answer"
            }
        )
        graph_builder.add_edge(start_key="tools", end_key="LLM")
        graph_builder.add_edge(start_key="extract_final_answer", end_key=END)

        return graph_builder.compile()
```
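As a quick orientation, a minimal sketch of building and invoking the agent, mirroring how the tests in this commit drive the factory (the question payload below is hypothetical; real ones come from the scoring API):

```python
from langchain_core.messages import HumanMessage
from langgraph.graph import MessagesState

from agent_factory import AgentFactory

factory = AgentFactory()  # assumes a local Ollama server is running
agent = factory.get()

# Seed the graph with the system prompt plus the question to answer.
initial_state = MessagesState(
    messages=[
        factory.system_prompt,
        HumanMessage(content='{"task_id": "...", "question": "..."}')  # hypothetical payload
    ]
)

final_state = agent.invoke(input=initial_state)
print(final_state["messages"][-1].content)  # the extracted <ANSWER> text
```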
app.py
CHANGED
@@ -1,25 +1,125 @@ (further hunks follow; unchanged context is elided below)
```python
# The old version's fixed-answer stub ("This is a default answer."), its
# "import inspect", and its inline question-fetching block were removed in
# favor of the code below.
import dotenv
import importlib.resources
import json
import os
from typing import Any

import gradio as gr
import requests
import pandas as pd
from pathlib import Path
from langchain_core.messages import HumanMessage
from langgraph.graph import MessagesState
from langgraph.graph.graph import CompiledGraph

from agent_factory import AgentFactory

dotenv.load_dotenv()

HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")
HF_USERNAME = os.getenv("HF_USERNAME")

# (Keep Constants as is)
# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

DATA_PATH = Path(str(importlib.resources.files("data")))
QUESTIONS_FILE_PATH = DATA_PATH / "questions.jsonl"
AGENT_ANSWERS_FILE_PATH = DATA_PATH / "agent-answers.jsonl"


# --- Basic Agent Definition ---
# ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
class BasicAgent:

    __agent_factory: AgentFactory
    __agent: CompiledGraph

    def __init__(self):
        self.__agent_factory = AgentFactory()
        self.__agent = self.__agent_factory.get()
        print("BasicAgent initialized.")

    def __call__(self, question: str) -> str:
        print(f"Agent received question (first 50 chars): {question[:50]}...")

        initial_state = MessagesState(
            messages=[
                self.__agent_factory.system_prompt,
                HumanMessage(content=question)
            ]
        )

        final_state = self.__agent.invoke(input=initial_state)

        answer = final_state["messages"][-1].content

        print(f"Agent returning answer: {answer}")
        return answer


def download_questions_and_files() -> dict[str, Any]:
    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    files_base_url = f"{api_url}/files"

    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            print("Fetched questions list is empty.")
            return {
                "error": "Fetched questions list is empty or invalid format."
            }
        print(f"Fetched {len(questions_data)} questions.")
    except requests.exceptions.RequestException as e:
        print(f"Error fetching questions: {e}")
        return {
            "error": f"Error fetching questions: {e}"
        }
    except requests.exceptions.JSONDecodeError as e:
        print(f"Error decoding JSON response from questions endpoint: {e}")
        print(f"Response text: {response.text[:500]}")
        return {
            "error": f"Error decoding server response for questions: {e}"
        }
    except Exception as e:
        print(f"An unexpected error occurred fetching questions: {e}")
        return {
            "error": f"An unexpected error occurred fetching questions: {e}"
        }

    # Save input questions and related files into the data subdirectory
    try:
        with open(QUESTIONS_FILE_PATH, mode="w") as f:
            for cur_question in questions_data:
                json.dump(cur_question, f)
                f.write("\n")

                file_name = cur_question["file_name"]
                if len(file_name) > 0:
                    file_url = f"{files_base_url}/{cur_question['task_id']}"
                    response = requests.get(file_url)
                    out_file_path = DATA_PATH / file_name
                    with open(out_file_path, 'wb') as file:
                        file.write(response.content)
    except requests.exceptions.RequestException as e:
        print(f"Error fetching question-related file: {e}")
        return {
            "error": f"Error fetching question-related file: {e}"
        }
    except Exception as e:
        print(f"An unexpected error occurred fetching question-related file: {e}")
        return {
            "error": f"An unexpected error occurred fetching question-related file: {e}"
        }

    return questions_data


def run_and_submit_all() -> tuple[str, pd.DataFrame | None]:
    """
    Fetches all questions, runs the BasicAgent on them, submits all answers,
    and displays the results.
    """
    # --- Determine HF Space Runtime URL and Repo URL ---
    space_id = os.getenv("SPACE_ID")  # Get the SPACE_ID for sending link to the code

    username = f"{HF_USERNAME}"
    print(f"User: {username}")

    api_url = DEFAULT_API_URL
    submit_url = f"{api_url}/submit"

    # 1. Instantiate Agent ( modify this part to create your agent)
    # ... (unchanged lines elided by the diff view) ...
    except Exception as e:
        print(f"Error instantiating agent: {e}")
        return f"Error initializing agent: {e}", None
    # In the case of an app running as a Hugging Face space, this link points
    # towards your codebase (useful for others, so please keep it public)
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    print(agent_code)

    # 2. Fetch Questions and related files (they get saved into the data directory)
    questions_data = download_questions_and_files()

    # 3. Run your Agent and save the agent's answers for later review
    results_log = []
    answers_payload = []
    print(f"Running agent on {len(questions_data)} questions...")
    for item in questions_data:
        task_id = item.get("task_id")
        question_text = json.dumps(item)
        if not task_id or question_text is None:
            print(f"Skipping item with missing task_id or question: {item}")
            continue
        # ... (unchanged lines elided by the diff view) ...
        print("Agent did not produce any answers to submit.")
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    with open(AGENT_ANSWERS_FILE_PATH, mode="w") as f:
        for cur_answer in answers_payload:
            json.dump(cur_answer, f)
            f.write("\n")

    # 4. Prepare Submission
    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
    status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
    # ... (unchanged lines elided by the diff view) ...

    # 5. Submit
    print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
    headers = {
        "Authorization": f"Bearer {HF_ACCESS_TOKEN}",
        "Content-Type": "application/json"
    }
    try:
        response = requests.post(
            submit_url,
            json=submission_data,
            headers=headers,
            timeout=60
        )
        response.raise_for_status()
        result_data = response.json()
        final_status = (
        # ... (unchanged lines elided by the diff view) ...


# ... (unchanged lines elided by the diff view) ...

if __name__ == "__main__":
    # ... (unchanged lines elided by the diff view) ...
    print("-"*(60 + len(" App Starting ")) + "\n")

    print("Launching Gradio Interface for Basic Agent Evaluation...")
    demo.launch(debug=True, share=False)
```
data/__init__.py
ADDED
File without changes
local_setup/__init__.py
ADDED
File without changes
local_setup/nltk_config.py
ADDED
@@ -0,0 +1,12 @@
```python
import nltk


def download_nltk_packages() -> None:
    nltk.download('punkt')
    nltk.download('averaged_perceptron_tagger')
    nltk.download('stopwords')
    nltk.download('wordnet')


if __name__ == "__main__":
    download_nltk_packages()
```
question_retriever.py
ADDED
@@ -0,0 +1,42 @@
```python
import importlib.resources
import json
from pathlib import Path


__questions_path = (
    Path(str(importlib.resources.files("data"))) / "questions.jsonl"
)


def get_question(task_id: str) -> str | None:
    """
    Given the ID of one of the available questions, reads it from the
    JSONL file where the questions were previously downloaded.

    Args:
        task_id: The hash code of the question.

    Returns:
        The JSONL string with the required question.
    """
    with open(__questions_path, 'r', encoding='utf-8') as file:
        for line in file:
            data = json.loads(line)
            if data["task_id"] == task_id:
                return line

    return None


def get_all_questions() -> list[dict]:
    """
    Retrieves the list of all questions previously downloaded.

    Returns:
        The list of questions previously downloaded, each parsed into a dict.
    """
    questions = []
    with open(__questions_path, 'r', encoding='utf-8') as file:
        for line in file:
            questions += [json.loads(line)]
    return questions
```
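A small sketch of how these helpers combine with the downloaded data (the task_id below is hypothetical; real IDs come from `questions.jsonl`):

```python
import json

from question_retriever import get_question, get_all_questions

# Look up one question by its task_id and parse the JSONL record.
raw = get_question(task_id="00000000-0000-0000-0000-000000000000")  # hypothetical ID
if raw is not None:
    record = json.loads(raw)
    print(record["question"], record["file_name"])

# Or iterate over everything that was downloaded.
for q in get_all_questions():
    print(q["task_id"])
```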
requirements.txt
CHANGED
@@ -1,2 +1,31 @@
```text
beautifulsoup4
python-dotenv
duckduckgo-search
faiss-cpu
gradio[oauth]
helium
langchain
langchain_community
langchain-core
# NOTE: langchain-huggingface is imported by tools/web_page_info_retriever.py
# and should also be listed here.
langchain-ollama
langchain-unstructured[local]
langgraph
opencv-python
pandas
pdfminer
pillow
pydantic
pytest
# https://www.youtube.com/watch?v=VgxnyKnB3qc
# https://github.com/juanbindez/pytubefix
pytubefix
requests
torch
transformers
ultralytics
# NOTE: For unstructured to work locally, also install the system
# requirements described here:
# https://docs.unstructured.io/open-source/installation/full-installation
unstructured[all-docs]
youtube-transcript-api
```
tests/__init__.py
ADDED
File without changes
tests/data/__init__.py
ADDED
File without changes
tests/download_data.py
ADDED
@@ -0,0 +1,31 @@
```python
import json
import requests

from app import DEFAULT_API_URL


def main() -> None:
    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    dest_file = "data/questions.jsonl"
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            print("Fetched questions list is empty.")
            return
        print(f"Fetched {len(questions_data)} questions.")
    except requests.exceptions.RequestException as e:
        print(f"Error fetching questions: {e}")
        return

    with open(dest_file, mode="w") as f:
        for item in questions_data:
            json.dump(item, f)
            f.write("\n")

    print("Done.")


if __name__ == "__main__":
    main()
```
tests/resources/__init__.py
ADDED
File without changes
tests/resources/african-penguins-kelp-gull.jpg
ADDED
tests/resources/penguin.jpeg
ADDED
tests/test_agent.py
ADDED
@@ -0,0 +1,27 @@
```python
from langchain_core.messages import HumanMessage
from langgraph.graph import MessagesState

from agent_factory import AgentFactory
from question_retriever import get_question


def test_agent() -> None:
    # given
    # grocery list
    task_id = "3cef3a44-215e-4aed-8e3b-b1e3f08063b7"
    question = get_question(task_id=task_id)

    agent_factory = AgentFactory()
    agent = agent_factory.get()

    initial_state = MessagesState(
        messages=[
            agent_factory.system_prompt,
            HumanMessage(content=question)
        ]
    )

    final_state = agent.invoke(input=initial_state)
    answer = final_state["messages"][-1].content

    print(answer)
```
tests/test_download_questions_and_files.py
ADDED
@@ -0,0 +1,6 @@
```python
from app import download_questions_and_files


def test_download_questions_and_files() -> None:
    download_questions_and_files()
    print("Download success.")
```
tests/tools/__init__.py
ADDED
File without changes
tests/tools/test_audio_transcriber.py
ADDED
@@ -0,0 +1,53 @@
```python
from langchain_core.messages import HumanMessage
from langgraph.graph import MessagesState

from agent_factory import AgentFactory
from question_retriever import get_question


def test_audio_transcriber() -> None:
    # given
    task_id = "1f975693-876d-457b-a649-393859e79bf3"
    question = get_question(task_id=task_id)

    agent_factory = AgentFactory()
    agent = agent_factory.get()

    initial_state = MessagesState(
        messages=[
            agent_factory.system_prompt,
            HumanMessage(content=question)
        ]
    )

    # when
    final_state = agent.invoke(input=initial_state)

    # then
    answer = final_state["messages"][-1].content

    assert answer == "132,133,134,197,245"


def test_audio_transcriber_pie_recipe() -> None:
    # given
    task_id = "99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3"
    question = get_question(task_id=task_id)

    agent_factory = AgentFactory()
    agent = agent_factory.get()

    initial_state = MessagesState(
        messages=[
            agent_factory.system_prompt,
            HumanMessage(content=question)
        ]
    )

    # when
    final_state = agent.invoke(input=initial_state)

    # then
    answer = final_state["messages"][-1].content

    assert answer == "cornstarch,granulated sugar,lemon juice,ripe strawberries,vanilla extract"
```
tests/tools/test_bird_classifier.py
ADDED
@@ -0,0 +1,43 @@
```python
import importlib.resources

from pathlib import Path
from PIL import Image
from transformers import pipeline

__resources_path = Path(str(importlib.resources.files("tests.resources")))


def test_bird_classifier_with_one_single_bird() -> None:
    # given
    img_file = __resources_path / "penguin.jpeg"
    img = Image.open(img_file)

    # when
    pipe = pipeline(
        task="image-classification",
        model="dennisjooo/Birds-Classifier-EfficientNetB2"
    )

    result = pipe(img)
    result = result[0]

    # then
    assert "penguin" in result["label"].lower()


def test_bird_classifier_with_multiple_birds() -> None:
    # given
    img_file = __resources_path / "african-penguins-kelp-gull.jpg"
    img = Image.open(img_file)

    # when
    pipe = pipeline(
        task="image-classification",
        model="dennisjooo/Birds-Classifier-EfficientNetB2"
    )

    result = pipe(img)
    result = result[0]

    # then
    assert "penguin" not in result["label"].lower()
```
tests/tools/test_excel_table_content_retriever.py
ADDED
@@ -0,0 +1,29 @@
```python
from langchain_core.messages import HumanMessage
from langgraph.graph import MessagesState

from agent_factory import AgentFactory
from question_retriever import get_question


def test_excel_table_content_retriever() -> None:
    # given
    task_id = "7bd855d8-463d-4ed5-93ca-5fe35145f733"
    question = get_question(task_id=task_id)

    agent_factory = AgentFactory()
    agent = agent_factory.get()

    initial_state = MessagesState(
        messages=[
            agent_factory.system_prompt,
            HumanMessage(content=question)
        ]
    )

    # when
    final_state = agent.invoke(input=initial_state)

    # then
    answer = final_state["messages"][-1].content

    assert answer.lower() == "89706"
```
tests/tools/test_math_reasoning.py
ADDED
@@ -0,0 +1,30 @@
```python
from langchain_core.messages import HumanMessage
from langgraph.graph import MessagesState

from agent_factory import AgentFactory
from question_retriever import get_question


def test_math_reasoning() -> None:
    # given
    # Table describing an operation
    task_id = "6f37996b-2ac7-44b0-8e68-6d28256631b4"
    question = get_question(task_id=task_id)

    agent_factory = AgentFactory()
    agent = agent_factory.get()

    initial_state = MessagesState(
        messages=[
            agent_factory.system_prompt,
            HumanMessage(content=question)
        ]
    )

    # when
    final_state = agent.invoke(input=initial_state)

    # then
    answer = final_state["messages"][-1].content

    assert answer.lower() == "b,e"
```
tests/tools/test_python_script_executor.py
ADDED
@@ -0,0 +1,29 @@
```python
from langchain_core.messages import HumanMessage
from langgraph.graph import MessagesState

from agent_factory import AgentFactory
from question_retriever import get_question


def test_python_script_executor() -> None:
    # given
    task_id = "f918266a-b3e0-4914-865d-4faa564f1aef"
    question = get_question(task_id=task_id)

    agent_factory = AgentFactory()
    agent = agent_factory.get()

    initial_state = MessagesState(
        messages=[
            agent_factory.system_prompt,
            HumanMessage(content=question)
        ]
    )

    # when
    final_state = agent.invoke(input=initial_state)

    # then
    answer = final_state["messages"][-1].content

    assert answer.lower() == "0"
```
tests/tools/test_string_reverser.py
ADDED
@@ -0,0 +1,27 @@
```python
from langchain_core.messages import HumanMessage
from langgraph.graph import MessagesState

from agent_factory import AgentFactory
from question_retriever import get_question


def test_string_reverser() -> None:
    # given
    task_id = "2d83110e-a098-4ebb-9987-066c06fa42d0"
    question = get_question(task_id=task_id)

    agent_factory = AgentFactory()
    agent = agent_factory.get()

    initial_state = MessagesState(
        messages=[
            agent_factory.system_prompt,
            HumanMessage(content=question)
        ]
    )

    final_state = agent.invoke(input=initial_state)

    answer = final_state["messages"][-1].content

    assert answer.lower() == "right"
```
tests/tools/test_web_page_info_retriever.py
ADDED
@@ -0,0 +1,21 @@
```python
from tools import web_page_info_retriever


def test_web_document_info_retriever() -> None:
    # given
    web_url = "https://en.wikipedia.org/wiki/Albert_Einstein"
    query = "Albert Einstein nobel prize"

    # when
    results = web_page_info_retriever.invoke({
        "web_url": web_url,
        "query": query
    })

    # then
    all_text = " ".join(results)
    assert "1922" in all_text
    assert "1921" in all_text
    assert "photoelectric" in all_text
    assert "Nobel" in all_text
```
tests/tools/test_web_search.py
ADDED
@@ -0,0 +1,28 @@
```python
from langchain_core.messages import HumanMessage
from langgraph.graph import MessagesState

from agent_factory import AgentFactory
from question_retriever import get_question


def test_web_search() -> None:
    # given
    task_id = "8e867cd7-cff9-4e6c-867a-ff5ddc2550be"
    question = get_question(task_id=task_id)

    agent_factory = AgentFactory()
    agent = agent_factory.get()

    initial_state = MessagesState(
        messages=[
            agent_factory.system_prompt,
            HumanMessage(content=question)
        ]
    )

    # when
    final_state = agent.invoke(input=initial_state)
    answer = final_state["messages"][-1].content

    # then
    assert answer == "2"
```
tests/tools/test_whisper.py
ADDED
@@ -0,0 +1,53 @@
```python
import json

import torch

from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

from question_retriever import get_question
from tools.data_helpers import get_file_path


def test_whisper() -> None:

    task_id = "1f975693-876d-457b-a649-393859e79bf3"
    question = json.loads(get_question(task_id=task_id))

    audio_file = get_file_path(file_name=question["file_name"])

    # Keep the GPU free for the LLM served by Ollama
    # cuda_available = torch.cuda.is_available()
    cuda_available = False
    device = "cuda:0" if cuda_available else "cpu"
    torch_dtype = torch.float16 if cuda_available else torch.float32

    model_id = "openai/whisper-large-v3-turbo"

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )

    sample = audio_file

    generate_kwargs = {
        "return_timestamps": True,
    }

    result = pipe(sample, generate_kwargs=generate_kwargs)

    print(result["text"])
```
tests/tools/test_youtube_transcript.py
ADDED
@@ -0,0 +1,28 @@
```python
from langchain_core.messages import HumanMessage
from langgraph.graph import MessagesState

from agent_factory import AgentFactory
from question_retriever import get_question


def test_get_youtube_video_transcript() -> None:
    # given
    task_id = "9d191bce-651d-4746-be2d-7ef8ecadb9c2"
    question = get_question(task_id=task_id)

    agent_factory = AgentFactory()
    agent = agent_factory.get()

    initial_state = MessagesState(
        messages=[
            agent_factory.system_prompt,
            HumanMessage(content=question)
        ]
    )

    # when
    final_state = agent.invoke(input=initial_state)
    answer = final_state["messages"][-1].content

    # then
    assert answer.lower() == "extremely"
```
tests/tools/test_youtube_video_analysis.py
ADDED
@@ -0,0 +1,28 @@
```python
from langchain_core.messages import HumanMessage
from langgraph.graph import MessagesState

from agent_factory import AgentFactory
from question_retriever import get_question


def test_youtube_video_analysis() -> None:
    # given
    task_id = "a1e91b78-d3d8-4675-bb8d-62741b4b68a6"
    question = get_question(task_id=task_id)

    agent_factory = AgentFactory()
    agent = agent_factory.get()

    initial_state = MessagesState(
        messages=[
            agent_factory.system_prompt,
            HumanMessage(content=question)
        ]
    )

    # when
    final_state = agent.invoke(input=initial_state)
    answer = final_state["messages"][-1].content

    # then
    assert answer.lower() == "2"
```
tests/tools/test_youtube_video_frame_sampler.py
ADDED
@@ -0,0 +1,31 @@
```python
import json
import tempfile

from pathlib import Path
from tools.youtube_helpers import youtube_video_frame_sampler, youtube_video_to_frame_captions


def test_youtube_video_frame_sampler() -> None:
    # given
    temp_frames_dir = tempfile.TemporaryDirectory()
    dest_dir = temp_frames_dir.name

    # when
    youtube_video_frame_sampler(
        addr="https://www.youtube.com/watch?v=L1vXCYZAYYM",
        dest_dir=dest_dir
    )

    # then
    dest_path = Path(dest_dir)
    assert len(list(dest_path.glob('*.jpg'))) == 61


def test_youtube_video_captions_generator() -> None:
    # given, when
    captions_str = youtube_video_to_frame_captions(
        addr="https://www.youtube.com/watch?v=L1vXCYZAYYM",
    )

    # then
    # If this parses, the tool returned well-formed JSON.
    captions = json.loads(captions_str)
    assert isinstance(captions, list)
```
tools/__init__.py
ADDED
@@ -0,0 +1,8 @@
```python
from .audio_transcriber import transcribe_audio_file
from .excel_table_content_retriever import get_excel_table_content
from .math_tools import sum_list
from .python_script_executor import execute_python_script
from .string_reverser import reverse_string
from .web_page_info_retriever import web_page_info_retriever
from .youtube_video_transcript_retriever import get_youtube_video_transcript
from .youtube_helpers import youtube_video_to_frame_captions
```
tools/audio_transcriber.py
ADDED
@@ -0,0 +1,54 @@
```python
import torch

from langchain_core.tools import tool
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

from .data_helpers import get_file_path


@tool(parse_docstring=True)
def transcribe_audio_file(file_name: str) -> str:
    """
    Transcribes an audio file to text.

    Args:
        file_name: The name of the audio file. This is simply the file name,
            not the full path.

    Returns:
        The transcribed text.
    """
    # Specific setting for a local run where the GPU is kept busy by the
    # LLM served by Ollama.
    cuda_available = False
    device = "cuda:0" if cuda_available else "cpu"
    torch_dtype = torch.float16 if cuda_available else torch.float32

    model_id = "openai/whisper-large-v3-turbo"

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True
    )
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )

    generate_kwargs = {
        "return_timestamps": True,
    }

    file_path = get_file_path(file_name)
    result = pipe(file_path, generate_kwargs=generate_kwargs)

    return result["text"]
```
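One observation: the Whisper model and pipeline are rebuilt on every tool call. A possible refinement (a sketch, not part of this commit) caches them with `functools.lru_cache` so repeated transcriptions reuse the same pipeline:

```python
from functools import lru_cache

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline


@lru_cache(maxsize=1)
def get_asr_pipeline(model_id: str = "openai/whisper-large-v3-turbo"):
    # Built once, reused across tool calls; CPU-only to keep the GPU
    # free for the LLM served by Ollama.
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch.float32,
        low_cpu_mem_usage=True, use_safetensors=True
    )
    processor = AutoProcessor.from_pretrained(model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch.float32,
        device="cpu",
    )
```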
tools/data_helpers.py
ADDED
@@ -0,0 +1,16 @@
```python
import importlib.resources
from pathlib import Path


def get_file_path(file_name: str) -> str:
    """
    Returns the full path of a question file.

    Args:
        file_name: The file name specified in the question.

    Returns:
        The full path of the file that was previously downloaded.
    """
    data_path = Path(str(importlib.resources.files("data")))
    return str(data_path / file_name)
```
tools/excel_table_content_retriever.py
ADDED
@@ -0,0 +1,23 @@
```python
import pandas as pd
from langchain_core.tools import tool

from .data_helpers import get_file_path


@tool(parse_docstring=True)
def get_excel_table_content(file_name: str) -> str:
    """
    Given an Excel file name, returns its content as a string, and
    explicitly returns the list of column names as well.
    It assumes the file contains a table. Reads the first sheet with
    pandas and returns its string representation.

    Args:
        file_name: Name of the Excel file.

    Returns:
        String representation of the content of the first sheet of the
        file.
    """
    file_path = get_file_path(file_name)
    df = pd.read_excel(io=file_path)
    return str(df) + f"\nColumn names: {df.columns.tolist()}\n"
```
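Note that `str(df)` obeys pandas' display options and truncates large sheets. If the full contents were needed, a variant (a sketch, not part of this commit) could render everything explicitly:

```python
import pandas as pd


def excel_to_full_text(file_path: str) -> str:
    # to_string() renders every row and column, bypassing display truncation.
    df = pd.read_excel(io=file_path)
    return df.to_string() + f"\nColumn names: {df.columns.tolist()}\n"
```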
tools/math_tools.py
ADDED
@@ -0,0 +1,15 @@
```python
from langchain_core.tools import tool


@tool(parse_docstring=True, return_direct=True)
def sum_list(numbers: list[float]) -> float:
    """
    Sums the provided input numbers.

    Args:
        numbers: The sequence of numbers to sum.

    Returns:
        The sum of the input numbers.
    """
    return sum(numbers)
```
tools/python_script_executor.py
ADDED
@@ -0,0 +1,27 @@
```python
import subprocess as sub

from langchain_core.tools import tool

from tools.data_helpers import get_file_path


@tool(parse_docstring=True)
def execute_python_script(file_name: str) -> str:
    """
    Given the Python source file name, executes it and returns the output
    captured from stdout.

    Args:
        file_name: The name of the Python source to execute. This is only the
            file name, not the full path.

    Returns:
        The execution output, captured from stdout.
    """
    source_file = get_file_path(file_name)
    # NOTE: this executes the downloaded script as-is; only stdout is
    # returned, and result.stderr is discarded.
    result = sub.run(
        args=["python", source_file],
        capture_output=True,
        encoding="utf-8"
    )
    return result.stdout
```
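Since `args=["python", ...]` runs whatever `python` resolves to on the PATH, a variant pinned to the interpreter running the agent (a sketch, not part of this commit) would be:

```python
import subprocess as sub
import sys


def run_with_current_interpreter(source_file: str) -> str:
    # sys.executable points at the interpreter running this process,
    # which avoids picking up a different "python" from the PATH.
    result = sub.run(
        args=[sys.executable, source_file],
        capture_output=True,
        encoding="utf-8"
    )
    # Surface errors too, instead of silently returning an empty string.
    return result.stdout if result.returncode == 0 else result.stderr
```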
tools/string_reverser.py
ADDED
@@ -0,0 +1,16 @@
```python
from langchain_core.tools import tool


@tool(parse_docstring=True)
def reverse_string(s: str) -> str:
    """
    Returns the reverse of a string. Use this tool when you suspect that
    the provided prompt is written with its characters in reverse order.

    Args:
        s: The input string.

    Returns:
        The output string. It is the input with its characters in reverse
        order.
    """
    return s[::-1]
```
tools/video_sampling.py
ADDED
@@ -0,0 +1,102 @@
```python
import json
import os
import cv2

from transformers import BlipProcessor, BlipForConditionalGeneration

# model_id = "Salesforce/blip-image-captioning-base"
model_id = "Salesforce/blip-image-captioning-large"
captioning_processor = BlipProcessor.from_pretrained(model_id)
captioning_model = BlipForConditionalGeneration.from_pretrained(model_id)


def extract_frames(video_path, output_folder, interval_ms=2000) -> None:
    """
    Extracts frames from a video into an output folder at a specified time
    interval. Frames are saved as *.jpg images.

    Args:
        video_path: The file name of the video to sample.
        output_folder: The output directory for the extracted frames.
        interval_ms: The sampling interval in milliseconds.
            NOTE: No anti-aliasing filter is applied.
    """
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)  # Get fps
    # Compute the sampling interval as a number of frames to skip
    # (at least 1, to guard the modulo operation below).
    interval_frames = max(1, int(fps * interval_ms * 0.001))

    frame_count = 0
    saved_frame_count = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Keep only selected frames
        if frame_count % interval_frames == 0:
            frame_filename = os.path.join(
                output_folder,
                f"frame_{saved_frame_count:04d}.jpg"
            )
            cv2.imwrite(frame_filename, frame)
            saved_frame_count += 1

        frame_count += 1

    cap.release()


def extract_frame_captions(
    video_path,
    interval_ms=2000
) -> str:
    """
    Extracts frame captions from a video at a specified time
    interval.

    Args:
        video_path: The file name of the video to sample.
        interval_ms: The sampling interval in milliseconds.
            NOTE: No anti-aliasing filter is applied.

    Returns:
        A JSON string encoding the list of frame captions.
    """
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)  # Get fps
    # Compute the sampling interval as a number of frames to skip
    # (at least 1, to guard the modulo operation below).
    interval_frames = max(1, int(fps * interval_ms * 0.001))

    frame_count = 0
    saved_frame_count = 0

    captions = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Keep only selected frames
        if frame_count % interval_frames == 0:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            inputs = captioning_processor(
                frame,
                text="Detailed image description:",
                return_tensors="pt"
            )
            out = captioning_model.generate(**inputs)
            cur_caption = (
                captioning_processor.decode(out[0], skip_special_tokens=True)
            )
            captions += [cur_caption]
            saved_frame_count += 1

        frame_count += 1

    cap.release()
    return json.dumps(captions)
```
|
tools/web_page_info_retriever.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import faiss
|
2 |
+
from langchain_community.docstore.in_memory import InMemoryDocstore
|
3 |
+
from langchain_community.vectorstores import FAISS
|
4 |
+
from langchain_community.vectorstores.utils import DistanceStrategy
|
5 |
+
from langchain_core.tools import tool
|
6 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
7 |
+
from langchain_unstructured import UnstructuredLoader
|
8 |
+
|
9 |
+
|
10 |
+
@tool(parse_docstring=True)
|
11 |
+
def web_page_info_retriever(
|
12 |
+
web_url: str,
|
13 |
+
query: str,
|
14 |
+
k: int = 10
|
15 |
+
) -> list[str]:
|
16 |
+
"""
|
17 |
+
Retrieves information on the fly from a web page.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
web_url: The url of the web page.
|
21 |
+
query: The user query.
|
22 |
+
k: The maximum number of documents to retrieve. Use a reasonable
|
23 |
+
number depending on the amount of context you want to retrieve.
|
24 |
+
Usually a number between 10 and 20 should suffice (but there is
|
25 |
+
no upper bound to this parameter).
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
A list of strings containing the most relevant documents retrieved.
|
29 |
+
"""
|
30 |
+
loader = UnstructuredLoader(web_url=web_url)
|
31 |
+
docs = loader.load()
|
32 |
+
|
33 |
+
embeddings = HuggingFaceEmbeddings(
|
34 |
+
model_name="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
|
35 |
+
model_kwargs={"device": "cpu"}
|
36 |
+
)
|
37 |
+
|
38 |
+
index = faiss.IndexFlatIP(len(embeddings.embed_query("hello world")))
|
39 |
+
|
40 |
+
vector_store = FAISS(
|
41 |
+
embedding_function=embeddings,
|
42 |
+
index=index,
|
43 |
+
docstore= InMemoryDocstore(),
|
44 |
+
index_to_docstore_id={},
|
45 |
+
distance_strategy=DistanceStrategy.COSINE,
|
46 |
+
)
|
47 |
+
vector_store.add_documents(documents=docs)
|
48 |
+
|
49 |
+
retrieved_docs = vector_store.similarity_search_with_relevance_scores(
|
50 |
+
query=query,
|
51 |
+
k=k
|
52 |
+
)
|
53 |
+
|
54 |
+
sorted_contents = [
|
55 |
+
t[0].page_content
|
56 |
+
for t in sorted(retrieved_docs, key=lambda x: x[1], reverse=True)
|
57 |
+
]
|
58 |
+
|
59 |
+
return sorted_contents
|
tools/youtube_helpers.py
ADDED
@@ -0,0 +1,62 @@
```python
import os
import tempfile

from langchain_core.tools import tool
from pytubefix import YouTube

from .video_sampling import extract_frames, extract_frame_captions


def download_video(url, output_path):
    """
    Downloads the video into an output path.

    Args:
        url: The URL of the YouTube video.
        output_path: The output folder where to download the video.

    Returns:
        The file name of the downloaded video.
    """
    yt = YouTube(url)
    stream = yt.streams.get_lowest_resolution()
    stream.download(output_path)
    return os.path.join(output_path, stream.default_filename)


def youtube_video_frame_sampler(addr: str, dest_dir: str) -> None:
    """
    Downsamples a YouTube video into frames. Saves the frames into a
    destination directory.

    Args:
        addr: The URL of the YouTube video.
        dest_dir: The destination directory.
    """
    temp_dir = tempfile.TemporaryDirectory()
    download_path = temp_dir.name

    video_path = download_video(addr, download_path)
    extract_frames(video_path, dest_dir)


@tool(parse_docstring=True)
def youtube_video_to_frame_captions(addr: str) -> str:
    """
    Analyzes video frames from a YouTube video and obtains
    captions for each frame. This is useful when we need to
    answer questions on the images shown in the video. It adds
    computer vision capabilities to the LLM.

    Args:
        addr: The URL of the YouTube video.

    Returns:
        A JSON string encoding the list of frame captions.
    """
    temp_dir = tempfile.TemporaryDirectory()
    download_path = temp_dir.name

    video_path = download_video(addr, download_path)
    return extract_frame_captions(video_path)
```
tools/youtube_video_transcript_retriever.py
ADDED
@@ -0,0 +1,24 @@
```python
import youtube_transcript_api as yt_api
from langchain_core.tools import tool


@tool(parse_docstring=True)
def get_youtube_video_transcript(addr: str) -> str:
    """
    Given the address of a YouTube video, returns the transcript of the
    audio. This is useful when we need to answer questions about what is
    said in the video.

    Args:
        addr: The URL of the YouTube video.

    Returns:
        The transcript of the audio extracted from the video. Different items
        of the transcript are separated by ";". These different items may be
        sentences pronounced by different characters.
    """
    # NOTE: this assumes a URL of the form ...watch?v=<id> with no further
    # query parameters.
    video_id = addr.split(sep="=")[1]
    ytt_api = yt_api.YouTubeTranscriptApi()
    fetched_data = ytt_api.fetch(video_id)
    result_transcript = ";".join([t.text for t in fetched_data])
    return result_transcript
```
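The `addr.split(sep="=")[1]` parsing only handles URLs of the form `...watch?v=<id>` with no extra query parameters. A more robust sketch (not part of this commit) using only the standard library:

```python
from urllib.parse import urlparse, parse_qs


def extract_video_id(addr: str) -> str:
    # Handles extra query parameters, e.g. ...watch?v=<id>&t=42s
    query = parse_qs(urlparse(addr).query)
    return query["v"][0]
```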