Jeremy Live committed on
Commit 698b8b0 · 1 Parent(s): ffa9fb3

memory per session

Files changed (1): app.py (+86, -33)
app.py CHANGED
@@ -229,8 +229,16 @@ def create_agent():
 
     logger.info("Starting agent creation process...")
 
-def create_agent(llm, db_connection):
-    """Create and return a SQL database agent with conversation memory."""
+def create_agent(llm, db_connection, *, run_test: bool = True):
+    """Create and return a SQL database agent with conversation memory.
+
+    Args:
+        llm: Language model instance
+        db_connection: SQLDatabase connection
+        run_test: If True, executes a small test query which may add entries
+            to agent memory. Disable for per-session agents to start
+            with a clean memory.
+    """
     if not llm:
         error_msg = "Cannot create agent: LLM is not available"
         logger.error(error_msg)
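The new `run_test` flag sits behind a bare `*`, so callers must spell it out by name. A tiny, self-contained illustration of that calling convention (the body below is a toy stand-in, not the real function):

```python
def create_agent(llm, db_connection, *, run_test: bool = True):
    """Toy stand-in that only mirrors the new keyword-only signature."""
    return f"agent(run_test={run_test})"


print(create_agent("llm", "db", run_test=False))  # OK: flag passed by name

try:
    create_agent("llm", "db", False)  # positional flag is rejected
except TypeError as exc:
    print(f"TypeError: {exc}")
```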
@@ -244,10 +252,14 @@ def create_agent(llm, db_connection):
     try:
         logger.info("Creating SQL agent with memory...")
 
-        # Create conversation memory
+        # Create conversation memory (configurable window via env MEMORY_K)
+        try:
+            memory_k = int(os.getenv("MEMORY_K", "8"))
+        except Exception:
+            memory_k = 8
         memory = ConversationBufferWindowMemory(
             memory_key="chat_history",
-            k=5,  # Keep last 5 message exchanges in memory
+            k=memory_k,
             return_messages=True,
             output_key="output"
         )
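The memory window is no longer hard-coded to 5 exchanges; it is read from the `MEMORY_K` environment variable with a default of 8. A minimal sketch of how the windowed buffer behaves, assuming the same `ConversationBufferWindowMemory` class that `app.py` already uses (the import path can vary across langchain versions):

```python
import os

from langchain.memory import ConversationBufferWindowMemory  # path may vary by version

# Same fallback pattern as the commit: env var, default of 8 exchanges.
try:
    memory_k = int(os.getenv("MEMORY_K", "8"))
except ValueError:
    memory_k = 8

memory = ConversationBufferWindowMemory(
    memory_key="chat_history",
    k=memory_k,
    return_messages=True,
    output_key="output",
)

# Save more exchanges than the window can hold...
for turn in range(memory_k + 3):
    memory.save_context({"input": f"question {turn}"}, {"output": f"answer {turn}"})

# ...and only the most recent `memory_k` exchanges (2 * k messages) survive.
history = memory.load_memory_variables({})["chat_history"]
assert len(history) == 2 * memory_k
```

Because `k` counts exchanges rather than individual messages, `MEMORY_K=8` keeps the last 8 question/answer pairs available to the prompt.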
@@ -271,15 +283,16 @@ def create_agent(llm, db_connection):
             return_intermediate_steps=True  # Important for memory to work properly
         )
 
-        # Test the agent with a simple query
-        logger.info("Testing agent with a simple query...")
-        try:
-            test_query = "SELECT 1"
-            test_result = agent.run(test_query)
-            logger.info(f"Agent test query successful: {str(test_result)[:200]}...")
-        except Exception as e:
-            logger.warning(f"Agent test query failed (this might be expected): {str(e)}")
-            # Continue even if test fails, as it might be due to model limitations
+        # if run_test:
+        #     # Test the agent with a simple query
+        #     logger.info("Testing agent with a simple query...")
+        #     try:
+        #         test_query = "SELECT 1"
+        #         test_result = agent.run(test_query)
+        #         logger.info(f"Agent test query successful: {str(test_result)[:200]}...")
+        #     except Exception as e:
+        #         logger.warning(f"Agent test query failed (this might be expected): {str(e)}")
+        #         # Continue even if test fails, as it might be due to model limitations
 
         logger.info("SQL agent created successfully")
         return agent, ""
@@ -497,7 +510,7 @@ def convert_to_messages_format(chat_history):
 
     return messages
 
-async def stream_agent_response(question: str, chat_history: List[List[str]]) -> Tuple[str, Optional["go.Figure"]]:
+async def stream_agent_response(question: str, chat_history: List[List[str]], session_agent=None) -> Tuple[str, Optional["go.Figure"]]:
     """Procesa la pregunta del usuario y devuelve la respuesta del agente con memoria de conversación."""
     global agent  # Make sure we can modify the agent's memory
 
@@ -517,7 +530,8 @@ async def stream_agent_response(question: str, chat_history: List[List[str]]) ->
     user_message = HumanMessage(content=question)
     messages.append(user_message)
 
-    if not agent:
+    active_agent = session_agent if session_agent is not None else agent
+    if not active_agent:
         error_msg = (
             "## ⚠️ Error: Agente no inicializado\n\n"
             "No se pudo inicializar el agente de base de datos. Por favor, verifica que:\n"
531
  # Update the agent's memory with the full conversation history
532
  try:
533
  # Rebuild agent memory from chat history pairs
534
- if hasattr(agent, 'memory') and agent.memory is not None:
535
- agent.memory.clear()
536
  for i in range(0, len(messages)-1, 2): # (user, assistant)
537
  if i+1 < len(messages):
538
- agent.memory.save_context(
539
  {"input": messages[i].content},
540
  {"output": messages[i+1].content}
541
  )
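The rebuild walks the flat message list in steps of two and writes each (user, assistant) pair back into the active agent's memory. A self-contained sketch of that pairing logic, using the window memory directly through a hypothetical `rebuild_memory` helper (not part of `app.py`):

```python
from typing import List

from langchain.memory import ConversationBufferWindowMemory  # path may vary by version


def rebuild_memory(flat_history: List[str], k: int = 8) -> ConversationBufferWindowMemory:
    """Rebuild a fresh window memory from an alternating [user, assistant, ...] list."""
    memory = ConversationBufferWindowMemory(
        memory_key="chat_history", k=k, return_messages=True, output_key="output"
    )
    memory.clear()
    for i in range(0, len(flat_history) - 1, 2):  # step over (user, assistant) pairs
        memory.save_context(
            {"input": flat_history[i]},
            {"output": flat_history[i + 1]},
        )
    return memory


mem = rebuild_memory(["hi", "hello!", "list the tables", "users, orders"])
print(mem.load_memory_variables({})["chat_history"])  # two reconstructed exchanges
```

The real code reads `.content` from HumanMessage/AIMessage objects; plain strings are used here only to keep the sketch standalone.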
@@ -550,7 +564,7 @@ async def stream_agent_response(question: str, chat_history: List[List[str]]) ->
     # Execute the agent with proper error handling
     try:
         # Let the agent use its memory; don't pass raw chat_history
-        response = await agent.ainvoke({"input": question})
+        response = await active_agent.ainvoke({"input": question})
         logger.info(f"Agent response type: {type(response)}")
         logger.info(f"Agent response content: {str(response)[:500]}...")
 
@@ -640,7 +654,7 @@ async def stream_agent_response(question: str, chat_history: List[List[str]]) ->
             "Devuelve SOLO la consulta SQL en un bloque ```sql``` para responder a: "
             f"{question}. No incluyas explicación ni texto adicional."
         )
-        sql_only_resp = await agent.ainvoke({"input": sql_only_prompt})
+        sql_only_resp = await active_agent.ainvoke({"input": sql_only_prompt})
         sql_only_text = str(sql_only_resp)
         sql_query2 = extract_sql_query(sql_only_text)
         if sql_query2 and looks_like_sql(sql_query2):
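Both `ainvoke` calls now go through `active_agent`, so a session-scoped agent answers with its own memory when one exists. The result is then logged and stringified; below is a small, hedged sketch of pulling the final text out of an agent result, assuming it is either a plain string or a dict with an "output" key (matching `output_key="output"` above):

```python
from typing import Any


def extract_output(response: Any) -> str:
    """Best-effort extraction of the agent's final answer (hypothetical helper)."""
    if isinstance(response, dict):
        # Fall back to the whole dict if "output" is missing.
        return str(response.get("output", response))
    return str(response)


print(extract_output({"input": "SELECT 1", "output": "1", "intermediate_steps": []}))
print(extract_output("plain string answer"))
```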
@@ -866,6 +880,9 @@ def create_ui():
     if not env_ok:
         gr.Warning("⚠️ " + env_message)
 
+    # Create session-scoped state for the agent with memory
+    session_agent_state = gr.State(value=None)
+
     # Create the chat interface
     with gr.Row():
         chatbot = gr.Chatbot(
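`gr.State` is created inside `create_ui()`, so every browser session gets its own slot for the agent and nothing is shared between users. A minimal, self-contained sketch of the pattern with a hypothetical per-session counter instead of an agent:

```python
import gradio as gr


def bump(count, message):
    """Increment the per-session counter; the value round-trips through gr.State."""
    count = (count or 0) + 1
    return count, f"This session has sent {count} message(s). Last: {message}"


with gr.Blocks() as demo:
    session_count = gr.State(value=0)  # one independent value per browser session
    box = gr.Textbox(label="Message")
    out = gr.Markdown()
    # State goes in both inputs and outputs, exactly like session_agent_state above.
    box.submit(bump, inputs=[session_count, box], outputs=[session_count, out])

if __name__ == "__main__":
    demo.launch()
```

Returning the state value from the handler is what persists it for the next event in the same session, which is the same round trip `bot_response` performs with `session_agent_state`.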
@@ -975,12 +992,28 @@ def create_ui():
     # Hidden component for streaming output
     streaming_output_display = gr.Textbox(visible=False)
 
-    return demo, chatbot, chart_display, question_input, submit_button, streaming_output_display
+    return (
+        demo,
+        chatbot,
+        chart_display,
+        question_input,
+        submit_button,
+        streaming_output_display,
+        session_agent_state,
+    )
 
 def create_application():
     """Create and configure the Gradio application."""
     # Create the UI components
-    demo, chatbot, chart_display, question_input, submit_button, streaming_output_display = create_ui()
+    (
+        demo,
+        chatbot,
+        chart_display,
+        question_input,
+        submit_button,
+        streaming_output_display,
+        session_agent_state,
+    ) = create_ui()
 
     def user_message(user_input: str, chat_history: List[Dict[str, str]]) -> Tuple[str, List[Dict[str, str]]]:
         """Add user message to chat history (messages format) and clear input."""
@@ -997,20 +1030,36 @@ def create_application():
 
         return "", chat_history
 
-    async def bot_response(chat_history: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Optional[go.Figure]]:
-        """Generate bot response for messages-format chat history and return optional chart figure."""
+    async def bot_response(
+        chat_history: List[Dict[str, str]],
+        session_agent,
+    ) -> Tuple[List[Dict[str, str]], Optional[go.Figure], Any]:
+        """Generate bot response using a session-scoped agent with memory."""
         if not chat_history:
-            return chat_history, None
+            return chat_history, None, session_agent
 
         # Ensure last message is a user turn awaiting assistant reply
         last = chat_history[-1]
         if not isinstance(last, dict) or last.get("role") != "user" or not last.get("content"):
-            return chat_history, None
+            return chat_history, None, session_agent
 
         try:
             question = last["content"]
             logger.info(f"Processing question: {question}")
 
+            # Ensure we have a session-specific agent with memory
+            local_agent = session_agent
+            if local_agent is None:
+                try:
+                    local_agent, _ = create_agent(llm, db_connection, run_test=False)
+                    session_agent = local_agent
+                    logger.info("Created new session agent with memory")
+                except Exception as e:
+                    logger.error(f"Could not create session agent: {e}")
+                    # Fallback to global agent if available
+                    session_agent = None
+                    local_agent = None
+
             # Convert prior messages to pair history for stream_agent_response()
             pair_history: List[List[str]] = []
             i = 0
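On the first question of a session, `bot_response` builds a dedicated agent with `run_test=False` (so the startup test query cannot seed a fresh memory) and caches it in the session state; later questions reuse it, and on failure the code falls back to the global agent. A compact sketch of that lazy-init-and-cache pattern, with `make_agent` standing in for `create_agent(llm, db_connection, run_test=False)`:

```python
from typing import Any, Optional, Tuple


def make_agent() -> Any:
    """Hypothetical stand-in for create_agent(llm, db_connection, run_test=False)."""
    return object()


def get_session_agent(session_agent: Optional[Any]) -> Tuple[Optional[Any], Optional[Any]]:
    """Return (agent_to_use, updated_session_state), creating the agent on first use."""
    if session_agent is not None:
        return session_agent, session_agent  # reuse the cached per-session agent
    try:
        agent = make_agent()
        return agent, agent  # cache it for the rest of the session
    except Exception:
        return None, None  # caller falls back to the global agent


agent_to_use, state = get_session_agent(None)   # first turn: create and cache
agent_again, state = get_session_agent(state)   # later turns: reuse
assert agent_again is agent_to_use
```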
@@ -1029,7 +1078,11 @@ def create_application():
                 i += 1
 
             # Call the agent for this new user question
-            assistant_message, chart_fig = await stream_agent_response(question, pair_history)
+            assistant_message, chart_fig = await stream_agent_response(
+                question,
+                pair_history,
+                session_agent=session_agent,
+            )
 
             # Append assistant message back into messages history
             chat_history.append({"role": "assistant", "content": assistant_message})
@@ -1066,14 +1119,14 @@ def create_application():
             )
 
             logger.info("Response generation complete")
-            return chat_history, chart_fig
+            return chat_history, chart_fig, session_agent
 
         except Exception as e:
             error_msg = f"## ❌ Error\n\nError al procesar la solicitud:\n\n```\n{str(e)}\n```"
             logger.error(error_msg, exc_info=True)
             # Ensure we add an assistant error message for the UI
             chat_history.append({"role": "assistant", "content": error_msg})
-            return chat_history, None
+            return chat_history, None, session_agent
 
     # Event handlers
     with demo:
@@ -1085,8 +1138,8 @@ def create_application():
             queue=True
         ).then(
             fn=bot_response,
-            inputs=[chatbot],
-            outputs=[chatbot, chart_display],
+            inputs=[chatbot, session_agent_state],
+            outputs=[chatbot, chart_display, session_agent_state],
             api_name="ask"
         )
 
@@ -1098,8 +1151,8 @@ def create_application():
            queue=True
        ).then(
            fn=bot_response,
-            inputs=[chatbot],
-            outputs=[chatbot, chart_display]
+            inputs=[chatbot, session_agent_state],
+            outputs=[chatbot, chart_display, session_agent_state]
        )
 
    return demo
 