"""Tool-using question-answering agent assembled with LangChain and LangGraph.

The module defines a set of tools (arithmetic, web/Wikipedia/ArXiv search,
audio transcription, YouTube transcripts, OCR, Python execution, Excel
parsing), a small FAISS vector store of example Q&A pairs, and ``graph()``,
which wires a retriever node, an LLM assistant node and a tool node into a
compiled LangGraph state graph.

Importing this module has side effects: it loads ``.env``, loads the Whisper
"base" model, reads ``system_prompt.txt`` from the working directory and
builds the embedding model and FAISS index.
"""

import os
import math
from typing import TypedDict, Dict, Any, Optional, List

import whisper
import pandas as pd
import pytesseract
from PIL import Image
from dotenv import load_dotenv
from youtube_transcript_api import YouTubeTranscriptApi
from langchain.tools import StructuredTool, Tool
from langchain_community.utilities import WikipediaAPIWrapper, ArxivAPIWrapper
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.documents import Document
from langchain.tools.retriever import create_retriever_tool
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition

# Load environment variables
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")


## ----- TOOL DEFINITIONS ----- ##

# Math Tools
def add_numbers(a: float, b: float) -> float:
    """Return ``a + b``."""
    return a + b


def subtract_numbers(a: float, b: float) -> float:
    """Return ``a - b``."""
    return a - b


def multiply_numbers(a: float, b: float) -> float:
    """Return ``a * b``."""
    return a * b


def divide_numbers(a: float, b: float) -> float:
    """Return ``a / b``.

    Raises:
        ValueError: If ``b`` is zero.
    """
    if b == 0:
        raise ValueError("Division by zero")
    return a / b


def power(a: float, b: float) -> float:
    """Return ``a`` raised to the power ``b``."""
    return a ** b


def modulus(a: float, b: float) -> float:
    """Return the remainder ``a % b``.

    Raises:
        ValueError: If ``b`` is zero (mirrors ``divide_numbers``).
    """
    if b == 0:
        raise ValueError("Modulo by zero")
    return a % b


def square_root(a: float) -> float:
    """Return the square root of ``a``.

    Raises:
        ValueError: If ``a`` is negative.
    """
    if a < 0:
        raise ValueError("Cannot compute square root of a negative number")
    return math.sqrt(a)


def logarithm(a: float, base: float = math.e) -> float:
    """Return the logarithm of ``a`` in the given ``base`` (natural log by default).

    Raises:
        ValueError: If ``a`` or ``base`` is not positive, or if ``base`` is 1.
    """
    if a <= 0 or base <= 0:
        raise ValueError("Logarithm arguments must be positive")
    if base == 1:
        # math.log(a, 1) divides by log(1) == 0 and would surface as an
        # unhelpful ZeroDivisionError; report it as a domain error instead.
        raise ValueError("Logarithm base must not be 1")
    return math.log(a, base)


# Web Search Tools
web_search_tool = Tool.from_function(
    func=DuckDuckGoSearchRun().run,
    name="Web_Search",
    description="Search the internet for general-purpose queries.",
)

wikipedia_tool = Tool.from_function(
    func=WikipediaAPIWrapper().run,
    name="Wikipedia_Search",
    description="Search Wikipedia for factual or encyclopedic information.",
)

arxiv_tool = Tool.from_function(
    func=ArxivAPIWrapper().run,
    name="ArXiv_Search",
    description="Search ArXiv for scientific papers. Input should be a research topic or query.",
)

# Audio Transcription
# Loaded once at import time so every transcription call reuses the same model.
whisper_model = whisper.load_model("base")


def transcribe_audio(file_path: str) -> str:
    """Transcribe audio files using Whisper."""
    return whisper_model.transcribe(file_path)["text"]


# YouTube Transcript
def get_youtube_transcript(video_id: str) -> str:
    """Extract transcript from YouTube video using video ID.

    The text of every transcript entry is joined with single spaces.
    """
    transcript = YouTubeTranscriptApi.get_transcript(video_id)
    return " ".join(entry["text"] for entry in transcript)


# OCR Tool
def extract_text_from_image(image_path: str) -> str:
    """Extract text from an image file."""
    return pytesseract.image_to_string(Image.open(image_path))


# Code Execution
def execute_python_code(code: str) -> str:
    """Execute a Python script and return the output.

    The code runs with empty globals; the returned string is ``str()`` of the
    dictionary of names the script bound. Any exception raised by the script
    is returned as ``"Error: <message>"`` instead of propagating, so the
    calling agent can read the failure.
    """
    # SECURITY NOTE(review): `exec` runs arbitrary, model-generated code in
    # this process with full privileges. Sandbox or remove this tool before
    # exposing the agent to untrusted input.
    try:
        local_vars = {}
        exec(code, {}, local_vars)
        return str(local_vars)
    except Exception as e:
        return f"Error: {e}"


# Excel Parsing
def total_sales_from_excel(file_path: str) -> str:
    """Compute total food sales from an Excel file.

    Reads the workbook with pandas, keeps rows whose ``Category`` column
    equals ``"Food"``, sums their ``Sales`` column and formats the total with
    two decimals followed by ``" USD"``.
    """
    df = pd.read_excel(file_path)
    food_df = df[df["Category"] == "Food"]
    return f"{food_df['Sales'].sum():.2f} USD"


## ----- TOOL LIST ----- ##
# The math helpers take typed (and mostly multiple) numeric arguments, so they
# are wrapped with StructuredTool, which infers an argument schema from each
# signature. `Tool` only supports a single string input and cannot call them.
tools = [
    StructuredTool.from_function(add_numbers, name="Add_Numbers", description="Add two numbers."),
    StructuredTool.from_function(subtract_numbers, name="Subtract_Numbers", description="Subtract two numbers."),
    StructuredTool.from_function(multiply_numbers, name="Multiply_Numbers", description="Multiply two numbers."),
    StructuredTool.from_function(divide_numbers, name="Divide_Numbers", description="Divide two numbers."),
    StructuredTool.from_function(power, name="Power", description="Raise one number to the power of another."),
    StructuredTool.from_function(modulus, name="Modulus", description="Compute the modulus (remainder) of a division."),
    StructuredTool.from_function(square_root, name="Square_Root", description="Compute the square root of a number."),
    StructuredTool.from_function(logarithm, name="Logarithm", description="Compute the logarithm of a number with a given base."),
    web_search_tool,
    wikipedia_tool,
    arxiv_tool,
    Tool.from_function(transcribe_audio, name="Transcribe_Audio", description="Transcribe audio to text."),
    Tool.from_function(get_youtube_transcript, name="YouTube_Transcript", description="Extract transcript from YouTube."),
    Tool.from_function(extract_text_from_image, name="Image_OCR", description="Extract text from an image."),
    Tool.from_function(execute_python_code, name="Python_Code_Executor", description="Run Python code."),
    Tool.from_function(total_sales_from_excel, name="Excel_Sales_Parser", description="Parse Excel file for total food sales."),
]

## ----- SYSTEM PROMPT ----- ##
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()
sys_msg = SystemMessage(content=system_prompt)

## ----- EMBEDDINGS & VECTOR DB (FAISS) ----- ##
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

documents = [
    Document(page_content="What is the capital of France? Paris.", metadata={"source": "example"}),
    Document(page_content="How many legs does a spider have? 8.", metadata={"source": "example"}),
]

vector_store = FAISS.from_documents(documents, embeddings)

# NOTE(review): this tool is not part of `tools`, so it is never bound to the
# LLM here. Its name contains a space, which chat-model providers may reject
# as a function name -- confirm before adding it to `tools`.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="Retrieve similar questions from a vector store.",
)


## ----- GRAPH PIPELINE ----- ##
def graph():
    """Build and compile the agent graph.

    The flow is START -> retriever -> assistant; after the assistant,
    ``tools_condition`` decides the next step, and the tools node always
    hands control back to the assistant.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``; call
        ``.invoke({"messages": [...]})`` on it to run the agent.
    """
    llm = ChatOpenAI(model="gpt-4o", temperature=0)
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node to generate answers."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    # Use a retriever node to inject a similar example
    def retriever(state: MessagesState):
        """Retriever node to provide example context.

        Searches the vector store with the content of the first message and
        returns the system prompt, the current messages and, when a match
        exists, a human message quoting the closest stored example.
        """
        similar = vector_store.similarity_search(state["messages"][0].content)
        if not similar:
            return {"messages": [sys_msg] + state["messages"]}
        example = HumanMessage(content=f"Similar Q&A for context:\n\n{similar[0].page_content}")
        return {"messages": [sys_msg] + state["messages"] + [example]}

    # Build graph
    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()


## ----- TESTING (Optional) ----- ##
if __name__ == "__main__":
    test_question = "How many albums did Taylor Swift release before 2020?"
    # `graph` is a factory: it must be called to obtain the compiled graph
    # before `.invoke` is available.
    agent = graph()
    response = agent.invoke({"messages": [HumanMessage(content=test_question)]})
    for msg in response["messages"]:
        msg.pretty_print()