Jeremy Live
commited on
Commit
·
698b8b0
1
Parent(s):
ffa9fb3
memory per session
Browse files
app.py
CHANGED
@@ -229,8 +229,16 @@ def create_agent():
|
|
229 |
|
230 |
logger.info("Starting agent creation process...")
|
231 |
|
232 |
-
def create_agent(llm, db_connection):
|
233 |
-
"""Create and return a SQL database agent with conversation memory.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
if not llm:
|
235 |
error_msg = "Cannot create agent: LLM is not available"
|
236 |
logger.error(error_msg)
|
@@ -244,10 +252,14 @@ def create_agent(llm, db_connection):
|
|
244 |
try:
|
245 |
logger.info("Creating SQL agent with memory...")
|
246 |
|
247 |
-
# Create conversation memory
|
|
|
|
|
|
|
|
|
248 |
memory = ConversationBufferWindowMemory(
|
249 |
memory_key="chat_history",
|
250 |
-
k=
|
251 |
return_messages=True,
|
252 |
output_key="output"
|
253 |
)
|
@@ -271,15 +283,16 @@ def create_agent(llm, db_connection):
|
|
271 |
return_intermediate_steps=True # Important for memory to work properly
|
272 |
)
|
273 |
|
274 |
-
#
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
|
|
283 |
|
284 |
logger.info("SQL agent created successfully")
|
285 |
return agent, ""
|
@@ -497,7 +510,7 @@ def convert_to_messages_format(chat_history):
|
|
497 |
|
498 |
return messages
|
499 |
|
500 |
-
async def stream_agent_response(question: str, chat_history: List[List[str]]) -> Tuple[str, Optional["go.Figure"]]:
|
501 |
"""Procesa la pregunta del usuario y devuelve la respuesta del agente con memoria de conversación."""
|
502 |
global agent # Make sure we can modify the agent's memory
|
503 |
|
@@ -517,7 +530,8 @@ async def stream_agent_response(question: str, chat_history: List[List[str]]) ->
|
|
517 |
user_message = HumanMessage(content=question)
|
518 |
messages.append(user_message)
|
519 |
|
520 |
-
if not agent
|
|
|
521 |
error_msg = (
|
522 |
"## ⚠️ Error: Agente no inicializado\n\n"
|
523 |
"No se pudo inicializar el agente de base de datos. Por favor, verifica que:\n"
|
@@ -531,11 +545,11 @@ async def stream_agent_response(question: str, chat_history: List[List[str]]) ->
|
|
531 |
# Update the agent's memory with the full conversation history
|
532 |
try:
|
533 |
# Rebuild agent memory from chat history pairs
|
534 |
-
if hasattr(
|
535 |
-
|
536 |
for i in range(0, len(messages)-1, 2): # (user, assistant)
|
537 |
if i+1 < len(messages):
|
538 |
-
|
539 |
{"input": messages[i].content},
|
540 |
{"output": messages[i+1].content}
|
541 |
)
|
@@ -550,7 +564,7 @@ async def stream_agent_response(question: str, chat_history: List[List[str]]) ->
|
|
550 |
# Execute the agent with proper error handling
|
551 |
try:
|
552 |
# Let the agent use its memory; don't pass raw chat_history
|
553 |
-
response = await
|
554 |
logger.info(f"Agent response type: {type(response)}")
|
555 |
logger.info(f"Agent response content: {str(response)[:500]}...")
|
556 |
|
@@ -640,7 +654,7 @@ async def stream_agent_response(question: str, chat_history: List[List[str]]) ->
|
|
640 |
"Devuelve SOLO la consulta SQL en un bloque ```sql``` para responder a: "
|
641 |
f"{question}. No incluyas explicación ni texto adicional."
|
642 |
)
|
643 |
-
sql_only_resp = await
|
644 |
sql_only_text = str(sql_only_resp)
|
645 |
sql_query2 = extract_sql_query(sql_only_text)
|
646 |
if sql_query2 and looks_like_sql(sql_query2):
|
@@ -866,6 +880,9 @@ def create_ui():
|
|
866 |
if not env_ok:
|
867 |
gr.Warning("⚠️ " + env_message)
|
868 |
|
|
|
|
|
|
|
869 |
# Create the chat interface
|
870 |
with gr.Row():
|
871 |
chatbot = gr.Chatbot(
|
@@ -975,12 +992,28 @@ def create_ui():
|
|
975 |
# Hidden component for streaming output
|
976 |
streaming_output_display = gr.Textbox(visible=False)
|
977 |
|
978 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
979 |
|
980 |
def create_application():
|
981 |
"""Create and configure the Gradio application."""
|
982 |
# Create the UI components
|
983 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
984 |
|
985 |
def user_message(user_input: str, chat_history: List[Dict[str, str]]) -> Tuple[str, List[Dict[str, str]]]:
|
986 |
"""Add user message to chat history (messages format) and clear input."""
|
@@ -997,20 +1030,36 @@ def create_application():
|
|
997 |
|
998 |
return "", chat_history
|
999 |
|
1000 |
-
async def bot_response(
|
1001 |
-
|
|
|
|
|
|
|
1002 |
if not chat_history:
|
1003 |
-
return chat_history, None
|
1004 |
|
1005 |
# Ensure last message is a user turn awaiting assistant reply
|
1006 |
last = chat_history[-1]
|
1007 |
if not isinstance(last, dict) or last.get("role") != "user" or not last.get("content"):
|
1008 |
-
return chat_history, None
|
1009 |
|
1010 |
try:
|
1011 |
question = last["content"]
|
1012 |
logger.info(f"Processing question: {question}")
|
1013 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1014 |
# Convert prior messages to pair history for stream_agent_response()
|
1015 |
pair_history: List[List[str]] = []
|
1016 |
i = 0
|
@@ -1029,7 +1078,11 @@ def create_application():
|
|
1029 |
i += 1
|
1030 |
|
1031 |
# Call the agent for this new user question
|
1032 |
-
assistant_message, chart_fig = await stream_agent_response(
|
|
|
|
|
|
|
|
|
1033 |
|
1034 |
# Append assistant message back into messages history
|
1035 |
chat_history.append({"role": "assistant", "content": assistant_message})
|
@@ -1066,14 +1119,14 @@ def create_application():
|
|
1066 |
)
|
1067 |
|
1068 |
logger.info("Response generation complete")
|
1069 |
-
return chat_history, chart_fig
|
1070 |
|
1071 |
except Exception as e:
|
1072 |
error_msg = f"## ❌ Error\n\nError al procesar la solicitud:\n\n```\n{str(e)}\n```"
|
1073 |
logger.error(error_msg, exc_info=True)
|
1074 |
# Ensure we add an assistant error message for the UI
|
1075 |
chat_history.append({"role": "assistant", "content": error_msg})
|
1076 |
-
return chat_history, None
|
1077 |
|
1078 |
# Event handlers
|
1079 |
with demo:
|
@@ -1085,8 +1138,8 @@ def create_application():
|
|
1085 |
queue=True
|
1086 |
).then(
|
1087 |
fn=bot_response,
|
1088 |
-
inputs=[chatbot],
|
1089 |
-
outputs=[chatbot, chart_display],
|
1090 |
api_name="ask"
|
1091 |
)
|
1092 |
|
@@ -1098,8 +1151,8 @@ def create_application():
|
|
1098 |
queue=True
|
1099 |
).then(
|
1100 |
fn=bot_response,
|
1101 |
-
inputs=[chatbot],
|
1102 |
-
outputs=[chatbot, chart_display]
|
1103 |
)
|
1104 |
|
1105 |
return demo
|
|
|
229 |
|
230 |
logger.info("Starting agent creation process...")
|
231 |
|
232 |
+
def create_agent(llm, db_connection, *, run_test: bool = True):
|
233 |
+
"""Create and return a SQL database agent with conversation memory.
|
234 |
+
|
235 |
+
Args:
|
236 |
+
llm: Language model instance
|
237 |
+
db_connection: SQLDatabase connection
|
238 |
+
run_test: If True, executes a small test query which may add entries
|
239 |
+
to agent memory. Disable for per-session agents to start
|
240 |
+
with a clean memory.
|
241 |
+
"""
|
242 |
if not llm:
|
243 |
error_msg = "Cannot create agent: LLM is not available"
|
244 |
logger.error(error_msg)
|
|
|
252 |
try:
|
253 |
logger.info("Creating SQL agent with memory...")
|
254 |
|
255 |
+
# Create conversation memory (configurable window via env MEMORY_K)
|
256 |
+
try:
|
257 |
+
memory_k = int(os.getenv("MEMORY_K", "8"))
|
258 |
+
except Exception:
|
259 |
+
memory_k = 8
|
260 |
memory = ConversationBufferWindowMemory(
|
261 |
memory_key="chat_history",
|
262 |
+
k=memory_k,
|
263 |
return_messages=True,
|
264 |
output_key="output"
|
265 |
)
|
|
|
283 |
return_intermediate_steps=True # Important for memory to work properly
|
284 |
)
|
285 |
|
286 |
+
# if run_test:
|
287 |
+
# # Test the agent with a simple query
|
288 |
+
# logger.info("Testing agent with a simple query...")
|
289 |
+
# try:
|
290 |
+
# test_query = "SELECT 1"
|
291 |
+
# test_result = agent.run(test_query)
|
292 |
+
# logger.info(f"Agent test query successful: {str(test_result)[:200]}...")
|
293 |
+
# except Exception as e:
|
294 |
+
# logger.warning(f"Agent test query failed (this might be expected): {str(e)}")
|
295 |
+
# # Continue even if test fails, as it might be due to model limitations
|
296 |
|
297 |
logger.info("SQL agent created successfully")
|
298 |
return agent, ""
|
|
|
510 |
|
511 |
return messages
|
512 |
|
513 |
+
async def stream_agent_response(question: str, chat_history: List[List[str]], session_agent=None) -> Tuple[str, Optional["go.Figure"]]:
|
514 |
"""Procesa la pregunta del usuario y devuelve la respuesta del agente con memoria de conversación."""
|
515 |
global agent # Make sure we can modify the agent's memory
|
516 |
|
|
|
530 |
user_message = HumanMessage(content=question)
|
531 |
messages.append(user_message)
|
532 |
|
533 |
+
active_agent = session_agent if session_agent is not None else agent
|
534 |
+
if not active_agent:
|
535 |
error_msg = (
|
536 |
"## ⚠️ Error: Agente no inicializado\n\n"
|
537 |
"No se pudo inicializar el agente de base de datos. Por favor, verifica que:\n"
|
|
|
545 |
# Update the agent's memory with the full conversation history
|
546 |
try:
|
547 |
# Rebuild agent memory from chat history pairs
|
548 |
+
if hasattr(active_agent, 'memory') and active_agent.memory is not None:
|
549 |
+
active_agent.memory.clear()
|
550 |
for i in range(0, len(messages)-1, 2): # (user, assistant)
|
551 |
if i+1 < len(messages):
|
552 |
+
active_agent.memory.save_context(
|
553 |
{"input": messages[i].content},
|
554 |
{"output": messages[i+1].content}
|
555 |
)
|
|
|
564 |
# Execute the agent with proper error handling
|
565 |
try:
|
566 |
# Let the agent use its memory; don't pass raw chat_history
|
567 |
+
response = await active_agent.ainvoke({"input": question})
|
568 |
logger.info(f"Agent response type: {type(response)}")
|
569 |
logger.info(f"Agent response content: {str(response)[:500]}...")
|
570 |
|
|
|
654 |
"Devuelve SOLO la consulta SQL en un bloque ```sql``` para responder a: "
|
655 |
f"{question}. No incluyas explicación ni texto adicional."
|
656 |
)
|
657 |
+
sql_only_resp = await active_agent.ainvoke({"input": sql_only_prompt})
|
658 |
sql_only_text = str(sql_only_resp)
|
659 |
sql_query2 = extract_sql_query(sql_only_text)
|
660 |
if sql_query2 and looks_like_sql(sql_query2):
|
|
|
880 |
if not env_ok:
|
881 |
gr.Warning("⚠️ " + env_message)
|
882 |
|
883 |
+
# Create session-scoped state for the agent with memory
|
884 |
+
session_agent_state = gr.State(value=None)
|
885 |
+
|
886 |
# Create the chat interface
|
887 |
with gr.Row():
|
888 |
chatbot = gr.Chatbot(
|
|
|
992 |
# Hidden component for streaming output
|
993 |
streaming_output_display = gr.Textbox(visible=False)
|
994 |
|
995 |
+
return (
|
996 |
+
demo,
|
997 |
+
chatbot,
|
998 |
+
chart_display,
|
999 |
+
question_input,
|
1000 |
+
submit_button,
|
1001 |
+
streaming_output_display,
|
1002 |
+
session_agent_state,
|
1003 |
+
)
|
1004 |
|
1005 |
def create_application():
|
1006 |
"""Create and configure the Gradio application."""
|
1007 |
# Create the UI components
|
1008 |
+
(
|
1009 |
+
demo,
|
1010 |
+
chatbot,
|
1011 |
+
chart_display,
|
1012 |
+
question_input,
|
1013 |
+
submit_button,
|
1014 |
+
streaming_output_display,
|
1015 |
+
session_agent_state,
|
1016 |
+
) = create_ui()
|
1017 |
|
1018 |
def user_message(user_input: str, chat_history: List[Dict[str, str]]) -> Tuple[str, List[Dict[str, str]]]:
|
1019 |
"""Add user message to chat history (messages format) and clear input."""
|
|
|
1030 |
|
1031 |
return "", chat_history
|
1032 |
|
1033 |
+
async def bot_response(
|
1034 |
+
chat_history: List[Dict[str, str]],
|
1035 |
+
session_agent,
|
1036 |
+
) -> Tuple[List[Dict[str, str]], Optional[go.Figure], Any]:
|
1037 |
+
"""Generate bot response using a session-scoped agent with memory."""
|
1038 |
if not chat_history:
|
1039 |
+
return chat_history, None, session_agent
|
1040 |
|
1041 |
# Ensure last message is a user turn awaiting assistant reply
|
1042 |
last = chat_history[-1]
|
1043 |
if not isinstance(last, dict) or last.get("role") != "user" or not last.get("content"):
|
1044 |
+
return chat_history, None, session_agent
|
1045 |
|
1046 |
try:
|
1047 |
question = last["content"]
|
1048 |
logger.info(f"Processing question: {question}")
|
1049 |
|
1050 |
+
# Ensure we have a session-specific agent with memory
|
1051 |
+
local_agent = session_agent
|
1052 |
+
if local_agent is None:
|
1053 |
+
try:
|
1054 |
+
local_agent, _ = create_agent(llm, db_connection, run_test=False)
|
1055 |
+
session_agent = local_agent
|
1056 |
+
logger.info("Created new session agent with memory")
|
1057 |
+
except Exception as e:
|
1058 |
+
logger.error(f"Could not create session agent: {e}")
|
1059 |
+
# Fallback to global agent if available
|
1060 |
+
session_agent = None
|
1061 |
+
local_agent = None
|
1062 |
+
|
1063 |
# Convert prior messages to pair history for stream_agent_response()
|
1064 |
pair_history: List[List[str]] = []
|
1065 |
i = 0
|
|
|
1078 |
i += 1
|
1079 |
|
1080 |
# Call the agent for this new user question
|
1081 |
+
assistant_message, chart_fig = await stream_agent_response(
|
1082 |
+
question,
|
1083 |
+
pair_history,
|
1084 |
+
session_agent=session_agent,
|
1085 |
+
)
|
1086 |
|
1087 |
# Append assistant message back into messages history
|
1088 |
chat_history.append({"role": "assistant", "content": assistant_message})
|
|
|
1119 |
)
|
1120 |
|
1121 |
logger.info("Response generation complete")
|
1122 |
+
return chat_history, chart_fig, session_agent
|
1123 |
|
1124 |
except Exception as e:
|
1125 |
error_msg = f"## ❌ Error\n\nError al procesar la solicitud:\n\n```\n{str(e)}\n```"
|
1126 |
logger.error(error_msg, exc_info=True)
|
1127 |
# Ensure we add an assistant error message for the UI
|
1128 |
chat_history.append({"role": "assistant", "content": error_msg})
|
1129 |
+
return chat_history, None, session_agent
|
1130 |
|
1131 |
# Event handlers
|
1132 |
with demo:
|
|
|
1138 |
queue=True
|
1139 |
).then(
|
1140 |
fn=bot_response,
|
1141 |
+
inputs=[chatbot, session_agent_state],
|
1142 |
+
outputs=[chatbot, chart_display, session_agent_state],
|
1143 |
api_name="ask"
|
1144 |
)
|
1145 |
|
|
|
1151 |
queue=True
|
1152 |
).then(
|
1153 |
fn=bot_response,
|
1154 |
+
inputs=[chatbot, session_agent_state],
|
1155 |
+
outputs=[chatbot, chart_display, session_agent_state]
|
1156 |
)
|
1157 |
|
1158 |
return demo
|