notsakeeb committed on
Commit
5e86b1e
·
verified ·
1 Parent(s): 9220e8a

Update ToolSet.py

Browse files
Files changed (1) hide show
  1. ToolSet.py +368 -0
ToolSet.py CHANGED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import whisper
3
+ from pydantic import BaseModel, Field
4
+ from langchain_experimental.utilities import PythonREPL
5
+ import cv2
6
+ from yt_dlp import YoutubeDL
7
+ from ultralytics import YOLO
8
+ from typing import List, Dict
9
+ from typing import TypedDict, Annotated
10
+ from langchain_community.tools.tavily_search import TavilySearchResults
11
+ from langchain_community.document_loaders import WikipediaLoader
12
+ from langchain_community.document_loaders import ArxivLoader
13
+ from langchain.tools import Tool, tool
14
+
15
@tool
def multiply(a: float, b: float) -> float:
    """Multiplies two numbers.
    Args:
        a (float): the first number
        b (float): the second number
    """
    # The docstring wording is left as-is: the tool decorator exposes it to the
    # agent as this tool's description.
    product = a * b
    return product
23
+
24
+
25
@tool
def add(a: float, b: float) -> float:
    """Adds two numbers.
    Args:
        a (float): the first number
        b (float): the second number
    """
    # The docstring wording is left as-is: the tool decorator exposes it to the
    # agent as this tool's description.
    total = a + b
    return total
33
+
34
+
35
@tool
def subtract(a: float, b: float) -> float:
    """Subtracts two numbers.
    Args:
        a (float): the first number
        b (float): the second number
    """
    # The result of subtracting two floats is a float, so the return annotation
    # is `float` (it previously claimed `int`).
    return a - b
43
+
44
@tool
def divide(a: float, b: float) -> float:
    """Divides two numbers.
    Args:
        a (float): the first float number
        b (float): the second float number
    """
    # Reject a zero divisor up front with a ValueError carrying a readable
    # message for the agent.
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b
54
+
55
+
56
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.
    Args:
        a (int): the first number
        b (int): the second number
    """
    # No zero check is performed here; `a % 0` raises ZeroDivisionError.
    remainder = a % b
    return remainder
64
+
65
+
66
@tool
def power(a: float, b: float) -> float:
    """Get the power of two numbers.
    Args:
        a (float): the first number
        b (float): the second number
    """
    # Two-argument pow() is the function form of the ** operator.
    return pow(a, b)
74
+
75
+
76
@tool
def get_web_search_result(query: str) -> dict:
    """Fetches information from the internet (web) based on given query.

    Args:
        query: The search query.

    Returns:
        A dictionary whose "web_search_results" key holds the search results.
    """
    print("get_web_search_result")
    # At most three results are requested from Tavily per query.
    tavily_search = TavilySearchResults(max_results=3)
    search_docs = tavily_search.invoke(query)
    # The results are wrapped in a dict (not returned as a bare string), which
    # is why the return annotation is `dict`.
    return {"web_search_results": search_docs}
90
+
91
+
92
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return maximum 5 results. Use this tool only if the query specifies Wiki or Wikipedia.
    Args:
        query: The search query.

    Returns:
        A dictionary whose "wiki_results" key holds the formatted documents.
    """
    print("wiki_search")
    search_docs = WikipediaLoader(query=query, load_max_docs=5).load()
    # Each document is rendered as a <Document ...> section; sections are
    # separated by a "---" rule so the agent can tell them apart.
    formatted_search_docs = "\n\n---\n\n".join(
        [
            # .get() keeps a document with incomplete metadata from aborting
            # the whole search with a KeyError.
            f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
            for doc in search_docs
        ])
    return {"wiki_results": formatted_search_docs}
109
+
110
+
111
@tool
def arxiv_search(query: str) -> dict:
    """Search Arxiv for a query and return maximum 3 result.

    Args:
        query: The search query.
    Returns:
        A dictionary whose "arxiv_results" key holds the formatted documents.
    """
    print("arxiv_search")
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    sections = []
    for doc in search_docs:
        meta = doc.metadata
        # Arxiv documents are not guaranteed to carry a "source" key (the
        # loader's default metadata is Published/Title/Authors/Summary), so
        # indexing it directly raised KeyError. Fall back to the entry id or
        # the paper title instead.
        source = meta.get("source") or meta.get("entry_id") or meta.get("Title", "")
        # Only the first 1000 characters of each paper are kept to bound the
        # size of the tool output.
        sections.append(
            f'<Document source="{source}" page="{meta.get("page", "")}"/>\n'
            f"{doc.page_content[:1000]}\n</Document>"
        )
    formatted_search_docs = "\n\n---\n\n".join(sections)
    return {"arxiv_results": formatted_search_docs}
128
+
129
@tool
def reverse_text(prompt: str) -> str:
    """
    Returns the reversed version of a given reversed text so that the text makes sense.

    Args:
        prompt: The prompt which contains word and sentence in a reverse order.

    Returns:
        A reversed version of the reversed sentence which is human readable and understandable.
    """
    print("restoring_text")
    # Walk the characters back to front and glue them together again.
    restored = "".join(reversed(prompt))
    return restored
143
+
144
+
145
@tool
def transcribe_audio(file_path: str):
    """
    Transcribes an audio file to text using local Whisper model.

    Args:
        file_path: Path to the audio file

    Returns:
        A dictionary containing the transcription and metadata
    """
    try:
        print(f"Transcribing audio file: {file_path}")

        # Guard clause: report a missing path as an error payload.
        if not os.path.exists(file_path):
            return {
                "status": "error",
                "message": f"File not found: {file_path}",
            }

        # The "small" Whisper checkpoint is loaded on every call.
        # Other sizes: tiny, base, medium, large.
        whisper_model = whisper.load_model("small")
        outcome = whisper_model.transcribe(file_path)

        # Build the payload once, log it, then hand the same dict back.
        payload = {
            "status": "success",
            "transcription": outcome["text"],
            "language": outcome.get("language", "unknown"),
            "file_path": file_path,
        }
        print(payload)
        return payload

    except Exception as e:
        # Any failure (model load, decoding, ...) is reported as an error
        # payload rather than raised to the caller.
        failure = {
            "status": "error",
            "message": f"Error transcribing audio: {str(e)}",
        }
        print(failure)
        return failure
196
+
197
+
198
class PythonREPLInput(BaseModel):
    # Argument schema passed as `args_schema` to the python_repl tool below:
    # a single string field holding the source code to run.
    code: str = Field(description="The Python code string to execute.")
200
+
201
# One shared REPL instance backs the tool.
python_repl = PythonREPL()

# NOTE(review): this tool runs whatever code the agent sends through
# `python_repl.run`; the "do NOT execute harmful code" wording in the
# description is advisory only, not an enforced sandbox — confirm the
# deployment environment is isolated.
python_repl_tool = Tool(
    name="python_repl",
    description="""A Python REPL shell (Read-Eval-Print Loop).
Use this to execute single or multi-line python commands.
Input should be syntactically valid Python code.
Always end your code with `print(...)` to see the output.
Do NOT execute code that could be harmful to the host system.
You are allowed to download files from URLs.
Do not use this tool as a web search.
Do NOT send commands that block indefinitely (e.g., `input()`).""",
    func=python_repl.run,
    args_schema=PythonREPLInput
)
216
+
217
+
218
class YouTubeFrameExtractor:
    """Download a YouTube video and count YOLO detections on sampled frames."""

    def __init__(self, model_path: str = 'yolov8n.pt', frame_rate: int = 1):
        """Load the detector.

        Args:
            model_path: Weights file handed to YOLO.
            frame_rate: Number of frames per second of video to sample.
        """
        self.model = YOLO(model_path)
        self.frame_rate = frame_rate  # frames per second to sample

    def download_video(self, url: str) -> str:
        """Download the video at `url` and return the local filename.

        The file is written to the current working directory and named after
        the video id.
        """
        ydl_opts = {
            'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/mp4',
            'outtmpl': '%(id)s.%(ext)s',
        }
        with YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(url, download=True)
            return ydl.prepare_filename(info)

    def extract_counts_per_frame(self, url: str) -> List[Dict[str, int]]:
        """Return one `{class_name: count}` dict per sampled frame.

        The video is downloaded, every `fps / frame_rate`-th frame is run
        through the detector, and the downloaded file is deleted afterwards,
        even when reading or detection fails part-way through.

        Args:
            url: Address of the video to analyse.

        Returns:
            A list with one dict per sampled frame mapping a detected class
            name to the number of detections of that class in the frame.
        """
        video_path = self.download_video(url)
        frame_counts: List[Dict[str, int]] = []
        cap = cv2.VideoCapture(video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            # Distance (in frames) between two samples; never below 1 so the
            # modulo test below is always well-defined.
            sample_interval = max(1, int(round(fps / self.frame_rate)))
            frame_idx = 0

            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_idx % sample_interval == 0:
                    counts: Dict[str, int] = {}
                    for det in self.model(frame):
                        # Only the last column of each detection row (the
                        # class id) is needed for counting.
                        for *_, cls in det.boxes.data.tolist():
                            name = self.model.names[int(cls)]
                            counts[name] = counts.get(name, 0) + 1
                    frame_counts.append(counts)
                frame_idx += 1
        finally:
            # Always free the capture handle and the downloaded file so a
            # failed run does not leak either of them.
            cap.release()
            if os.path.exists(video_path):
                os.remove(video_path)
        return frame_counts
259
+
260
def max_object_counter_tool() -> Tool:
    """Build the tool that reports the peak per-frame count of an object.

    A YouTubeFrameExtractor (and therefore its detection model) is created
    once here and reused by every call of the returned tool.

    Returns:
        A Tool whose input is '<YouTube_URL> <object_name>' and whose output
        is a sentence stating the largest count found in any sampled frame.
    """
    extractor = YouTubeFrameExtractor()

    def _max_object(input_str: str) -> str:
        # Expect input: '<video_url> <object_name>'. Split only once so that
        # multi-word object names ("traffic light") are kept whole.
        parts = input_str.strip().split(maxsplit=1)
        if len(parts) < 2:
            return "Usage: <YouTube_URL> <object_name>"
        url = parts[0]
        # Collapse inner whitespace and drop quotes the caller may have added.
        obj_name = " ".join(parts[1].split()).strip("\"'")
        if not obj_name:
            return "Usage: <YouTube_URL> <object_name>"
        try:
            frames = extractor.extract_counts_per_frame(url)
        except Exception as e:
            # Report failures as text, matching the transcriber tool.
            return f"Error: {e}"
        if not frames:
            return "No frames processed or unable to download video."
        # Compute max occurrences across frames; class names are compared
        # case-insensitively so "Person" and "person" both match.
        wanted = obj_name.lower()
        max_count = max(
            sum(count for name, count in frame.items() if name.lower() == wanted)
            for frame in frames
        )
        return f"Maximum count of '{obj_name}' in any sampled frame: {max_count}"

    return Tool(
        name="youtube_max_object_counter",
        func=_max_object,
        description=(
            "Downloads a YouTube video, samples frames at a given rate, runs YOLO detection, "
            "and returns the maximum count of the specified object across all sampled frames."
        )
    )
284
+
285
+
286
class YouTubeTranscriber:
    """Fetch the audio track of a YouTube video and transcribe it with Whisper."""

    def __init__(self, model_size: str = "small"):
        # Whisper checkpoint name (tiny/base/small/medium/large/turbo).
        self.model = whisper.load_model(model_size)

    def download_audio(self, url: str) -> str:
        """
        Fetch just the audio track for `url` and return the name of the local mp3 file.
        """
        # Post-processing step: have FFmpeg extract the audio as 192 kbps mp3.
        extract_mp3 = {
            "key": "FFmpegExtractAudio",
            "preferredcodec": "mp3",
            "preferredquality": "192",
        }
        options = {
            "format": "bestaudio/best",      # best available audio stream
            "postprocessors": [extract_mp3],
            "outtmpl": "%(id)s.%(ext)s",     # file is named "<video_id>.<ext>"
            "quiet": True,
        }
        with YoutubeDL(options) as downloader:
            video_info = downloader.extract_info(url, download=True)
        return f"{video_info['id']}.mp3"

    def transcribe(self, audio_path: str, language: str = "en") -> str:
        """
        Transcribe the audio file at `audio_path` and return the text.

        The audio file is not deleted here.
        """
        decode_kwargs = {"language": language, "without_timestamps": True}
        outcome = self.model.transcribe(audio_path, **decode_kwargs)
        return outcome["text"]
320
+
321
+
322
def transcription_generation_tool() -> Tool:
    """
    Returns a LangChain Tool that takes a YouTube URL and optional language code,
    then returns the transcription text.

    The YouTubeTranscriber is created once here and shared by every call of
    the returned tool.
    """
    transcriber = YouTubeTranscriber(model_size="small")

    def _transcribe_tool(input_str: str) -> str:
        # Expect: '<YouTube_URL> [language_code] "Question Text"'.
        # Only the text before the first double quote carries the URL and the
        # optional language code; the quoted question is not used by the tool
        # itself, so a missing question must not make the call fail.
        header = input_str.split('"', 1)[0].split()
        if not header:
            return 'Usage: <YouTube_URL> [language_code] "Question text"'
        url = header[0]
        lang = header[1] if len(header) > 1 else "en"

        audio_file = None
        try:
            audio_file = transcriber.download_audio(url)
            return transcriber.transcribe(audio_file, language=lang)
        except Exception as e:
            return f"Error: {e}"
        finally:
            # Remove the downloaded audio even when transcription failed.
            if audio_file and os.path.exists(audio_file):
                os.remove(audio_file)

    return Tool(
        name="youtube_transcriber",
        func=_transcribe_tool,
        description=(
            "Downloads audio from YouTube, transcribes it, and answers a question based on the transcript. "
            "Usage: <YouTube_URL> [language_code] \"Question text\""
        )
    )
352
+
353
# Every tool exposed to the agent. The two factory calls at the end run when
# this module is imported and construct their YouTubeFrameExtractor /
# YouTubeTranscriber instances at that point.
toolset = [
    get_web_search_result,
    wiki_search,
    arxiv_search,
    reverse_text,
    transcribe_audio,
    python_repl_tool,
    multiply,
    add,
    subtract,
    divide,
    modulus,
    power,
    max_object_counter_tool(),
    transcription_generation_tool()
]