File size: 3,266 Bytes
88f3c8e
 
 
 
 
 
f5a41c7
88f3c8e
 
 
 
85e517e
88f3c8e
 
3b9afc9
130d92d
85e517e
3b9afc9
 
88f3c8e
 
 
 
 
 
 
 
 
 
 
a66da40
130d92d
3b9afc9
88f3c8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6ac13f
88f3c8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6ac13f
88f3c8e
 
 
 
 
 
 
 
3b9afc9
88f3c8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b9afc9
88f3c8e
 
 
 
 
 
 
 
 
 
 
3b9afc9
88f3c8e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from typing import TypedDict, Annotated
from langgraph.graph.message import add_messages
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
from langgraph.prebuilt import ToolNode
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import tools_condition
from langchain_ollama.chat_models import ChatOllama

from tools import search_tool, weather_info_tool, hub_stats_tool


import gradio as gr


from smolagents import GradioUI, CodeAgent, HfApiModel


from retriever import load_guest_dataset

from dotenv import load_dotenv
import yaml
load_dotenv()

# Import our custom tools from their modules
from retriever import load_guest_dataset


# Import our custom tools from their modules
from retriever import load_guest_dataset

# NOTE(review): `model` is not referenced anywhere in this file; kept only
# because constructing it at import time may be relied on elsewhere.
model = HfApiModel()
# NOTE(review): assumes load_guest_dataset() returns a tool object that
# LangChain's bind_tools/ToolNode accept — confirm in retriever.py.
guest_info_tool = load_guest_dataset()

# Load prompts from YAML file (safe_load: the file cannot construct arbitrary
# Python objects).
with open("prompts.yaml", "r", encoding="utf-8") as stream:
    prompt_templates = yaml.safe_load(stream)

# Get the system prompt from the YAML file
system_prompt = prompt_templates["system_prompt"]

# Initialize the chat model
chat = ChatOllama(model="qwen2:7b", verbose=True)

# Define available tools. Register the guest-info tool built above, not the
# `load_guest_dataset` loader itself: exposing the loader would let the LLM
# "call" the dataset-loading function instead of querying guest information,
# and would leave `guest_info_tool` unused.
tools = [
    guest_info_tool,
    search_tool,
    weather_info_tool,
    hub_stats_tool,
]

# Bind tools to the chat model so it can emit tool calls, and build the graph
# node that executes those calls.
chat_with_tools = chat.bind_tools(tools)
tool_node = ToolNode(tools)
# State schema shared by every node of the agent graph.
class AgentState(TypedDict):
    """Graph state: the running conversation.

    ``messages`` carries the ``add_messages`` reducer, so when a node returns
    ``{"messages": [...]}`` LangGraph merges those messages into the existing
    list (appending new ones) instead of replacing it.
    """

    messages: Annotated[list[AnyMessage], add_messages]


def assistant(state: AgentState):
    """Run one LLM step of the agent loop.

    Args:
        state: Current graph state; ``state["messages"]`` is the conversation
            so far (user input, prior model replies, tool results).

    Returns:
        A partial state update ``{"messages": [reply]}``; LangGraph merges it
        into the state via the ``add_messages`` reducer on ``AgentState``.
    """
    # Copy so the state's own list is never mutated in place.
    messages = list(state["messages"])
    # The system prompt is never written back into the graph state, so it must
    # be re-attached on every call. The previous check ("exactly one
    # HumanMessage") dropped it after the first tool round-trip and for any
    # multi-message input, leaving the model without its ReAct instructions.
    if not messages or not isinstance(messages[0], SystemMessage):
        messages.insert(0, SystemMessage(content=system_prompt))

    return {
        "messages": [chat_with_tools.invoke(messages)],
    }


## The graph
# Nodes do the work: "assistant" asks the LLM for the next step, "tools"
# executes whatever tool calls the LLM requested.
builder = StateGraph(AgentState)
for node_name, node_impl in (("assistant", assistant), ("tools", tool_node)):
    builder.add_node(node_name, node_impl)

# Edges decide the control flow: always begin at the assistant; from there
# tools_condition routes to "tools" when the latest message asks for a tool and
# otherwise finishes with the direct response; tool results loop back to the
# assistant.
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")

# Compiled without a checkpointer: each invoke() works only on the input it is
# given, and a "thread_id" in the config is not used for persistence.
graph_app = builder.compile()

# NOTE(review): not referenced anywhere in this file.
graph_state = {}

# Gradio callback: (user_message, chat_history) -> (updated_chat_history, cleared_textbox)
def chat_fn(message, history):
    """Send ``message`` to the agent graph and record the exchange.

    Args:
        message: Text typed by the user.
        history: The Chatbot's value as a list of ``(user, bot)`` pairs
            (tuples format). ``None`` is treated as an empty history.

    Returns:
        ``(history, "")``: the updated chat history for the Chatbot and an
        empty string that clears the input Textbox.
    """
    # Work on a copy; never mutate the list Gradio handed in.
    history = list(history or [])
    if not message or not message.strip():
        # Nothing to send; just clear the textbox.
        return history, ""

    session_id = "session-123"

    # The graph is compiled without a checkpointer, so the thread_id alone
    # gives the model no memory of earlier turns. Replay the visible
    # conversation instead, led by the system prompt so it is present on every
    # turn (the assistant node only adds it itself for a lone user message).
    messages = [SystemMessage(content=system_prompt)]
    for user_text, bot_text in history:
        # Skip empty/None slots and non-text entries (e.g. file tuples).
        if isinstance(user_text, str) and user_text:
            messages.append(HumanMessage(content=user_text))
        if isinstance(bot_text, str) and bot_text:
            messages.append(AIMessage(content=bot_text))
    messages.append(HumanMessage(content=message))

    result = graph_app.invoke(
        {"messages": messages},
        config={"configurable": {"thread_id": session_id}},
    )

    # The last message in the final state is the model's answer.
    response = result["messages"][-1].content
    history.append((message, response))
    return history, ""

 
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("### LangGraph Chat with Gradio")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Type your message")
    send_btn = gr.Button("Send")

    # Send on a button click AND on Enter in the textbox (previously Enter did
    # nothing). Both pass the typed text plus the current chat history, then
    # update the chat and clear the textbox.
    for trigger in (send_btn.click, msg.submit):
        trigger(fn=chat_fn, inputs=[msg, chatbot], outputs=[chatbot, msg])


if __name__ == "__main__":
    demo.launch()