APRG commited on
Commit
dbc454d
·
verified ·
1 Parent(s): 2dc3ffd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -110
app.py CHANGED
@@ -13,8 +13,7 @@ from langchain_core.tools import tool
13
  from langchain_core.messages import AIMessage, ToolMessage, HumanMessage, BaseMessage, SystemMessage
14
  from random import randint
15
 
16
- from tkinter import messagebox
17
- #messagebox.showinfo("Test", "Script run successfully")
18
 
19
  import gradio as gr
20
  import logging
@@ -26,12 +25,15 @@ class OrderState(TypedDict):
26
  finished: bool
27
 
28
  # System instruction for the BaristaBot
29
- BARISTABOT_SYSINT = (
30
  "system",
31
- "You are a BaristaBot, an interactive cafe ordering system. A human will talk to you about the "
32
- "available products. Answer questions about menu items, help customers place orders, and "
33
- "confirm details before finalizing. Use the provided tools to manage the order."
34
- )
 
 
 
35
 
36
  WELCOME_MSG = "Welcome to the BaristaBot cafe. Type `q` to quit. How may I serve you today?"
37
 
@@ -39,52 +41,24 @@ WELCOME_MSG = "Welcome to the BaristaBot cafe. Type `q` to quit. How may I serve
39
  llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest")
40
 
41
  @tool
42
- def get_menu() -> str:
43
- """Provide the cafe menu."""
44
- #messagebox.showinfo("Test", "Script run successfully")
45
- with open("menu.txt", 'r', encoding = "UTF-8") as f:
46
- return f.read()
47
-
48
- @tool
49
- def add_to_order(drink: str, modifiers: Iterable[str] = []) -> str:
50
- """Adds the specified drink to the customer's order."""
51
- return f"{drink} ({', '.join(modifiers) if modifiers else 'no modifiers'})"
52
-
53
- @tool
54
- def confirm_order() -> str:
55
- """Asks the customer to confirm the order."""
56
- return "Order confirmation requested"
57
-
58
- @tool
59
- def get_order() -> str:
60
- """Returns the current order."""
61
- return "Current order details requested"
62
-
63
- @tool
64
- def clear_order() -> str:
65
- """Clears the current order."""
66
- return "Order cleared"
67
-
68
- @tool
69
- def place_order() -> int:
70
- """Sends the order to the kitchen."""
71
- #messagebox.showinfo("Test", "Order successful!")
72
- return randint(2, 10) # Estimated wait time
73
-
74
- def chatbot_with_tools(state: OrderState) -> OrderState:
75
- """Chatbot with tool handling."""
76
- logging.info(f"Messagelist sent to chatbot node: {[msg.content for msg in state.get('messages', [])]}")
77
  defaults = {"order": [], "finished": False}
78
 
79
  # Ensure we always have at least a system message
80
  if not state.get("messages", []):
81
- new_output = AIMessage(content=WELCOME_MSG)
82
- return defaults | state | {"messages": [SystemMessage(content=BARISTABOT_SYSINT), new_output]}
83
 
84
  try:
85
  # Prepend system instruction if not already present
86
  messages_with_system = [
87
- SystemMessage(content=BARISTABOT_SYSINT)
88
  ] + state.get("messages", [])
89
 
90
  # Process messages through the LLM
@@ -95,52 +69,6 @@ def chatbot_with_tools(state: OrderState) -> OrderState:
95
  # Fallback if LLM processing fails
96
  return defaults | state | {"messages": [AIMessage(content=f"I'm having trouble processing that. {str(e)}")]}
97
 
98
- def order_node(state: OrderState) -> OrderState:
99
- """Handles order-related tool calls."""
100
- logging.info("order node")
101
- tool_msg = state.get("messages", [])[-1]
102
- order = state.get("order", [])
103
- outbound_msgs = []
104
- order_placed = False
105
-
106
- for tool_call in tool_msg.tool_calls:
107
- tool_name = tool_call["name"]
108
- tool_args = tool_call["args"]
109
-
110
- if tool_name == "add_to_order":
111
- modifiers = tool_args.get("modifiers", [])
112
- modifier_str = ", ".join(modifiers) if modifiers else "no modifiers"
113
- order.append(f'{tool_args["drink"]} ({modifier_str})')
114
- response = "\n".join(order)
115
-
116
- elif tool_name == "confirm_order":
117
- response = "Your current order:\n" + "\n".join(order) + "\nIs this correct?"
118
-
119
- elif tool_name == "get_order":
120
- response = "\n".join(order) if order else "(no order)"
121
-
122
- elif tool_name == "clear_order":
123
- order.clear()
124
- response = "Order cleared"
125
-
126
- elif tool_name == "place_order":
127
- order_text = "\n".join(order)
128
- order_placed = True
129
- response = f"Order placed successfully!\nYour order:\n{order_text}\nEstimated wait: {randint(2, 10)} minutes"
130
-
131
- else:
132
- raise NotImplementedError(f'Unknown tool call: {tool_name}')
133
-
134
- outbound_msgs.append(
135
- ToolMessage(
136
- content=response,
137
- name=tool_name,
138
- tool_call_id=tool_call["id"],
139
- )
140
- )
141
-
142
- return {"messages": outbound_msgs, "order": order, "finished": order_placed}
143
-
144
  def maybe_route_to_tools(state: OrderState) -> str:
145
  """Route between chat and tool nodes."""
146
  if not (msgs := state.get("messages", [])):
@@ -149,20 +77,16 @@ def maybe_route_to_tools(state: OrderState) -> str:
149
  msg = msgs[-1]
150
 
151
  if state.get("finished", False):
152
- logging.info("from chatbot GOTO End node")
153
  return END
154
 
155
  elif hasattr(msg, "tool_calls") and len(msg.tool_calls) > 0:
156
  if any(tool["name"] in tool_node.tools_by_name.keys() for tool in msg.tool_calls):
157
- logging.info("from chatbot GOTO tools node")
158
  return "tools"
159
- else:
160
- logging.info("from chatbot GOTO order node")
161
- return "ordering"
162
 
163
- else:
164
- logging.info("from chatbot GOTO human node")
165
- return "human"
166
 
167
  def human_node(state: OrderState) -> OrderState:
168
  """Handle user input."""
@@ -174,7 +98,7 @@ def human_node(state: OrderState) -> OrderState:
174
 
175
  return state
176
 
177
- def maybe_exit_human_node(state: OrderState) -> Literal["chatbot", "__end__"]:
178
  """Determine if conversation should continue."""
179
  if state.get("finished", False):
180
  logging.info("from human GOTO End node")
@@ -184,32 +108,31 @@ def maybe_exit_human_node(state: OrderState) -> Literal["chatbot", "__end__"]:
184
  logging.info("Chatbot response obtained, ending conversation")
185
  return END
186
  else:
187
- logging.info("from human GOTO chatbot node")
188
- return "chatbot"
189
 
190
  # Prepare tools
191
- auto_tools = [get_menu]
192
  tool_node = ToolNode(auto_tools)
193
 
194
- order_tools = [add_to_order, confirm_order, get_order, clear_order, place_order]
195
 
196
  # Bind all tools to the LLM
197
- llm_with_tools = llm.bind_tools(auto_tools + order_tools)
198
 
199
  # Build the graph
200
  graph_builder = StateGraph(OrderState)
201
 
202
  # Add nodes
203
- graph_builder.add_node("chatbot", chatbot_with_tools)
204
  graph_builder.add_node("human", human_node)
205
  graph_builder.add_node("tools", tool_node)
206
- graph_builder.add_node("ordering", order_node)
207
 
208
  # Add edges and routing
209
- graph_builder.add_conditional_edges("chatbot", maybe_route_to_tools)
210
  graph_builder.add_conditional_edges("human", maybe_exit_human_node)
211
- graph_builder.add_edge("tools", "chatbot")
212
- graph_builder.add_edge("ordering", "chatbot")
213
  graph_builder.add_edge(START, "human")
214
 
215
  # Compile the graph
 
13
  from langchain_core.messages import AIMessage, ToolMessage, HumanMessage, BaseMessage, SystemMessage
14
  from random import randint
15
 
16
+ import wikipedia
 
17
 
18
  import gradio as gr
19
  import logging
 
25
  finished: bool
26
 
27
  # System instruction for the BaristaBot
28
+ SYSINT = (
29
  "system",
30
+ "You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: "
31
+ "FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings."
32
+ "If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise."
33
+ "If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise."
34
+ "If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."
35
+ "If a tool required for task completion is unavailable after multiple tries, return 0."
36
+ )
37
 
38
  WELCOME_MSG = "Welcome to the BaristaBot cafe. Type `q` to quit. How may I serve you today?"
39
 
 
41
  llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest")
42
 
43
  @tool
44
+ def wikipedia_search(title: str) -> str:
45
+ """Provides a short snippet from a Wikipedia article with the given itle"""
46
+ page = wikipedia.page(title)
47
+ return page.content[:100]
48
+
49
+ def agent_node(state: OrderState) -> OrderState:
50
+ """agent with tool handling."""
51
+ print(f"Messagelist sent to agent node: {[msg.content for msg in state.get('messages', [])]}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  defaults = {"order": [], "finished": False}
53
 
54
  # Ensure we always have at least a system message
55
  if not state.get("messages", []):
56
+ return defaults | state | {"messages": []}
 
57
 
58
  try:
59
  # Prepend system instruction if not already present
60
  messages_with_system = [
61
+ SystemMessage(content=SYSINT)
62
  ] + state.get("messages", [])
63
 
64
  # Process messages through the LLM
 
69
  # Fallback if LLM processing fails
70
  return defaults | state | {"messages": [AIMessage(content=f"I'm having trouble processing that. {str(e)}")]}
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  def maybe_route_to_tools(state: OrderState) -> str:
73
  """Route between chat and tool nodes."""
74
  if not (msgs := state.get("messages", [])):
 
77
  msg = msgs[-1]
78
 
79
  if state.get("finished", False):
80
+ print("from agent GOTO End node")
81
  return END
82
 
83
  elif hasattr(msg, "tool_calls") and len(msg.tool_calls) > 0:
84
  if any(tool["name"] in tool_node.tools_by_name.keys() for tool in msg.tool_calls):
85
+ print("from agent GOTO tools node")
86
  return "tools"
 
 
 
87
 
88
+ print("tool call failed, letting agent try again")
89
+ return "human"
 
90
 
91
  def human_node(state: OrderState) -> OrderState:
92
  """Handle user input."""
 
98
 
99
  return state
100
 
101
+ def maybe_exit_human_node(state: OrderState) -> Literal["agent", "__end__"]:
102
  """Determine if conversation should continue."""
103
  if state.get("finished", False):
104
  logging.info("from human GOTO End node")
 
108
  logging.info("Chatbot response obtained, ending conversation")
109
  return END
110
  else:
111
+ logging.info("from human GOTO agent node")
112
+ return "agent"
113
 
114
  # Prepare tools
115
+ auto_tools = []
116
  tool_node = ToolNode(auto_tools)
117
 
118
+ interactive_tools = [wikipedia_search]
119
 
120
  # Bind all tools to the LLM
121
+ llm_with_tools = llm.bind_tools(auto_tools + interactive_tools)
122
 
123
  # Build the graph
124
  graph_builder = StateGraph(OrderState)
125
 
126
  # Add nodes
127
+ graph_builder.add_node("chatbot", agent_node)
128
  graph_builder.add_node("human", human_node)
129
  graph_builder.add_node("tools", tool_node)
 
130
 
131
  # Add edges and routing
132
+ graph_builder.add_conditional_edges("agent", maybe_route_to_tools)
133
  graph_builder.add_conditional_edges("human", maybe_exit_human_node)
134
+ graph_builder.add_edge("tools", "agent")
135
+ graph_builder.add_edge("ordering", "agent")
136
  graph_builder.add_edge(START, "human")
137
 
138
  # Compile the graph