# Last commit by FrancescaScipioni: f7b61d7 (verified) - "Fix: updated tool names to use valid characters"
import contextlib
import io
import math
import os
from typing import TypedDict, Dict, Any, Optional, List
from urllib.parse import parse_qs, urlparse

import whisper
import pandas as pd
import pytesseract
from PIL import Image
from dotenv import load_dotenv
from youtube_transcript_api import YouTubeTranscriptApi
from langchain.tools import Tool
from langchain_core.tools import StructuredTool
from langchain_community.utilities import WikipediaAPIWrapper, ArxivAPIWrapper
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.documents import Document
from langchain.tools.retriever import create_retriever_tool
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
# Populate the process environment from a .env file, then read the OpenAI key
# from it (None when the variable is not set).
load_dotenv()
openai_api_key = os.environ.get("OPENAI_API_KEY")
## ----- TOOL DEFINITIONS ----- ##
# Math Tools
def add_numbers(a: float, b: float) -> float:
    """Return the sum ``a + b``."""
    total = a + b
    return total
def subtract_numbers(a: float, b: float) -> float:
    """Return the difference ``a - b``."""
    difference = a - b
    return difference
def multiply_numbers(a: float, b: float) -> float:
    """Return the product ``a * b``."""
    product = a * b
    return product
def divide_numbers(a: float, b: float) -> float:
    """Return the quotient ``a / b``.

    Raises:
        ValueError: If ``b`` is zero.
    """
    if b != 0:
        return a / b
    raise ValueError("Division by zero")
def power(a: float, b: float) -> float:
    """Return ``a`` raised to the power ``b``."""
    # Two-argument pow() is exactly equivalent to the ** operator.
    return pow(a, b)
def modulus(a: float, b: float) -> float:
    """Return the remainder of ``a`` divided by ``b`` (Python ``%`` semantics)."""
    remainder = a % b
    return remainder
def square_root(a: float) -> float:
    """Return the non-negative square root of ``a``.

    Raises:
        ValueError: If ``a`` is negative.
    """
    if a < 0:
        raise ValueError("Cannot compute square root of a negative number")
    root = math.sqrt(a)
    return root
def logarithm(a: float, base: float = math.e) -> float:
    """Return the logarithm of ``a`` in the given ``base`` (natural log by default).

    Raises:
        ValueError: If ``a`` or ``base`` is not positive, or if ``base`` is 1.
    """
    if a <= 0 or base <= 0:
        raise ValueError("Logarithm arguments must be positive")
    # log base 1 is undefined; math.log would otherwise fail with a bare
    # ZeroDivisionError from dividing by log(1) == 0.
    if base == 1:
        raise ValueError("Logarithm base must not be 1")
    return math.log(a, base)
# Web Search Tools
# Each search tool wraps the `run` method of a LangChain community search
# utility in a `Tool`.
web_search_tool = Tool.from_function(
    name="Web_Search",
    description="Search the internet for general-purpose queries.",
    func=DuckDuckGoSearchRun().run,
)
wikipedia_tool = Tool.from_function(
    name="Wikipedia_Search",
    description="Search Wikipedia for factual or encyclopedic information.",
    func=WikipediaAPIWrapper().run,
)
arxiv_tool = Tool.from_function(
    name="ArXiv_Search",
    description=(
        "Search ArXiv for scientific papers. "
        "Input should be a research topic or query."
    ),
    func=ArxivAPIWrapper().run,
)
# Audio Transcription
# Loaded once when the module is imported and shared by every transcription call.
whisper_model = whisper.load_model("base")


def transcribe_audio(file_path: str) -> str:
    """Transcribe an audio file with the module-level Whisper model.

    Args:
        file_path: Path to the audio file to transcribe.

    Returns:
        The ``"text"`` entry of Whisper's transcription result.
    """
    result = whisper_model.transcribe(file_path)
    return result["text"]
# YouTube Transcript
def get_youtube_transcript(video_id: str) -> str:
    """Extract the transcript of a YouTube video as one string.

    Args:
        video_id: A bare video ID, or a YouTube URL (``youtube.com/watch?v=...``,
            ``youtu.be/...``, ``youtube.com/shorts/...``, ``/embed/...``). URLs
            are reduced to their video ID first, because the tool description
            given to the LLM invites it to pass a link.

    Returns:
        The ``"text"`` of every transcript segment, joined with single spaces.
    """
    # NOTE(review): youtube-transcript-api 1.x replaced the static
    # `get_transcript` with an instance `fetch` method - confirm the pinned version.
    transcript = YouTubeTranscriptApi.get_transcript(_extract_youtube_id(video_id))
    return " ".join(entry["text"] for entry in transcript)


def _extract_youtube_id(video_id_or_url: str) -> str:
    """Return the video ID from a YouTube URL; return other input stripped but unchanged."""
    candidate = video_id_or_url.strip()
    # urlparse only fills in the host when a scheme is present.
    parsed = urlparse(candidate if "://" in candidate else "https://" + candidate)
    host = parsed.hostname or ""
    if host.endswith("youtu.be"):
        return parsed.path.lstrip("/").split("/")[0] or candidate
    if host.endswith("youtube.com"):
        query_ids = parse_qs(parsed.query).get("v")
        if query_ids:
            return query_ids[0]
        segments = [part for part in parsed.path.split("/") if part]
        if len(segments) >= 2 and segments[0] in ("shorts", "embed", "live", "v"):
            return segments[1]
    return candidate
# OCR Tool
def extract_text_from_image(image_path: str) -> str:
    """Extract text from an image file with Tesseract OCR.

    Args:
        image_path: Path to an image file that Pillow can open.

    Returns:
        The text recognised by ``pytesseract.image_to_string``.
    """
    # Image.open keeps the underlying file open until the image is closed; the
    # context manager releases the handle as soon as OCR has finished.
    with Image.open(image_path) as image:
        return pytesseract.image_to_string(image)
# Code Execution
def execute_python_code(code: str) -> str:
    """Execute a Python snippet and return what it printed plus the names it defined.

    Args:
        code: Python source to execute, in a fresh namespace.

    Returns:
        Everything the snippet wrote to stdout, followed by ``str()`` of the
        top-level names it defined (for example ``"hi\\n{'x': 1}"``). If the
        snippet raises, returns ``"Error: <message>"`` instead.
    """
    # SECURITY: exec() runs arbitrary code with this process's full privileges
    # and no sandbox. The code usually comes from the LLM, so treat it as
    # untrusted and run this tool only in an isolated environment.
    namespace: dict = {}
    captured = io.StringIO()
    try:
        # One dict serves as both globals and locals. With separate dicts,
        # top-level names (including imports) would land in locals and be
        # invisible to functions the snippet defines.
        with contextlib.redirect_stdout(captured):
            exec(code, namespace)
    except Exception as e:
        return f"Error: {e}"
    # exec() inserts __builtins__ into the globals dict; it is not user output.
    defined = {name: value for name, value in namespace.items() if name != "__builtins__"}
    return captured.getvalue() + str(defined)
# Excel Parsing
def total_sales_from_excel(file_path: str) -> str:
    """Sum the "Sales" column over rows whose "Category" is "Food".

    Args:
        file_path: Path to an Excel workbook with "Category" and "Sales" columns.

    Returns:
        The total formatted with two decimals and a " USD" suffix, e.g. "3.75 USD".
    """
    sheet = pd.read_excel(file_path)
    is_food = sheet["Category"] == "Food"
    food_total = sheet.loc[is_food, "Sales"].sum()
    return f"{food_total:.2f} USD"
## ----- TOOL LIST ----- ##
# `Tool` is LangChain's single-input tool type: it presents one string argument
# to the model, so it cannot drive the two-argument math functions.
# `StructuredTool.from_function` infers a named, typed argument schema from each
# function's signature, so (a, b), (a, base) and (file_path) are passed by name.
tools = [
    StructuredTool.from_function(add_numbers, name="Add_Numbers", description="Add two numbers."),
    StructuredTool.from_function(subtract_numbers, name="Subtract_Numbers", description="Subtract two numbers."),
    StructuredTool.from_function(multiply_numbers, name="Multiply_Numbers", description="Multiply two numbers."),
    StructuredTool.from_function(divide_numbers, name="Divide_Numbers", description="Divide two numbers."),
    StructuredTool.from_function(power, name="Power", description="Raise one number to the power of another."),
    StructuredTool.from_function(modulus, name="Modulus", description="Compute the modulus (remainder) of a division."),
    StructuredTool.from_function(square_root, name="Square_Root", description="Compute the square root of a number."),
    StructuredTool.from_function(logarithm, name="Logarithm", description="Compute the logarithm of a number with a given base."),
    # Search tools are single-string-input wrappers defined above.
    web_search_tool,
    wikipedia_tool,
    arxiv_tool,
    StructuredTool.from_function(transcribe_audio, name="Transcribe_Audio", description="Transcribe audio to text."),
    StructuredTool.from_function(get_youtube_transcript, name="YouTube_Transcript", description="Extract transcript from YouTube."),
    StructuredTool.from_function(extract_text_from_image, name="Image_OCR", description="Extract text from an image."),
    StructuredTool.from_function(execute_python_code, name="Python_Code_Executor", description="Run Python code."),
    StructuredTool.from_function(total_sales_from_excel, name="Excel_Sales_Parser", description="Parse Excel file for total food sales."),
]
## ----- SYSTEM PROMPT ----- ##
# The prompt is first looked up relative to the current working directory
# (original behavior). If it is not there, fall back to the copy that sits next
# to this module, so importing the agent from another directory still works.
_prompt_path = "system_prompt.txt"
if not os.path.exists(_prompt_path):
    _prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "system_prompt.txt")
with open(_prompt_path, "r", encoding="utf-8") as f:
    system_prompt = f.read()
sys_msg = SystemMessage(content=system_prompt)
## ----- EMBEDDINGS & VECTOR DB (FAISS) ----- ##
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Seed corpus of example Q&A strings, searched by the graph's retriever node.
documents = [
    Document(page_content=text, metadata={"source": "example"})
    for text in (
        "What is the capital of France? Paris.",
        "How many legs does a spider have? 8.",
    )
]
vector_store = FAISS.from_documents(documents, embeddings)
# Retriever tool over the example vector store. As written, it is not part of
# the `tools` list bound to the LLM. Its name uses only letters, digits and
# underscores: OpenAI rejects tool names containing spaces.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question_Search",
    description="Retrieve similar questions from a vector store.",
)
## ----- GRAPH PIPELINE ----- ##
def graph():
    """Build and compile the agent pipeline.

    Flow: START -> retriever -> assistant; the assistant routes to the tool
    node via ``tools_condition``, and tool results loop back to the assistant.

    Returns:
        The compiled graph; call ``.invoke({"messages": [...]})`` on it.
    """
    llm = ChatOpenAI(model="gpt-4o", temperature=0)
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Call the tool-enabled LLM on the conversation, system prompt first."""
        # The system prompt is prepended at call time instead of being written
        # into the state: MessagesState's add_messages reducer appends new
        # messages after existing ones, so a SystemMessage returned by a node
        # would land after the user's question rather than at the front.
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    def retriever(state: MessagesState):
        """Append the most similar stored Q&A example as extra context."""
        # Only the top hit is used, so request just one.
        similar = vector_store.similarity_search(state["messages"][0].content, k=1)
        if not similar:
            return {"messages": []}
        example = HumanMessage(content=f"Similar Q&A for context:\n\n{similar[0].page_content}")
        return {"messages": [example]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()
## ----- TESTING (Optional) ----- ##
if __name__ == "__main__":
    # `graph` is a factory function: build the compiled pipeline, then invoke it.
    agent = graph()
    test_question = "How many albums did Taylor Swift release before 2020?"
    response = agent.invoke({"messages": [HumanMessage(content=test_question)]})
    for msg in response["messages"]:
        msg.pretty_print()