File size: 4,090 Bytes
fc341bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage
from langchain.tools import tool
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from dotenv import load_dotenv

# Load environment variables (e.g. GROQ_API_KEY, TAVILY_API_KEY) from a local .env file.
load_dotenv()


# LLM factory
def initialize_llm():
    """Create and return a ChatGroq chat model.

    Uses the qwen-qwq-32b model at temperature 0 (deterministic output);
    the API key is read from the GROQ_API_KEY environment variable.
    """
    return ChatGroq(
        model_name="qwen-qwq-32b",
        temperature=0,
        groq_api_key=os.getenv("GROQ_API_KEY"),
    )

# Web-search tool factory
def initialize_search_tool():
    """Create and return a TavilySearchResults web-search tool instance."""
    return TavilySearchResults()



# Define Tools
def get_weather(location: str, search_tool: TavilySearchResults = None) -> str:
    """Fetch the current weather information for a given location using Tavily search.

    Args:
        location: Place name to look up, e.g. "Paris".
        search_tool: Optional pre-built Tavily tool; a fresh one is created
            when omitted.

    Returns:
        The raw Tavily search result for the weather query
        (NOTE(review): may be a list of result dicts rather than a plain
        str despite the annotation — confirm against TavilySearchResults).
    """
    tool_instance = search_tool if search_tool is not None else initialize_search_tool()
    return tool_instance.run(f"current weather in {location}")


def initialize_recommendation_chain(llm: ChatGroq) -> Runnable:
    """Build a prompt | llm runnable that turns a weather condition into advice.

    Args:
        llm: Chat model used to generate the recommendation text.

    Returns:
        A Runnable expecting {"weather_condition": str} as its input.
    """
    # NOTE: the template text below is part of the model prompt — keep it stable.
    template = """
    You are a helpful assistant that gives weather-based advice.

    Given the current weather condition: "{weather_condition}", provide:
    1. Clothing or activity recommendations suited for this weather.
    2. At least one health tip to stay safe or comfortable in this condition.

    Be concise and clear.
    """
    prompt = ChatPromptTemplate.from_template(template)
    return prompt | llm



def get_recommendation(weather_condition: str, recommendation_chain: Runnable = None) -> str:
    """Give activity/clothing recommendations and health tips based on the weather condition using an LLM.

    Args:
        weather_condition: Short description of the current weather.
        recommendation_chain: Optional pre-built chain; when omitted, a fresh
            LLM and chain are constructed on the fly.

    Returns:
        The chain's invoke() result (NOTE(review): a prompt | llm chain
        typically returns a message object, not a plain str as annotated —
        confirm against callers).
    """
    chain = recommendation_chain
    if chain is None:
        chain = initialize_recommendation_chain(initialize_llm())
    return chain.invoke({"weather_condition": weather_condition})



def build_graph():
    """Build the graph using Groq and custom prompt/tools setup.

    Wires a two-node LangGraph agent: an "assistant" node (the Groq LLM with
    both tools bound) and a "tools" node that executes requested tool calls,
    looping back to the assistant until it answers without a tool call.

    Returns:
        The compiled LangGraph graph, invokable with {"messages": [...]}.
    """

    # Initialize the LLM
    llm = initialize_llm()

    # Initialize Tavily tool
    search_tool = initialize_search_tool()


    # Initialize the recommendation chain (shares the same LLM instance)
    recommendation_chain = initialize_recommendation_chain(llm)

    # Define tools as closures so each captures its pre-built dependency.
    # NOTE: the @tool docstrings below are exposed to the LLM as tool
    # descriptions, so their wording is runtime behavior — keep them stable.
    @tool
    def weather_tool(location: str) -> str:
        """Fetch the current weather information for a given location."""
        return get_weather(location, search_tool)  # Pass the search tool

    @tool
    def recommendation_tool(weather_condition: str) -> str:
        """Get recommendations based on weather."""
        return get_recommendation(weather_condition, recommendation_chain)

    tools = [weather_tool, recommendation_tool]

    # Bind tools to LLM so it can emit structured tool calls
    llm_with_tools = llm.bind_tools(tools)

    # Define assistant node
    def assistant(state: MessagesState):
        """Assistant node: run the tool-aware LLM over the accumulated messages."""
        print("Entering assistant node...")
        response = llm_with_tools.invoke(state["messages"])
        print(f"Assistant says: {response.content}")
        return {"messages": [response]}

    # Create graph: assistant -> (tools_condition) -> tools -> assistant,
    # terminating when the assistant replies without requesting a tool.
    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.set_entry_point("assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    graph = builder.compile()

    return graph



# Script entry point: ask the agent a sample question and print the full
# message trace (human question, any tool calls, and the final answer).
if __name__ == "__main__":
    agent = build_graph()
    question = "What are the Upanishads?"
    result = agent.invoke({"messages": [HumanMessage(content=question)]})
    for message in result["messages"]:
        message.pretty_print()