# app.py
import os
import json
import re
import time
import tempfile
import logging
import threading
from datetime import datetime
from dotenv import load_dotenv
import gradio as gr
load_dotenv()
from model_logic import (
    get_available_providers, get_model_display_names_for_provider,
    get_default_model_display_name_for_provider, call_model_stream, MODELS_BY_PROVIDER
)
from memory_logic import (
    initialize_memory_system,
    add_memory_entry, retrieve_memories_semantic, get_all_memories_cached, clear_all_memory_data_backend,
    add_rule_entry, retrieve_rules_semantic, remove_rule_entry, get_all_rules_cached, clear_all_rules_data_backend,
    save_faiss_indices_to_disk, STORAGE_BACKEND as MEMORY_STORAGE_BACKEND  # Imported for UI display
)
from websearch_logic import scrape_url, search_and_scrape_duckduckgo, search_and_scrape_google
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(threadName)s - %(message)s')
logger = logging.getLogger(__name__)
for lib_name in ["urllib3", "requests", "huggingface_hub", "PIL.PngImagePlugin", "matplotlib", "gradio_client.client", "multipart.multipart", "httpx", "sentence_transformers", "faiss", "datasets"]:
    logging.getLogger(lib_name).setLevel(logging.WARNING)
WEB_SEARCH_ENABLED = os.getenv("WEB_SEARCH_ENABLED", "true").lower() == "true"
TOOL_DECISION_PROVIDER_ENV = os.getenv("TOOL_DECISION_PROVIDER", "groq")
TOOL_DECISION_MODEL_ID_ENV = os.getenv("TOOL_DECISION_MODEL", "llama3-8b-8192")
MAX_HISTORY_TURNS = int(os.getenv("MAX_HISTORY_TURNS", 7))
current_chat_session_history = []
DEFAULT_SYSTEM_PROMPT = os.getenv(
    "DEFAULT_SYSTEM_PROMPT",
    "You are a helpful AI research assistant. Your primary goal is to answer questions and perform research tasks accurately and thoroughly. You can use tools like web search and page browsing. When providing information from the web, cite your sources if possible. If asked to perform a task beyond your capabilities, explain politely. Be concise unless asked for detail."
)
logger.info(f"App Config: WebSearch={WEB_SEARCH_ENABLED}, ToolDecisionProvider={TOOL_DECISION_PROVIDER_ENV}, ToolDecisionModelID={TOOL_DECISION_MODEL_ID_ENV}, MemoryBackend={MEMORY_STORAGE_BACKEND}")
def format_insights_for_prompt(retrieved_insights_list: list[str]) -> tuple[str, list[dict]]:
    if not retrieved_insights_list:
        return "No specific guiding principles or learned insights retrieved.", []
    parsed = []
    for text in retrieved_insights_list:
        match = re.match(r"\[(CORE_RULE|RESPONSE_PRINCIPLE|BEHAVIORAL_ADJUSTMENT|GENERAL_LEARNING)\|([\d\.]+?)\](.*)", text.strip(), re.DOTALL | re.IGNORECASE)
        if match:
            parsed.append({"type": match.group(1).upper().replace(" ", "_"), "score": match.group(2), "text": match.group(3).strip(), "original": text.strip()})
        else:
            parsed.append({"type": "GENERAL_LEARNING", "score": "0.5", "text": text.strip(), "original": text.strip()})
    try:
        parsed.sort(key=lambda x: float(x["score"]) if x["score"].replace('.', '', 1).isdigit() else -1.0, reverse=True)
    except ValueError:
        logger.warning("FORMAT_INSIGHTS: Sort error due to invalid score format.")
    grouped = {"CORE_RULE": [], "RESPONSE_PRINCIPLE": [], "BEHAVIORAL_ADJUSTMENT": [], "GENERAL_LEARNING": []}
    for p_item in parsed:
        grouped.get(p_item["type"], grouped["GENERAL_LEARNING"]).append(f"- (Score: {p_item['score']}) {p_item['text']}")
    sections = [f"{k.replace('_', ' ').title()}:\n" + "\n".join(v) for k, v in grouped.items() if v]
    return "\n\n".join(sections) if sections else "No guiding principles retrieved.", parsed
def generate_interaction_metrics(user_input: str, bot_response: str, provider: str, model_display_name: str, api_key_override: str = None) -> dict:
    metric_start_time = time.time()
    logger.info(f"Generating metrics with: {provider}/{model_display_name}")
    metric_prompt_content = f"User: \"{user_input}\"\nAI: \"{bot_response}\"\nMetrics: \"takeaway\" (3-7 words), \"response_success_score\" (0.0-1.0), \"future_confidence_score\" (0.0-1.0). Output JSON ONLY, ensure it's a single, valid JSON object."
    metric_messages = [{"role": "system", "content": "You are a precise JSON output agent. Output a single JSON object containing interaction metrics as requested by the user. Do not include any explanatory text before or after the JSON object."}, {"role": "user", "content": metric_prompt_content}]
    try:
        metrics_provider_final, metrics_model_display_final = provider, model_display_name
        metrics_model_env = os.getenv("METRICS_MODEL")
        if metrics_model_env and "/" in metrics_model_env:
            m_prov, m_id = metrics_model_env.split('/', 1)
            m_disp_name = next((dn for dn, mid in MODELS_BY_PROVIDER.get(m_prov.lower(), {}).get("models", {}).items() if mid == m_id), None)
            if m_disp_name: metrics_provider_final, metrics_model_display_final = m_prov, m_disp_name
            else: logger.warning(f"METRICS_MODEL '{metrics_model_env}' not found, using interaction model.")
        response_chunks = list(call_model_stream(provider=metrics_provider_final, model_display_name=metrics_model_display_final, messages=metric_messages, api_key_override=api_key_override, temperature=0.05, max_tokens=200))
        resp_str = "".join(response_chunks).strip()
        json_match = re.search(r"```json\s*(\{.*?\})\s*```", resp_str, re.DOTALL | re.IGNORECASE) or re.search(r"(\{.*?\})", resp_str, re.DOTALL)
        if json_match: metrics_data = json.loads(json_match.group(1))
        else:
            logger.warning(f"METRICS_GEN: Non-JSON response from {metrics_provider_final}/{metrics_model_display_final}: '{resp_str}'")
            return {"takeaway": "N/A", "response_success_score": 0.5, "future_confidence_score": 0.5, "error": "metrics format error"}
        parsed_metrics = {"takeaway": metrics_data.get("takeaway", "N/A"), "response_success_score": float(metrics_data.get("response_success_score", 0.5)), "future_confidence_score": float(metrics_data.get("future_confidence_score", 0.5)), "error": metrics_data.get("error")}
        logger.info(f"METRICS_GEN: Generated in {time.time() - metric_start_time:.2f}s. Data: {parsed_metrics}")
        return parsed_metrics
    except Exception as e:
        logger.error(f"METRICS_GEN Error: {e}", exc_info=False)
        return {"takeaway": "N/A", "response_success_score": 0.5, "future_confidence_score": 0.5, "error": str(e)}
def process_user_interaction_gradio(user_input: str, provider_name: str, model_display_name: str, chat_history_for_prompt: list[dict], custom_system_prompt: str = None, ui_api_key_override: str = None):
    process_start_time = time.time()
    request_id = os.urandom(4).hex()
    logger.info(f"PUI_GRADIO [{request_id}] Start. User: '{user_input[:50]}...' Provider: {provider_name}/{model_display_name} Hist_len:{len(chat_history_for_prompt)}")
    history_str_for_prompt = "\n".join([f"{('User' if t_msg['role'] == 'user' else 'AI')}: {t_msg['content']}" for t_msg in chat_history_for_prompt[-(MAX_HISTORY_TURNS * 2):]])
    yield "status", "<i>[Checking guidelines (semantic search)...]</i>"
    initial_insights = retrieve_rules_semantic(f"{user_input}\n{history_str_for_prompt}", k=5)
    initial_insights_ctx_str, parsed_initial_insights_list = format_insights_for_prompt(initial_insights)
    logger.info(f"PUI_GRADIO [{request_id}]: Initial RAG (insights) found {len(initial_insights)}. Context: {initial_insights_ctx_str[:150]}...")
    action_type, action_input_dict = "quick_respond", {}
    user_input_lower = user_input.lower()
    time_before_tool_decision = time.time()
    if WEB_SEARCH_ENABLED and ("http://" in user_input or "https://" in user_input):
        url_match = re.search(r'(https?://[^\s]+)', user_input)
        if url_match: action_type, action_input_dict = "scrape_url_and_report", {"url": url_match.group(1)}
    if action_type == "quick_respond" and len(user_input.split()) <= 3 and any(kw in user_input_lower for kw in ["hello", "hi", "thanks", "ok", "bye"]) and "?" not in user_input: pass
    elif action_type == "quick_respond" and WEB_SEARCH_ENABLED and (len(user_input.split()) > 3 or "?" in user_input or any(w in user_input_lower for w in ["what is", "how to", "explain", "search for"])):
        yield "status", "<i>[LLM choosing best approach...]</i>"
        history_snippet = "\n".join([f"{msg['role']}: {msg['content'][:100]}" for msg in chat_history_for_prompt[-2:]])
        guideline_snippet = initial_insights_ctx_str[:200].replace('\n', ' ')
        tool_sys_prompt = "You are a precise routing agent... Output JSON only. Example: {\"action\": \"search_duckduckgo_and_report\", \"action_input\": {\"search_engine_query\": \"query\"}}"
        tool_user_prompt = f"User Query: \"{user_input}\"\nRecent History:\n{history_snippet}\nGuidelines: {guideline_snippet}...\nAvailable Actions: quick_respond, answer_using_conversation_memory, search_duckduckgo_and_report, scrape_url_and_report.\nSelect one action and input. Output JSON."
        tool_decision_messages = [{"role": "system", "content": tool_sys_prompt}, {"role": "user", "content": tool_user_prompt}]
        tool_provider, tool_model_id = TOOL_DECISION_PROVIDER_ENV, TOOL_DECISION_MODEL_ID_ENV
        tool_model_display = next((dn for dn, mid in MODELS_BY_PROVIDER.get(tool_provider.lower(), {}).get("models", {}).items() if mid == tool_model_id), None)
        if not tool_model_display: tool_model_display = get_default_model_display_name_for_provider(tool_provider)
        if tool_model_display:
            try:
                logger.info(f"PUI_GRADIO [{request_id}]: Tool decision LLM: {tool_provider}/{tool_model_display}")
                tool_resp_chunks = list(call_model_stream(provider=tool_provider, model_display_name=tool_model_display, messages=tool_decision_messages, temperature=0.0, max_tokens=150))
                tool_resp_raw = "".join(tool_resp_chunks).strip()
                json_match_tool = re.search(r"\{.*\}", tool_resp_raw, re.DOTALL)
                if json_match_tool:
                    action_data = json.loads(json_match_tool.group(0))
                    action_type, action_input_dict = action_data.get("action", "quick_respond"), action_data.get("action_input", {})
                    if not isinstance(action_input_dict, dict): action_input_dict = {}
                    logger.info(f"PUI_GRADIO [{request_id}]: LLM Tool Decision: Action='{action_type}', Input='{action_input_dict}'")
                else: logger.warning(f"PUI_GRADIO [{request_id}]: Tool decision LLM non-JSON. Raw: {tool_resp_raw}")
            except Exception as e: logger.error(f"PUI_GRADIO [{request_id}]: Tool decision LLM error: {e}", exc_info=False)
        else: logger.error(f"No model for tool decision provider {tool_provider}.")
    elif action_type == "quick_respond" and not WEB_SEARCH_ENABLED and (len(user_input.split()) > 4 or "?" in user_input or any(w in user_input_lower for w in ["remember", "recall"])):
        action_type = "answer_using_conversation_memory"
    logger.info(f"PUI_GRADIO [{request_id}]: Tool decision logic took {time.time() - time_before_tool_decision:.3f}s. Action: {action_type}, Input: {action_input_dict}")
    yield "status", f"<i>[Path: {action_type}. Preparing response...]</i>"
    final_system_prompt_str, final_user_prompt_content_str = custom_system_prompt or DEFAULT_SYSTEM_PROMPT, ""
    if action_type == "quick_respond":
        final_system_prompt_str += " Respond directly using guidelines & history."
        final_user_prompt_content_str = f"History:\n{history_str_for_prompt}\nGuidelines:\n{initial_insights_ctx_str}\nQuery: \"{user_input}\"\nResponse:"
    elif action_type == "answer_using_conversation_memory":
        yield "status", "<i>[Searching conversation memory (semantic)...]</i>"
        retrieved_mems = retrieve_memories_semantic(f"User query: {user_input}\nContext:\n{history_str_for_prompt[-1000:]}", k=2)
        memory_context = "Relevant Past Interactions:\n" + "\n".join([f"- User:{m.get('user_input','')}->AI:{m.get('bot_response','')} (Takeaway:{m.get('metrics',{}).get('takeaway','N/A')})" for m in retrieved_mems]) if retrieved_mems else "No relevant past interactions found."
        final_system_prompt_str += " Respond using Memory Context, guidelines, & history."
        final_user_prompt_content_str = f"History:\n{history_str_for_prompt}\nGuidelines:\n{initial_insights_ctx_str}\nMemory Context:\n{memory_context}\nQuery: \"{user_input}\"\nResponse (use memory context if relevant):"
    elif WEB_SEARCH_ENABLED and action_type in ["search_duckduckgo_and_report", "scrape_url_and_report"]:
        query_or_url = action_input_dict.get("search_engine_query") if "search" in action_type else action_input_dict.get("url")
        if not query_or_url:
            final_system_prompt_str += " Respond directly (web action failed: no input)."
            final_user_prompt_content_str = f"History:\n{history_str_for_prompt}\nGuidelines:\n{initial_insights_ctx_str}\nQuery: \"{user_input}\"\nResponse:"
        else:
            yield "status", f"<i>[Web: '{query_or_url[:60]}'...]</i>"
            web_results, max_results = [], 1 if action_type == "scrape_url_and_report" else 2
            try:
                if action_type == "search_duckduckgo_and_report": web_results = search_and_scrape_duckduckgo(query_or_url, num_results=max_results)
                elif action_type == "scrape_url_and_report":
                    res = scrape_url(query_or_url)
                    if res and (res.get("content") or res.get("error")): web_results = [res]
            except Exception as e: web_results = [{"url": query_or_url, "title": "Tool Error", "error": str(e)}]
            scraped_content = "\n".join([f"Source {i+1}:\nURL:{r.get('url','N/A')}\nTitle:{r.get('title','N/A')}\nContent:\n{(r.get('content') or r.get('error') or 'N/A')[:3500]}\n---" for i, r in enumerate(web_results)]) if web_results else f"No results from {action_type} for '{query_or_url}'."
            yield "status", "<i>[Synthesizing web report...]</i>"
            final_system_prompt_str += " Generate report/answer from web content, history, & guidelines. Cite URLs as [Source X]."
            final_user_prompt_content_str = f"History:\n{history_str_for_prompt}\nGuidelines:\n{initial_insights_ctx_str}\nWeb Content:\n{scraped_content}\nQuery: \"{user_input}\"\nReport/Response (cite sources [Source X]):"
    else:  # Fallback
        final_system_prompt_str += " Respond directly (unknown action path)."
        final_user_prompt_content_str = f"History:\n{history_str_for_prompt}\nGuidelines:\n{initial_insights_ctx_str}\nQuery: \"{user_input}\"\nResponse:"
    final_llm_messages = [{"role": "system", "content": final_system_prompt_str}, {"role": "user", "content": final_user_prompt_content_str}]
    logger.debug(f"PUI_GRADIO [{request_id}]: Final LLM System Prompt: {final_system_prompt_str[:200]}...")
    logger.debug(f"PUI_GRADIO [{request_id}]: Final LLM User Prompt Start: {final_user_prompt_content_str[:200]}...")
    streamed_response, time_before_llm = "", time.time()
    try:
        for chunk in call_model_stream(provider=provider_name, model_display_name=model_display_name, messages=final_llm_messages, api_key_override=ui_api_key_override, temperature=0.6, max_tokens=2500):
            if isinstance(chunk, str) and chunk.startswith("Error:"): streamed_response += f"\n{chunk}\n"; yield "response_chunk", f"\n{chunk}\n"; break
            streamed_response += chunk; yield "response_chunk", chunk
    except Exception as e: streamed_response += f"\n\n(Error: {str(e)[:150]})"; yield "response_chunk", f"\n\n(Error: {str(e)[:150]})"
    logger.info(f"PUI_GRADIO [{request_id}]: Main LLM stream took {time.time() - time_before_llm:.3f}s.")
    final_bot_text = streamed_response.strip() or "(No response or error.)"
    logger.info(f"PUI_GRADIO [{request_id}]: Finished. Total: {time.time() - process_start_time:.2f}s. Resp len: {len(final_bot_text)}")
    yield "final_response_and_insights", {"response": final_bot_text, "insights_used": parsed_initial_insights_list}
def deferred_learning_and_memory_task(user_input: str, bot_response: str, provider: str, model_disp_name: str, insights_reflected: list[dict], api_key_override: str = None):
    start_time, task_id = time.time(), os.urandom(4).hex()
    logger.info(f"DEFERRED [{task_id}]: START User='{user_input[:40]}...', Bot='{bot_response[:40]}...'")
    try:
        metrics = generate_interaction_metrics(user_input, bot_response, provider, model_disp_name, api_key_override)
        logger.info(f"DEFERRED [{task_id}]: Metrics: {metrics}")
        add_memory_entry(user_input, metrics, bot_response)
        summary = f"User:\"{user_input}\"\nAI:\"{bot_response}\"\nMetrics(takeaway):{metrics.get('takeaway','N/A')},Success:{metrics.get('response_success_score','N/A')}"
        existing_rules_ctx = "\n".join([f"- \"{r}\"" for r in retrieve_rules_semantic(f"{summary}\n{user_input}", k=10)]) or "No existing rules context."
        insight_sys_prompt = """You are an expert AI knowledge base curator... (Your full long prompt from ai-learn)... Output ONLY JSON list."""  # Ensure this is the FULL prompt
        insight_user_prompt = f"""Interaction Summary:\n{summary}\nRelevant Existing Rules:\n{existing_rules_ctx}\nConsidered Principles:\n{json.dumps([p['original'] for p in insights_reflected if 'original' in p]) if insights_reflected else "None"}\nTask: Generate JSON list of add/update operations... (Full task description from ai-learn)"""
        insight_msgs = [{"role": "system", "content": insight_sys_prompt}, {"role": "user", "content": insight_user_prompt}]
        insight_prov, insight_model_disp = provider, model_disp_name
        insight_env_model = os.getenv("INSIGHT_MODEL_OVERRIDE")
        if insight_env_model and "/" in insight_env_model:
            i_p, i_id = insight_env_model.split('/', 1)
            i_d_n = next((dn for dn, mid in MODELS_BY_PROVIDER.get(i_p.lower(), {}).get("models", {}).items() if mid == i_id), None)
            if i_d_n: insight_prov, insight_model_disp = i_p, i_d_n
        logger.info(f"DEFERRED [{task_id}]: Generating insights with {insight_prov}/{insight_model_disp}")
        raw_ops_json = "".join(list(call_model_stream(provider=insight_prov, model_display_name=insight_model_disp, messages=insight_msgs, api_key_override=api_key_override, temperature=0.05, max_tokens=2000))).strip()
        ops, processed_count = [], 0
        json_match_ops = re.search(r"```json\s*(\[.*?\])\s*```", raw_ops_json, re.DOTALL | re.I) or re.search(r"(\[.*?\])", raw_ops_json, re.DOTALL)
        if json_match_ops:
            try: ops = json.loads(json_match_ops.group(1))
            except Exception as e: logger.error(f"DEFERRED [{task_id}]: JSON ops parse error: {e}")
        if isinstance(ops, list) and ops:
            logger.info(f"DEFERRED [{task_id}]: LLM provided {len(ops)} insight ops.")
            for op in ops:
                if not isinstance(op, dict): continue
                action, insight_text = op.get("action", "").lower(), op.get("insight", "").strip()
                if not insight_text or not re.match(r"\[(CORE_RULE|RESPONSE_PRINCIPLE|BEHAVIORAL_ADJUSTMENT|GENERAL_LEARNING)\|([\d\.]+?)\]", insight_text, re.I): continue
                if action == "add":
                    success, _ = add_rule_entry(insight_text)
                    if success: processed_count += 1
                elif action == "update":
                    old_insight = op.get("old_insight_to_replace", "").strip()
                    if old_insight and old_insight != insight_text: remove_rule_entry(old_insight)
                    success, _ = add_rule_entry(insight_text)
                    if success: processed_count += 1
            logger.info(f"DEFERRED [{task_id}]: Processed {processed_count} insight ops.")
        else: logger.info(f"DEFERRED [{task_id}]: No valid insight ops from LLM.")
    except Exception as e: logger.error(f"DEFERRED [{task_id}]: CRITICAL ERROR: {e}", exc_info=True)
    logger.info(f"DEFERRED [{task_id}]: END. Total: {time.time() - start_time:.2f}s")
def handle_gradio_chat_submit(user_msg_txt: str, gr_hist_list: list, sel_prov_name: str, sel_model_disp_name: str, ui_api_key: str | None, cust_sys_prompt: str):
    global current_chat_session_history
    cleared_input, updated_gr_hist, status_txt = "", list(gr_hist_list), "Initializing..."
    def_detect_out_md, def_fmt_out_txt, def_dl_btn = gr.Markdown("*Processing...*"), gr.Textbox("*Waiting...*"), gr.DownloadButton(interactive=False, value=None, visible=False)
    if not user_msg_txt.strip():
        status_txt = "Error: Empty message."
        updated_gr_hist.append((user_msg_txt or "(Empty)", status_txt))
        yield (cleared_input, updated_gr_hist, status_txt, def_detect_out_md, def_fmt_out_txt, def_dl_btn); return
    updated_gr_hist.append((user_msg_txt, "<i>Thinking...</i>"))
    yield (cleared_input, updated_gr_hist, status_txt, def_detect_out_md, def_fmt_out_txt, def_dl_btn)
    internal_hist = list(current_chat_session_history); internal_hist.append({"role": "user", "content": user_msg_txt})
    if len(internal_hist) > (MAX_HISTORY_TURNS * 2 + 1):
        if internal_hist[0]["role"] == "system" and len(internal_hist) > (MAX_HISTORY_TURNS * 2 + 1): internal_hist = [internal_hist[0]] + internal_hist[-(MAX_HISTORY_TURNS * 2):]
        else: internal_hist = internal_hist[-(MAX_HISTORY_TURNS * 2):]
    final_bot_resp_acc, insights_used_parsed = "", []
    try:
        processor_gen = process_user_interaction_gradio(user_input=user_msg_txt, provider_name=sel_prov_name, model_display_name=sel_model_disp_name, chat_history_for_prompt=internal_hist, custom_system_prompt=cust_sys_prompt.strip() or None, ui_api_key_override=ui_api_key.strip() if ui_api_key else None)
        curr_bot_disp_msg = ""
        for upd_type, upd_data in processor_gen:
            if upd_type == "status":
                status_txt = upd_data
                if updated_gr_hist and updated_gr_hist[-1][0] == user_msg_txt: updated_gr_hist[-1] = (user_msg_txt, f"{curr_bot_disp_msg} <i>{status_txt}</i>" if curr_bot_disp_msg else f"<i>{status_txt}</i>")
            elif upd_type == "response_chunk":
                curr_bot_disp_msg += upd_data
                if updated_gr_hist and updated_gr_hist[-1][0] == user_msg_txt: updated_gr_hist[-1] = (user_msg_txt, curr_bot_disp_msg)
            elif upd_type == "final_response_and_insights":
                final_bot_resp_acc, insights_used_parsed = upd_data["response"], upd_data["insights_used"]
                status_txt = "Response complete."
                if not curr_bot_disp_msg and final_bot_resp_acc: curr_bot_disp_msg = final_bot_resp_acc
                if updated_gr_hist and updated_gr_hist[-1][0] == user_msg_txt: updated_gr_hist[-1] = (user_msg_txt, curr_bot_disp_msg or "(No text)")
                def_fmt_out_txt = gr.Textbox(value=curr_bot_disp_msg)
                if curr_bot_disp_msg and not curr_bot_disp_msg.startswith("Error:"):
def_dl_btn = gr.DownloadButton(label="Download Report (.md)", value=curr_bot_disp_msg, filename=f"ai_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md", visible=True, interactive=True) | |
insights_md = "### Insights Considered:\n" + ("\n".join([f"- **[{i.get('type','N/A')}|{i.get('score','N/A')}]** {i.get('text','N/A')[:100]}..." for i in insights_used_parsed[:3]]) if insights_used_parsed else "*None specific.*") | |
def_detect_out_md = gr.Markdown(insights_md) | |
yield (cleared_input, updated_gr_hist, status_txt, def_detect_out_md, def_fmt_out_txt, def_dl_btn) | |
if upd_type == "final_response_and_insights": break | |
except Exception as e: | |
logger.error(f"Chat handler error: {e}", exc_info=True); status_txt = f"Error: {str(e)[:100]}" | |
if updated_gr_hist and updated_gr_hist[-1][0] == user_msg_txt: updated_gr_hist[-1] = (user_msg_txt, status_txt) | |
else: updated_gr_hist.append((user_msg_txt, status_txt)) | |
yield (cleared_input, updated_gr_hist, status_txt, def_detect_out_md, def_fmt_out_txt, def_dl_btn); return | |
if final_bot_resp_acc and not final_bot_resp_acc.startswith("Error:"): | |
current_chat_session_history.extend([{"role": "user", "content": user_msg_txt}, {"role": "assistant", "content": final_bot_resp_acc}]) | |
hist_len_check = MAX_HISTORY_TURNS * 2 | |
if current_chat_session_history and current_chat_session_history[0]["role"] == "system": hist_len_check +=1 | |
if len(current_chat_session_history) > hist_len_check: | |
current_chat_session_history = ([current_chat_session_history[0]] if current_chat_session_history[0]["role"] == "system" else []) + current_chat_session_history[- (MAX_HISTORY_TURNS * 2):] | |
threading.Thread(target=deferred_learning_and_memory_task, args=(user_msg_txt, final_bot_resp_acc, sel_prov_name, sel_model_disp_name, insights_used_parsed, ui_api_key.strip() if ui_api_key else None), daemon=True).start() | |
status_txt = "Response complete. Background learning initiated." | |
else: status_txt = "Processing finished; no response or error." | |
yield (cleared_input, updated_gr_hist, status_txt, def_detect_out_md, def_fmt_out_txt, def_dl_btn) | |
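# --- Knowledge management helpers (rules) ---
# Rules can be uploaded as a .txt/.jsonl file; entries are split on "\n\n---\n\n"
# (or one rule per line as a fallback) and added via add_rule_entry.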
def ui_view_rules_action_fn(): return "\n\n---\n\n".join(get_all_rules_cached()) or "No rules found."
def ui_upload_rules_action_fn(file_obj, progress=gr.Progress()):
    if not file_obj: return "No file."
    try: content = open(file_obj.name, 'r', encoding='utf-8').read()
    except Exception as e: return f"Error reading file: {e}"
    if not content.strip(): return "File empty."
    potential_rules = content.split("\n\n---\n\n")
    if len(potential_rules) == 1 and "\n" in content: potential_rules = [r.strip() for r in content.splitlines() if r.strip()]
    if not potential_rules: return "No rules found in file."
    added, skipped, errors = 0, 0, 0; total = len(potential_rules)
    for idx, rule_text in enumerate(potential_rules):
        if not rule_text.strip(): continue
        success, status = add_rule_entry(rule_text.strip())
        if success: added += 1
        elif status == "duplicate": skipped += 1
        else: errors += 1
        progress((idx + 1) / total, desc=f"Processing {idx+1}/{total} rules...")
    return f"Rules Upload: Total {total}. Added: {added}, Skipped (duplicates): {skipped}, Errors/Invalid: {errors}."
def ui_view_memories_action_fn(): return get_all_memories_cached() or []
def ui_upload_memories_action_fn(file_obj, progress=gr.Progress()):
    if not file_obj: return "No file."
    try: content = open(file_obj.name, 'r', encoding='utf-8').read()
    except Exception as e: return f"Error reading file: {e}"
    if not content.strip(): return "File empty."
    mem_objs, fmt_errors, added, save_errors = [], 0, 0, 0
    try:
        parsed = json.loads(content)
        mem_objs = parsed if isinstance(parsed, list) else [parsed]
    except json.JSONDecodeError:
        for line in content.splitlines():
            if line.strip():
                try: mem_objs.append(json.loads(line))
                except json.JSONDecodeError: fmt_errors += 1
    if not mem_objs and fmt_errors == 0: return "No valid memories in file."
    total = len(mem_objs)
    for idx, mem_data in enumerate(mem_objs):
        if isinstance(mem_data, dict) and all(k in mem_data for k in ["user_input", "bot_response", "metrics"]):
            success, _ = add_memory_entry(mem_data["user_input"], mem_data["metrics"], mem_data["bot_response"])
            if success: added += 1
            else: save_errors += 1
        else: fmt_errors += 1
        progress((idx + 1) / total, desc=f"Processing {idx+1}/{total} memories...")
    return f"Memories Upload: Total {total}. Added: {added}, Format Errors: {fmt_errors}, Save Errors: {save_errors}."
custom_theme = gr.themes.Base(primary_hue="teal", secondary_hue="purple", neutral_hue="zinc", text_size="sm", spacing_size="sm", radius_size="sm", font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"])
custom_css = """ body { font-family: 'Inter', sans-serif; } .gradio-container { max-width: 96% !important; margin: auto !important; padding-top: 1rem !important; } footer { display: none !important; } .gr-button { white-space: nowrap; } .gr-input, .gr-textarea textarea, .gr-dropdown input { border-radius: 8px !important; } .gr-chatbot .message { border-radius: 10px !important; box-shadow: 0 2px 5px rgba(0,0,0,0.08) !important; } .prose { h1 { font-size: 1.8rem; margin-bottom: 0.6em; margin-top: 0.8em; } h2 { font-size: 1.4rem; margin-bottom: 0.5em; margin-top: 0.7em; } h3 { font-size: 1.15rem; margin-bottom: 0.4em; margin-top: 0.6em; } p { margin-bottom: 0.8em; line-height: 1.65; } ul, ol { margin-left: 1.5em; margin-bottom: 0.8em; } code { background-color: #f1f5f9; padding: 0.2em 0.45em; border-radius: 4px; font-size: 0.9em; } pre > code { display: block; padding: 0.8em; overflow-x: auto; background-color: #f8fafc; border: 1px solid #e2e8f0; border-radius: 6px;}} .compact-group .gr-input-label, .compact-group .gr-dropdown-label { font-size: 0.8rem !important; padding-bottom: 2px !important;}"""
with gr.Blocks(theme=custom_theme, css=custom_css, title="AI Research Mega Agent v4") as demo:
    gr.Markdown("# AI Research Mega Agent (Advanced Memory & Dynamic Models)", elem_classes="prose")
    avail_provs = get_available_providers()
    def_prov = avail_provs[0] if avail_provs else None
    def_models = get_model_display_names_for_provider(def_prov) if def_prov else []
    def_model = get_default_model_display_name_for_provider(def_prov) if def_prov else None
    with gr.Row():
        with gr.Column(scale=1, min_width=320):
            gr.Markdown("## ⚙️ Configuration", elem_classes="prose")
            with gr.Accordion("API & Model Settings", open=True):
                with gr.Group(elem_classes="compact-group"):
                    gr.Markdown("### LLM Provider & Model", elem_classes="prose")
                    prov_sel_dd = gr.Dropdown(label="Provider", choices=avail_provs, value=def_prov, interactive=True)
                    model_sel_dd = gr.Dropdown(label="Model", choices=def_models, value=def_model, interactive=True)
                    api_key_tb = gr.Textbox(label="API Key Override", type="password", placeholder="Optional", info="Overrides .env key for session.")
                with gr.Group(elem_classes="compact-group"):
                    gr.Markdown("### System Prompt", elem_classes="prose")
                    sys_prompt_tb = gr.Textbox(label="Custom System Prompt Base", lines=6, value=DEFAULT_SYSTEM_PROMPT, interactive=True)
            with gr.Accordion("Knowledge Management (Backend: " + MEMORY_STORAGE_BACKEND + ")", open=False):  # Show backend type
                gr.Markdown("### Rules (Insights)", elem_classes="prose"); view_rules_btn = gr.Button("View Rules"); upload_rules_fobj = gr.File(label="Upload Rules (.txt/.jsonl)", file_types=[".txt", ".jsonl"]); rules_stat_tb = gr.Textbox(label="Status", interactive=False, lines=2); clear_rules_btn = gr.Button("⚠️ Clear All Rules", variant="stop")
                if MEMORY_STORAGE_BACKEND == "RAM": save_faiss_ram_btn = gr.Button("Save FAISS Indices (RAM Backend)")
                gr.Markdown("### Memories", elem_classes="prose"); view_mems_btn = gr.Button("View Memories"); upload_mems_fobj = gr.File(label="Upload Memories (.jsonl)", file_types=[".jsonl"]); mems_stat_tb = gr.Textbox(label="Status", interactive=False, lines=2); clear_mems_btn = gr.Button("⚠️ Clear All Memories", variant="stop")
        with gr.Column(scale=3):
            gr.Markdown("## 💬 AI Research Assistant Chat", elem_classes="prose"); main_chat_disp = gr.Chatbot(label="Chat", height=650, bubble_full_width=False, avatar_images=(None, "https://raw.githubusercontent.com/huggingface/brand-assets/main/hf-logo-with-title.png"), show_copy_button=True, render_markdown=True, sanitize_html=True)
            with gr.Row(): user_msg_tb = gr.Textbox(show_label=False, placeholder="Ask a question or give an instruction...", scale=7, lines=1, max_lines=5, autofocus=True); send_btn = gr.Button("Send", variant="primary", scale=1, min_width=100)
            agent_stat_tb = gr.Textbox(label="Agent Status", interactive=False, lines=1, value="Initializing...")
            with gr.Tabs():
                with gr.TabItem("Report/Output"): gr.Markdown("AI's full response/report.", elem_classes="prose"); fmt_report_tb = gr.Textbox(label="Output", lines=20, interactive=True, show_copy_button=True, value="*Responses appear here...*"); dl_report_btn = gr.DownloadButton(label="Download Report", interactive=False, visible=False)
                with gr.TabItem("Details / Data"): gr.Markdown("Intermediate details, loaded data.", elem_classes="prose"); detect_out_md = gr.Markdown("*Insights used or details show here...*"); gr.HTML("<hr style='margin:1em 0;'>"); gr.Markdown("### Rules Viewer", elem_classes="prose"); rules_disp_ta = gr.TextArea(label="Rules Snapshot", lines=10, interactive=False); gr.HTML("<hr style='margin:1em 0;'>"); gr.Markdown("### Memories Viewer", elem_classes="prose"); mems_disp_json = gr.JSON(label="Memories Snapshot")
    def dyn_upd_model_dd(sel_prov_dyn: str):
        models_dyn, def_model_dyn = get_model_display_names_for_provider(sel_prov_dyn), get_default_model_display_name_for_provider(sel_prov_dyn)
        return gr.Dropdown(choices=models_dyn, value=def_model_dyn, interactive=True)
    prov_sel_dd.change(fn=dyn_upd_model_dd, inputs=prov_sel_dd, outputs=model_sel_dd)
    chat_ins = [user_msg_tb, main_chat_disp, prov_sel_dd, model_sel_dd, api_key_tb, sys_prompt_tb]
    chat_outs = [user_msg_tb, main_chat_disp, agent_stat_tb, detect_out_md, fmt_report_tb, dl_report_btn]
    send_btn.click(fn=handle_gradio_chat_submit, inputs=chat_ins, outputs=chat_outs); user_msg_tb.submit(fn=handle_gradio_chat_submit, inputs=chat_ins, outputs=chat_outs)
    view_rules_btn.click(fn=ui_view_rules_action_fn, outputs=rules_disp_ta)
    upload_rules_fobj.upload(fn=ui_upload_rules_action_fn, inputs=[upload_rules_fobj], outputs=[rules_stat_tb], show_progress="full").then(fn=ui_view_rules_action_fn, outputs=rules_disp_ta)
    clear_rules_btn.click(fn=lambda: "All rules cleared." if clear_all_rules_data_backend() else "Error clearing rules.", outputs=rules_stat_tb).then(fn=ui_view_rules_action_fn, outputs=rules_disp_ta)
if MEMORY_STORAGE_BACKEND == "RAM": save_faiss_ram_btn.click(fn=save_faiss_indices_to_disk, outputs=None, success_message="Attempted to save FAISS indices to disk.") | |
    view_mems_btn.click(fn=ui_view_memories_action_fn, outputs=mems_disp_json)
    upload_mems_fobj.upload(fn=ui_upload_memories_action_fn, inputs=[upload_mems_fobj], outputs=[mems_stat_tb], show_progress="full").then(fn=ui_view_memories_action_fn, outputs=mems_disp_json)
    clear_mems_btn.click(fn=lambda: "All memories cleared." if clear_all_memory_data_backend() else "Error clearing memories.", outputs=mems_stat_tb).then(fn=ui_view_memories_action_fn, outputs=mems_disp_json)
    def app_load_fn():
        initialize_memory_system()
        logger.info("App loaded. Memory system initialized.")
        return f"AI Systems Initialized (Backend: {MEMORY_STORAGE_BACKEND}). Ready."
    demo.load(fn=app_load_fn, inputs=None, outputs=agent_stat_tb)
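# Entry point. Server host/port and the debug/share flags are read from the
# GRADIO_SERVER_NAME, GRADIO_PORT, GRADIO_DEBUG, and GRADIO_SHARE environment variables.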
if __name__ == "__main__":
    logger.info(f"Starting Gradio AI Research Mega Agent (v4 with Advanced Memory: {MEMORY_STORAGE_BACKEND})...")
    app_port, app_server = int(os.getenv("GRADIO_PORT", 7860)), os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
    app_debug, app_share = os.getenv("GRADIO_DEBUG", "False").lower() == "true", os.getenv("GRADIO_SHARE", "False").lower() == "true"
    logger.info(f"Launching Gradio server: http://{app_server}:{app_port}. Debug: {app_debug}, Share: {app_share}")
    demo.queue().launch(server_name=app_server, server_port=app_port, debug=app_debug, share=app_share)
    logger.info("Gradio application shut down.")