wt002 commited on
Commit
9363094
·
verified ·
1 Parent(s): 20c067b

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +118 -191
agent.py CHANGED
@@ -1,219 +1,146 @@
1
  import os
2
- from dotenv import load_dotenv
3
- from langgraph.graph import START, StateGraph, MessagesState
4
- from langgraph.prebuilt import tools_condition
5
- from langgraph.prebuilt import ToolNode
6
- from langchain_google_genai import ChatGoogleGenerativeAI
7
- from langchain_openai import ChatOpenAI
8
- from langchain_google_genai import ChatGoogleGenerativeAI
9
- from langchain_huggingface import ChatHuggingFace
10
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
11
- from langchain_community.tools.tavily_search import TavilySearchResults
12
- from langchain_community.document_loaders import WikipediaLoader
13
- from langchain_community.document_loaders import ArxivLoader
14
- from langchain_community.vectorstores import SupabaseVectorStore
15
- from langchain_core.messages import SystemMessage, HumanMessage
16
- from langchain_core.tools import tool
17
- from langchain.tools.retriever import create_retriever_tool
18
- from supabase.client import Client, create_client
19
-
20
- load_dotenv()
21
 
22
- @tool
23
- def multiply(a: int, b: int) -> int:
24
- """Multiply two numbers.
25
- Args:
26
- a: first int
27
- b: second int
28
- """
29
- return a * b
30
 
31
- @tool
32
- def add(a: int, b: int) -> int:
33
- """Add two numbers.
34
-
35
- Args:
36
- a: first int
37
- b: second int
38
- """
39
- return a + b
40
 
41
  @tool
42
- def subtract(a: int, b: int) -> int:
43
- """Subtract two numbers.
44
-
45
  Args:
46
- a: first int
47
- b: second int
 
 
48
  """
49
- return a - b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  @tool
52
- def divide(a: int, b: int) -> int:
53
- """Divide two numbers.
54
-
55
  Args:
56
- a: first int
57
- b: second int
 
 
58
  """
59
- if b == 0:
60
- raise ValueError("Cannot divide by zero.")
61
- return a / b
 
 
 
 
 
 
 
 
 
 
62
 
63
  @tool
64
- def modulus(a: int, b: int) -> int:
65
- """Get the modulus of two numbers.
66
-
67
  Args:
68
- a: first int
69
- b: second int
 
 
70
  """
71
- return a % b
72
 
73
- @tool
74
- def wiki_search(query: str) -> str:
75
- """Search Wikipedia for a query and return maximum 2 results.
76
-
77
- Args:
78
- query: The search query."""
79
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
80
- formatted_search_docs = "\n\n---\n\n".join(
81
- [
82
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
83
- for doc in search_docs
84
- ])
85
- return {"wiki_results": formatted_search_docs}
 
 
86
 
87
- @tool
88
- def web_search(query: str) -> str:
89
- """Search Tavily for a query and return maximum 3 results.
90
-
91
- Args:
92
- query: The search query."""
93
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
94
- formatted_search_docs = "\n\n---\n\n".join(
95
- [
96
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
97
- for doc in search_docs
98
- ])
99
- return {"web_results": formatted_search_docs}
100
 
101
  @tool
102
- def arvix_search(query: str) -> str:
103
- """Search Arxiv for a query and return maximum 3 result.
104
-
105
  Args:
106
- query: The search query."""
107
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
108
- formatted_search_docs = "\n\n---\n\n".join(
109
- [
110
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
111
- for doc in search_docs
112
- ])
113
- return {"arvix_results": formatted_search_docs}
114
-
115
-
116
-
117
- # load the system prompt from the file
118
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
119
- system_prompt = f.read()
120
-
121
- # System message
122
- sys_msg = SystemMessage(content=system_prompt)
123
-
124
- # build a retriever
125
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
126
- supabase: Client = create_client(
127
- os.environ.get("SUPABASE_URL"),
128
- os.environ.get("SUPABASE_SERVICE_KEY"))
129
- vector_store = SupabaseVectorStore(
130
- client=supabase,
131
- embedding= embeddings,
132
- table_name="documents",
133
- query_name="match_documents_langchain",
134
- )
135
- create_retriever_tool = create_retriever_tool(
136
- retriever=vector_store.as_retriever(),
137
- name="Question Search",
138
- description="A tool to retrieve similar questions from a vector store.",
139
- )
140
-
141
-
142
-
143
- tools = [
144
- multiply,
145
- add,
146
- subtract,
147
- divide,
148
- modulus,
149
- wiki_search,
150
- web_search,
151
- arvix_search,
152
- ]
153
-
154
- # Build graph function
155
- def build_graph(provider: str = "huggingface", huggingface_model: str = "mistral"):
156
- """Build the graph with tool binding."""
157
-
158
- if provider == "google":
159
- llm = ChatGoogleGenerativeAI(
160
- model="gemini-2.0-flash",
161
- temperature=0,
162
- google_api_key=os.getenv("GOOGLE_API_KEY")
163
- )
164
 
165
- elif provider == "huggingface":
166
- if huggingface_model == "mistral":
167
- repo_id = "mistralai/Mistral-7B-Instruct-v0.1"
168
- elif huggingface_model == "llama":
169
- repo_id = "Meta-DeepLearning/llama-2-7b-chat-hf"
170
- else:
171
- raise ValueError("Unsupported Hugging Face model")
172
-
173
- hf_token = os.getenv("HF_TOKEN")
174
- headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
175
-
176
- llm = ChatHuggingFace(
177
- llm=HuggingFaceEndpoint(
178
- repo_id=repo_id,
179
- temperature=0,
180
- )
181
- )
182
 
183
- else:
184
- raise ValueError("Invalid provider. Choose 'google' or 'huggingface'.")
 
 
 
185
 
186
- return llm
187
 
188
- # ✅ Bind tools if defined
189
- llm_with_tools = llm.bind_tools(tools) # Make sure `tools` is defined/imported
190
- return llm_with_tools
191
 
192
- # Node
193
- def assistant(state: MessagesState):
194
- """Assistant node"""
195
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
 
 
 
 
196
 
197
- def retriever(state: MessagesState):
198
- """Retriever node"""
199
- similar_question = vector_store.similarity_search(state["messages"][0].content)
200
- example_msg = HumanMessage(
201
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
202
- )
203
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
204
-
205
- builder = StateGraph(MessagesState)
206
- builder.add_node("retriever", retriever)
207
- builder.add_node("assistant", assistant)
208
- builder.add_node("tools", ToolNode(tools))
209
- builder.add_edge(START, "retriever")
210
- builder.add_edge("retriever", "assistant")
211
- builder.add_conditional_edges(
212
- "assistant",
213
- tools_condition,
214
- )
215
- builder.add_edge("tools", "assistant")
216
 
217
- # Compile graph
218
- return builder.compile()
 
219
 
 
 
 
 
 
 
1
  import os
2
+ import tempfile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
+ import requests
5
+ from google import genai
6
+ from google.genai import types
7
+ from smolagents import tool
 
 
 
 
8
 
 
 
 
 
 
 
 
 
 
9
 
10
  @tool
11
+ def download_file_of_task_id(task_id: str, file_name: str) -> str:
12
+ """
13
+ Download a file associated with a specific task ID and save it to a temporary location.
14
  Args:
15
+ task_id (str): The unique identifier of the task associated with the file to download.
16
+ file_name (str): The name to assign to the downloaded file.
17
+ Returns:
18
+ str: Path to the downloaded file or an error message if the download fails.
19
  """
20
+
21
+ try:
22
+ # Create temporary file
23
+ temp_dir = tempfile.gettempdir()
24
+ filepath = os.path.join(temp_dir, file_name)
25
+
26
+ # Download the file
27
+ response = requests.get(f"https://agents-course-unit4-scoring.hf.space/files/{task_id}",
28
+ stream=True)
29
+ response.raise_for_status()
30
+
31
+ # Save the file
32
+ with open(filepath, 'wb') as f:
33
+ for chunk in response.iter_content(chunk_size=8192):
34
+ f.write(chunk)
35
+
36
+ return filepath
37
+ except Exception as e:
38
+ return f"Error downloading file: {e!s}"
39
+
40
 
41
  @tool
42
+ def analyze_audio_file(path_file_audio: str, query: str) -> str:
43
+ """
44
+ Analyzes an MP3 audio file to answer a specific query.
45
  Args:
46
+ path_file_audio (str): Path to the MP3 audio file to be analyzed.
47
+ query (str): Question or query to analyze the content of the audio file.
48
+ Returns:
49
+ str: The result of the analysis of audio.
50
  """
51
+
52
+ client = genai.Client(api_key=os.getenv("API_KEY"))
53
+
54
+ myfile = client.files.upload(file=path_file_audio)
55
+
56
+ response = client.models.generate_content(
57
+ model=os.getenv("GOOGLE_MODEL_ID"),
58
+ contents=[f"Carefully analyze the audio to answer the question correctly.\n\n The question is {query}",
59
+ myfile]
60
+ )
61
+
62
+ return response.text
63
+
64
 
65
  @tool
66
+ def analyze_youtube_video(url_youtube_video: str, query: str) -> str:
67
+ """
68
+ Analyzes a YouTube video using the provided query.
69
  Args:
70
+ url_youtube_video (str): URL of the YouTube video to analyze.
71
+ query (str): Query or question to analyze the content of the video.
72
+ Returns:
73
+ str: Result of the video analysis.
74
  """
 
75
 
76
+ client = genai.Client(api_key=os.getenv("API_KEY"))
77
+
78
+ response = client.models.generate_content(
79
+ model=f"models/{os.getenv('GOOGLE_MODEL_ID')}",
80
+ contents=types.Content(
81
+ parts=[
82
+ types.Part(
83
+ file_data=types.FileData(file_uri=url_youtube_video)
84
+ ),
85
+ types.Part(text=f"Carefully analyze each frame of the video to answer the question correctly.\n\n The question is {query}")
86
+ ]
87
+ )
88
+ )
89
+
90
+ return response.text
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  @tool
94
+ def analyze_image_file(path_file_image: str, query: str) -> str:
95
+ """
96
+ Analyzes an image file to answer a specific query.
97
  Args:
98
+ path_file_image (str): Path to the image file to be analyzed.
99
+ query (str): Question or query to analyze the content of the image file.
100
+ Returns:
101
+ str: The result of the analysis of audio.
102
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
+ client = genai.Client(api_key=os.getenv("API_KEY"))
105
+
106
+ myfile = client.files.upload(file=path_file_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
+ response = client.models.generate_content(
109
+ model=os.getenv('GOOGLE_MODEL_ID'),
110
+ contents=[myfile,
111
+ f"Carefully analyze the image file and think to answer the question correctly.\n\n The question is {query}"]
112
+ )
113
 
114
+ return response.text
115
 
 
 
 
116
 
117
+ @tool
118
+ def analyze_xlsx_file(file_path: str, query: str) -> str:
119
+ """
120
+ Analyze an Excel file using pandas and answer a question about it.
121
+ Args:
122
+ file_path: Path to the Excel file
123
+ query: Question about the data
124
+ Returns:
125
+ Analysis result or error message
126
+ """
127
 
128
+ try:
129
+ import pandas as pd
130
+
131
+ # Read the Excel file
132
+ df = pd.read_excel(file_path)
133
+
134
+ # Run various analyses based on the query
135
+ result = f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
136
+ result += f"Columns: {', '.join(df.columns)}\n\n"
 
 
 
 
 
 
 
 
 
 
137
 
138
+ # Add summary statistics
139
+ result += "Summary statistics:\n"
140
+ result += str(df.describe())
141
 
142
+ return result
143
+ except ImportError:
144
+ return "Error: pandas and openpyxl are not installed. Please install them with 'pip install pandas openpyxl'."
145
+ except Exception as e:
146
+ return f"Error analyzing Excel file: {e!s}"