FrancescaScipioni's picture
fixed FAISS initial document
4f8f8da verified
raw
history blame
7.25 kB
import math
import os
from typing import TypedDict, Dict, Any, Optional, List

import pandas as pd
import pytesseract
import whisper
from PIL import Image
from dotenv import load_dotenv
from youtube_transcript_api import YouTubeTranscriptApi
from langchain.tools import StructuredTool, Tool
from langchain.tools.retriever import create_retriever_tool
from langchain_community.utilities import WikipediaAPIWrapper, ArxivAPIWrapper
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.documents import Document
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
# Load environment variables (load_dotenv is called with no arguments, so it
# uses its default .env lookup).
load_dotenv()
# os.getenv returns None when OPENAI_API_KEY is unset.
# NOTE(review): openai_api_key is not referenced anywhere else in the visible
# code; ChatOpenAI presumably reads OPENAI_API_KEY from the environment itself
# -- confirm before removing.
openai_api_key = os.getenv("OPENAI_API_KEY")
## ----- TOOL DEFINITIONS ----- ##
# Math Tools
def add_numbers(a: float, b: float) -> float:
    """Return the sum of ``a`` and ``b``."""
    total = a + b
    return total
def subtract_numbers(a: float, b: float) -> float:
    """Return ``a`` minus ``b``."""
    difference = a - b
    return difference
def multiply_numbers(a: float, b: float) -> float:
    """Return the product of ``a`` and ``b``."""
    product = a * b
    return product
def divide_numbers(a: float, b: float) -> float:
    """Return ``a`` divided by ``b``.

    Raises:
        ValueError: if ``b`` equals zero.
    """
    if b != 0:
        return a / b
    raise ValueError("Division by zero")
def power(a: float, b: float) -> float:
    """Return ``a`` raised to the power ``b``."""
    # Two-argument pow() is the function form of the ** operator.
    return pow(a, b)
def modulus(a: float, b: float) -> float:
    """Return the remainder of ``a`` divided by ``b`` (Python ``%`` semantics).

    The result takes the sign of ``b``; a zero ``b`` raises
    ``ZeroDivisionError`` from the operator itself.
    """
    remainder = a % b
    return remainder
def square_root(a: float) -> float:
    """Return the non-negative square root of ``a``.

    Raises:
        ValueError: if ``a`` is negative.
    """
    if not a < 0:
        return math.sqrt(a)
    raise ValueError("Cannot compute square root of a negative number")
def logarithm(a: float, base: float = math.e) -> float:
    """Return the logarithm of ``a`` in the given ``base``.

    Args:
        a: Value to take the logarithm of; must be positive.
        base: Logarithm base; must be positive and different from 1.
            Defaults to ``math.e`` (natural logarithm).

    Raises:
        ValueError: if ``a`` or ``base`` is not positive, or if ``base`` is 1.
    """
    if a <= 0 or base <= 0:
        raise ValueError("Logarithm arguments must be positive")
    # log(1) == 0, so math.log(a, 1) would divide by zero; reject it with the
    # same exception type used for the other invalid arguments.
    if base == 1:
        raise ValueError("Logarithm base must not be 1")
    return math.log(a, base)
# Web Search Tools
# General-purpose web search backed by DuckDuckGo.
# The tool name avoids spaces: OpenAI function calling only accepts names made
# of letters, digits, underscores and dashes, and this tool is bound to
# ChatOpenAI via bind_tools().
web_search_tool = Tool.from_function(
    func=DuckDuckGoSearchRun().run,
    name="web_search",
    description="Search the internet for general-purpose queries.",
)
# Encyclopedic lookup backed by the Wikipedia API wrapper.
# The tool name avoids spaces: OpenAI function calling only accepts names made
# of letters, digits, underscores and dashes, and this tool is bound to
# ChatOpenAI via bind_tools().
wikipedia_tool = Tool.from_function(
    func=WikipediaAPIWrapper().run,
    name="wikipedia_search",
    description="Search Wikipedia for factual or encyclopedic information.",
)
# Scientific paper search backed by the ArXiv API wrapper.
# The tool name avoids spaces: OpenAI function calling only accepts names made
# of letters, digits, underscores and dashes, and this tool is bound to
# ChatOpenAI via bind_tools().
arxiv_tool = Tool.from_function(
    func=ArxivAPIWrapper().run,
    name="arxiv_search",
    description="Search ArXiv for scientific papers. Input should be a research topic or query.",
)
# Audio Transcription
# Whisper "base" model, loaded once at module import and reused by
# transcribe_audio for every call.
whisper_model = whisper.load_model("base")
def transcribe_audio(file_path: str) -> str:
    """Run the module-level Whisper model on ``file_path`` and return its text.

    Args:
        file_path: Path to the audio file handed to ``whisper_model.transcribe``.

    Returns:
        The ``"text"`` entry of the transcription result.
    """
    result = whisper_model.transcribe(file_path)
    return result["text"]
# YouTube Transcript
def get_youtube_transcript(video_id: str) -> str:
    """Fetch the transcript of a YouTube video and return it as one string.

    Args:
        video_id: YouTube video ID passed to ``YouTubeTranscriptApi.get_transcript``.

    Returns:
        The ``"text"`` field of every transcript entry, joined by single spaces.
    """
    entries = YouTubeTranscriptApi.get_transcript(video_id)
    pieces = []
    for entry in entries:
        pieces.append(entry["text"])
    return " ".join(pieces)
# OCR Tool
def extract_text_from_image(image_path: str) -> str:
    """Extract text from an image file using Tesseract OCR.

    Args:
        image_path: Path to the image file to open with PIL.

    Returns:
        The string produced by ``pytesseract.image_to_string``.
    """
    # Open the image in a context manager so the underlying file handle is
    # closed as soon as OCR finishes, instead of waiting for garbage collection.
    with Image.open(image_path) as image:
        return pytesseract.image_to_string(image)
# Code Execution
def execute_python_code(code: str) -> str:
    """Execute a Python snippet and return its resulting top-level variables.

    SECURITY(review): ``exec`` runs ``code`` with the full privileges of this
    process. The snippet comes from the LLM (and so, indirectly, from user
    input); this is not a sandbox. Run the agent in an isolated environment
    or replace this with a restricted executor.

    Args:
        code: Python source to execute.

    Returns:
        ``str()`` of a dict holding the names the snippet defined at top
        level, or ``"Error: <message>"`` if the snippet raised an exception.
        Text printed by the snippet is not captured.
    """
    # A single namespace serves as both globals and locals. With separate
    # dicts, functions defined by the snippet could not see its other
    # top-level names (imports, helpers, or themselves when recursing).
    namespace: Dict[str, Any] = {}
    try:
        exec(code, namespace)
    except Exception as e:
        return f"Error: {e}"
    # exec() injects __builtins__ into the globals dict; hide it so the
    # result only shows what the snippet itself defined.
    defined = {name: value for name, value in namespace.items() if name != "__builtins__"}
    return str(defined)
# Excel Parsing
def total_sales_from_excel(file_path: str) -> str:
    """Sum the ``Sales`` column over rows whose ``Category`` is ``"Food"``.

    Args:
        file_path: Path of the Excel workbook read with ``pd.read_excel``.

    Returns:
        The total formatted with two decimals followed by ``" USD"``.
    """
    sheet = pd.read_excel(file_path)
    is_food = sheet["Category"] == "Food"
    total = sheet.loc[is_food, "Sales"].sum()
    return f"{total:.2f} USD"
## ----- TOOL LIST ----- ##
# Tools exposed to the LLM.
#
# Locally defined functions are wrapped with StructuredTool so the argument
# schema is inferred from each signature: the plain Tool wrapper accepts only a
# single string input, which cannot drive the two-argument or float-typed math
# helpers. Names defined here use only letters, digits and underscores, as
# required by OpenAI function calling.
tools = [
    StructuredTool.from_function(func=add_numbers, name="add_numbers", description="Add two numbers."),
    StructuredTool.from_function(func=subtract_numbers, name="subtract_numbers", description="Subtract two numbers."),
    StructuredTool.from_function(func=multiply_numbers, name="multiply_numbers", description="Multiply two numbers."),
    StructuredTool.from_function(func=divide_numbers, name="divide_numbers", description="Divide two numbers."),
    StructuredTool.from_function(func=power, name="power", description="Raise one number to the power of another."),
    StructuredTool.from_function(func=modulus, name="modulus", description="Compute the modulus (remainder) of a division."),
    StructuredTool.from_function(func=square_root, name="square_root", description="Compute the square root of a number."),
    StructuredTool.from_function(func=logarithm, name="logarithm", description="Compute the logarithm of a number with a given base."),
    web_search_tool,
    wikipedia_tool,
    arxiv_tool,
    StructuredTool.from_function(func=transcribe_audio, name="transcribe_audio", description="Transcribe audio to text."),
    StructuredTool.from_function(func=get_youtube_transcript, name="youtube_transcript", description="Extract transcript from YouTube."),
    StructuredTool.from_function(func=extract_text_from_image, name="image_ocr", description="Extract text from an image."),
    StructuredTool.from_function(func=execute_python_code, name="python_code_executor", description="Run Python code."),
    StructuredTool.from_function(func=total_sales_from_excel, name="excel_sales_parser", description="Parse Excel file for total food sales."),
]
## ----- SYSTEM PROMPT ----- ##
# Read the system prompt from system_prompt.txt (path is relative to the
# current working directory) and wrap it as the SystemMessage used by the graph.
with open("system_prompt.txt", encoding="utf-8") as f:
    system_prompt = f.read()

sys_msg = SystemMessage(content=system_prompt)
## ----- EMBEDDINGS & VECTOR DB (FAISS) ----- ##
# Sentence-embedding model used to index and query the example Q&A documents.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Seed Q&A examples; the in-memory FAISS index below is built from these.
documents = [
    Document(
        page_content="What is the capital of France? Paris.",
        metadata={"source": "example"},
    ),
    Document(
        page_content="How many legs does a spider have? 8.",
        metadata={"source": "example"},
    ),
]

vector_store = FAISS.from_documents(documents, embeddings)

# Retriever exposed as a tool. It is not part of the `tools` list; the graph's
# retriever node queries `vector_store` directly.
retriever_tool = create_retriever_tool(
    name="Question Search",
    description="Retrieve similar questions from a vector store.",
    retriever=vector_store.as_retriever(),
)
## ----- GRAPH PIPELINE ----- ##
def graph():
    """Build and compile the agent workflow.

    Flow: START -> "retriever" -> "assistant"; from "assistant" the
    ``tools_condition`` edge decides the next step, and "tools" always hands
    control back to "assistant".

    Returns:
        The compiled ``StateGraph`` over ``MessagesState``.
    """
    chat_model = ChatOpenAI(model="gpt-4o", temperature=0).bind_tools(tools)

    def retriever(state: MessagesState):
        """Prepend the system message and, if found, append a similar Q&A example."""
        history = state["messages"]
        # The first message's content is used as the similarity query.
        matches = vector_store.similarity_search(history[0].content)
        outgoing = [sys_msg] + history
        if matches:
            hint = HumanMessage(
                content=f"Similar Q&A for context:\n\n{matches[0].page_content}"
            )
            outgoing = outgoing + [hint]
        return {"messages": outgoing}

    def assistant(state: MessagesState):
        """Ask the tool-enabled model for the next message."""
        reply = chat_model.invoke(state["messages"])
        return {"messages": [reply]}

    workflow = StateGraph(MessagesState)

    # Nodes
    workflow.add_node("retriever", retriever)
    workflow.add_node("assistant", assistant)
    workflow.add_node("tools", ToolNode(tools))

    # Edges
    workflow.add_edge(START, "retriever")
    workflow.add_edge("retriever", "assistant")
    workflow.add_conditional_edges("assistant", tools_condition)
    workflow.add_edge("tools", "assistant")

    return workflow.compile()
## ----- TESTING (Optional) ----- ##
if __name__ == "__main__":
    # Manual smoke test: run one question through the agent and print every
    # message in the resulting conversation.
    test_question = "How many albums did Taylor Swift release before 2020?"
    # graph is a factory function; it must be called to obtain the compiled
    # graph, which is the object that provides .invoke().
    agent = graph()
    response = agent.invoke({"messages": [HumanMessage(content=test_question)]})
    for msg in response["messages"]:
        msg.pretty_print()