|
import os
import shutil
from typing import TypedDict, Annotated, List

from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.tools import tool
from langchain.tools import Tool
from langchain.tools.retriever import create_retriever_tool
from langchain.vectorstores import Chroma
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.tools import DuckDuckGoSearchRun, WikipediaQueryRun, ArxivQueryRun
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities import WikipediaAPIWrapper, ArxivAPIWrapper
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.messages import AnyMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
|
|
|
|
|
# Load variables from a local .env file into the process environment
# (e.g. GROQ_API_KEY, which initialize_llm() reads via os.getenv).
load_dotenv()

# System prompt for the tool-using agent: the model is asked to report its
# reasoning and then end with a "FINAL ANSWER: ..." line whose value follows
# strict formatting rules (bare numbers, as few words as possible, or a
# comma-separated list).
Agent_prompt_template = '''You are a helpful assistant tasked with answering questions using a set of tools.
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
Your answer should only start with "FINAL ANSWER: ", then follows with the answer. '''

# The prompt wrapped as a SystemMessage; assistant() prepends it to the
# conversation on every model call.
sys_msg = SystemMessage(content=Agent_prompt_template)
|
|
|
|
|
|
|
def initialize_llm():
    """Build the Groq-hosted chat model used by the agent.

    Returns:
        ChatGroq: a ``qwen-qwq-32b`` client with temperature 0, authenticated
        with the ``GROQ_API_KEY`` environment variable (None if it is unset).
    """
    api_key = os.getenv("GROQ_API_KEY")
    return ChatGroq(
        temperature=0,
        model_name="qwen-qwq-32b",
        groq_api_key=api_key,
    )
|
|
|
|
|
def initialize_search_tool():
    """Create a Tavily web-search tool with its default configuration.

    Returns:
        TavilySearchResults: a new, independent search tool instance.
    """
    tavily = TavilySearchResults()
    return tavily
|
|
|
|
|
def get_weather(location: str, search_tool: TavilySearchResults = None) -> str:
    """Look up the current weather for a place by running a Tavily search.

    Args:
        location (str): Place name, interpolated into the query
            ``"current weather in <location>"``.
        search_tool (TavilySearchResults, optional): Tool used to run the
            query. When None (the default), a fresh one is created with
            ``initialize_search_tool()``.

    Returns:
        str: Whatever ``search_tool.run`` returns for the query.
        NOTE(review): Tavily tools may return a list of result dicts rather
        than a plain str — confirm before relying on the annotation.
    """
    tool_to_use = search_tool if search_tool is not None else initialize_search_tool()
    return tool_to_use.run(f"current weather in {location}")
|
|
|
|
|
def initialize_recommendation_chain(llm: ChatGroq) -> Runnable:
    """Build a prompt-to-LLM chain that turns a weather condition into advice.

    Args:
        llm (ChatGroq): Chat model the formatted prompt is piped into.

    Returns:
        Runnable: ``prompt | llm``; invoke it with
        ``{"weather_condition": <text>}``.
    """
    template = (
        "\n"
        "You are a helpful assistant that gives weather-based advice.\n"
        "\n"
        'Given the current weather condition: "{weather_condition}", provide:\n'
        "1. Clothing or activity recommendations suited for this weather.\n"
        "2. At least one health tip to stay safe or comfortable in this condition.\n"
        "\n"
        "Be concise and clear.\n"
    )
    prompt = ChatPromptTemplate.from_template(template)
    return prompt | llm
|
|
|
def get_recommendation(weather_condition: str, recommendation_chain: Runnable = None) -> str:
    """Produce clothing/activity advice and a health tip for a weather condition.

    Args:
        weather_condition (str): Description of the current weather; fills the
            ``{weather_condition}`` slot of the recommendation prompt.
        recommendation_chain (Runnable, optional): Pre-built chain to run. When
            None (the default), a new chain is built around a fresh
            ``initialize_llm()`` model.

    Returns:
        str: The chain's output. NOTE(review): for a prompt piped into a chat
        model this is presumably a message object rather than a plain str —
        confirm before relying on the annotation.
    """
    chain = recommendation_chain
    if chain is None:
        chain = initialize_recommendation_chain(initialize_llm())
    payload = {"weather_condition": weather_condition}
    return chain.invoke(payload)
|
|
|
|
|
@tool
def add(x: int, y: int) -> int:
    """
    Add two integers together.

    Args:
        x (int): The first addend.
        y (int): The second addend.

    Returns:
        int: The sum x + y.
    """
    total = x + y
    return total
|
|
|
@tool
def subtract(x: int, y: int) -> int:
    """
    Subtract one integer from another.

    Args:
        x (int): The minuend.
        y (int): The subtrahend (subtracted from x).

    Returns:
        int: The difference x - y.
    """
    difference = x - y
    return difference
|
|
|
@tool
def multiply(x: int, y: int) -> int:
    """
    Multiply two integers.

    Args:
        x (int): The first factor.
        y (int): The second factor.

    Returns:
        int: The product x * y.
    """
    product = x * y
    return product
|
|
|
@tool
def divide(x: int, y: int) -> float:
    """
    Divide one number by another using true division.

    Args:
        x (int): The numerator.
        y (int): The denominator; must be non-zero.

    Returns:
        float: The quotient x / y.

    Raises:
        ValueError: If y is zero.
    """
    if y != 0:
        return x / y
    raise ValueError("Cannot divide by zero.")
|
|
|
@tool
def square(x: int) -> int:
    """
    Return a number multiplied by itself.

    Args:
        x (int): The value to square.

    Returns:
        int: x * x.
    """
    squared = x * x
    return squared
|
|
|
@tool
def cube(x: int) -> int:
    """
    Return the third power of a number.

    Args:
        x (int): The value to cube.

    Returns:
        int: x * x * x.
    """
    squared = x * x
    return squared * x
|
|
|
@tool
def power(x: int, y: int) -> int:
    """
    Raise a base to an exponent.

    Args:
        x (int): The base.
        y (int): The exponent.

    Returns:
        int: x raised to the power y (a float when y is negative).
    """
    return pow(x, y)
|
|
|
@tool
def factorial(n: int) -> int:
    """
    Calculates the factorial of a non-negative integer.

    Args:
        n (int): The non-negative integer.

    Returns:
        int: The factorial of n (0! == 1).

    Raises:
        ValueError: If n is negative.
    """
    import math  # local import keeps this tool self-contained

    # Explicit check preserves this tool's own error message for negatives.
    if n < 0:
        raise ValueError("Factorial is not defined for negative numbers.")
    # math.factorial is exact for ints and implemented in C, replacing the
    # hand-rolled multiplication loop.
    return math.factorial(n)
|
|
|
@tool
def mean(numbers: list) -> float:
    """
    Compute the arithmetic mean (average) of a list of numbers.

    Args:
        numbers (list): The values to average; must not be empty.

    Returns:
        float: sum(numbers) / len(numbers).

    Raises:
        ValueError: If the list is empty.
    """
    count = len(numbers)
    if count == 0:
        raise ValueError("The list is empty.")
    return sum(numbers) / count
|
|
|
@tool
def standard_deviation(numbers: list) -> float:
    """
    Calculates the population standard deviation of a list of numbers.

    Args:
        numbers (list): A list of numbers.

    Returns:
        float: sqrt(sum((x - mean)^2) / len(numbers)).

    Raises:
        ValueError: If the list is empty.
    """
    if not numbers:
        raise ValueError("The list is empty.")
    # The mean is computed inline on purpose: `mean` in this module is an
    # @tool-wrapped StructuredTool, not a plain function, and calling it with
    # a bare list fails its argument-schema validation.
    count = len(numbers)
    mean_value = sum(numbers) / count
    variance = sum((x - mean_value) ** 2 for x in numbers) / count
    return variance ** 0.5
|
|
|
|
|
|
|
class MessagesState(TypedDict):
    """Graph state holding the running conversation.

    NOTE: this definition shadows the ``MessagesState`` imported from
    ``langgraph.graph`` at the top of the module; within this module the name
    refers to this class.
    """

    # Messages in the conversation. The ``add_messages`` reducer makes
    # LangGraph append each node's returned messages to the history instead
    # of overwriting the whole list with the latest update.
    messages: Annotated[List[AnyMessage], add_messages]
|
|
|
|
|
# Directory the Chroma vector store persists to (relative to the current
# working directory). NOTE: initialize_chroma_store() deletes and recreates
# it on every call, so its contents do not survive a new initialization.
PERSIST_DIR = "./chroma_store"
|
|
|
def initialize_chroma_store():
    """Create a fresh, empty Chroma vector store backed by ``PERSIST_DIR``.

    Any existing directory at ``PERSIST_DIR`` is removed and recreated first,
    so every call starts from an empty store.

    Returns:
        Chroma: a vector store that embeds text with the
        ``sentence-transformers/all-mpnet-base-v2`` HuggingFace model and
        persists to ``PERSIST_DIR``.
    """
    # Remove any previous store so this call always starts empty.
    if os.path.exists(PERSIST_DIR):
        shutil.rmtree(PERSIST_DIR)
    os.makedirs(PERSIST_DIR)

    return Chroma(
        embedding_function=HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-mpnet-base-v2"
        ),
        persist_directory=PERSIST_DIR,
    )
|
|
|
# Module-level vector store, built (and wiped) at import time.
# NOTE(review): initialize_chroma_store() returns an empty store and no code
# visible here adds documents, so this retriever currently finds nothing.
vector_store = initialize_chroma_store()

# Tool names are sent to the model as function names. OpenAI-style
# function-calling APIs (which Groq follows) only accept letters, digits,
# '_' and '-', so the name must not contain spaces.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="question_search",
    description="A tool to retrieve similar questions from a vector store.",
)
|
|
|
|
|
|
|
@tool
def weather_tool(location: str) -> str:
    """
    Fetches the current weather for a location using a web search.

    Args:
        location (str): The location to fetch weather for.

    Returns:
        str: The weather information.
    """
    # There is no module-level `search_tool` (only function locals of that
    # name elsewhere), so passing it raised NameError. get_weather creates
    # its own Tavily tool when none is supplied.
    return get_weather(location)
|
|
|
|
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
from langchain_community.tools.ddg_search import DuckDuckGoSearchRun |
|
from langchain_community.tools.wikipedia.tool import WikipediaQueryRun |
|
from langchain_community.utilities.wikipedia import WikipediaAPIWrapper |
|
from langchain_community.tools.arxiv.tool import ArxivQueryRun |
|
from langchain_community.utilities.arxiv import ArxivAPIWrapper |
|
from langchain.tools import tool |
|
|
|
|
|
@tool
def web_search(query: str) -> str:
    """Search the web for a given query and return the summary of the top result."""
    search_tool = TavilySearchResults()
    result = search_tool.run(query)
    # TavilySearchResults can hand back a plain string (e.g. an error message)
    # instead of a list of result dicts; indexing it as a list would raise a
    # confusing TypeError, so pass the text through unchanged.
    if isinstance(result, str):
        return result
    # An empty result list would otherwise raise IndexError.
    if not result:
        return f"No web search results found for: {query}"
    return result[0].get("content", "")
|
|
|
|
|
@tool |
|
def duckduckgo_search(query: str) -> str: |
|
"""Search the web using DuckDuckGo for a given query and return the result.""" |
|
search_tool = DuckDuckGoSearchRun(verbose=False) |
|
return search_tool.run(query) |
|
|
|
|
|
@tool |
|
def wikipedia_search(query: str) -> str: |
|
"""Search Wikipedia for a given query and return the top 3 results.""" |
|
wrapper = WikipediaAPIWrapper(top_k_results=3) |
|
wikipedia = WikipediaQueryRun(api_wrapper=wrapper, verbose=False) |
|
return wikipedia.run(query) |
|
|
|
|
|
@tool |
|
def arxiv_search(query: str) -> str: |
|
"""Search arXiv for academic papers based on a query and return the top 3 results.""" |
|
wrapper = ArxivAPIWrapper( |
|
top_k_results=3, |
|
ARXIV_MAX_QUERY_LENGTH=300, |
|
load_max_docs=3, |
|
load_all_available_meta=False, |
|
doc_content_chars_max=40000 |
|
) |
|
arxiv = ArxivQueryRun(api_wrapper=wrapper, verbose=False) |
|
return arxiv.run(query) |
|
|
|
# Tools exposed to the LLM. NOTE: weather_tool and retriever_tool are defined
# in this module but are not part of this list.
tools = [
    arxiv_search,
    duckduckgo_search,
    web_search,
    wikipedia_search,
    add,
    subtract,
    multiply,
    divide,
    square,
    cube,
    power,
    factorial,
    mean,
    standard_deviation,
]

# Reuse initialize_llm() rather than duplicating its ChatGroq configuration,
# so the model settings are defined in exactly one place.
llm = initialize_llm()

# Model with the tool schemas attached, so it can emit tool calls.
llm_with_tools = llm.bind_tools(tools)
|
|
|
|
|
|
|
class ToolAgentState(TypedDict):
    """State for the assistant/tools graph built by build_graph()."""

    # Messages in the conversation. Without a reducer LangGraph overwrites
    # this key on every node update, dropping the user question and the AI
    # tool-call message before the next model call; ``add_messages`` appends
    # each node's returned messages to the history instead.
    messages: Annotated[List[AnyMessage], add_messages]
|
|
|
def assistant(state: ToolAgentState):
    """Run one model turn over the conversation so far.

    Prepends the module-level system prompt (``sys_msg``) to the state's
    messages, invokes the tool-bound model, and returns its reply as a
    one-message update for the ``messages`` key.
    """
    prompt_messages = [sys_msg, *state["messages"]]
    reply = llm_with_tools.invoke(prompt_messages)
    return {"messages": [reply]}
|
|
|
|
|
|
|
def build_graph():
    """Compile the assistant <-> tools loop.

    Flow: START -> "assistant"; from "assistant", ``tools_condition`` routes
    to "tools" when the model requested tool calls and ends the run
    otherwise; "tools" always hands control back to "assistant".

    Returns:
        The compiled graph, ready for ``.invoke({"messages": [...]})``.
    """
    graph_builder = StateGraph(ToolAgentState)
    graph_builder.add_node("assistant", assistant)
    graph_builder.add_node("tools", ToolNode(tools))

    graph_builder.add_edge(START, "assistant")
    graph_builder.add_conditional_edges("assistant", tools_condition)
    graph_builder.add_edge("tools", "assistant")

    return graph_builder.compile()
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke run: send one question through the agent and print every
    # message of the resulting conversation.
    question = "When did India won a world cup in cricket before 2000?"
    agent_graph = build_graph()
    final_state = agent_graph.invoke({"messages": [HumanMessage(content=question)]})
    for message in final_state["messages"]:
        message.pretty_print()