# Page header captured with the file: HarshitSundriyal - "Update agent.py" - commit 2c9e49b (verified)
import os
import shutil
from typing import TypedDict, Annotated, List

from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.tools import Tool
from langchain.tools import tool
from langchain.tools.retriever import create_retriever_tool
from langchain.vectorstores import Chroma
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.tools import DuckDuckGoSearchRun, WikipediaQueryRun, ArxivQueryRun
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities import WikipediaAPIWrapper, ArxivAPIWrapper
from langchain_core.messages import AnyMessage
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
# Load environment variables from a local .env file
# (GROQ_API_KEY is read further down via os.getenv).
load_dotenv()
# Custom Agent Prompt Template
# System prompt: asks the model to report its reasoning and then finish with a
# "FINAL ANSWER: ..." line whose value follows the formatting rules spelled out
# in the text itself.
Agent_prompt_template = '''You are a helpful assistant tasked with answering questions using a set of tools.
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
Your answer should only start with "FINAL ANSWER: ", then follows with the answer. '''
# Wrapped once as a SystemMessage; assistant() prepends it to every model call.
sys_msg = SystemMessage(content=Agent_prompt_template)
# Initialize LLM
def initialize_llm():
    """Build the Groq-hosted chat model used by this module.

    Returns:
        ChatGroq: A client for the "qwen-qwq-32b" model with temperature 0,
        configured with the value of the GROQ_API_KEY environment variable.
    """
    groq_settings = {
        "temperature": 0,
        "model_name": "qwen-qwq-32b",
        "groq_api_key": os.getenv("GROQ_API_KEY"),
    }
    return ChatGroq(**groq_settings)
# Initialize Tavily Search Tool
def initialize_search_tool():
    """Create a Tavily search tool with its default configuration.

    Returns:
        TavilySearchResults: A freshly constructed search tool.
    """
    tavily_tool = TavilySearchResults()
    return tavily_tool
# Weather tool
def get_weather(location: str, search_tool: TavilySearchResults = None) -> str:
    """Look up the current weather for a place through a search tool.

    Args:
        location (str): Name of the place to ask about.
        search_tool (TavilySearchResults, optional): Object whose ``run``
            method executes the query. When omitted, a new tool is created
            with ``initialize_search_tool``.

    Returns:
        Whatever ``search_tool.run`` returns for the query
        ``"current weather in <location>"``.
    """
    active_tool = initialize_search_tool() if search_tool is None else search_tool
    return active_tool.run(f"current weather in {location}")
# Recommendation chain
def initialize_recommendation_chain(llm: ChatGroq) -> Runnable:
    """Compose the weather-advice prompt with a chat model.

    Args:
        llm (ChatGroq): Model that receives the rendered prompt.

    Returns:
        Runnable: ``prompt | llm``; invoke it with a mapping that provides the
        ``weather_condition`` template variable.
    """
    # The template text is what the model sees, so it is kept exactly as-is.
    template_text = """
You are a helpful assistant that gives weather-based advice.
Given the current weather condition: "{weather_condition}", provide:
1. Clothing or activity recommendations suited for this weather.
2. At least one health tip to stay safe or comfortable in this condition.
Be concise and clear.
"""
    advice_prompt = ChatPromptTemplate.from_template(template_text)
    chain = advice_prompt | llm
    return chain
def get_recommendation(weather_condition: str, recommendation_chain: Runnable = None) -> str:
    """Ask a recommendation chain for advice suited to the given weather.

    Args:
        weather_condition (str): Description of the current weather.
        recommendation_chain (Runnable, optional): Chain to invoke. When
            omitted, one is built from ``initialize_llm`` and
            ``initialize_recommendation_chain``.

    Returns:
        The value returned by the chain's ``invoke`` call for
        ``{"weather_condition": weather_condition}``.
    """
    chain = recommendation_chain
    if chain is None:
        chain = initialize_recommendation_chain(initialize_llm())
    payload = {"weather_condition": weather_condition}
    return chain.invoke(payload)
# Math tools
@tool
def add(x: int, y: int) -> int:
    """
    Adds two integers.
    Args:
        x (int): The first integer.
        y (int): The second integer.
    Returns:
        int: The sum of x and y.
    """
    # Docstring wording is left untouched: @tool uses it as the tool description.
    total = x + y
    return total
@tool
def subtract(x: int, y: int) -> int:
    """
    Subtracts two integers.
    Args:
        x (int): The first integer.
        y (int): The second integer.
    Returns:
        int: The difference between x and y.
    """
    # Docstring wording is left untouched: @tool uses it as the tool description.
    difference = x - y
    return difference
@tool
def multiply(x: int, y: int) -> int:
    """
    Multiplies two integers.
    Args:
        x (int): The first integer.
        y (int): The second integer.
    Returns:
        int: The product of x and y.
    """
    # Docstring wording is left untouched: @tool uses it as the tool description.
    product = x * y
    return product
@tool
def divide(x: int, y: int) -> float:
    """
    Divides two numbers.
    Args:
        x (int): The numerator.
        y (int): The denominator.
    Returns:
        float: The result of the division.
    Raises:
        ValueError: If y is zero.
    """
    # Docstring wording is left untouched: @tool uses it as the tool description.
    # Happy path first; a zero denominator falls through to the explicit error.
    if y != 0:
        return x / y
    raise ValueError("Cannot divide by zero.")
@tool
def square(x: int) -> int:
    """
    Calculates the square of a number.
    Args:
        x (int): The number to square.
    Returns:
        int: The square of x.
    """
    # Docstring wording is left untouched: @tool uses it as the tool description.
    squared = x * x
    return squared
@tool
def cube(x: int) -> int:
    """
    Calculates the cube of a number.
    Args:
        x (int): The number to cube.
    Returns:
        int: The cube of x.
    """
    # Docstring wording is left untouched: @tool uses it as the tool description.
    squared = x * x
    return squared * x
@tool
def power(x: int, y: int) -> int:
    """
    Raises a number to the power of another number.
    Args:
        x (int): The base number.
        y (int): The exponent.
    Returns:
        int: x raised to the power of y.
    """
    # Docstring wording is left untouched: @tool uses it as the tool description.
    # Two-argument pow() is the function form of the ** operator.
    return pow(x, y)
@tool
def factorial(n: int) -> int:
    """
    Calculates the factorial of a non-negative integer.
    Args:
        n (int): The non-negative integer.
    Returns:
        int: The factorial of n.
    Raises:
        ValueError: If n is negative.
    """
    # Docstring wording is left untouched: @tool uses it as the tool description.
    if n < 0:
        raise ValueError("Factorial is not defined for negative numbers.")
    if n in (0, 1):
        return 1
    # Multiply n, n-1, ..., 2; integer multiplication is order-independent.
    product = 1
    for factor in range(n, 1, -1):
        product *= factor
    return product
@tool
def mean(numbers: list) -> float:
    """
    Calculates the mean of a list of numbers.
    Args:
        numbers (list): A list of numbers.
    Returns:
        float: The mean of the numbers.
    Raises:
        ValueError: If the list is empty.
    """
    # Docstring wording is left untouched: @tool uses it as the tool description.
    if not numbers:
        raise ValueError("The list is empty.")
    total = sum(numbers)
    return total / len(numbers)
@tool
def standard_deviation(numbers: list) -> float:
    """
    Calculates the standard deviation of a list of numbers.
    Args:
        numbers (list): A list of numbers.
    Returns:
        float: The standard deviation of the numbers.
    Raises:
        ValueError: If the list is empty.
    """
    if not numbers:
        raise ValueError("The list is empty.")
    count = len(numbers)
    # Compute the average inline. The sibling `mean` is decorated with @tool,
    # so that name refers to a tool object rather than a plain function and
    # cannot be called as `mean(numbers)`.
    average = sum(numbers) / count
    # Population variance: the squared deviations are divided by N, not N - 1.
    variance = sum((value - average) ** 2 for value in numbers) / count
    return variance ** 0.5
# --- Vector Store + Retriever ---
# State schema
class MessagesState(TypedDict):
    """Conversation state holding the list of chat messages.

    This definition shadows the ``MessagesState`` imported from
    ``langgraph.graph`` at the top of the file.
    """
    # `add_messages` is registered as the reducer so that node updates are
    # merged into the existing history. The previous metadata was a plain
    # description string, which registers no reducer, so every update would
    # have replaced the whole list.
    messages: Annotated[List[AnyMessage], add_messages]
# === VECTOR STORE SETUP ===
PERSIST_DIR = "./chroma_store"


def initialize_chroma_store():
    """Create an empty Chroma vector store persisted under ``PERSIST_DIR``.

    Any directory already present at ``PERSIST_DIR`` is deleted first, so the
    returned store never contains data from an earlier run.

    Returns:
        Chroma: Vector store using the
        "sentence-transformers/all-mpnet-base-v2" embedding model.
    """
    # Start from a clean slate: drop whatever a previous run left behind.
    store_already_exists = os.path.exists(PERSIST_DIR)
    if store_already_exists:
        shutil.rmtree(PERSIST_DIR)
    os.makedirs(PERSIST_DIR)

    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    return Chroma(embedding_function=embedding_model, persist_directory=PERSIST_DIR)
# Built once at import time. initialize_chroma_store() deletes PERSIST_DIR
# before creating the store, so the store starts out empty.
vector_store = initialize_chroma_store()
# Create retriever tool
# Wraps the store's default retriever as a tool.
# NOTE(review): this tool is not part of the `tools` list defined further down,
# so the agent never calls it. Its name also contains a space; confirm the
# model provider accepts such names before registering it.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="A tool to retrieve similar questions from a vector store."
)
@tool
def weather_tool(location: str) -> str:
    """
    Fetches the weather for a location.
    Args:
        location (str): The location to fetch weather for.
    Returns:
        str: The weather information.
    """
    # The previous call passed a global named `search_tool`, but no such
    # module-level name is defined in this file, so every invocation raised
    # NameError. get_weather() builds its own search tool when none is given.
    return get_weather(location)
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.tools.ddg_search import DuckDuckGoSearchRun
from langchain_community.tools.wikipedia.tool import WikipediaQueryRun
from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
from langchain_community.tools.arxiv.tool import ArxivQueryRun
from langchain_community.utilities.arxiv import ArxivAPIWrapper
from langchain.tools import tool
# 1. Tavily Web Search Tool (already in correct format)
@tool
def web_search(query: str) -> str:
    """Search the web for a given query and return the summary."""
    results = TavilySearchResults().run(query)
    # The Tavily tool may hand back a plain string (for example an error
    # description) instead of a list of hits; pass that through unchanged
    # rather than indexing into it.
    if isinstance(results, str):
        return results
    # An empty result list previously raised IndexError.
    if not results:
        return "No results found."
    # Only the first hit's text is returned, as before.
    first_hit = results[0]
    return first_hit["content"]
# 2. DuckDuckGo Search Tool
@tool
def duckduckgo_search(query: str) -> str:
    """Search the web using DuckDuckGo for a given query and return the result."""
    # Docstring wording is left untouched: @tool uses it as the tool description.
    # A new DuckDuckGo runner is created for each call.
    return DuckDuckGoSearchRun(verbose=False).run(query)
# 3. Wikipedia Search Tool
@tool
def wikipedia_search(query: str) -> str:
    """Search Wikipedia for a given query and return the top 3 results."""
    # Docstring wording is left untouched: @tool uses it as the tool description.
    # The wrapper is configured with top_k_results=3.
    wiki_runner = WikipediaQueryRun(
        api_wrapper=WikipediaAPIWrapper(top_k_results=3),
        verbose=False,
    )
    return wiki_runner.run(query)
# 4. Arxiv Search Tool
@tool
def arxiv_search(query: str) -> str:
    """Search arXiv for academic papers based on a query and return the top 3 results."""
    # Docstring wording is left untouched: @tool uses it as the tool description.
    # Wrapper settings are gathered in one place, then passed through unchanged.
    wrapper_options = dict(
        top_k_results=3,
        ARXIV_MAX_QUERY_LENGTH=300,
        load_max_docs=3,
        load_all_available_meta=False,
        doc_content_chars_max=40000,
    )
    paper_search = ArxivQueryRun(
        api_wrapper=ArxivAPIWrapper(**wrapper_options),
        verbose=False,
    )
    return paper_search.run(query)
# Tools exposed to the model, in the order they are bound.
# weather_tool and retriever_tool are defined above but are not registered here.
tools = [
    arxiv_search, duckduckgo_search, web_search, wikipedia_search,
    add, subtract, multiply, divide, square, cube, power, factorial,
    mean, standard_deviation,
]
# === LLM with Tools ===
# Reuse the module's factory instead of repeating the same ChatGroq settings.
llm = initialize_llm()
llm_with_tools = llm.bind_tools(tools)
# === LangGraph State ===
class ToolAgentState(TypedDict):
    """Graph state: the conversation passed between the graph's nodes."""
    # `add_messages` is the reducer for this key: messages returned by a node
    # are merged into the existing list instead of replacing it. The previous
    # metadata was a plain description string, which registers no reducer, so
    # each node update overwrote the history and the model lost the original
    # question and its own tool calls after the first step.
    messages: Annotated[List[AnyMessage], add_messages]
def assistant(state: ToolAgentState):
    """Graph node: send the system prompt plus the conversation to the model.

    Args:
        state (ToolAgentState): Current graph state; only ``state["messages"]``
            is read.

    Returns:
        dict: ``{"messages": [reply]}`` where ``reply`` is the value returned
        by ``llm_with_tools.invoke``.
    """
    conversation = [sys_msg] + state["messages"]
    reply = llm_with_tools.invoke(conversation)
    return {"messages": [reply]}
# === Build Graph ===
def build_graph():
    """Assemble and compile the assistant/tools loop.

    The graph starts at the "assistant" node. ``tools_condition`` decides where
    to go after each assistant step, and the "tools" node always hands control
    back to "assistant".

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    workflow = StateGraph(ToolAgentState)
    workflow.add_node("assistant", assistant)
    workflow.add_node("tools", ToolNode(tools))
    # An edge from START marks "assistant" as the entry point.
    workflow.add_edge(START, "assistant")
    workflow.add_conditional_edges("assistant", tools_condition)
    workflow.add_edge("tools", "assistant")
    return workflow.compile()
# === Run ===
if __name__ == "__main__":
    # Demo run: ask one hard-coded question and print every message in the
    # final state.
    agent_graph = build_graph()
    user_question = "When did India won a world cup in cricket before 2000?"
    initial_state = {"messages": [HumanMessage(content=user_question)]}
    final_state = agent_graph.invoke(initial_state)
    for message in final_state["messages"]:
        message.pretty_print()