|
import math
import os
from typing import TypedDict, Dict, Any, Optional, List

import pandas as pd
import pytesseract
import whisper
from dotenv import load_dotenv
from PIL import Image
from youtube_transcript_api import YouTubeTranscriptApi

from langchain.tools import Tool
from langchain.tools.retriever import create_retriever_tool
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.utilities import WikipediaAPIWrapper, ArxivAPIWrapper
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import StructuredTool
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
|
|
|
|
|
# Load variables from a local .env file into the process environment.
load_dotenv()

# NOTE: ChatOpenAI reads OPENAI_API_KEY from the environment on its own; this
# module-level copy is not referenced anywhere else in the visible code.
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
|
|
|
|
|
|
|
|
def add_numbers(a: float, b: float) -> float:
    """Return the sum ``a + b``."""
    total = a + b
    return total
|
def subtract_numbers(a: float, b: float) -> float:
    """Return the difference ``a - b`` (``b`` is subtracted from ``a``)."""
    difference = a - b
    return difference
|
def multiply_numbers(a: float, b: float) -> float:
    """Return the product ``a * b``."""
    product = a * b
    return product
|
def divide_numbers(a: float, b: float) -> float:
    """Return the quotient ``a / b``.

    Raises:
        ValueError: if ``b`` is zero.
    """
    if b != 0:
        return a / b
    raise ValueError("Division by zero")
|
def power(a: float, b: float) -> float:
    """Return ``a`` raised to the power ``b``."""
    return pow(a, b)
|
def modulus(a: float, b: float) -> float:
    """Return the remainder of ``a`` divided by ``b``.

    Uses Python ``%`` semantics: a non-zero result has the sign of ``b``.
    A zero ``b`` raises ZeroDivisionError.
    """
    remainder = a % b
    return remainder
|
def square_root(a: float) -> float:
    """Return the non-negative square root of ``a``.

    Raises:
        ValueError: if ``a`` is negative.
    """
    is_negative = a < 0
    if is_negative:
        raise ValueError("Cannot compute square root of a negative number")
    return math.sqrt(a)
|
def logarithm(a: float, base: float = math.e) -> float:
    """Return the logarithm of ``a`` in ``base`` (natural log by default).

    Raises:
        ValueError: if ``a`` or ``base`` is not positive, or ``base`` is 1.
    """
    if a <= 0 or base <= 0:
        raise ValueError("Logarithm arguments must be positive")
    if base == 1:
        # math.log(a, 1) divides by log(1) == 0 and would leak a
        # ZeroDivisionError instead of the documented ValueError.
        raise ValueError("Logarithm base cannot be 1")
    # The dedicated functions are exact for powers of their base, whereas the
    # generic change-of-base form can be one ulp off, e.g.
    # math.log(1000, 10) == 2.9999999999999996.
    if base == 10:
        return math.log10(a)
    if base == 2:
        return math.log2(a)
    return math.log(a, base)
|
|
|
|
|
# Knowledge-source tools. Each LangChain utility exposes a ``run`` method that
# takes a query string, so a single-input ``Tool`` wraps it directly.
web_search_tool = Tool.from_function(
    name="Web_Search",
    description="Search the internet for general-purpose queries.",
    func=DuckDuckGoSearchRun().run,
)

wikipedia_tool = Tool.from_function(
    name="Wikipedia_Search",
    description="Search Wikipedia for factual or encyclopedic information.",
    func=WikipediaAPIWrapper().run,
)

arxiv_tool = Tool.from_function(
    name="ArXiv_Search",
    description="Search ArXiv for scientific papers. Input should be a research topic or query.",
    func=ArxivAPIWrapper().run,
)
|
|
|
|
|
# Whisper speech-to-text model, loaded once at import time and shared by every
# transcribe_audio call.
whisper_model = whisper.load_model("base")


def transcribe_audio(file_path: str) -> str:
    """Return the text Whisper transcribes from the audio file at ``file_path``."""
    transcription = whisper_model.transcribe(file_path)
    return transcription["text"]
|
|
|
|
|
def _extract_youtube_video_id(video: str) -> str:
    """Return the video ID from a YouTube URL, or the stripped input if it is not one.

    Recognizes ``youtube.com/watch?v=ID``, ``youtu.be/ID`` and
    ``youtube.com/{embed,shorts,live,v}/ID``, with or without a scheme.
    """
    from urllib.parse import parse_qs, urlparse

    text = video.strip()
    try:
        # urlparse only fills in the host when a scheme is present.
        parsed = urlparse(text if "://" in text else f"https://{text}")
    except ValueError:
        return text
    host = (parsed.hostname or "").lower()
    path_parts = [part for part in parsed.path.split("/") if part]

    candidate = ""
    if host == "youtu.be":
        candidate = path_parts[0] if path_parts else ""
    elif host == "youtube.com" or host.endswith(".youtube.com"):
        candidate = parse_qs(parsed.query).get("v", [""])[0]
        if not candidate and len(path_parts) >= 2 and path_parts[0] in {"embed", "shorts", "live", "v"}:
            candidate = path_parts[1]
    # Anything unrecognized (including a bare video ID) passes through as-is.
    return candidate or text


def get_youtube_transcript(video_id: str) -> str:
    """Fetch a YouTube video's transcript and return it as a single string.

    Args:
        video_id: The video ID (e.g. ``"dQw4w9WgXcQ"``). A full YouTube URL is
            also accepted, because the model often forwards the link it was
            given; the ID is extracted from it first.

    Returns:
        The transcript entries' ``"text"`` fields joined with single spaces.
    """
    transcript = YouTubeTranscriptApi.get_transcript(_extract_youtube_video_id(video_id))
    return " ".join(entry["text"] for entry in transcript)
|
|
|
|
|
def extract_text_from_image(image_path: str) -> str:
    """Run Tesseract OCR on the image at ``image_path`` and return the text found.

    The image is opened in a ``with`` block so its file handle is released as
    soon as OCR finishes. ``Image.open`` opens the file lazily and would
    otherwise keep it open until the object is garbage-collected.
    """
    with Image.open(image_path) as image:
        return pytesseract.image_to_string(image)
|
|
|
|
|
def execute_python_code(code: str) -> str:
    """Execute a Python snippet and return what it printed.

    The snippet runs in a fresh namespace, and anything it writes to stdout is
    captured and returned. If it printed nothing, the string form of the
    variables it defined is returned instead (the previous behaviour). Any
    exception, including ``SystemExit`` from ``exit()``/``sys.exit()``, is
    reported as ``"Error: <message>"`` instead of propagating.

    WARNING: this runs arbitrary code with the agent process's privileges and
    is NOT a sandbox. The code is written by the LLM, which may be steered by
    untrusted web or file content, so run the agent in an isolated environment.
    """
    import contextlib
    import io

    # A single dict serves as both globals and locals. With separate dicts,
    # top-level assignments land in the locals dict, where functions defined
    # by the snippet cannot see them (NameError).
    namespace: dict = {}
    captured = io.StringIO()
    try:
        with contextlib.redirect_stdout(captured):
            exec(code, namespace)
    except (Exception, SystemExit) as e:
        return f"Error: {e}"

    output = captured.getvalue()
    if output:
        return output
    # exec() inserts __builtins__ into the namespace; hide dunder entries.
    return str({name: value for name, value in namespace.items() if not name.startswith("__")})
|
|
|
|
|
def total_sales_from_excel(
    file_path: str,
    category: str = "Food",
    category_column: str = "Category",
    sales_column: str = "Sales",
) -> str:
    """Sum the sales of one category in a spreadsheet (food sales by default).

    Args:
        file_path: Path to an Excel workbook (its first sheet is read) or, if
            the name ends in ``.csv``, to a CSV file.
        category: Category label to keep.
        category_column: Column holding the category labels.
        sales_column: Column holding the sales amounts.

    Returns:
        The total with two decimals and a ``" USD"`` suffix, e.g.
        ``"123.40 USD"``, or ``"0.00 USD"`` when no row matches.

    Raises:
        KeyError: if either column is missing.
    """
    if str(file_path).lower().endswith(".csv"):
        df = pd.read_csv(file_path)
    else:
        df = pd.read_excel(file_path)
    selected = df.loc[df[category_column] == category, sales_column]
    return f"{selected.sum():.2f} USD"
|
|
|
|
|
|
|
# Tools exposed to the model: bound with llm.bind_tools() and executed by
# ToolNode inside graph().
#
# The math helpers take one or two typed numeric arguments, so they are
# wrapped with StructuredTool, which derives a JSON argument schema from the
# function signature. A plain ``Tool`` accepts a single string, so a call
# such as Add_Numbers would reach add_numbers with one str argument and fail
# with a TypeError. The remaining helpers take one string (a path, ID or
# source code) and work as single-input ``Tool`` objects.
tools = [
    StructuredTool.from_function(add_numbers, name="Add_Numbers", description="Add two numbers: a + b."),
    StructuredTool.from_function(subtract_numbers, name="Subtract_Numbers", description="Subtract two numbers: a - b."),
    StructuredTool.from_function(multiply_numbers, name="Multiply_Numbers", description="Multiply two numbers: a * b."),
    StructuredTool.from_function(divide_numbers, name="Divide_Numbers", description="Divide two numbers: a / b."),
    StructuredTool.from_function(power, name="Power", description="Raise a to the power b."),
    StructuredTool.from_function(modulus, name="Modulus", description="Compute the remainder of a divided by b."),
    StructuredTool.from_function(square_root, name="Square_Root", description="Compute the square root of a non-negative number a."),
    StructuredTool.from_function(logarithm, name="Logarithm", description="Compute the logarithm of a in the given base (natural log if base is omitted)."),
    web_search_tool,
    wikipedia_tool,
    arxiv_tool,
    Tool.from_function(transcribe_audio, name="Transcribe_Audio", description="Transcribe audio to text. Input: path to the audio file."),
    Tool.from_function(get_youtube_transcript, name="YouTube_Transcript", description="Extract transcript from YouTube. Input: the YouTube video ID."),
    Tool.from_function(extract_text_from_image, name="Image_OCR", description="Extract text from an image. Input: path to the image file."),
    Tool.from_function(execute_python_code, name="Python_Code_Executor", description="Run Python code. Input: Python source code."),
    Tool.from_function(total_sales_from_excel, name="Excel_Sales_Parser", description="Parse Excel file for total food sales. Input: path to the Excel file."),
]
|
|
|
|
|
|
|
# Load the agent's system prompt. The file is looked up relative to the
# current working directory first (the original behaviour). If it is not
# there, the directory containing this module is tried, so the agent also
# works when launched from elsewhere.
_SYSTEM_PROMPT_FILENAME = "system_prompt.txt"
_system_prompt_path = _SYSTEM_PROMPT_FILENAME
if not os.path.exists(_system_prompt_path) and "__file__" in globals():
    _system_prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _SYSTEM_PROMPT_FILENAME)

with open(_system_prompt_path, "r", encoding="utf-8") as f:
    system_prompt = f.read()

sys_msg = SystemMessage(content=system_prompt)
|
|
|
|
|
|
|
# Embedding model plus a tiny in-memory FAISS index of example Q&A pairs.
# The graph's retriever node searches this index with the incoming question
# and shows the closest pair to the model as context.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

documents = [
    Document(page_content=qa_text, metadata={"source": "example"})
    for qa_text in (
        "What is the capital of France? Paris.",
        "How many legs does a spider have? 8.",
    )
]

vector_store = FAISS.from_documents(documents, embeddings)
|
|
|
# Retriever tool over the example Q&A store. It is not part of ``tools`` (the
# graph's retriever node queries ``vector_store`` directly) but stays
# available for binding. OpenAI function names must match
# ^[a-zA-Z0-9_-]{1,64}$, so the name uses an underscore, like the other tool
# names, instead of a space that the API would reject.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question_Search",
    description="Retrieve similar questions from a vector store.",
)
|
|
|
|
|
|
|
def graph():
    """Build and compile the agent graph.

    Flow: START -> retriever -> assistant. The assistant then loops through
    the tools node for as long as the model requests tool calls;
    ``tools_condition`` routes to END once it stops.

    Returns:
        The compiled graph; run it with ``.invoke({"messages": [...]})``.
    """
    llm = ChatOpenAI(model="gpt-4o", temperature=0)
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Call the tool-enabled model on the conversation, led by the system prompt."""
        messages = state["messages"]
        # The system prompt is prepended at call time, not written into the
        # state. MessagesState's add_messages reducer replaces messages whose
        # id already exists and APPENDS new ones, so returning
        # [sys_msg] + messages from a node would put the system prompt after
        # the user's question.
        if not any(isinstance(m, SystemMessage) for m in messages):
            messages = [sys_msg] + messages
        return {"messages": [llm_with_tools.invoke(messages)]}

    def retriever(state: MessagesState):
        """Append the stored Q&A pair most similar to the first message as context."""
        similar = vector_store.similarity_search(state["messages"][0].content)
        if not similar:
            return {"messages": []}
        example = HumanMessage(content=f"Similar Q&A for context:\n\n{similar[0].page_content}")
        return {"messages": [example]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build the graph, ask one question, print every message.
    test_question = "How many albums did Taylor Swift release before 2020?"
    # graph is a factory function; it must be called to get the compiled
    # graph (calling .invoke on the function itself raises AttributeError).
    agent = graph()
    response = agent.invoke({"messages": [HumanMessage(content=test_question)]})
    for msg in response["messages"]:
        msg.pretty_print()