# tools.py
import os
import time
import requests

import openai
import pandas as pd
# from langchain_community.tools import DuckDuckGoSearchRun
from pathlib import Path
# from PIL import Image
# import pytesseract
from old.old2state import AgentState
from langchain.schema import HumanMessage
import regex as re
from duckduckgo_search import DDGS

DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"


def _download_file_for_task(task_id: str, ext: str) -> str:
    """
    Helper: attempt to GET the remote file for a given task_id.
    Saves under ./hf_files/{task_id}.{ext}. Returns the local path if successful,
    or an empty string if there is no file or the download failed.
    """
    print("reached _download_file_for_task")
    os.makedirs("hf_files", exist_ok=True)
    local_path = os.path.join("hf_files", f"{task_id}.{ext}")
    url = f"{DEFAULT_API_URL}/files/{task_id}"
    try:
        resp = requests.get(url, timeout=10)
        if resp.status_code == 200 and resp.content:
            print(f"Downloaded file from {url} to {local_path}")
            with open(local_path, "wb") as f:
                f.write(resp.content)
            return local_path
    except Exception:
        pass
    # If we get here, either the server returned 404 or the download errored out
    return ""


def web_search_tool(state: AgentState) -> AgentState:
    """
    Expects: state["web_search_query"] is a non-empty string.
    Returns: {"web_search_query": None, "web_search_result": <string>}.
    Retries up to 5 times on either a DuckDuckGo "202 Ratelimit" response or any exception (e.g. timeout).
    """
    print("reached web_search_tool")
    query = state.get("web_search_query", "")
    if not query:
        return {}  # nothing to do

    ddg = DDGS()
    max_retries = 5
    result_text = ""
    for attempt in range(1, max_retries + 1):
        try:
            result_text = str(ddg.text(query, max_results=5))
        except Exception as e:
            # Network error or timeout - retry up to max_retries
            if attempt < max_retries:
                print(f"web_search_tool: exception '{e}', retrying in 4 seconds ({attempt}/{max_retries})")
                time.sleep(4)
                continue
            else:
                # Final attempt failed
                return {
                    "web_search_query": None,
                    "web_search_result": f"Error during DuckDuckGo search: {e}"
                }
        # Check for DuckDuckGo rate-limit indicator
        if "202 Ratelimit" in result_text:
            if attempt < max_retries:
                print(f"web_search_tool: received '202 Ratelimit', retrying in 4 seconds ({attempt}/{max_retries})")
                time.sleep(4)
                continue
            else:
                # Final attempt still rate-limited
                break
        # Successful response (no exception and no rate-limit text)
        break

    return {
        "web_search_query": None,
        "web_search_result": result_text
    }
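
# Hedged example of the partial-state contract (query and result text are illustrative,
# not real search output): the wrapper consumes web_search_query and sets it back to None
# so the graph does not re-trigger the tool on the next step.
#
#   out = web_search_tool({"web_search_query": "current UEFA president"})
#   # out == {"web_search_query": None, "web_search_result": "<stringified DDGS results>"}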


def ocr_image_tool(state: AgentState) -> AgentState:
    """
    Expects: state["ocr_path"] is either:
      - a local image path (e.g. "./hf_files/abc.png"), OR
      - a Task ID (e.g. "abc123"), in which case we try downloading
        GET {DEFAULT_API_URL}/files/{task_id} with .png/.jpg/.jpeg extensions.
    Returns:
      {
        "ocr_path": None,
        "ocr_result": "<OCR text + brief caption or an error message>"
      }
    """
    print("reached ocr_image_tool")
    path_or_id = state.get("ocr_path", "")
    # if not path_or_id:
    #     return {}

    # 1) Determine local_img: download by Task ID, trying each image extension
    # (the original local-path check is kept here, commented out)
    # if os.path.exists(path_or_id):
    #     local_img = path_or_id
    # else:
    local_img = ""
    for ext in ("png", "jpg", "jpeg"):
        candidate = _download_file_for_task(state.get("task_id"), ext)
        if candidate:
            local_img = candidate
            break

    if not local_img or not os.path.exists(local_img):
        return {
            "ocr_path": None,
            "ocr_result": "Error: No image file found (local nonexistent or download failed)."
        }

    # 2) Read raw bytes
    try:
        with open(local_img, "rb") as f:
            image_bytes = f.read()
    except Exception as e:
        return {
            "ocr_path": None,
            "ocr_result": f"Error reading image file: {e}"
        }

    # 3) Prepare HF Inference headers
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        return {
            "ocr_path": None,
            "ocr_result": "Error: HF_TOKEN not set in environment."
        }
    headers = {"Authorization": f"Bearer {hf_token}"}

    # 4) Call HF's vision OCR model to extract text
    ocr_text = ""
    try:
        ocr_resp = requests.post(
            "https://api-inference.huggingface.co/models/google/vit-ocr",
            headers=headers,
            files={"file": image_bytes},
            timeout=30
        )
        ocr_resp.raise_for_status()
        ocr_json = ocr_resp.json()
        # The JSON has "pages" -> list of blocks -> "lines" -> each line has "text"
        lines = []
        for page in ocr_json.get("pages", []):
            for line in page.get("lines", []):
                lines.append(line.get("text", "").strip())
        ocr_text = "\n".join(lines).strip() or "(no visible text)"
    except Exception as e:
        ocr_text = f"Error during HF OCR: {e}"

    # 5) Call HF's image-captioning model to get a brief description
    caption = ""
    try:
        cap_resp = requests.post(
            "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-base",
            headers=headers,
            files={"file": image_bytes},
            timeout=30
        )
        cap_resp.raise_for_status()
        cap_json = cap_resp.json()
        # The response looks like: {"generated_text": "...caption..."}
        caption = cap_json.get("generated_text", "").strip()
        if not caption:
            caption = "(no caption returned)"
    except Exception as e:
        caption = f"Error during HF captioning: {e}"

    # 6) Combine OCR + caption
    combined = f"OCR text:\n{ocr_text}\n\nImage caption:\n{caption}"
    print(f"combined: {combined}")
    return {
        "ocr_path": None,
        "ocr_result": combined
    }
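
# Sketch of the expected return shape (values are illustrative, not real model output);
# both HF Inference calls above degrade to an error string instead of raising, so the
# graph always receives an "ocr_result" it can route on.
#
#   out = ocr_image_tool({"ocr_path": "abc123", "task_id": "abc123"})
#   # out["ocr_result"] -> "OCR text:\n...\n\nImage caption:\n..."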


def parse_excel_tool(state: AgentState) -> AgentState:
    """
    Expects state["excel_path"] to be either:
      - A real local .xlsx path, or
      - A Task ID string (e.g. "abc123"), in which case we GET /files/abc123.xlsx.
    Returns:
      {
        "excel_path": None,
        "excel_sheet_name": None,
        "excel_result": "<stringified records or Markdown table>"
      }
    Always attempts to download the file for the given path or task ID.
    """
    print("reached parse_excel_tool")
    local_xlsx = _download_file_for_task(state.get("task_id"), "xlsx")
    path_or_id = state.get("excel_path", "")
    sheet = state.get("excel_sheet_name", "")
    if not path_or_id:
        return {}

    # Always attempt to download the file, regardless of local existence.
    # If we finally have a real file, read it.
    if local_xlsx and os.path.exists(local_xlsx):
        try:
            print("reached excel file found")
            xls = pd.ExcelFile(local_xlsx)
            if sheet and sheet in xls.sheet_names:
                df = pd.read_excel(xls, sheet_name=sheet)
            else:
                df = pd.read_excel(xls, sheet_name=xls.sheet_names[0])
            records = df.to_dict(orient="records")
            text = str(records)
            print("reached excel file found: ")
            print(text)
            print()
            return {
                "excel_path": None,
                "excel_sheet_name": None,
                "excel_result": text
            }
        except Exception as e:
            print(f">>> parse_excel_tool: Error reading Excel file {local_xlsx}: {e}")
            # Fall back to scanning for Markdown below

    # Fallback: scan any HumanMessage for a Markdown-style table
    messages = state.get("messages", [])
    table_lines = []
    collecting = False
    for msg in messages:
        if isinstance(msg, HumanMessage):
            for line in msg.content.splitlines():
                if re.match(r"^\s*\|\s*[-A-Za-z0-9]", line):
                    collecting = True
                if collecting:
                    if not re.match(r"^\s*\|", line):
                        collecting = False
                        break
                    table_lines.append(line)
        if table_lines:
            break

    if not table_lines:
        return {
            "excel_path": None,
            "excel_sheet_name": None,
            "excel_result": "Error: No Excel file found and no Markdown table detected in prompt."
        }

    clean_rows = [row for row in table_lines if not re.match(r"^\s*\|\s*-+", row)]
    table_block = "\n".join(clean_rows).strip()
    print(f"Parsed excel as excel_result: {table_block}")
    return {
        "excel_path": None,
        "excel_sheet_name": None,
        "excel_result": table_block
    }
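
# Fallback illustration (hypothetical table, assuming no .xlsx could be downloaded):
# if the prompt contains a Markdown table such as
#
#   | Item  | Price |
#   | ----- | ----- |
#   | Apple | 1.20  |
#
# the scan above collects the pipe-delimited lines, drops the dash separator row,
# and returns the remaining rows joined by newlines as excel_result.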


def audio_transcriber_tool(state: AgentState) -> AgentState:
    """
    LangGraph tool for transcribing audio via OpenAI's Whisper API.
    Expects: state["audio_path"] to be either:
      - A local file path (e.g. "./hf_files/abc.mp3"), OR
      - A Task ID (e.g. "abc123"), in which case we try downloading
        GET {DEFAULT_API_URL}/files/{task_id} with .mp3, .wav, .m4a extensions.
    Returns:
      {
        "audio_path": None,
        "transcript": "<text or error message>"
      }
    Always attempts to download the file for the given path or task ID.
    """
    print("reached audio_transcriber_tool")
    path_or_id = state.get("audio_path", "")
    if not path_or_id:
        return {}

    # Always attempt to download the file, regardless of local existence
    local_audio = ""
    for ext in ("mp3", "wav", "m4a"):
        candidate = _download_file_for_task(state.get("task_id"), ext)
        if candidate:
            local_audio = candidate
            break

    if not local_audio or not os.path.exists(local_audio):
        return {
            "audio_path": None,
            "transcript": "Error: No audio file found (download failed)."
        }

    # Send to OpenAI Whisper
    try:
        openai.api_key = os.getenv("OPENAI_API_KEY")
        if not openai.api_key:
            raise RuntimeError("OPENAI_API_KEY is not set in environment.")
        with open(local_audio, "rb") as audio_file:
            print("reached openai.audio.transcriptions.create")
            response = openai.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
            )
        print("reached response")
        text = response.text.strip()
    except Exception as e:
        text = f"Error during transcription: {e}"

    print(f"Transcribed as transcript: {text}")
    return {
        "audio_path": None,
        "transcript": text
    }
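
# Usage sketch (hypothetical task id; requires OPENAI_API_KEY to be set in the environment):
#
#   out = audio_transcriber_tool({"audio_path": "abc123", "task_id": "abc123"})
#   # out == {"audio_path": None, "transcript": "<Whisper text or error message>"}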


def wikipedia_search_tool(state: AgentState) -> AgentState:
    """
    LangGraph wrapper for searching Wikipedia.
    Expects: state["wiki_query"] to be a non-empty string.
    Returns:
      {
        "wiki_query": None,
        "wiki_result": "<text summary of first matching page or an error message>"
      }
    If no valid wiki_query is provided, returns {}.
    """
    print("reached wikipedia search tool")
    query = state.get("wiki_query", "").strip()
    if not query:
        return {}

    try:
        # 1) Use the MediaWiki API to search for page titles matching the query
        search_params = {
            "action": "query",
            "list": "search",
            "srsearch": query,
            "format": "json",
            "utf8": 1
        }
        search_resp = requests.get("https://en.wikipedia.org/w/api.php", params=search_params, timeout=10)
        search_resp.raise_for_status()
        search_data = search_resp.json()
        search_results = search_data.get("query", {}).get("search", [])
        # print("wikipedia: search_results", search_results)
        if not search_results:
            return {"wiki_query": None, "wiki_result": f"No Wikipedia page found for '{query}'."}

        # 2) Take the first search result's title
        first_title = search_results[0].get("title", "")
        if not first_title:
            return {"wiki_query": None, "wiki_result": "Unexpected format from Wikipedia search."}

        # 3) Fetch the page summary for that title via the REST summary endpoint
        title_for_url = requests.utils.requote_uri(first_title)
        summary_url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{title_for_url}"
        summary_resp = requests.get(summary_url, timeout=10)
        summary_resp.raise_for_status()
        summary_data = summary_resp.json()

        # 4) Extract either the "extract" field or a fallback message
        summary_text = summary_data.get("extract")
        if not summary_text:
            summary_text = summary_data.get("description", "No summary available.")

        return {
            "wiki_query": None,
            "wiki_result": f"Title: {first_title}\n\n{summary_text}"
        }
    except requests.exceptions.RequestException as e:
        return {"wiki_query": None, "wiki_result": f"Wikipedia search error: {e}"}
    except Exception as e:
        return {"wiki_query": None, "wiki_result": f"Unexpected error in wikipedia_search_tool: {e}"}


def run_tools(state: AgentState, tool_out: AgentState) -> AgentState:
    """
    Merges whatever partial state the tool wrapper returned (tool_out)
    into the main state. That is, combine previous keys with new keys:
      new_state = { **state, **tool_out }.
    This node should be wired as its own graph node, not as a transition function.
    """
    new_state = {**state, **tool_out}
    return new_state
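
# Merge semantics sketch: keys returned by a tool overwrite the same keys in the prior
# state, and everything else is carried forward unchanged (values here are illustrative).
#
#   prior = {"task_id": "abc123", "web_search_query": "q", "messages": []}
#   merged = run_tools(prior, {"web_search_query": None, "web_search_result": "..."})
#   # merged["web_search_query"] is None; merged["task_id"] == "abc123"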