File size: 7,245 Bytes
83eac9e
97bdee5
5579ccc
 
83eac9e
 
22c8505
83eac9e
 
22c8505
83eac9e
8dd7ab8
 
83eac9e
 
22c8505
 
4f8f8da
83eac9e
 
 
22c8505
83eac9e
22c8505
 
8569e5d
83eac9e
97bdee5
83eac9e
 
 
 
97bdee5
83eac9e
97bdee5
83eac9e
 
97bdee5
83eac9e
97bdee5
 
83eac9e
97bdee5
 
83eac9e
3f8f511
83eac9e
3f8f511
83eac9e
3f8f511
 
 
83eac9e
3f8f511
83eac9e
3f8f511
 
 
83eac9e
3f8f511
83eac9e
3f8f511
 
83eac9e
 
535f4d4
5579ccc
83eac9e
 
5579ccc
83eac9e
5579ccc
83eac9e
5579ccc
83eac9e
5579ccc
83eac9e
5579ccc
83eac9e
5579ccc
 
83eac9e
5579ccc
83eac9e
5579ccc
 
 
 
 
 
 
83eac9e
5579ccc
 
 
 
83eac9e
5579ccc
83eac9e
8569e5d
97bdee5
83eac9e
 
 
 
 
 
 
 
3f8f511
 
5579ccc
83eac9e
 
 
 
 
22c8505
 
 
 
 
 
 
 
83eac9e
22c8505
83eac9e
 
4f8f8da
 
 
 
83eac9e
 
 
 
 
 
 
 
 
 
021c8f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83eac9e
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import contextlib
import io
import math
import os
from typing import TypedDict, Dict, Any, Optional, List

import pandas as pd
import pytesseract
import whisper
from dotenv import load_dotenv
from PIL import Image
from youtube_transcript_api import YouTubeTranscriptApi

from langchain.tools import StructuredTool, Tool
from langchain_community.utilities import WikipediaAPIWrapper, ArxivAPIWrapper
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.documents import Document
from langchain.tools.retriever import create_retriever_tool
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition

# Load environment variables (python-dotenv reads a .env file, if one is found,
# into the process environment).
load_dotenv()
# os.getenv returns None when OPENAI_API_KEY is not set.
# NOTE(review): this value is not passed to ChatOpenAI anywhere in this file —
# presumably the client reads the key from the environment itself; confirm.
openai_api_key = os.getenv("OPENAI_API_KEY")

## ----- TOOL DEFINITIONS ----- ##

# Math Tools
def add_numbers(a: float, b: float) -> float:
    """Return the sum of ``a`` and ``b``."""
    total = a + b
    return total
def subtract_numbers(a: float, b: float) -> float:
    """Return ``a`` minus ``b``."""
    difference = a - b
    return difference
def multiply_numbers(a: float, b: float) -> float:
    """Return the product of ``a`` and ``b``."""
    product = a * b
    return product
def divide_numbers(a: float, b: float) -> float:
    """Return ``a`` divided by ``b``.

    Raises:
        ValueError: If ``b`` is zero.
    """
    if b != 0:
        return a / b
    raise ValueError("Division by zero")
def power(a: float, b: float) -> float:
    """Return ``a`` raised to the power ``b``."""
    return pow(a, b)
def modulus(a: float, b: float) -> float:
    """Return the remainder of ``a`` divided by ``b`` (Python ``%`` semantics).

    Raises:
        ValueError: If ``b`` is zero. This mirrors ``divide_numbers`` so both
            division-style tools report a zero divisor the same way instead
            of leaking a ``ZeroDivisionError``.
    """
    if b == 0:
        raise ValueError("Modulus by zero")
    return a % b
def square_root(a: float) -> float:
    """Return the square root of ``a``.

    Raises:
        ValueError: If ``a`` is negative.
    """
    is_negative = a < 0
    if is_negative:
        raise ValueError("Cannot compute square root of a negative number")
    root = math.sqrt(a)
    return root
def logarithm(a: float, base: float = math.e) -> float:
    """Return the logarithm of ``a`` in the given ``base``.

    Args:
        a: Value to take the logarithm of; must be positive.
        base: Logarithm base; must be positive and different from 1.
            Defaults to ``math.e`` (natural logarithm).

    Raises:
        ValueError: If ``a`` or ``base`` is not positive, or if ``base`` is 1.
    """
    if a <= 0 or base <= 0:
        raise ValueError("Logarithm arguments must be positive")
    # log(1) == 0, so math.log(a, 1) divides by zero internally; reject it
    # with the same exception type as the other invalid-argument cases.
    if base == 1:
        raise ValueError("Logarithm base must not be 1")
    return math.log(a, base)

# Web Search Tools
# General-purpose web search backed by DuckDuckGo.
# The tool name is a plain identifier: OpenAI tool/function names may only
# contain letters, digits, underscores and hyphens, so a name with a space
# is rejected once the tool is bound to the chat model.
web_search_tool = Tool.from_function(
    func=DuckDuckGoSearchRun().run,
    name="web_search",
    description="Search the internet for general-purpose queries.",
)

# Encyclopedic lookup through the Wikipedia API wrapper.
# The tool name is a plain identifier: OpenAI tool/function names may only
# contain letters, digits, underscores and hyphens, so a name with a space
# is rejected once the tool is bound to the chat model.
wikipedia_tool = Tool.from_function(
    func=WikipediaAPIWrapper().run,
    name="wikipedia_search",
    description="Search Wikipedia for factual or encyclopedic information.",
)

# Scientific paper search through the ArXiv API wrapper.
# The tool name is a plain identifier: OpenAI tool/function names may only
# contain letters, digits, underscores and hyphens, so a name with a space
# is rejected once the tool is bound to the chat model.
arxiv_tool = Tool.from_function(
    func=ArxivAPIWrapper().run,
    name="arxiv_search",
    description="Search ArXiv for scientific papers. Input should be a research topic or query.",
)

# Audio Transcription
# Load the "base" Whisper model once, at import time; transcribe_audio()
# reuses this single module-level instance on every call.
whisper_model = whisper.load_model("base")

def transcribe_audio(file_path: str) -> str:
    """Transcribe an audio file with the module-level Whisper model.

    Args:
        file_path: Path of the audio file handed to ``whisper_model.transcribe``.

    Returns:
        The ``"text"`` entry of the transcription result.
    """
    transcription = whisper_model.transcribe(file_path)
    return transcription["text"]

# YouTube Transcript
def get_youtube_transcript(video_id: str) -> str:
    """Return a YouTube video's transcript as one space-joined string.

    Args:
        video_id: Video ID passed to ``YouTubeTranscriptApi.get_transcript``.

    Returns:
        The ``"text"`` field of every transcript entry, joined with spaces.
    """
    segments = YouTubeTranscriptApi.get_transcript(video_id)
    texts = [segment["text"] for segment in segments]
    return " ".join(texts)

# OCR Tool
def extract_text_from_image(image_path: str) -> str:
    """Extract text from an image file using Tesseract OCR.

    Args:
        image_path: Path of the image file to open.

    Returns:
        The text recognised by ``pytesseract.image_to_string``.
    """
    # Open the image in a context manager so its file handle is released as
    # soon as OCR finishes instead of lingering until garbage collection.
    with Image.open(image_path) as image:
        return pytesseract.image_to_string(image)

# Code Execution
def execute_python_code(code: str) -> str:
    """Execute a Python snippet and return what it produced.

    The snippet runs in one fresh namespace used as both globals and locals,
    so functions it defines can see its imports, its variables and each
    other. Anything the snippet prints is captured and included in the
    result.

    SECURITY: ``exec`` runs the snippet with the full privileges of this
    process (file system, network, environment variables). The code comes
    from a language model, so treat it as untrusted and run this tool only
    inside a sandboxed or otherwise disposable environment.

    Args:
        code: Python source code to execute.

    Returns:
        The ``str()`` of the dict of names the snippet defined, preceded by
        the captured standard output (and a newline) when the snippet printed
        anything; or ``"Error: <message>"`` if the snippet raised.
    """
    namespace: dict = {}
    captured = io.StringIO()
    try:
        with contextlib.redirect_stdout(captured):
            exec(code, namespace)
    except Exception as e:
        return f"Error: {e}"

    # exec() injects __builtins__ into the globals dict; it is not something
    # the snippet defined, so keep it out of the reported variables.
    namespace.pop("__builtins__", None)
    variables = str(namespace)
    printed = captured.getvalue()
    if not printed:
        return variables
    return f"{printed.rstrip()}\n{variables}"

# Excel Parsing
def total_sales_from_excel(file_path: str) -> str:
    """Sum the ``Sales`` column over rows whose ``Category`` is ``"Food"``.

    Args:
        file_path: Path of the Excel workbook read with ``pd.read_excel``.

    Returns:
        The total formatted with two decimals and a ``" USD"`` suffix.
    """
    sheet = pd.read_excel(file_path)
    is_food = sheet["Category"] == "Food"
    food_rows = sheet[is_food]
    food_total = food_rows["Sales"].sum()
    return f"{food_total:.2f} USD"

## ----- TOOL LIST ----- ##

# Every tool exposed to the chat model.
#
# Plain Python functions are wrapped with StructuredTool so the argument
# schema is inferred from each signature: the single-input `Tool` class only
# passes one string, which cannot drive two-argument or float-typed functions.
# Names are plain identifiers because OpenAI tool/function names may only
# contain letters, digits, underscores and hyphens.
tools = [
    StructuredTool.from_function(func=add_numbers, name="add_numbers", description="Add two numbers."),
    StructuredTool.from_function(func=subtract_numbers, name="subtract_numbers", description="Subtract two numbers."),
    StructuredTool.from_function(func=multiply_numbers, name="multiply_numbers", description="Multiply two numbers."),
    StructuredTool.from_function(func=divide_numbers, name="divide_numbers", description="Divide two numbers."),
    StructuredTool.from_function(func=power, name="power", description="Raise one number to the power of another."),
    StructuredTool.from_function(func=modulus, name="modulus", description="Compute the modulus (remainder) of a division."),
    StructuredTool.from_function(func=square_root, name="square_root", description="Compute the square root of a number."),
    StructuredTool.from_function(func=logarithm, name="logarithm", description="Compute the logarithm of a number with a given base."),
    web_search_tool,
    wikipedia_tool,
    arxiv_tool,
    StructuredTool.from_function(func=transcribe_audio, name="transcribe_audio", description="Transcribe audio to text."),
    StructuredTool.from_function(func=get_youtube_transcript, name="youtube_transcript", description="Extract transcript from YouTube."),
    StructuredTool.from_function(func=extract_text_from_image, name="image_ocr", description="Extract text from an image."),
    StructuredTool.from_function(func=execute_python_code, name="python_code_executor", description="Run Python code."),
    StructuredTool.from_function(func=total_sales_from_excel, name="excel_sales_parser", description="Parse Excel file for total food sales."),
]

## ----- SYSTEM PROMPT ----- ##

# Read the system prompt at import time. The path is relative, so it is
# resolved against the current working directory of the process.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()
# Wrapped once as a SystemMessage; the graph's retriever node prepends it to
# the conversation messages.
sys_msg = SystemMessage(content=system_prompt)

## ----- EMBEDDINGS & VECTOR DB (FAISS) ----- ##

# Sentence-embedding model used to index and query the example Q&A documents.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Seed examples: each document holds a question and its answer in one string.
documents = [
    Document(page_content="What is the capital of France? Paris.", metadata={"source": "example"}),
    Document(page_content="How many legs does a spider have? 8.", metadata={"source": "example"}),
]
# FAISS index built from the seed documents at import time; the graph's
# retriever node queries it with similarity_search().
vector_store = FAISS.from_documents(documents, embeddings)

# Retriever over the example vector store, wrapped as a tool.
# NOTE(review): this tool is not added to the `tools` list in this file, so
# the model never sees it here. If it is ever bound to an OpenAI model, the
# name will need to be a plain identifier (no spaces) — confirm intended use.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="Retrieve similar questions from a vector store."
)

## ----- GRAPH PIPELINE ----- ##

def graph():
    """Build and compile the agent graph.

    Flow: START -> retriever -> assistant. After each assistant turn,
    ``tools_condition`` decides whether the reply goes to the tools node;
    the tools node always hands control back to the assistant.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    chat_model = ChatOpenAI(model="gpt-4o", temperature=0)
    tool_aware_model = chat_model.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node: ask the tool-aware model for the next message."""
        reply = tool_aware_model.invoke(state["messages"])
        return {"messages": [reply]}

    def retriever(state: MessagesState):
        """Retriever node: add the system prompt and, if found, a similar example."""
        conversation = state["messages"]
        # The first message of the conversation is used as the search query.
        matches = vector_store.similarity_search(conversation[0].content)
        enriched = [sys_msg] + conversation
        if matches:
            hint = HumanMessage(content=f"Similar Q&A for context:\n\n{matches[0].page_content}")
            enriched = enriched + [hint]
        return {"messages": enriched}

    # Nodes
    workflow = StateGraph(MessagesState)
    workflow.add_node("retriever", retriever)
    workflow.add_node("assistant", assistant)
    workflow.add_node("tools", ToolNode(tools))

    # Edges
    workflow.add_edge(START, "retriever")
    workflow.add_edge("retriever", "assistant")
    workflow.add_conditional_edges("assistant", tools_condition)
    workflow.add_edge("tools", "assistant")

    compiled = workflow.compile()
    return compiled


## ----- TESTING (Optional) ----- ##

if __name__ == "__main__":
    # Manual smoke test: run one question through the agent and print every
    # message of the resulting conversation.
    test_question = "How many albums did Taylor Swift release before 2020?"
    # `graph` is a factory function; it must be called to obtain the compiled
    # graph, which is the object that provides `.invoke()`.
    agent = graph()
    response = agent.invoke({"messages": [HumanMessage(content=test_question)]})
    for msg in response["messages"]:
        msg.pretty_print()