dawid-lorek commited on
Commit
09253eb
·
verified ·
1 Parent(s): 836efcb

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +43 -120
agent.py CHANGED
@@ -5,111 +5,60 @@ from typing import Any
5
 
6
  from llama_index.llms.openai import OpenAI
7
  from llama_index.core.agent.react import ReActAgent
8
- from llama_index.core.agent.workflow import AgentWorkflow
9
- from llama_index.core.tools import FunctionTool, ToolMetadata
10
 
11
- # Tool: DuckDuckGo Web Search
12
- from llama_index.tools.duckduckgo import DuckDuckGoSearchTool
13
 
14
- # Tool: Python code eval (for simple code/number/output questions)
15
  def eval_python_code(code: str) -> str:
16
- """
17
- Evaluate simple Python code and return result as string.
18
- Use for 'What is the output of this code?' or math.
19
- """
20
  try:
21
- # Only eval expressions (NOT exec for safety!)
22
  return str(eval(code, {"__builtins__": {}}))
23
  except Exception as e:
24
  return f"ERROR: {e}"
25
 
26
- # Tool: Strict output formatting
27
  def format_gaia_answer(answer: str, question: str = "") -> str:
28
- """Postprocess: GAIA strict answer format enforcement."""
29
  if not answer:
30
  return ""
31
- # Remove quotes/brackets/periods, apologies, "Final Answer:"
32
  answer = re.sub(r'(?i)final answer:?\s*', '', answer).strip()
33
  answer = re.sub(r'(?i)i(\'?m| cannot| can\'t| unable to| apologize| not available|process the file).*', '', answer).strip()
34
  if answer.startswith('"') and answer.endswith('"'): answer = answer[1:-1]
35
  if answer.startswith('[') and answer.endswith(']'): answer = answer[1:-1]
36
  if not re.match(r'^[A-Za-z]+\.$', answer): answer = re.sub(r'\.$', '', answer)
37
- # Numeric
38
- if re.search(r'how many|number of|at bats|total sales|albums|output.*python|highest number', question, re.I):
39
- num = re.search(r'(\$?\d[\d,\.]*)', answer)
40
- if num: return num.group(1).replace(',', '')
41
- # Surname/first name/code/city
42
- if 'first name' in question: return answer.split()[0]
43
- if 'surname' in question: return answer.split()[-1]
44
- if 'city' in question: return answer.split()[0]
45
- if re.search(r'IOC country code|award number|NASA', question, re.I):
46
- code = re.search(r'[A-Z0-9]{3,}', answer)
47
- if code: return code.group(0)
48
- if re.search(r'list|comma.*separated|page numbers', question, re.I):
49
- items = [x.strip('",.').lower() for x in re.split(r'[,\n]', answer) if x.strip()]
50
- if 'page numbers' in question:
51
- nums = [int(x) for x in re.findall(r'\d+', answer)]
52
- return ', '.join(str(n) for n in sorted(nums))
53
- if 'ingredient' in question or 'vegetable' in question:
54
- merged = []
55
- skip = False
56
- for i, item in enumerate(items):
57
- if skip: skip = False; continue
58
- if i+1 < len(items) and item in ['sweet', 'green', 'lemon', 'ripe', 'whole', 'fresh']:
59
- merged.append(f"{item} {items[i+1]}")
60
- skip = True
61
- else: merged.append(item)
62
- merged = sorted(set(merged))
63
- return ', '.join(merged)
64
- return ', '.join(items)
65
  return answer.strip().rstrip('.').strip()
66
 
67
- # Tool: OCR for images (incl. chessboards/screenshots)
68
- def ocr_image(file_path: str) -> str:
69
- """Extract text from image file."""
70
- from PIL import Image
71
- import pytesseract
72
- try:
73
- img = Image.open(file_path)
74
- return pytesseract.image_to_string(img)
75
- except Exception as e:
76
- return f"ERROR: {e}"
77
-
78
- # Tool: Audio transcription (Whisper)
79
- def transcribe_audio(file_path: str) -> str:
80
- """Transcribe audio file with Whisper."""
81
- try:
82
- import whisper
83
- model = whisper.load_model("base")
84
- result = model.transcribe(file_path)
85
- return result.get("text", "")
86
- except Exception as e:
87
- return f"ERROR: {e}"
88
-
89
- # Tool: YouTube video transcription
90
- def transcribe_youtube(url: str) -> str:
91
- """Download and transcribe a YouTube video (audio only)."""
92
- import tempfile, os
93
- try:
94
- import whisper
95
- import yt_dlp
96
- with tempfile.TemporaryDirectory() as tmpdir:
97
- ydl_opts = {'format': 'bestaudio/best', 'outtmpl': os.path.join(tmpdir, 'audio.%(ext)s')}
98
- with yt_dlp.YoutubeDL(ydl_opts) as ydl:
99
- ydl.download([url])
100
- audio_path = [os.path.join(tmpdir, f) for f in os.listdir(tmpdir) if f.startswith("audio")][0]
101
- model = whisper.load_model("base")
102
- result = model.transcribe(audio_path)
103
- return result.get("text", "")
104
- except Exception as e:
105
- return f"ERROR: {e}"
106
-
107
- # ---- LlamaIndex agent and workflow setup ----
108
-
109
- # 1. Initialize LLM
110
  llm = OpenAI(model="gpt-4o", api_key=os.environ.get("OPENAI_API_KEY"))
111
 
112
- # 2. Register tools
113
  tools = [
114
  DuckDuckGoSearchTool(),
115
  FunctionTool.from_defaults(
@@ -117,21 +66,6 @@ tools = [
117
  name="python_eval",
118
  description="Evaluate simple Python code and return result as string. Use for math or code output."
119
  ),
120
- FunctionTool.from_defaults(
121
- ocr_image,
122
- name="ocr_image",
123
- description="Extract text from an image file (provide file path)."
124
- ),
125
- FunctionTool.from_defaults(
126
- transcribe_audio,
127
- name="transcribe_audio",
128
- description="Transcribe an audio file using Whisper (provide file path)."
129
- ),
130
- FunctionTool.from_defaults(
131
- transcribe_youtube,
132
- name="transcribe_youtube",
133
- description="Download a YouTube video, extract and transcribe its audio using Whisper."
134
- ),
135
  FunctionTool.from_defaults(
136
  format_gaia_answer,
137
  name="format_gaia_answer",
@@ -139,7 +73,7 @@ tools = [
139
  ),
140
  ]
141
 
142
- # 3. Agent setup (ReAct, so can reason with tools)
143
  agent = ReActAgent.from_tools(
144
  tools=tools,
145
  llm=llm,
@@ -147,27 +81,16 @@ agent = ReActAgent.from_tools(
147
  verbose=False
148
  )
149
 
150
- # 4. Async entrypoint, suitable for HuggingFace Spaces or Gradio
151
  async def answer_question(question: str, task_id: str = None, file_path: str = None) -> str:
152
- """
153
- Main async function for the agent.
154
- Passes the question and uses tools as needed.
155
- - task_id: for future use, if you want to fetch files from a remote API.
156
- - file_path: if a file (image, audio, etc) is present locally, pass it.
157
- """
158
- # Example: if you want to always try OCR/audio on a file before reasoning, you could do:
159
- # If question contains "image" or "chess" and file_path is set, run OCR first
160
- if file_path and any(word in question.lower() for word in ["image", "chess", "screenshot"]):
161
- ocr_text = ocr_image(file_path)
162
- question = f"Extracted text from image: {ocr_text}\n\n{question}"
163
- if file_path and any(word in question.lower() for word in ["audio", "mp3", "transcribe"]):
164
- audio_text = transcribe_audio(file_path)
165
- question = f"Transcribed audio: {audio_text}\n\n{question}"
166
-
167
- # Run agent
168
  result = await agent.achat(question)
169
  return result.response
170
 
171
- # Synchronous wrapper for legacy compat
172
  def answer_question_sync(question: str, task_id: str = None, file_path: str = None) -> str:
173
- return asyncio.run(answer_question(question, task_id, file_path))
 
 
 
 
 
 
5
 
6
  from llama_index.llms.openai import OpenAI
7
  from llama_index.core.agent.react import ReActAgent
8
+ from llama_index.core.tools import FunctionTool
 
9
 
10
+ # Correct import for LlamaIndex >= 0.10
11
+ from llama_index.tools.duckduckgo_search import DuckDuckGoSearchTool
12
 
13
+ # Simple tool: Evaluate Python code for math/code questions
14
  def eval_python_code(code: str) -> str:
 
 
 
 
15
  try:
 
16
  return str(eval(code, {"__builtins__": {}}))
17
  except Exception as e:
18
  return f"ERROR: {e}"
19
 
20
+ # Strict output formatting for GAIA
21
  def format_gaia_answer(answer: str, question: str = "") -> str:
 
22
  if not answer:
23
  return ""
 
24
  answer = re.sub(r'(?i)final answer:?\s*', '', answer).strip()
25
  answer = re.sub(r'(?i)i(\'?m| cannot| can\'t| unable to| apologize| not available|process the file).*', '', answer).strip()
26
  if answer.startswith('"') and answer.endswith('"'): answer = answer[1:-1]
27
  if answer.startswith('[') and answer.endswith(']'): answer = answer[1:-1]
28
  if not re.match(r'^[A-Za-z]+\.$', answer): answer = re.sub(r'\.$', '', answer)
29
+ if question:
30
+ if re.search(r'how many|number of|at bats|total sales|albums|output.*python|highest number', question, re.I):
31
+ num = re.search(r'(\$?\d[\d,\.]*)', answer)
32
+ if num: return num.group(1).replace(',', '')
33
+ if 'first name' in question: return answer.split()[0]
34
+ if 'surname' in question: return answer.split()[-1]
35
+ if 'city' in question: return answer.split()[0]
36
+ if re.search(r'IOC country code|award number|NASA', question, re.I):
37
+ code = re.search(r'[A-Z0-9]{3,}', answer)
38
+ if code: return code.group(0)
39
+ if re.search(r'list|comma.*separated|page numbers', question, re.I):
40
+ items = [x.strip('",.').lower() for x in re.split(r'[,\n]', answer) if x.strip()]
41
+ if 'page numbers' in question:
42
+ nums = [int(x) for x in re.findall(r'\d+', answer)]
43
+ return ', '.join(str(n) for n in sorted(nums))
44
+ if 'ingredient' in question or 'vegetable' in question:
45
+ merged = []
46
+ skip = False
47
+ for i, item in enumerate(items):
48
+ if skip: skip = False; continue
49
+ if i+1 < len(items) and item in ['sweet', 'green', 'lemon', 'ripe', 'whole', 'fresh']:
50
+ merged.append(f"{item} {items[i+1]}")
51
+ skip = True
52
+ else: merged.append(item)
53
+ merged = sorted(set(merged))
54
+ return ', '.join(merged)
55
+ return ', '.join(items)
 
56
  return answer.strip().rstrip('.').strip()
57
 
58
+ # LLM setup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  llm = OpenAI(model="gpt-4o", api_key=os.environ.get("OPENAI_API_KEY"))
60
 
61
+ # Tool registry
62
  tools = [
63
  DuckDuckGoSearchTool(),
64
  FunctionTool.from_defaults(
 
66
  name="python_eval",
67
  description="Evaluate simple Python code and return result as string. Use for math or code output."
68
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  FunctionTool.from_defaults(
70
  format_gaia_answer,
71
  name="format_gaia_answer",
 
73
  ),
74
  ]
75
 
76
+ # Main agent
77
  agent = ReActAgent.from_tools(
78
  tools=tools,
79
  llm=llm,
 
81
  verbose=False
82
  )
83
 
84
+ # Async entrypoint
85
  async def answer_question(question: str, task_id: str = None, file_path: str = None) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  result = await agent.achat(question)
87
  return result.response
88
 
89
+ # Synchronous wrapper
90
  def answer_question_sync(question: str, task_id: str = None, file_path: str = None) -> str:
91
+ return asyncio.run(answer_question(question, task_id, file_path))
92
+
93
+ # For compatibility with app.py (GAIAAgent class)
94
+ class GaiaAgent:
95
+ def __call__(self, question: str, task_id: str = None, file_path: str = None) -> str:
96
+ return answer_question_sync(question, task_id, file_path)