# Source: FrancescaScipioni, commit 22c8505 "selected the llm to use" (verified), 7.52 kB.
import contextlib
import io
import math
import os
from typing import TypedDict, Dict, Any, Optional, List

import pandas as pd
import pytesseract
import whisper
from dotenv import load_dotenv
from langchain.tools import Tool
from langchain.utilities import WikipediaAPIWrapper, ArxivAPIWrapper, DuckDuckGoSearchRun
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import StructuredTool, tool
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode, tools_condition
from PIL import Image
from youtube_transcript_api import YouTubeTranscriptApi
# Populate os.environ from a local .env file (if present) before reading keys.
load_dotenv()
## ----- API KEYS ----- ##
# None when the variable is not set.
openai_api_key = os.environ.get("OPENAI_API_KEY")
## ----- TOOLS DEFINITION ----- ##
# ** Math Tools ** #
def add_numbers(a: float, b: float) -> float:
    """
    Return the sum of two numbers.

    Args:
        a (float): Left operand.
        b (float): Right operand.
    Returns:
        float: a plus b.
    """
    total = a + b
    return total
def subtract_numbers(a: float, b: float) -> float:
    """
    Return the difference a - b.

    Args:
        a (float): Minuend (value subtracted from).
        b (float): Subtrahend (value to subtract).
    Returns:
        float: a minus b.
    """
    difference = a - b
    return difference
def multiply_numbers(a: float, b: float) -> float:
    """
    Return the product of two numbers.

    Args:
        a (float): First factor.
        b (float): Second factor.
    Returns:
        float: a times b.
    """
    product = a * b
    return product
def divide_numbers(a: float, b: float) -> float:
    """
    Return the quotient a / b.

    Args:
        a (float): Dividend (numerator).
        b (float): Divisor (denominator); must be non-zero.
    Returns:
        float: a divided by b.
    Raises:
        ValueError: If b equals zero.
    """
    if b != 0:
        return a / b
    raise ValueError("Division by zero")
def power(a: float, b: float) -> float:
    """
    Raise the first number to the power of the second.

    Args:
        a (float): The base.
        b (float): The exponent.
    Returns:
        float: The result of the exponentiation.
    Raises:
        ValueError: If the result is not a real number (a negative base
            raised to a non-integer exponent).
        ZeroDivisionError: If zero is raised to a negative power.
    """
    result = a ** b
    # Python's ** silently returns a complex number for a negative base with a
    # fractional exponent, e.g. (-8) ** (1/3); reject it so the declared float
    # return type holds and the caller gets a clear error instead.
    if isinstance(result, complex):
        raise ValueError("Result is not a real number")
    return result
def modulus(a: float, b: float) -> float:
    """
    Return the remainder of a divided by b (Python's % operator).

    With Python's % semantics the remainder takes the sign of the divisor,
    and a zero divisor raises ZeroDivisionError.

    Args:
        a (float): The dividend.
        b (float): The divisor.
    Returns:
        float: a modulo b.
    """
    remainder = a % b
    return remainder
def square_root(a: float) -> float:
    """
    Return the non-negative square root of a.

    Args:
        a (float): The radicand.
    Returns:
        float: math.sqrt(a).
    Raises:
        ValueError: If a is below zero.
    """
    if a < 0:
        raise ValueError("Cannot compute square root of a negative number")
    root = math.sqrt(a)
    return root
def logarithm(a: float, base: float = math.e) -> float:
    """
    Compute the logarithm of a number with a specified base.

    Args:
        a (float): The number.
        base (float, optional): The logarithmic base (default is natural log).
    Returns:
        float: The logarithm.
    Raises:
        ValueError: If a or base is not positive, or if base is 1.
    """
    if a <= 0 or base <= 0:
        raise ValueError("Logarithm arguments must be positive")
    # log base 1 is undefined; math.log(a, 1) divides by log(1) == 0 and fails
    # with an opaque ZeroDivisionError, so report it as an invalid argument.
    if base == 1:
        raise ValueError("Logarithm base must not be 1")
    return math.log(a, base)
# ** Search Tools ** #
# Tool names must match OpenAI's function-name pattern ^[a-zA-Z0-9_-]+$
# (no spaces); otherwise the request built by llm.bind_tools() is rejected.
# DuckDuckGo Web Search
duckduckgo_search = DuckDuckGoSearchRun()
web_search_tool = Tool.from_function(
    func=duckduckgo_search.run,
    name="web_search",
    description="Use this tool to search the internet for general-purpose queries.",
)
# Wikipedia Search
wikipedia_search = WikipediaAPIWrapper()
wikipedia_tool = Tool.from_function(
    func=wikipedia_search.run,
    name="wikipedia_search",
    description="Use this tool to search Wikipedia for factual or encyclopedic information.",
)
# ArXiv Search
arxiv_search = ArxivAPIWrapper()
arxiv_tool = Tool.from_function(
    func=arxiv_search.run,
    name="arxiv_search",
    description="Use this tool to search ArXiv for scientific papers. Input should be a research topic or query.",
)
# ** Audio Transcription Tool ** #
# Whisper "base" checkpoint, loaded once at import time and reused by every call.
model = whisper.load_model("base")
@tool
def transcribe_audio(file_path: str) -> str:
    """Transcribe spoken words from an audio file into text."""
    # NOTE: the docstring above is the tool description sent to the LLM.
    # transcribe() returns a mapping; only its "text" entry is surfaced.
    return model.transcribe(file_path)["text"]
# ** youtube-transcript-api Tool ** #
@tool
def get_youtube_transcript(video_id: str) -> str:
    """Get transcript of a YouTube video from its video ID."""
    # NOTE: the docstring above is the tool description sent to the LLM.
    # Each transcript segment is indexed by "text"; the snippets are joined
    # into one space-separated string.
    segments = YouTubeTranscriptApi.get_transcript(video_id)
    return " ".join(segment["text"] for segment in segments)
# ** Image Tool ** #
@tool
def extract_text_from_image(image_path: str) -> str:
    """Extract text from an image using OCR."""
    # Image.open() keeps the underlying file open; the context manager closes
    # it once OCR is done instead of leaking the handle until garbage collection.
    with Image.open(image_path) as image:
        return pytesseract.image_to_string(image)
# ** Code Execution Tool ** #
# SECURITY: `code` is produced by the LLM and runs with full interpreter
# privileges via exec() -- there is no sandbox. Run this agent only in an
# isolated environment.
@tool
def execute_python_code(code: str) -> str:
    """Execute a Python code string and return the output."""
    # Returns anything the code printed, followed by str() of the variables it
    # defined; when nothing is printed only the variables are returned. Any
    # exception (including SystemExit) is reported as "Error: <message>".
    #
    # A single namespace dict (rather than separate globals/locals) lets
    # functions defined in `code` see names bound at its top level, e.g. an
    # `import` used inside a `def`.
    namespace: Dict[str, Any] = {}
    captured = io.StringIO()
    try:
        with contextlib.redirect_stdout(captured):
            exec(code, namespace)
    except (Exception, SystemExit) as e:
        # SystemExit (from exit()/sys.exit() in the snippet) would otherwise
        # terminate the whole agent process.
        return f"Error: {e}"
    # exec() injects __builtins__ into the namespace; it is noise for the caller.
    namespace.pop("__builtins__", None)
    variables = str(namespace)
    output = captured.getvalue()
    return f"{output}{variables}" if output else variables
# ** Excel Parsing Tool ** #
@tool
def total_sales_from_excel(file_path: str, category: str = "Food") -> str:
    """Compute total sales for one category (default "Food") from an Excel file."""
    # Expects "Category" and "Sales" columns; pandas raises KeyError if either
    # is missing. The category match is exact (case-sensitive).
    df = pd.read_excel(file_path)
    matching_rows = df[df["Category"] == category]
    total_sales = matching_rows["Sales"].sum()
    return f"{total_sales:.2f} USD"
## ----- TOOLS LIST ----- ##
# OpenAI function calling only accepts tool names matching ^[a-zA-Z0-9_-]+$,
# so every name here is snake_case (names with spaces make the API reject the
# request built by bind_tools()).
# The math helpers take several typed arguments, so they are wrapped with
# StructuredTool, which infers an argument schema from the signature; the
# single-input Tool class exposes only one string argument and cannot pass
# two numbers.
tools = [
    # Math
    StructuredTool.from_function(func=add_numbers, name="add_numbers", description="Add two numbers."),
    StructuredTool.from_function(func=subtract_numbers, name="subtract_numbers", description="Subtract two numbers."),
    StructuredTool.from_function(func=multiply_numbers, name="multiply_numbers", description="Multiply two numbers."),
    StructuredTool.from_function(func=divide_numbers, name="divide_numbers", description="Divide two numbers."),
    StructuredTool.from_function(func=power, name="power", description="Raise one number to the power of another."),
    StructuredTool.from_function(func=modulus, name="modulus", description="Compute the modulus (remainder) of a division."),
    StructuredTool.from_function(func=square_root, name="square_root", description="Compute the square root of a number."),
    StructuredTool.from_function(func=logarithm, name="logarithm", description="Compute the logarithm of a number with a given base."),
    # Search
    web_search_tool,
    wikipedia_tool,
    arxiv_tool,
    # The tools below are already tool objects produced by @tool (name =
    # function name, description = docstring). Re-wrapping them with
    # Tool.from_function would replace their inferred argument schema with a
    # single string input.
    # Audio
    transcribe_audio,
    # Youtube
    get_youtube_transcript,
    # Image
    extract_text_from_image,
    # Code Execution
    execute_python_code,
    # Excel parsing
    total_sales_from_excel,
]
## ----- LLM MODEL ----- ##
# GPT-4o at temperature 0, then a variant with every tool in `tools` bound so
# the model can emit tool calls.
llm = ChatOpenAI(temperature=0, model="gpt-4o")
llm_with_tools = llm.bind_tools(tools=tools)
## ----- SYSTEM PROMPT ----- ##
# Look for the prompt in the current working directory first (the original
# behaviour); if it is not there, fall back to the copy next to this module so
# the agent can be launched from any directory.
_SYSTEM_PROMPT_NAME = "system_prompt.txt"
_system_prompt_path = _SYSTEM_PROMPT_NAME
if not os.path.isfile(_system_prompt_path) and "__file__" in globals():
    _system_prompt_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), _SYSTEM_PROMPT_NAME
    )
with open(_system_prompt_path, "r", encoding="utf-8") as f:
    system_prompt = f.read()
# Echoes the loaded prompt to stdout at import time (presumably for debugging).
print(system_prompt)
# System message
sys_msg = SystemMessage(content=system_prompt)
## ----- GRAPH AGENT PIPELINE ----- ##