Spaces:
Sleeping
Sleeping
image
Browse files
agent.py
CHANGED
@@ -14,7 +14,8 @@ from tools import (
|
|
14 |
arxiv_search_tool,
|
15 |
audio_transcriber_tool,
|
16 |
excel_tool,
|
17 |
-
analyze_code_tool
|
|
|
18 |
)
|
19 |
|
20 |
# ─────────────────────────── Configuration ───────────────────────────────
|
@@ -41,7 +42,8 @@ def build_graph():
|
|
41 |
arxiv_search_tool,
|
42 |
audio_transcriber_tool,
|
43 |
excel_tool,
|
44 |
-
analyze_code_tool
|
|
|
45 |
]
|
46 |
|
47 |
# Create the react agent - it will use the system prompt from the messages
|
|
|
14 |
arxiv_search_tool,
|
15 |
audio_transcriber_tool,
|
16 |
excel_tool,
|
17 |
+
analyze_code_tool,
|
18 |
+
image_tool
|
19 |
)
|
20 |
|
21 |
# ─────────────────────────── Configuration ───────────────────────────────
|
|
|
42 |
arxiv_search_tool,
|
43 |
audio_transcriber_tool,
|
44 |
excel_tool,
|
45 |
+
analyze_code_tool,
|
46 |
+
image_tool
|
47 |
]
|
48 |
|
49 |
# Create the react agent - it will use the system prompt from the messages
|
app.py
CHANGED
@@ -13,14 +13,16 @@ from state import AgentState
|
|
13 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
14 |
|
15 |
SYSTEM_PROMPT = """
|
16 |
-
You are a general AI assistant. I will ask you a question.
|
17 |
-
Report your thoughts in brief, and finish your answer with the following template:
|
18 |
-
FINAL ANSWER: [YOUR FINAL ANSWER]
|
19 |
|
20 |
IMPORTANT: When using tools that require file access (such as audio_transcriber_tool, excel_tool, analyze_code_tool, or image_tool), ALWAYS use the task_id parameter only. Do NOT use any file names mentioned by the user - ignore them completely and only pass the task_id.
|
21 |
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
24 |
"""
|
25 |
|
26 |
|
|
|
13 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
14 |
|
15 |
SYSTEM_PROMPT = """
|
16 |
+
You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
|
|
|
|
|
17 |
|
18 |
IMPORTANT: When using tools that require file access (such as audio_transcriber_tool, excel_tool, analyze_code_tool, or image_tool), ALWAYS use the task_id parameter only. Do NOT use any file names mentioned by the user - ignore them completely and only pass the task_id.
|
19 |
|
20 |
+
SEARCH STRATEGY:
|
21 |
+
- If wikipedia_search_tool fails or returns insufficient/irrelevant results, try these fallback strategies:
|
22 |
+
1. Try wikipedia_search_tool again with a broader, more general query (remove specific terms, use synonyms)
|
23 |
+
2. If Wikipedia still doesn't help, try arxiv_search_tool for academic/research topics
|
24 |
+
3. You can use multiple search attempts with different keywords to find better information
|
25 |
+
- Always evaluate if the search results are relevant and sufficient before proceeding to your final answer
|
26 |
"""
|
27 |
|
28 |
|
tools.py
CHANGED
@@ -47,78 +47,107 @@ def image_tool(task_id: str) -> str:
|
|
47 |
Returns: "OCR text + brief caption or an error message"
|
48 |
|
49 |
"""
|
50 |
-
print("
|
51 |
-
|
|
|
|
|
|
|
52 |
for ext in ("png", "jpg", "jpeg"):
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
57 |
|
58 |
if not local_img or not os.path.exists(local_img):
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
}
|
63 |
|
64 |
# 2) Read raw bytes
|
65 |
try:
|
|
|
66 |
with open(local_img, "rb") as f:
|
67 |
image_bytes = f.read()
|
|
|
68 |
except Exception as e:
|
69 |
-
|
70 |
-
|
|
|
71 |
|
72 |
# 3) Prepare HF Inference headers
|
73 |
hf_token = os.getenv("HF_TOKEN")
|
74 |
if not hf_token:
|
75 |
-
|
76 |
-
|
|
|
77 |
|
78 |
headers = {"Authorization": f"Bearer {hf_token}"}
|
|
|
79 |
|
80 |
-
# 4) Call HF
|
81 |
ocr_text = ""
|
82 |
try:
|
|
|
83 |
ocr_resp = requests.post(
|
84 |
-
"https://api-inference.huggingface.co/models/
|
85 |
headers=headers,
|
86 |
files={"file": image_bytes},
|
87 |
timeout=30
|
88 |
)
|
|
|
89 |
ocr_resp.raise_for_status()
|
90 |
ocr_json = ocr_resp.json()
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
98 |
except Exception as e:
|
99 |
ocr_text = f"Error during HF OCR: {e}"
|
|
|
100 |
|
101 |
-
# 5) Call HF
|
102 |
caption = ""
|
103 |
try:
|
|
|
104 |
cap_resp = requests.post(
|
105 |
"https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-base",
|
106 |
headers=headers,
|
107 |
files={"file": image_bytes},
|
108 |
timeout=30
|
109 |
)
|
|
|
110 |
cap_resp.raise_for_status()
|
111 |
cap_json = cap_resp.json()
|
112 |
-
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
if not caption:
|
115 |
-
caption = "(no caption
|
|
|
116 |
except Exception as e:
|
117 |
caption = f"Error during HF captioning: {e}"
|
|
|
118 |
|
119 |
# 6) Combine OCR + caption
|
120 |
combined = f"OCR text:\n{ocr_text}\n\nImage caption:\n{caption}"
|
121 |
-
print("
|
122 |
return combined
|
123 |
|
124 |
@tool
|
@@ -289,7 +318,7 @@ def analyze_code_tool(task_id: str) -> str:
|
|
289 |
# """
|
290 |
# Expects: state["web_search_query"] is a non-empty string.
|
291 |
# Returns: {"web_search_query": None, "web_search_result": <string>}.
|
292 |
-
# Retries up to 5 times on either a DuckDuckGo
|
293 |
# """
|
294 |
# print("reached web_search_tool")
|
295 |
# query = state.get("web_search_query", "")
|
|
|
47 |
Returns: "OCR text + brief caption or an error message"
|
48 |
|
49 |
"""
|
50 |
+
print(f"DEBUG: image_tool called with task_id: {task_id}")
|
51 |
+
|
52 |
+
local_img = None # Initialize the variable
|
53 |
+
|
54 |
+
# Try to download image file with different extensions
|
55 |
for ext in ("png", "jpg", "jpeg"):
|
56 |
+
print(f"DEBUG: Trying to download {task_id}.{ext}")
|
57 |
+
candidate = _download_file_for_task(task_id, ext)
|
58 |
+
if candidate:
|
59 |
+
local_img = candidate
|
60 |
+
print(f"DEBUG: Successfully downloaded image: {local_img}")
|
61 |
+
break
|
62 |
+
else:
|
63 |
+
print(f"DEBUG: Failed to download {task_id}.{ext}")
|
64 |
|
65 |
if not local_img or not os.path.exists(local_img):
|
66 |
+
error_msg = f"Error: No image file found for task_id {task_id} (tried png, jpg, jpeg extensions)"
|
67 |
+
print(f"DEBUG: {error_msg}")
|
68 |
+
return error_msg
|
|
|
69 |
|
70 |
# 2) Read raw bytes
|
71 |
try:
|
72 |
+
print(f"DEBUG: Reading image file: {local_img}")
|
73 |
with open(local_img, "rb") as f:
|
74 |
image_bytes = f.read()
|
75 |
+
print(f"DEBUG: Successfully read {len(image_bytes)} bytes from image")
|
76 |
except Exception as e:
|
77 |
+
error_msg = f"Error reading image file: {e}"
|
78 |
+
print(f"DEBUG: {error_msg}")
|
79 |
+
return error_msg
|
80 |
|
81 |
# 3) Prepare HF Inference headers
|
82 |
hf_token = os.getenv("HF_TOKEN")
|
83 |
if not hf_token:
|
84 |
+
error_msg = "Error: HF_TOKEN not set in environment."
|
85 |
+
print(f"DEBUG: {error_msg}")
|
86 |
+
return error_msg
|
87 |
|
88 |
headers = {"Authorization": f"Bearer {hf_token}"}
|
89 |
+
print("DEBUG: HF token found, proceeding with API calls")
|
90 |
|
91 |
+
# 4) Call HF's vision-ocr to extract text
|
92 |
ocr_text = ""
|
93 |
try:
|
94 |
+
print("DEBUG: Calling HF OCR API...")
|
95 |
ocr_resp = requests.post(
|
96 |
+
"https://api-inference.huggingface.co/models/microsoft/trocr-base-printed",
|
97 |
headers=headers,
|
98 |
files={"file": image_bytes},
|
99 |
timeout=30
|
100 |
)
|
101 |
+
print(f"DEBUG: OCR API response status: {ocr_resp.status_code}")
|
102 |
ocr_resp.raise_for_status()
|
103 |
ocr_json = ocr_resp.json()
|
104 |
+
print(f"DEBUG: OCR API response: {ocr_json}")
|
105 |
+
|
106 |
+
# Handle different response formats
|
107 |
+
if isinstance(ocr_json, list) and len(ocr_json) > 0:
|
108 |
+
# If it's a list, take the first result
|
109 |
+
ocr_text = ocr_json[0].get("generated_text", "").strip()
|
110 |
+
elif isinstance(ocr_json, dict):
|
111 |
+
ocr_text = ocr_json.get("generated_text", "").strip()
|
112 |
+
|
113 |
+
if not ocr_text:
|
114 |
+
ocr_text = "(no visible text detected)"
|
115 |
+
print(f"DEBUG: Extracted OCR text: {ocr_text}")
|
116 |
except Exception as e:
|
117 |
ocr_text = f"Error during HF OCR: {e}"
|
118 |
+
print(f"DEBUG: OCR failed: {e}")
|
119 |
|
120 |
+
# 5) Call HF's image-captioning to get a brief description
|
121 |
caption = ""
|
122 |
try:
|
123 |
+
print("DEBUG: Calling HF Image Captioning API...")
|
124 |
cap_resp = requests.post(
|
125 |
"https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-base",
|
126 |
headers=headers,
|
127 |
files={"file": image_bytes},
|
128 |
timeout=30
|
129 |
)
|
130 |
+
print(f"DEBUG: Captioning API response status: {cap_resp.status_code}")
|
131 |
cap_resp.raise_for_status()
|
132 |
cap_json = cap_resp.json()
|
133 |
+
print(f"DEBUG: Captioning API response: {cap_json}")
|
134 |
+
|
135 |
+
# Handle different response formats
|
136 |
+
if isinstance(cap_json, list) and len(cap_json) > 0:
|
137 |
+
caption = cap_json[0].get("generated_text", "").strip()
|
138 |
+
elif isinstance(cap_json, dict):
|
139 |
+
caption = cap_json.get("generated_text", "").strip()
|
140 |
+
|
141 |
if not caption:
|
142 |
+
caption = "(no caption generated)"
|
143 |
+
print(f"DEBUG: Generated caption: {caption}")
|
144 |
except Exception as e:
|
145 |
caption = f"Error during HF captioning: {e}"
|
146 |
+
print(f"DEBUG: Captioning failed: {e}")
|
147 |
|
148 |
# 6) Combine OCR + caption
|
149 |
combined = f"OCR text:\n{ocr_text}\n\nImage caption:\n{caption}"
|
150 |
+
print(f"DEBUG: Final result: {combined}")
|
151 |
return combined
|
152 |
|
153 |
@tool
|
|
|
318 |
# """
|
319 |
# Expects: state["web_search_query"] is a non-empty string.
|
320 |
# Returns: {"web_search_query": None, "web_search_result": <string>}.
|
321 |
+
# Retries up to 5 times on either a DuckDuckGo "202 Ratelimit" response or any exception (e.g. timeout).
|
322 |
# """
|
323 |
# print("reached web_search_tool")
|
324 |
# query = state.get("web_search_query", "")
|