laverdes commited on
Commit
a60e9fe
·
verified ·
1 Parent(s): a21e3ef

feat: smart_read_file, extract clean text, extra tools

Browse files
Files changed (1) hide show
  1. tools.py +178 -62
tools.py CHANGED
@@ -3,12 +3,14 @@ import base64
3
  import json
4
  import inspect
5
  import time
6
- from typing import Callable
 
 
7
 
8
  from datetime import datetime, timezone
 
9
 
10
  from langchain.tools import tool
11
-
12
  from langchain_community.tools.tavily_search import TavilySearchResults
13
  from langchain_core.messages import HumanMessage
14
  from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError
@@ -21,6 +23,10 @@ from langchain_google_community import SpeechToTextLoader
21
  from langchain_community.tools import YouTubeSearchTool
22
  from youtube_transcript_api import YouTubeTranscriptApi
23
  from langchain_community.tools.file_management.read import ReadFileTool
 
 
 
 
24
 
25
  from basic_agent import print_conversation
26
 
@@ -115,28 +121,94 @@ def search_and_extract(query: str) -> list[dict]:
115
  return structured_results
116
 
117
 
118
- youtube_search_api = YouTubeSearchTool()
119
-
120
  @tool
121
- def youtube_search_tool(query: str, number_of_results:int=3) -> list:
122
- """Search YouTube for a query and return the top number_of_results."""
 
 
 
 
 
123
  if CUSTOM_DEBUG:
124
  print_tool_call(
125
- youtube_search_tool,
126
- tool_name='youtube_search_tool',
127
- args={'query': query, number_of_results: number_of_results},
128
  )
129
- response = youtube_search_api.run(f"{query},{number_of_results}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  if CUSTOM_DEBUG:
131
- print_tool_response(response)
132
- return response
133
 
 
 
134
 
135
  def extract_video_id(url: str) -> str:
136
  parsed = urlparse(url)
137
  return parse_qs(parsed.query).get("v", [""])[0]
138
 
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  @tool
141
  def load_youtube_transcript(url: str) -> str:
142
  """Load a YouTube transcript using youtube_transcript_api."""
@@ -165,43 +237,21 @@ def load_youtube_transcript(url: str) -> str:
165
  return error_str
166
 
167
 
168
-
169
- gemini = ChatGoogleGenerativeAI(model="gemini-1.5-flash")
170
 
171
  @tool
172
- def image_query_tool(image_path: str, question: str) -> str:
173
- """
174
- Uses Gemini Vision to answer a question about an image.
175
- - image_path: file path to the image to analyze (.png)
176
- - question: the query to ask about the image
177
- """
178
- try:
179
- base64_img = encode_image_to_base64(image_path)
180
- except OSError:
181
- response = f"OSError: Invalid argument (invalid image path or file format): {image_path}. Please provide a valid PNG image."
182
- print_tool_response(response)
183
- return response
184
-
185
- base64_img_str = f"data:image/png;base64,{base64_img}"
186
  if CUSTOM_DEBUG:
187
  print_tool_call(
188
- image_query_tool,
189
- tool_name='image_query_tool',
190
- args={'base64_image': base64_img_str[:100], 'question': question},
191
  )
192
- msg = HumanMessage(content=[
193
- {"type": "text", "text": question},
194
- {"type": "image_url", "image_url": base64_img_str},
195
- ])
196
- try:
197
- response = gemini.invoke([msg])
198
- except ChatGoogleGenerativeAIError:
199
- response = "ChatGoogleGenerativeAIError: Invalid argument provided to Gemini: 400 Provided image is not valid"
200
- print_tool_response(response)
201
- return response
202
  if CUSTOM_DEBUG:
203
- print_tool_response(response.content)
204
- return response.content
205
 
206
 
207
  @tool
@@ -223,43 +273,109 @@ def search_and_extract_from_wikipedia(query: str) -> list:
223
 
224
  @tool
225
  def transcribe_audio(file_path: str) -> list:
226
- """Transcribe audio from a file using Google Speech-to-Text."""
 
227
  if CUSTOM_DEBUG:
228
  print_tool_call(
229
  transcribe_audio,
230
  tool_name='transcribe_audio',
231
  args={'file_path': file_path},
232
  )
233
- project_id = os.getenv("GOOGLE_CLOUD_PROJECT_ID")
234
- loader = SpeechToTextLoader(
235
- project_id=project_id,
236
- file_path=file_path,
237
- is_long = False, # Set to True for long audio files
238
- )
 
 
 
 
 
 
 
 
 
 
239
 
240
- docs = loader.load()
241
- docs_content = [doc.page_content for doc in docs]
 
 
 
242
 
243
  if CUSTOM_DEBUG:
244
  print_tool_response(docs_content)
245
  return docs_content
246
 
247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  read_tool = ReadFileTool()
249
 
250
 
251
  @tool
252
- def read_file_tool(file_path: str) -> str:
253
- """Read the content of a file. Use this tool to read .py, .csv, .md, text files, PDFs, etc."""
 
 
 
 
 
254
  if CUSTOM_DEBUG:
255
  print_tool_call(
256
- read_file_tool,
257
- tool_name='read_file_tool',
258
  args={'file_path': file_path},
259
  )
260
- response = read_tool.invoke({"file_path": file_path})
261
- if not os.path.exists(file_path):
262
- response = f"File not found: {file_path}"
263
- print_tool_response(response)
264
- print_tool_response(response)
265
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import json
4
  import inspect
5
  import time
6
+ import trafilatura
7
+
8
+ from typing import Callable, Union
9
 
10
  from datetime import datetime, timezone
11
+ from markitdown import MarkItDown
12
 
13
  from langchain.tools import tool
 
14
  from langchain_community.tools.tavily_search import TavilySearchResults
15
  from langchain_core.messages import HumanMessage
16
  from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError
 
23
  from langchain_community.tools import YouTubeSearchTool
24
  from youtube_transcript_api import YouTubeTranscriptApi
25
  from langchain_community.tools.file_management.read import ReadFileTool
26
+ from langchain.chains.summarize import load_summarize_chain
27
+ from langchain.prompts import PromptTemplate
28
+ from langchain_core.documents import Document
29
+ from langchain_openai import ChatOpenAI
30
 
31
  from basic_agent import print_conversation
32
 
 
121
  return structured_results
122
 
123
 
 
 
124
  @tool
125
+ def aggregate_information(results: list[str], query: str) -> str:
126
+ """
127
+ Processes a list of unstructured text chunks (e.g., search results) and produces a concise, query-specific summary.
128
+
129
+ Each input text is filtered and summarized individually in the context of the provided query. Irrelevant results are discarded.
130
+ Relevant content is aggregated and synthesized into a final, coherent answer that directly addresses the query.
131
+ """
132
  if CUSTOM_DEBUG:
133
  print_tool_call(
134
+ aggregate_information,
135
+ tool_name='aggregate_information',
136
+ args={'results': results, 'query': query},
137
  )
138
+ if not results:
139
+ response = "No search results provided."
140
+ if CUSTOM_DEBUG:
141
+ print_tool_response(response)
142
+ return response
143
+
144
+ # Convert to LangChain Document objects
145
+ docs = [Document(page_content=chunk) for chunk in results]
146
+
147
+ llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.2)
148
+
149
+ # Map Prompt — Summarize each document in light of the query
150
+ map_prompt = PromptTemplate.from_template(
151
+ "You are analyzing a search result in the context of the question: '{query}'.\n\n"
152
+ "Search result:\n{text}\n\n"
153
+ "Instructions:\n"
154
+ "- If the result contains information relevant to answering the query, summarize the relevant parts clearly.\n"
155
+ "- If the result is not helpful or unrelated, return 'IGNORE'.\n"
156
+ "- Do not include generic information or filler.\n"
157
+ "- Focus on extracting facts, key statements, or numbers that directly support the query.\n\n"
158
+ "Relevant Summary:"
159
+ )
160
+
161
+ # Combine Prompt — Aggregate the summaries to one final answer
162
+ combine_prompt = PromptTemplate.from_template(
163
+ "You are aggregating information to answer the following question: '{query}'.\n\n"
164
+ "Here are the summaries from filtered search results:\n{text}\n\n"
165
+ "Using the most relevant points, write a clear, concise, and complete answer to the original query.\n"
166
+ "If there's conflicting information, mention it briefly. Otherwise, focus on consensus.\n\n"
167
+ "Final Answer:"
168
+ )
169
+
170
+ chain = load_summarize_chain(
171
+ llm,
172
+ chain_type="map_reduce",
173
+ map_prompt=map_prompt.partial(query=query),
174
+ combine_prompt=combine_prompt.partial(query=query),
175
+ )
176
+
177
+ summary = chain.invoke({'input_documents': docs})
178
+ output_text = summary.get('output_text', str(summary))
179
+ output_text = json.dumps({'summary': output_text})
180
+
181
  if CUSTOM_DEBUG:
182
+ print_tool_response(output_text)
 
183
 
184
+ return output_text
185
+
186
 
187
  def extract_video_id(url: str) -> str:
188
  parsed = urlparse(url)
189
  return parse_qs(parsed.query).get("v", [""])[0]
190
 
191
 
192
+ @tool
193
+ def get_audio_from_youtube(urls: list[str], save_dir:str="./tmp/") -> list[str | PurePath | None]:
194
+ """Extracts audio from a YouTube video URL."""
195
+
196
+ if CUSTOM_DEBUG:
197
+ print_tool_call(
198
+ get_audio_from_youtube,
199
+ tool_name='get_audio_from_youtube',
200
+ args={'urls': urls, 'save_dir': save_dir},
201
+ )
202
+ loader = YoutubeAudioLoader(urls, save_dir)
203
+ audio_blobs = list(loader.yield_blobs())
204
+ paths = [str(blob.path) for blob in audio_blobs]
205
+
206
+ if CUSTOM_DEBUG:
207
+ print_tool_response(json.dumps({'paths': paths}))
208
+
209
+ return paths
210
+
211
+
212
  @tool
213
  def load_youtube_transcript(url: str) -> str:
214
  """Load a YouTube transcript using youtube_transcript_api."""
 
237
  return error_str
238
 
239
 
240
+ youtube_search_api = YouTubeSearchTool()
 
241
 
242
  @tool
243
+ def youtube_search_tool(query: str, number_of_results:int=3) -> list:
244
+ """Search YouTube for a query and return the top number_of_results."""
 
 
 
 
 
 
 
 
 
 
 
 
245
  if CUSTOM_DEBUG:
246
  print_tool_call(
247
+ youtube_search_tool,
248
+ tool_name='youtube_search_tool',
249
+ args={'query': query, number_of_results: number_of_results},
250
  )
251
+ response = youtube_search_api.run(f"{query},{number_of_results}")
 
 
 
 
 
 
 
 
 
252
  if CUSTOM_DEBUG:
253
+ print_tool_response(response)
254
+ return response
255
 
256
 
257
  @tool
 
273
 
274
  @tool
275
  def transcribe_audio(file_path: str) -> list:
276
+ """Transcribe audio from an audio file in file_path using Google Speech-to-Text."""
277
+ docs, docs_content = [], []
278
  if CUSTOM_DEBUG:
279
  print_tool_call(
280
  transcribe_audio,
281
  tool_name='transcribe_audio',
282
  args={'file_path': file_path},
283
  )
284
+ try:
285
+ loader = SpeechToTextLoader(
286
+ project_id=os.getenv("GOOGLE_CLOUD_PROJECT_ID"),
287
+ file_path=file_path,
288
+ is_long = False, # Set to True for long audio files
289
+ )
290
+
291
+ docs = loader.load()
292
+ except Exception as e:
293
+ print(f"Error loading audio file: {e}")
294
+ try:
295
+ loader = SpeechToTextLoader(
296
+ project_id=os.getenv("GOOGLE_CLOUD_PROJECT_ID"),
297
+ file_path=file_path,
298
+ is_long=True, # Set to True for long audio files
299
+ )
300
 
301
+ docs = loader.load()
302
+ except Exception as e:
303
+ docs_content = [f"Error loading audio file: {e}"]
304
+
305
+ docs_content = [doc.page_content for doc in docs] if docs else docs_content
306
 
307
  if CUSTOM_DEBUG:
308
  print_tool_response(docs_content)
309
  return docs_content
310
 
311
 
312
+ @tool
313
+ def extract_clean_text_from_url(url: str) -> str:
314
+ """Extract the main readable content from a webpage using trafilatura."""
315
+ if CUSTOM_DEBUG:
316
+ print_tool_call(
317
+ extract_clean_text_from_url,
318
+ tool_name='extract_clean_text_from_url',
319
+ args={'url': url},
320
+ )
321
+ downloaded = trafilatura.fetch_url(url)
322
+ response = ""
323
+ if not downloaded:
324
+ response = "Failed to download the page. Please check the URL."
325
+
326
+ if not "Failed" in response:
327
+ response = trafilatura.extract(downloaded)
328
+
329
+ response = response or "No meaningful content found."
330
+ if CUSTOM_DEBUG:
331
+ print_tool_response(response)
332
+ return response
333
+
334
+
335
  read_tool = ReadFileTool()
336
 
337
 
338
  @tool
339
+ def smart_read_file(file_path: str) -> str:
340
+ """
341
+ Smart tool to read a file based on its type.
342
+
343
+ - Use `read_file_tool` for simple text, CSV, code files.
344
+ - Use MarkItDown for PDFs, Word, Excel, HTML, and other complex formats.
345
+ """
346
  if CUSTOM_DEBUG:
347
  print_tool_call(
348
+ smart_read_file,
349
+ tool_name='smart_read_file',
350
  args={'file_path': file_path},
351
  )
352
+ _, ext = os.path.splitext(file_path.lower())
353
+
354
+ if ext in [".mp3", ".wav", ".m4a", ".flac"]:
355
+ # If the file is an audio file, transcribe it
356
+ return transcribe_audio.invoke({"file_path": file_path})
357
+
358
+ if ext in [".png", ".jpg", ".jpeg", ".gif", ".bmp"]:
359
+ # If the file is an image, use image_query_tool to analyze it
360
+ q = "What can you tell me about this image?"
361
+ return image_query_tool.invoke({"image_path": file_path, "question": q})
362
+
363
+ if any(ext in url_pattern for url_pattern in ["http://", "https://", "www."]):
364
+ if "youtube.com/watch?v=" in file_path:
365
+ transcript = load_youtube_transcript.invoke({"url": file_path})
366
+ if "Error loading" in transcript:
367
+ return get_audio_from_youtube.invoke({'urls': [file_path], 'save_dir': './tmp/'})
368
+ else:
369
+ return extract_clean_text_from_url.invoke(file_path)
370
+
371
+ md = MarkItDown()
372
+ try:
373
+ result = md.convert(file_path)
374
+ result = result.text_content
375
+ except Exception as e:
376
+ # print("Error reading file with MarkItDown:", e)
377
+ result = read_tool.invoke({"file_path": file_path})
378
+
379
+ if CUSTOM_DEBUG:
380
+ print_tool_response(result)
381
+ return result