Jeremy Live committed on
Commit
683b6ad
·
1 Parent(s): 0e57fc2

Revert "memory per session"

Browse files

This reverts commit 698b8b0faf1bc95aa16ebc5d7bb8c28bde8ecafc.

Files changed (1) hide show
  1. app.py +33 -86
app.py CHANGED
@@ -229,16 +229,8 @@ def create_agent():
229
 
230
  logger.info("Starting agent creation process...")
231
 
232
- def create_agent(llm, db_connection, *, run_test: bool = True):
233
- """Create and return a SQL database agent with conversation memory.
234
-
235
- Args:
236
- llm: Language model instance
237
- db_connection: SQLDatabase connection
238
- run_test: If True, executes a small test query which may add entries
239
- to agent memory. Disable for per-session agents to start
240
- with a clean memory.
241
- """
242
  if not llm:
243
  error_msg = "Cannot create agent: LLM is not available"
244
  logger.error(error_msg)
@@ -252,14 +244,10 @@ def create_agent(llm, db_connection, *, run_test: bool = True):
252
  try:
253
  logger.info("Creating SQL agent with memory...")
254
 
255
- # Create conversation memory (configurable window via env MEMORY_K)
256
- try:
257
- memory_k = int(os.getenv("MEMORY_K", "8"))
258
- except Exception:
259
- memory_k = 8
260
  memory = ConversationBufferWindowMemory(
261
  memory_key="chat_history",
262
- k=memory_k,
263
  return_messages=True,
264
  output_key="output"
265
  )
@@ -283,16 +271,15 @@ def create_agent(llm, db_connection, *, run_test: bool = True):
283
  return_intermediate_steps=True # Important for memory to work properly
284
  )
285
 
286
- # if run_test:
287
- # # Test the agent with a simple query
288
- # logger.info("Testing agent with a simple query...")
289
- # try:
290
- # test_query = "SELECT 1"
291
- # test_result = agent.run(test_query)
292
- # logger.info(f"Agent test query successful: {str(test_result)[:200]}...")
293
- # except Exception as e:
294
- # logger.warning(f"Agent test query failed (this might be expected): {str(e)}")
295
- # # Continue even if test fails, as it might be due to model limitations
296
 
297
  logger.info("SQL agent created successfully")
298
  return agent, ""
@@ -510,7 +497,7 @@ def convert_to_messages_format(chat_history):
510
 
511
  return messages
512
 
513
- async def stream_agent_response(question: str, chat_history: List[List[str]], session_agent=None) -> Tuple[str, Optional["go.Figure"]]:
514
  """Procesa la pregunta del usuario y devuelve la respuesta del agente con memoria de conversación."""
515
  global agent # Make sure we can modify the agent's memory
516
 
@@ -530,8 +517,7 @@ async def stream_agent_response(question: str, chat_history: List[List[str]], se
530
  user_message = HumanMessage(content=question)
531
  messages.append(user_message)
532
 
533
- active_agent = session_agent if session_agent is not None else agent
534
- if not active_agent:
535
  error_msg = (
536
  "## ⚠️ Error: Agente no inicializado\n\n"
537
  "No se pudo inicializar el agente de base de datos. Por favor, verifica que:\n"
@@ -545,11 +531,11 @@ async def stream_agent_response(question: str, chat_history: List[List[str]], se
545
  # Update the agent's memory with the full conversation history
546
  try:
547
  # Rebuild agent memory from chat history pairs
548
- if hasattr(active_agent, 'memory') and active_agent.memory is not None:
549
- active_agent.memory.clear()
550
  for i in range(0, len(messages)-1, 2): # (user, assistant)
551
  if i+1 < len(messages):
552
- active_agent.memory.save_context(
553
  {"input": messages[i].content},
554
  {"output": messages[i+1].content}
555
  )
@@ -564,7 +550,7 @@ async def stream_agent_response(question: str, chat_history: List[List[str]], se
564
  # Execute the agent with proper error handling
565
  try:
566
  # Let the agent use its memory; don't pass raw chat_history
567
- response = await active_agent.ainvoke({"input": question})
568
  logger.info(f"Agent response type: {type(response)}")
569
  logger.info(f"Agent response content: {str(response)[:500]}...")
570
 
@@ -654,7 +640,7 @@ async def stream_agent_response(question: str, chat_history: List[List[str]], se
654
  "Devuelve SOLO la consulta SQL en un bloque ```sql``` para responder a: "
655
  f"{question}. No incluyas explicación ni texto adicional."
656
  )
657
- sql_only_resp = await active_agent.ainvoke({"input": sql_only_prompt})
658
  sql_only_text = str(sql_only_resp)
659
  sql_query2 = extract_sql_query(sql_only_text)
660
  if sql_query2 and looks_like_sql(sql_query2):
@@ -880,9 +866,6 @@ def create_ui():
880
  if not env_ok:
881
  gr.Warning("⚠️ " + env_message)
882
 
883
- # Create session-scoped state for the agent with memory
884
- session_agent_state = gr.State(value=None)
885
-
886
  # Create the chat interface
887
  with gr.Row():
888
  chatbot = gr.Chatbot(
@@ -992,28 +975,12 @@ def create_ui():
992
  # Hidden component for streaming output
993
  streaming_output_display = gr.Textbox(visible=False)
994
 
995
- return (
996
- demo,
997
- chatbot,
998
- chart_display,
999
- question_input,
1000
- submit_button,
1001
- streaming_output_display,
1002
- session_agent_state,
1003
- )
1004
 
1005
  def create_application():
1006
  """Create and configure the Gradio application."""
1007
  # Create the UI components
1008
- (
1009
- demo,
1010
- chatbot,
1011
- chart_display,
1012
- question_input,
1013
- submit_button,
1014
- streaming_output_display,
1015
- session_agent_state,
1016
- ) = create_ui()
1017
 
1018
  def user_message(user_input: str, chat_history: List[Dict[str, str]]) -> Tuple[str, List[Dict[str, str]]]:
1019
  """Add user message to chat history (messages format) and clear input."""
@@ -1030,36 +997,20 @@ def create_application():
1030
 
1031
  return "", chat_history
1032
 
1033
- async def bot_response(
1034
- chat_history: List[Dict[str, str]],
1035
- session_agent,
1036
- ) -> Tuple[List[Dict[str, str]], Optional[go.Figure], Any]:
1037
- """Generate bot response using a session-scoped agent with memory."""
1038
  if not chat_history:
1039
- return chat_history, None, session_agent
1040
 
1041
  # Ensure last message is a user turn awaiting assistant reply
1042
  last = chat_history[-1]
1043
  if not isinstance(last, dict) or last.get("role") != "user" or not last.get("content"):
1044
- return chat_history, None, session_agent
1045
 
1046
  try:
1047
  question = last["content"]
1048
  logger.info(f"Processing question: {question}")
1049
 
1050
- # Ensure we have a session-specific agent with memory
1051
- local_agent = session_agent
1052
- if local_agent is None:
1053
- try:
1054
- local_agent, _ = create_agent(llm, db_connection, run_test=False)
1055
- session_agent = local_agent
1056
- logger.info("Created new session agent with memory")
1057
- except Exception as e:
1058
- logger.error(f"Could not create session agent: {e}")
1059
- # Fallback to global agent if available
1060
- session_agent = None
1061
- local_agent = None
1062
-
1063
  # Convert prior messages to pair history for stream_agent_response()
1064
  pair_history: List[List[str]] = []
1065
  i = 0
@@ -1078,11 +1029,7 @@ def create_application():
1078
  i += 1
1079
 
1080
  # Call the agent for this new user question
1081
- assistant_message, chart_fig = await stream_agent_response(
1082
- question,
1083
- pair_history,
1084
- session_agent=session_agent,
1085
- )
1086
 
1087
  # Append assistant message back into messages history
1088
  chat_history.append({"role": "assistant", "content": assistant_message})
@@ -1119,14 +1066,14 @@ def create_application():
1119
  )
1120
 
1121
  logger.info("Response generation complete")
1122
- return chat_history, chart_fig, session_agent
1123
 
1124
  except Exception as e:
1125
  error_msg = f"## ❌ Error\n\nError al procesar la solicitud:\n\n```\n{str(e)}\n```"
1126
  logger.error(error_msg, exc_info=True)
1127
  # Ensure we add an assistant error message for the UI
1128
  chat_history.append({"role": "assistant", "content": error_msg})
1129
- return chat_history, None, session_agent
1130
 
1131
  # Event handlers
1132
  with demo:
@@ -1138,8 +1085,8 @@ def create_application():
1138
  queue=True
1139
  ).then(
1140
  fn=bot_response,
1141
- inputs=[chatbot, session_agent_state],
1142
- outputs=[chatbot, chart_display, session_agent_state],
1143
  api_name="ask"
1144
  )
1145
 
@@ -1151,8 +1098,8 @@ def create_application():
1151
  queue=True
1152
  ).then(
1153
  fn=bot_response,
1154
- inputs=[chatbot, session_agent_state],
1155
- outputs=[chatbot, chart_display, session_agent_state]
1156
  )
1157
 
1158
  return demo
 
229
 
230
  logger.info("Starting agent creation process...")
231
 
232
+ def create_agent(llm, db_connection):
233
+ """Create and return a SQL database agent with conversation memory."""
 
 
 
 
 
 
 
 
234
  if not llm:
235
  error_msg = "Cannot create agent: LLM is not available"
236
  logger.error(error_msg)
 
244
  try:
245
  logger.info("Creating SQL agent with memory...")
246
 
247
+ # Create conversation memory
 
 
 
 
248
  memory = ConversationBufferWindowMemory(
249
  memory_key="chat_history",
250
+ k=5, # Keep last 5 message exchanges in memory
251
  return_messages=True,
252
  output_key="output"
253
  )
 
271
  return_intermediate_steps=True # Important for memory to work properly
272
  )
273
 
274
+ # Test the agent with a simple query
275
+ logger.info("Testing agent with a simple query...")
276
+ try:
277
+ test_query = "SELECT 1"
278
+ test_result = agent.run(test_query)
279
+ logger.info(f"Agent test query successful: {str(test_result)[:200]}...")
280
+ except Exception as e:
281
+ logger.warning(f"Agent test query failed (this might be expected): {str(e)}")
282
+ # Continue even if test fails, as it might be due to model limitations
 
283
 
284
  logger.info("SQL agent created successfully")
285
  return agent, ""
 
497
 
498
  return messages
499
 
500
+ async def stream_agent_response(question: str, chat_history: List[List[str]]) -> Tuple[str, Optional["go.Figure"]]:
501
  """Procesa la pregunta del usuario y devuelve la respuesta del agente con memoria de conversación."""
502
  global agent # Make sure we can modify the agent's memory
503
 
 
517
  user_message = HumanMessage(content=question)
518
  messages.append(user_message)
519
 
520
+ if not agent:
 
521
  error_msg = (
522
  "## ⚠️ Error: Agente no inicializado\n\n"
523
  "No se pudo inicializar el agente de base de datos. Por favor, verifica que:\n"
 
531
  # Update the agent's memory with the full conversation history
532
  try:
533
  # Rebuild agent memory from chat history pairs
534
+ if hasattr(agent, 'memory') and agent.memory is not None:
535
+ agent.memory.clear()
536
  for i in range(0, len(messages)-1, 2): # (user, assistant)
537
  if i+1 < len(messages):
538
+ agent.memory.save_context(
539
  {"input": messages[i].content},
540
  {"output": messages[i+1].content}
541
  )
 
550
  # Execute the agent with proper error handling
551
  try:
552
  # Let the agent use its memory; don't pass raw chat_history
553
+ response = await agent.ainvoke({"input": question})
554
  logger.info(f"Agent response type: {type(response)}")
555
  logger.info(f"Agent response content: {str(response)[:500]}...")
556
 
 
640
  "Devuelve SOLO la consulta SQL en un bloque ```sql``` para responder a: "
641
  f"{question}. No incluyas explicación ni texto adicional."
642
  )
643
+ sql_only_resp = await agent.ainvoke({"input": sql_only_prompt})
644
  sql_only_text = str(sql_only_resp)
645
  sql_query2 = extract_sql_query(sql_only_text)
646
  if sql_query2 and looks_like_sql(sql_query2):
 
866
  if not env_ok:
867
  gr.Warning("⚠️ " + env_message)
868
 
 
 
 
869
  # Create the chat interface
870
  with gr.Row():
871
  chatbot = gr.Chatbot(
 
975
  # Hidden component for streaming output
976
  streaming_output_display = gr.Textbox(visible=False)
977
 
978
+ return demo, chatbot, chart_display, question_input, submit_button, streaming_output_display
 
 
 
 
 
 
 
 
979
 
980
  def create_application():
981
  """Create and configure the Gradio application."""
982
  # Create the UI components
983
+ demo, chatbot, chart_display, question_input, submit_button, streaming_output_display = create_ui()
 
 
 
 
 
 
 
 
984
 
985
  def user_message(user_input: str, chat_history: List[Dict[str, str]]) -> Tuple[str, List[Dict[str, str]]]:
986
  """Add user message to chat history (messages format) and clear input."""
 
997
 
998
  return "", chat_history
999
 
1000
+ async def bot_response(chat_history: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Optional[go.Figure]]:
1001
+ """Generate bot response for messages-format chat history and return optional chart figure."""
 
 
 
1002
  if not chat_history:
1003
+ return chat_history, None
1004
 
1005
  # Ensure last message is a user turn awaiting assistant reply
1006
  last = chat_history[-1]
1007
  if not isinstance(last, dict) or last.get("role") != "user" or not last.get("content"):
1008
+ return chat_history, None
1009
 
1010
  try:
1011
  question = last["content"]
1012
  logger.info(f"Processing question: {question}")
1013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1014
  # Convert prior messages to pair history for stream_agent_response()
1015
  pair_history: List[List[str]] = []
1016
  i = 0
 
1029
  i += 1
1030
 
1031
  # Call the agent for this new user question
1032
+ assistant_message, chart_fig = await stream_agent_response(question, pair_history)
 
 
 
 
1033
 
1034
  # Append assistant message back into messages history
1035
  chat_history.append({"role": "assistant", "content": assistant_message})
 
1066
  )
1067
 
1068
  logger.info("Response generation complete")
1069
+ return chat_history, chart_fig
1070
 
1071
  except Exception as e:
1072
  error_msg = f"## ❌ Error\n\nError al procesar la solicitud:\n\n```\n{str(e)}\n```"
1073
  logger.error(error_msg, exc_info=True)
1074
  # Ensure we add an assistant error message for the UI
1075
  chat_history.append({"role": "assistant", "content": error_msg})
1076
+ return chat_history, None
1077
 
1078
  # Event handlers
1079
  with demo:
 
1085
  queue=True
1086
  ).then(
1087
  fn=bot_response,
1088
+ inputs=[chatbot],
1089
+ outputs=[chatbot, chart_display],
1090
  api_name="ask"
1091
  )
1092
 
 
1098
  queue=True
1099
  ).then(
1100
  fn=bot_response,
1101
+ inputs=[chatbot],
1102
+ outputs=[chatbot, chart_display]
1103
  )
1104
 
1105
  return demo