File size: 4,090 Bytes
fc341bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import os
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage
from langchain.tools import tool
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from dotenv import load_dotenv
# Load environment variables from a .env file before anything reads them
# (e.g. the os.getenv("GROQ_API_KEY") lookup in initialize_llm).
load_dotenv()
# Initialize LLM
def initialize_llm():
    """Create the Groq chat model used throughout this module.

    The API key is taken from the ``GROQ_API_KEY`` environment variable; the
    model is ``qwen-qwq-32b`` with temperature 0.

    Returns:
        A configured ``ChatGroq`` instance.
    """
    api_key = os.getenv("GROQ_API_KEY")
    return ChatGroq(
        temperature=0,
        model_name="qwen-qwq-32b",
        groq_api_key=api_key,
    )
# Initialize Tavily Search Tool
def initialize_search_tool():
    """Create a ``TavilySearchResults`` tool with its default settings."""
    return TavilySearchResults()
# Define Tools
def get_weather(location: str, search_tool: TavilySearchResults = None) -> str:
    """Look up the current weather for ``location`` via a search tool.

    Args:
        location: Place name interpolated into the search query.
        search_tool: Object whose ``run(query)`` performs the search. When
            ``None``, a new tool is created with ``initialize_search_tool()``.

    Returns:
        The search output as a string. A string result is returned unchanged;
        any other result (for example a list of hit dictionaries) is
        serialized to JSON so the declared ``str`` return type always holds.
    """
    import json  # local import: only needed to normalize non-string results

    if search_tool is None:
        search_tool = initialize_search_tool()

    results = search_tool.run(f"current weather in {location}")
    if isinstance(results, str):
        return results
    # default=str is deliberate: an unexpected value type inside a hit should
    # degrade to its string form rather than make the whole lookup fail.
    return json.dumps(results, ensure_ascii=False, default=str)
def initialize_recommendation_chain(llm: ChatGroq) -> Runnable:
    """Pipe the weather-advice prompt into ``llm`` and return the chain.

    The prompt has a single input variable, ``weather_condition``, and asks
    for clothing/activity recommendations plus at least one health tip.

    Args:
        llm: Chat model that receives the formatted prompt.

    Returns:
        The composed ``prompt | llm`` runnable.
    """
    # The template text is kept flush-left so the prompt sent to the model is
    # character-for-character the same as before.
    advice_template = """
You are a helpful assistant that gives weather-based advice.
Given the current weather condition: "{weather_condition}", provide:
1. Clothing or activity recommendations suited for this weather.
2. At least one health tip to stay safe or comfortable in this condition.
Be concise and clear.
"""
    advice_prompt = ChatPromptTemplate.from_template(advice_template)
    return advice_prompt | llm
def get_recommendation(weather_condition: str, recommendation_chain: Runnable = None) -> str:
    """Return clothing/activity advice and a health tip for a weather condition.

    Args:
        weather_condition: Description of the weather, passed to the chain as
            the ``weather_condition`` input.
        recommendation_chain: Runnable invoked with
            ``{"weather_condition": ...}``. When ``None``, a chain is built
            from ``initialize_llm()`` and ``initialize_recommendation_chain()``.

    Returns:
        The advice text. If the chain returns an object with a ``content``
        attribute (as a ``prompt | chat model`` chain does), that content is
        returned instead of the message object, so the declared ``str``
        return type holds.
    """
    if recommendation_chain is None:
        recommendation_chain = initialize_recommendation_chain(initialize_llm())

    response = recommendation_chain.invoke({"weather_condition": weather_condition})
    # Unwrap message objects; chains that already return plain text pass through.
    content = getattr(response, "content", response)
    return content if isinstance(content, str) else str(content)
def build_graph():
    """Assemble and compile the tool-calling agent graph.

    A Groq chat model is bound to two tools: a weather lookup and a
    weather-based recommendation helper. The graph starts at the
    ``assistant`` node; ``tools_condition`` decides the routing after each
    assistant turn, and the ``tools`` node always hands control back to
    ``assistant``.

    Returns:
        The compiled graph.
    """
    chat_model = initialize_llm()
    tavily = initialize_search_tool()
    advice_chain = initialize_recommendation_chain(chat_model)

    # NOTE: the tool names and docstrings below are kept verbatim; the @tool
    # decorator is expected to expose them to the model as the tool's name
    # and description.
    @tool
    def weather_tool(location: str) -> str:
        """Fetch the current weather information for a given location."""
        return get_weather(location, tavily)

    @tool
    def recommendation_tool(weather_condition: str) -> str:
        """Get recommendations based on weather."""
        return get_recommendation(weather_condition, advice_chain)

    toolbox = [weather_tool, recommendation_tool]
    tool_aware_model = chat_model.bind_tools(toolbox)

    def assistant(state: MessagesState):
        """Run the tool-aware model on the conversation and return its reply."""
        print("Entering assistant node...")
        reply = tool_aware_model.invoke(state["messages"])
        print(f"Assistant says: {reply.content}")
        return {"messages": [reply]}

    # Wire the nodes: assistant -> (conditional) -> tools -> assistant.
    workflow = StateGraph(MessagesState)
    workflow.add_node("assistant", assistant)
    workflow.add_node("tools", ToolNode(toolbox))
    workflow.set_entry_point("assistant")
    workflow.add_conditional_edges("assistant", tools_condition)
    workflow.add_edge("tools", "assistant")
    return workflow.compile()
# Main execution
if __name__ == "__main__":
    # Build the agent graph and send it a single question.
    agent = build_graph()
    prompt_text = "What are the Upanishads?"
    final_state = agent.invoke({"messages": [HumanMessage(content=prompt_text)]})
    # Print every message of the resulting conversation.
    for message in final_state["messages"]:
        message.pretty_print()
|