adriansanz commited on
Commit
40e41ce
·
verified ·
1 Parent(s): 92e40d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -365
app.py CHANGED
@@ -4,7 +4,7 @@ import os
4
  import dataclasses
5
 
6
  from langchain_core.language_models import LLM
7
- from typing import Optional, List
8
  import requests
9
  from typing import Dict
10
  import cv2
@@ -18,6 +18,18 @@ import re
18
  import json
19
  import hashlib
20
  from typing import Callable
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  class GeminiLLM(LLM):
23
  """Wrapper para usar Google Gemini como un LLM de LangChain."""
@@ -87,390 +99,66 @@ class GeminiLLM(LLM):
87
  return f"Error {response.status_code}: {response.text}"
88
 
89
 
90
- from langchain_core.prompts import PromptTemplate
91
- from langchain.chains import LLMChain
92
- import os
93
 
94
- gemini_llm = GeminiLLM()
95
 
96
- import os
97
- from math import sqrt
98
- from typing import Dict, List
99
- from langchain_community.tools.tavily_search import TavilySearchResults
100
- from langchain_community.document_loaders import WikipediaLoader
101
- from langchain_community.document_loaders import ArxivLoader
102
- import gradio as gr
103
- import requests
104
- import inspect
105
- import pandas as pd
106
- from langchain_core.documents import Document
107
- from smolagents import CodeAgent, tool, InferenceClientModel
108
 
 
109
  # (Keep Constants as is)
110
  # --- Constants ---
111
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
112
 
113
-
114
  @dataclasses.dataclass
115
  class WikiSourceDocument:
116
  source: str
117
  page: str
118
  page_content: str
119
 
 
120
  @tool
121
- def wiki_search(query: str, load_max_docs: int=3) -> List[Document]:
122
- """Search Wikipedia for a query and return maximum 2 results.
123
- Args:
124
- query: The search query.
125
- load_max_docs: The maximum number of documents to load."""
126
  search_docs = WikipediaLoader(query=query, load_max_docs=load_max_docs).load()
127
  return search_docs
128
 
129
  @tool
130
- def load_file(file_id: str) -> str:
131
- """Load a file from the Hugging Face Hub. It returns the content in bytes.
132
- Args:
133
- file_id: The file ID to load."""
134
- return requests.get(f"https://agents-course-unit4-scoring.hf.space/files/{file_id}").content
135
-
136
- @tool
137
- def web_search(query: str, max_results: int) -> Dict[str, str]:
138
- """Search Tavily for a query and return maximum 3 results.
139
- Args:
140
- query: The search query.
141
- max_results: The maximum number of results to return."""
142
  search_docs = TavilySearchResults(max_results=max_results).invoke(input=query)
143
  return {"web_results": search_docs}
144
 
145
-
146
  @tool
147
- def arxiv_search(query: str, load_max_docs: int) -> Dict[str, str]:
148
- """Search Arxiv for a query and return maximum 3 result.
149
- Args:
150
- query: The search query.
151
- load_max_docs: The maximum number of documents to load.
152
- """
153
  search_docs = ArxivLoader(query=query, load_max_docs=load_max_docs).load()
154
  formatted_search_docs = "\n\n---\n\n".join(
155
  [
156
- f'<Document Title="{doc.metadata["Title"]}" Published="{doc.metadata["Published"]}" Authors="{doc.metadata["Authors"]} Summary={doc.metadata["Summary"]}"/>\n{doc.page_content}\n</Document>'
 
 
157
  for doc in search_docs
158
  ]
159
  )
160
  return {"arxiv_results": formatted_search_docs}
161
 
162
-
163
- @tool
164
- def multiply(a: float, b: float) -> float:
165
- """
166
- Multiply two numbers and return the result.
167
- This function takes two floating-point numbers as arguments and
168
- returns their product. It performs basic multiplication.
169
- Args:
170
- a: The first number to be multiplied.
171
- b: The second number to be multiplied.
172
- """
173
- return a * b
174
-
175
-
176
- @tool
177
- def add(a: float, b: float) -> float:
178
- """
179
- Add two numbers and return the result.
180
- This function takes two floating-point numbers as arguments and
181
- returns their sum. It performs basic addition.
182
- Args:
183
- a: The first number to be added.
184
- b: The second number to be added.
185
- """
186
- return a + b
187
-
188
-
189
- @tool
190
- def subtract(a: float, b: float) -> float:
191
- """
192
- Subtracts two numbers.
193
- Args:
194
- a (float): the first number
195
- b (float): the second number
196
- """
197
- return a - b
198
-
199
-
200
- @tool
201
- def divide(a: float, b: float) -> float:
202
- """
203
- Divides two numbers.
204
- Args:
205
- a (float): the first float number
206
- b (float): the second float number
207
- """
208
- if b == 0:
209
- raise ValueError("Cannot divided by zero.")
210
- return a / b
211
-
212
-
213
- @tool
214
- def modulus(a: int, b: int) -> int:
215
- """
216
- Get the modulus of two numbers.
217
- Args:
218
- a (int): the first number
219
- b (int): the second number
220
- """
221
- return a % b
222
-
223
-
224
- @tool
225
- def power(a: float, b: float) -> float:
226
- """
227
- Get the power of two numbers.
228
- Args:
229
- a (float): the first number
230
- b (float): the second number
231
- """
232
- return a ** b
233
-
234
-
235
- @tool
236
- def square_root(a: float) -> float:
237
- """
238
- Get the square root of a number.
239
- Args:
240
- a (float): the number to get the square root of
241
- """
242
- if a >= 0:
243
- return a ** 0.5
244
- return sqrt(a)
245
-
246
-
247
- @tool
248
- def extract_numbers(text: str) -> List[float]:
249
- """
250
- Extract all numeric values from a given text.
251
- Args:
252
- text (str): Input text that may contain numbers.
253
- Returns:
254
- List[float]: A list of numbers found in the text.
255
- """
256
- import re
257
- return [float(num) for num in re.findall(r'\d+(?:\.\d+)?', text)]
258
-
259
- @tool
260
- def extract_keywords(text: str, top_n: int = 5) -> List[str]:
261
- """
262
- Extracts the most frequent keywords from a text (ignores very common words).
263
- Args:
264
- text (str): The input text.
265
- top_n (int): Number of keywords to return.
266
- Returns:
267
- List[str]: List of top keywords.
268
- """
269
- import re
270
- from collections import Counter
271
- stop_words = {"the", "a", "an", "and", "of", "in", "on", "for", "is", "at", "to", "by"}
272
- words = re.findall(r'\b[a-zA-Z]+\b', text.lower())
273
- filtered = [w for w in words if w not in stop_words]
274
- return [word for word, _ in Counter(filtered).most_common(top_n)]
275
- @tool
276
- def extract_names(text: str) -> List[str]:
277
- """
278
- Extracts words that start with a capital letter (possible names or surnames).
279
- Args:
280
- text (str): The input text.
281
- Returns:
282
- List[str]: List of unique candidate names.
283
- """
284
- import re
285
- names = re.findall(r'\b[A-Z][a-z]+\b', text)
286
- return list(dict.fromkeys(names))
287
-
288
- @tool
289
- def find_non_commutative_pairs(table: Dict[str, Dict[str, str]]) -> List[tuple]:
290
- """
291
- Finds pairs (a,b) where the operation * is not commutative.
292
- Args:
293
- table (dict): A nested dictionary representing the operation table.
294
- Returns:
295
- List[tuple]: List of pairs where a*b != b*a.
296
- """
297
- non_commutative = []
298
- elements = table.keys()
299
- for a in elements:
300
- for b in elements:
301
- if table[a][b] != table[b][a]:
302
- non_commutative.append((a, b))
303
- return non_commutative
304
-
305
- @tool
306
- def extract_dates(text: str) -> List[str]:
307
- """
308
- Extract dates from text and return them in ISO 8601 format (YYYY-MM-DD).
309
- Args:
310
- text (str): Input text.
311
- Returns:
312
- List[str]: List of dates as strings in ISO format.
313
- """
314
- import dateparser
315
- import re
316
-
317
- # Find all potential date substrings (simple heuristic)
318
- possible_dates = re.findall(r'\b(?:\d{1,2}[/-]\d{1,2}[/-]\d{2,4}|\w+ \d{1,2},? \d{4}|\d{4}-\d{2}-\d{2})\b', text)
319
- dates = []
320
- for d in possible_dates:
321
- parsed = dateparser.parse(d)
322
- if parsed:
323
- dates.append(parsed.strftime('%Y-%m-%d'))
324
- return dates
325
- @tool
326
- def normalize_text(text: str) -> str:
327
- """
328
- Normalize text: lowercase and remove punctuation.
329
- Args:
330
- text (str): Input text.
331
- Returns:
332
- str: Normalized text.
333
- """
334
- import string
335
- return text.lower().translate(str.maketrans('', '', string.punctuation))
336
-
337
- @tool
338
- def is_palindrome(text: str) -> bool:
339
- """
340
- Check if the given text is a palindrome (ignoring spaces and punctuation).
341
- Args:
342
- text (str): Input text.
343
- Returns:
344
- bool: True if palindrome, else False.
345
- """
346
- import re
347
- cleaned = re.sub(r'[\W_]+', '', text.lower())
348
- return cleaned == cleaned[::-1]
349
-
350
- @tool
351
- def filter_by_numeric_range(items: list, key: str, start: float, end: float) -> list:
352
- """
353
- Filter a list of dict-like objects by a numeric attribute in a given inclusive range.
354
-
355
- Args:
356
- items: List of dicts or objects with attribute `key`.
357
- key: Attribute/key to filter on.
358
- start: Start of range (inclusive).
359
- end: End of range (inclusive).
360
- Returns:
361
- Filtered list of items.
362
- """
363
- filtered = []
364
- for item in items:
365
- value = item.get(key) if isinstance(item, dict) else getattr(item, key, None)
366
- if value is not None and start <= value <= end:
367
- filtered.append(item)
368
- return filtered
369
-
370
-
371
- @tool
372
- def classify_items_by_list(items: list, category_a: list, category_b: list) -> dict:
373
- """
374
- Classify items into two categories based on membership.
375
-
376
- Args:
377
- items: List of items (strings).
378
- category_a: List of items for category A.
379
- category_b: List of items for category B.
380
- Returns:
381
- Dict with keys 'category_a' and 'category_b' listing matched items.
382
- """
383
- set_a = set(map(str.lower, category_a))
384
- set_b = set(map(str.lower, category_b))
385
- classified = {'category_a': [], 'category_b': []}
386
- for item in items:
387
- lower_item = item.lower()
388
- if lower_item in set_a:
389
- classified['category_a'].append(item)
390
- elif lower_item in set_b:
391
- classified['category_b'].append(item)
392
- return classified
393
-
394
- from typing import Dict
395
-
396
- @tool
397
- def web_search(query: str, max_results: int = 3) -> Dict[str, str]:
398
- """
399
- Perform a web search for a query and return up to max_results results as a dictionary.
400
-
401
- Args:
402
- query (str): The search query.
403
- max_results (int): Maximum number of results to return. Default is 3.
404
-
405
- Returns:
406
- Dict[str, str]: Dictionary with search results under the key "web_results".
407
- """
408
- search_docs = TavilySearchResults(max_results=max_results).invoke(input=query)
409
- return {"web_results": search_docs}
410
-
411
-
412
- from typing import List
413
-
414
- @tool
415
- def find_non_commutative_pairs(table: Dict[str, Dict[str, str]]) -> List[tuple]:
416
- """
417
- Finds pairs (a,b) where the operation * is not commutative.
418
- Args:
419
- table (dict): Nested dict representing operation table, e.g. table[a][b].
420
- Returns:
421
- List of pairs (a,b) where a*b != b*a.
422
- """
423
- non_commutative = []
424
- elements = list(table.keys())
425
- for a in elements:
426
- for b in elements:
427
- if table[a][b] != table[b][a]:
428
- non_commutative.append((a, b))
429
- return non_commutative
430
-
431
-
432
- # --- Basic Agent Definition ---
433
- # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
434
-
435
-
436
- # --- Helper para describir las tools ---
437
- def describe_tool(func: Callable) -> str:
438
- name = func.__name__
439
- sig = str(inspect.signature(func))
440
- doc = func.__doc__.strip().split('\n')[0] if func.__doc__ else "No description"
441
- return f"- {name}{sig}: {doc}"
442
-
443
  class BasicAgent:
444
- def __init__(self, llm=None, max_iterations=5):
445
  self.llm = llm or GeminiLLM()
 
446
  self.tools = {
447
  "wiki_search": wiki_search,
448
- "load_file": load_file,
449
  "web_search": web_search,
450
  "arxiv_search": arxiv_search,
451
- "multiply": multiply,
452
- "add": add,
453
- "subtract": subtract,
454
- "divide": divide,
455
- "modulus": modulus,
456
- "power": power,
457
- "square_root": square_root,
458
- "extract_numbers": extract_numbers,
459
  "extract_keywords": extract_keywords,
460
- "extract_names": extract_names,
461
- "find_non_commutative_pairs": find_non_commutative_pairs,
462
  "extract_dates": extract_dates,
 
463
  "normalize_text": normalize_text,
464
- "is_palindrome": is_palindrome,
465
- "filter_by_numeric_range": filter_by_numeric_range,
466
- "classify_items_by_list": classify_items_by_list,
467
  }
468
- # Cache para llamadas a tools
469
  self._cache = {}
470
  self.max_iterations = max_iterations
471
 
472
- # Construir prompt dinámico con info de tools
473
- tools_desc = "\n".join(describe_tool(f) for f in self.tools.values())
474
  prompt_str = (
475
  "You can use the following tools by calling them with syntax:\n"
476
  "tool:<tool_name>(arg1,arg2,...)\n\n"
@@ -481,10 +169,6 @@ class BasicAgent:
481
  self.prompt_template = PromptTemplate.from_template(prompt_str)
482
  self.chain = LLMChain(prompt=self.prompt_template, llm=self.llm)
483
 
484
- def register_tool(self, name: str, func: Callable):
485
- self.tools[name] = func
486
- print(f"[LOG] Registered new tool: {name}")
487
-
488
  def _cache_key(self, tool_name, args, kwargs):
489
  key_data = {"tool": tool_name, "args": args, "kwargs": kwargs}
490
  key_json = json.dumps(key_data, sort_keys=True, default=str)
@@ -492,33 +176,24 @@ class BasicAgent:
492
 
493
  def call_tool(self, tool_name: str, *args, **kwargs):
494
  func = self.tools.get(tool_name)
495
- if func is None:
496
- msg = f"Tool '{tool_name}' not found."
497
- print(f"[LOG] {msg}")
498
- return msg
499
-
500
  key = self._cache_key(tool_name, args, kwargs)
501
  if key in self._cache:
502
- print(f"[LOG] Returning cached result for tool '{tool_name}' with args={args} kwargs={kwargs}")
503
  return self._cache[key]
504
-
505
- func_name = getattr(func, "__name__", str(type(func)))
506
- print(f"[LOG] Calling tool: '{func_name}' with args={args} kwargs={kwargs}")
507
  try:
508
  result = func(*args, **kwargs)
509
- print(f"[LOG] Tool '{func_name}' returned: {result}")
510
  self._cache[key] = result
511
  return result
512
  except Exception as e:
513
- print(f"[ERROR] Tool '{func_name}' raised exception: {e}")
514
- return f"Error executing tool '{func_name}': {e}"
515
 
516
  def _parse_arg(self, arg: str):
517
  arg = arg.strip()
518
- if arg.lower() == "true":
519
- return True
520
- if arg.lower() == "false":
521
- return False
522
  try:
523
  return int(arg)
524
  except:
@@ -529,7 +204,6 @@ class BasicAgent:
529
  pass
530
  if (arg.startswith('"') and arg.endswith('"')) or (arg.startswith("'") and arg.endswith("'")):
531
  return arg[1:-1]
532
- # Intentar JSON para listas o dicts
533
  try:
534
  return json.loads(arg)
535
  except:
@@ -537,7 +211,6 @@ class BasicAgent:
537
  return arg
538
 
539
  def _run_once(self, text: str) -> (str, bool):
540
- # Ejecuta una iteración: LLM + ejecución tools
541
  llm_out = self.chain.run({"question": text})
542
  pattern = r"tool:(\w+)\((.*?)\)"
543
  tools_called = False
@@ -556,13 +229,18 @@ class BasicAgent:
556
 
557
  def __call__(self, question: str) -> str:
558
  text = question
559
- for i in range(self.max_iterations):
560
  text, used_tools = self._run_once(text)
561
  if not used_tools:
562
  break
563
  return text
564
 
565
 
 
 
 
 
 
566
  # --- Build Gradio Interface using Blocks ---
567
  def run_and_submit_all(profile: gr.OAuthProfile | None):
568
  """
 
4
  import dataclasses
5
 
6
  from langchain_core.language_models import LLM
7
+ from typing import Optional, List, Dict
8
  import requests
9
  from typing import Dict
10
  import cv2
 
18
  import json
19
  import hashlib
20
  from typing import Callable
21
+ from math import sqrt
22
+ from langchain_community.tools.tavily_search import TavilySearchResults
23
+ from langchain_community.document_loaders import WikipediaLoader
24
+ from langchain_community.document_loaders import ArxivLoader
25
+ import gradio as gr
26
+ import requests
27
+ import inspect
28
+ import pandas as pd
29
+ from langchain_core.documents import Document
30
+ from smolagents import CodeAgent, tool, InferenceClientModel
31
+ import dateparser
32
+ from collections import Counter
33
 
34
  class GeminiLLM(LLM):
35
  """Wrapper para usar Google Gemini como un LLM de LangChain."""
 
99
  return f"Error {response.status_code}: {response.text}"
100
 
101
 
 
 
 
102
 
 
103
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
+ gemini_llm = GeminiLLM()
106
  # (Keep Constants as is)
107
  # --- Constants ---
108
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
109
 
 
110
  @dataclasses.dataclass
111
  class WikiSourceDocument:
112
  source: str
113
  page: str
114
  page_content: str
115
 
116
+ # --- Herramientas de búsqueda ---
117
  @tool
118
+ def wiki_search(query: str, load_max_docs: int = 3) -> List[WikiSourceDocument]:
119
+ """Busca en Wikipedia y devuelve hasta load_max_docs resultados."""
 
 
 
120
  search_docs = WikipediaLoader(query=query, load_max_docs=load_max_docs).load()
121
  return search_docs
122
 
123
  @tool
124
+ def web_search(query: str, max_results: int = 3) -> Dict[str, str]:
125
+ """Busca en la web y devuelve hasta max_results resultados."""
 
 
 
 
 
 
 
 
 
 
126
  search_docs = TavilySearchResults(max_results=max_results).invoke(input=query)
127
  return {"web_results": search_docs}
128
 
 
129
  @tool
130
+ def arxiv_search(query: str, load_max_docs: int = 3) -> Dict[str, str]:
131
+ """Busca en Arxiv y devuelve hasta load_max_docs resultados formateados."""
 
 
 
 
132
  search_docs = ArxivLoader(query=query, load_max_docs=load_max_docs).load()
133
  formatted_search_docs = "\n\n---\n\n".join(
134
  [
135
+ f'<Document Title="{doc.metadata["Title"]}" Published="{doc.metadata["Published"]}" '
136
+ f'Authors="{doc.metadata["Authors"]}" Summary="{doc.metadata["Summary"]}"/>\n'
137
+ f'{doc.page_content}\n</Document>'
138
  for doc in search_docs
139
  ]
140
  )
141
  return {"arxiv_results": formatted_search_docs}
142
 
143
+ # --- Agente básico optimizado para preguntas ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  class BasicAgent:
145
+ def __init__(self, llm=None, max_iterations=3):
146
  self.llm = llm or GeminiLLM()
147
+ # Sólo herramientas de búsqueda y extracción textual clave
148
  self.tools = {
149
  "wiki_search": wiki_search,
 
150
  "web_search": web_search,
151
  "arxiv_search": arxiv_search,
 
 
 
 
 
 
 
 
152
  "extract_keywords": extract_keywords,
 
 
153
  "extract_dates": extract_dates,
154
+ "extract_names": extract_names,
155
  "normalize_text": normalize_text,
 
 
 
156
  }
 
157
  self._cache = {}
158
  self.max_iterations = max_iterations
159
 
160
+ # Descripción simplificada de herramientas para el prompt
161
+ tools_desc = "\n".join(f"- {name}: {func.__doc__.strip().splitlines()[0]}" for name, func in self.tools.items())
162
  prompt_str = (
163
  "You can use the following tools by calling them with syntax:\n"
164
  "tool:<tool_name>(arg1,arg2,...)\n\n"
 
169
  self.prompt_template = PromptTemplate.from_template(prompt_str)
170
  self.chain = LLMChain(prompt=self.prompt_template, llm=self.llm)
171
 
 
 
 
 
172
  def _cache_key(self, tool_name, args, kwargs):
173
  key_data = {"tool": tool_name, "args": args, "kwargs": kwargs}
174
  key_json = json.dumps(key_data, sort_keys=True, default=str)
 
176
 
177
  def call_tool(self, tool_name: str, *args, **kwargs):
178
  func = self.tools.get(tool_name)
179
+ if not func:
180
+ return f"Tool '{tool_name}' not found."
181
+
 
 
182
  key = self._cache_key(tool_name, args, kwargs)
183
  if key in self._cache:
 
184
  return self._cache[key]
185
+
 
 
186
  try:
187
  result = func(*args, **kwargs)
 
188
  self._cache[key] = result
189
  return result
190
  except Exception as e:
191
+ return f"Error executing tool '{tool_name}': {e}"
 
192
 
193
  def _parse_arg(self, arg: str):
194
  arg = arg.strip()
195
+ if arg.lower() in ("true", "false"):
196
+ return arg.lower() == "true"
 
 
197
  try:
198
  return int(arg)
199
  except:
 
204
  pass
205
  if (arg.startswith('"') and arg.endswith('"')) or (arg.startswith("'") and arg.endswith("'")):
206
  return arg[1:-1]
 
207
  try:
208
  return json.loads(arg)
209
  except:
 
211
  return arg
212
 
213
  def _run_once(self, text: str) -> (str, bool):
 
214
  llm_out = self.chain.run({"question": text})
215
  pattern = r"tool:(\w+)\((.*?)\)"
216
  tools_called = False
 
229
 
230
  def __call__(self, question: str) -> str:
231
  text = question
232
+ for _ in range(self.max_iterations):
233
  text, used_tools = self._run_once(text)
234
  if not used_tools:
235
  break
236
  return text
237
 
238
 
239
+
240
+
241
+
242
+
243
+
244
  # --- Build Gradio Interface using Blocks ---
245
  def run_and_submit_all(profile: gr.OAuthProfile | None):
246
  """