import os
import json
import requests
import time
from datetime import datetime
from dotenv import load_dotenv
from datasets import load_dataset, Dataset
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import logging
import sqlite3
import re
import threading  # For background tasks
import html  # For escaping HTML in Gradio Markdown

# Gradio
import gradio as gr

# Local scraper module
from scraper_module import scrape_url, search_and_scrape_duckduckgo, search_and_scrape_google

# --- Load Environment Variables ---
load_dotenv()

# --- ai-learn Configuration (Copied and adapted) ---
STORAGE_BACKEND = os.getenv("STORAGE_BACKEND", "SQLITE").upper()
SQLITE_DB_PATH = os.getenv("SQLITE_DB_PATH", "data/chatbot_memory.db")
HF_TOKEN = os.getenv("HF_TOKEN")
HF_MEMORY_DATASET_NAME = os.getenv("HF_MEMORY_DATASET_NAME", "your_hf_username/ai-brain-memory")  # Replace with your actual dataset
HF_RULES_DATASET_NAME = os.getenv("HF_RULES_DATASET_NAME", "your_hf_username/ai-brain-rules")  # Replace with your actual dataset
WEB_SEARCH_ENABLED = os.getenv("WEB_SEARCH_ENABLED", "true").lower() == "true"
TOOL_DECISION_PROVIDER = os.getenv("TOOL_DECISION_PROVIDER", "groq")
TOOL_DECISION_MODEL = os.getenv("TOOL_DECISION_MODEL", "llama3-8b-8192")
API_KEYS = {key: os.getenv(f"{key.upper()}_API_KEY") for key in ["HUGGINGFACE", "GROQ", "OPENROUTER", "TOGETHERAI", "COHERE", "XAI", "OPENAI", "TAVILY"]}
API_URLS = {
    "HUGGINGFACE": "https://api-inference.huggingface.co/models/", "GROQ": "https://api.groq.com/openai/v1/chat/completions",
    "OPENROUTER": "https://openrouter.ai/api/v1/chat/completions", "TOGETHERAI": "https://api.together.ai/v1/chat/completions",
    "COHERE": "https://api.cohere.ai/v2/chat", "XAI": "https://api.x.ai/v1/chat/completions",
    "OPENAI": "https://api.openai.com/v1/chat/completions",
}

# --- Logging Setup ---
logging.basicConfig(level=logging.INFO,  # Default to INFO, DEBUG can be very verbose
                    format='%(asctime)s - %(name)s - %(levelname)s - %(threadName)s - %(message)s')
logger = logging.getLogger(__name__)

# Reduce verbosity of some libraries
for lib_name in ["urllib3", "requests", "huggingface_hub", "PIL.PngImagePlugin", "datasets",
                 "sentence_transformers.SentenceTransformer", "faiss.loader", "duckduckgo_search",
                 "chardet", "charset_normalizer", "filelock", "matplotlib", "gradio_client.client"]:
    logging.getLogger(lib_name).setLevel(logging.WARNING)

logger.info(f"Initial Config: Storage={STORAGE_BACKEND}, WebSearch={WEB_SEARCH_ENABLED}, ToolDecision={TOOL_DECISION_PROVIDER}/{TOOL_DECISION_MODEL}")

# --- Globals for RAG (from ai-learn) ---
embedder, dimension = None, None
faiss_memory_index, memory_texts = None, []
faiss_rules_index, rules_texts = None, []  # rules_texts are insights

# --- Models Data (from ai-learn, for deferred learning model selection and UI) ---
# This needs to be kept up-to-date or dynamically fetched if possible.
models_data_global_scope = {
    "huggingface": ["mistralai/Mixtral-8x7B-Instruct-v0.1"],
    "groq": ["llama3-8b-8192", "mixtral-8x7b-32768", "gemma-7b-it", "llama3-70b-8192"],
    "openrouter": ["meta-llama/llama-3.1-8b-instruct", "openai/gpt-4o-mini", "anthropic/claude-3.5-sonnet", "openai/gpt-4o"],
    "togetherai": ["meta-llama/Llama-3-8b-chat-hf"],
    "cohere": ["command-r-plus"],
    "xai": ["grok-1.5-flash"],
    "openai": ["gpt-4o-mini", "gpt-3.5-turbo", "gpt-4o"]
}

# Map Gradio display names to (provider, model_id).
# This needs to be more robust if we expand UI model selection beyond Groq.
# For now, node_search's groq_model_select is specific to Groq.
# If we add a general provider dropdown, this mapping becomes more important.
# The current Gradio UI just has a Groq model selector.
# `handle_research_chat_submit` will use the provider "groq" and the selected model.
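# Illustrative sketch (not part of the original ai-learn code): one way the display-name ->
# (provider, model_id) mapping described above could be built from models_data_global_scope
# if a general provider dropdown is added. The name `build_model_display_map` and the
# "provider: model" label format are assumptions for this example only.
def build_model_display_map(models_by_provider=None):
    """Return {"provider: model_id": (provider, model_id)} suitable for Gradio dropdown choices."""
    models_by_provider = models_by_provider or models_data_global_scope
    display_map = {}
    for provider, model_ids in models_by_provider.items():
        for model_id in model_ids:
            display_map[f"{provider}: {model_id}"] = (provider, model_id)
    return display_map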
# --- Database & RAG Initialization (from ai-learn) --- | |
def get_sqlite_connection(): | |
db_dir = os.path.dirname(SQLITE_DB_PATH) | |
if db_dir and not os.path.exists(db_dir): os.makedirs(db_dir, exist_ok=True) | |
return sqlite3.connect(SQLITE_DB_PATH) | |
def init_sqlite_db(): | |
if STORAGE_BACKEND != "SQLITE": return | |
try: | |
with get_sqlite_connection() as conn: | |
cursor = conn.cursor() | |
cursor.execute("CREATE TABLE IF NOT EXISTS memories (id INTEGER PRIMARY KEY AUTOINCREMENT, memory_json TEXT NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)") | |
cursor.execute("CREATE TABLE IF NOT EXISTS rules (id INTEGER PRIMARY KEY AUTOINCREMENT, rule_text TEXT NOT NULL UNIQUE, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)") # "rules" are insights | |
conn.commit() | |
logger.info("SQLite DB tables checked/created for memories and rules.") | |
except Exception as e: logger.error(f"SQLite init error: {e}", exc_info=True) | |
def load_data_on_startup(): | |
global memory_texts, rules_texts, faiss_memory_index, faiss_rules_index, embedder, dimension | |
startup_time_start = time.time() | |
if not embedder: | |
try: | |
logger.info("Loading SentenceTransformer model (all-MiniLM-L6-v2)...") | |
model_load_start = time.time() | |
embedder = SentenceTransformer('all-MiniLM-L6-v2', cache_folder="./sentence_transformer_cache") | |
dimension = embedder.get_sentence_embedding_dimension() | |
if not dimension: dimension = 384 # Fallback if property not found | |
logger.info(f"SentenceTransformer loaded in {time.time() - model_load_start:.2f}s. Dimension: {dimension}") | |
except Exception as e: logger.critical(f"FATAL: Error loading SentenceTransformer: {e}", exc_info=True); raise | |
logger.info(f"LOAD_DATA: Backend: {STORAGE_BACKEND}") | |
# Load Memories | |
m_store_load_start = time.time() | |
m_store = [] | |
if STORAGE_BACKEND == "HF_DATASET" and HF_TOKEN: | |
try: | |
logger.info(f"Loading memories from Hugging Face Dataset: {HF_MEMORY_DATASET_NAME}") | |
d = load_dataset(HF_MEMORY_DATASET_NAME, token=HF_TOKEN, download_mode="force_redownload", trust_remote_code=True) | |
if "train" in d and "memories" in d["train"].column_names: m_store = [m for m in d["train"]["memories"] if isinstance(m, str)] | |
else: logger.warning(f"HF Dataset {HF_MEMORY_DATASET_NAME} structure unexpected or 'memories' column missing.") | |
except Exception as e: logger.error(f"LOAD_DATA (Memories HF): {e}") | |
elif STORAGE_BACKEND == "SQLITE": | |
try: | |
with get_sqlite_connection() as conn: m_store = [r[0] for r in conn.execute("SELECT memory_json FROM memories ORDER BY created_at ASC")] | |
except Exception as e: logger.error(f"LOAD_DATA (Memories SQL): {e}") | |
memory_texts[:] = m_store | |
logger.info(f"Loaded {len(memory_texts)} memories from {STORAGE_BACKEND} in {time.time() - m_store_load_start:.2f}s") | |
m_faiss_build_start = time.time() | |
try: | |
faiss_memory_index = faiss.IndexFlatL2(dimension) | |
if memory_texts: | |
logger.info(f"Encoding {len(memory_texts)} memories for FAISS index...") | |
em = embedder.encode([json.loads(mt)['user_input'] + " " + json.loads(mt)['bot_response'] for mt in memory_texts], convert_to_numpy=True, show_progress_bar=True) | |
if em.ndim == 2 and em.shape[0] == len(memory_texts) and em.shape[1] == dimension: faiss_memory_index.add(np.array(em, dtype=np.float32)) | |
else: logger.error(f"LOAD_DATA (Memories FAISS): Embedding shape error. Expected ({len(memory_texts)}, {dimension}), Got {em.shape if hasattr(em, 'shape') else 'N/A'}") | |
logger.info(f"Memory FAISS index built/loaded in {time.time()-m_faiss_build_start:.2f}s. Index total: {getattr(faiss_memory_index, 'ntotal', 'N/I')}") | |
except Exception as e: logger.error(f"LOAD_DATA (Memories FAISS build): {e}", exc_info=True) | |
# Load Rules (Insights) | |
r_store_load_start = time.time() | |
r_store = [] | |
if STORAGE_BACKEND == "HF_DATASET" and HF_TOKEN: | |
try: | |
logger.info(f"Loading rules from Hugging Face Dataset: {HF_RULES_DATASET_NAME}") | |
d = load_dataset(HF_RULES_DATASET_NAME, token=HF_TOKEN, trust_remote_code=True, download_mode="force_redownload") | |
if "train" in d and "rule_text" in d["train"].column_names: r_store = [str(r).strip() for r in d["train"]["rule_text"] if str(r).strip()] | |
else: logger.warning(f"HF Dataset {HF_RULES_DATASET_NAME} structure unexpected or 'rule_text' column missing.") | |
except Exception as e: logger.error(f"LOAD_DATA (Rules HF): {e}") | |
elif STORAGE_BACKEND == "SQLITE": | |
try: | |
with get_sqlite_connection() as conn: r_store = [str(r[0]).strip() for r in conn.execute("SELECT rule_text FROM rules ORDER BY created_at ASC") if str(r[0]).strip()] | |
except Exception as e: logger.error(f"LOAD_DATA (Rules SQL): {e}") | |
rules_texts[:] = sorted(list(set(r_store))) | |
logger.info(f"Loaded {len(rules_texts)} rules from {STORAGE_BACKEND} in {time.time() - r_store_load_start:.2f}s") | |
r_faiss_build_start = time.time() | |
try: | |
faiss_rules_index = faiss.IndexFlatL2(dimension) | |
if rules_texts: | |
logger.info(f"Encoding {len(rules_texts)} rules for FAISS index...") | |
em = embedder.encode(rules_texts, convert_to_numpy=True, show_progress_bar=True) | |
if em.ndim == 2 and em.shape[0] == len(rules_texts) and em.shape[1] == dimension: faiss_rules_index.add(np.array(em, dtype=np.float32)) | |
else: logger.error(f"LOAD_DATA (Rules FAISS): Embedding shape error. Expected ({len(rules_texts)}, {dimension}), Got {em.shape if hasattr(em, 'shape') else 'N/A'}") | |
logger.info(f"Rules FAISS index built/loaded in {time.time()-r_faiss_build_start:.2f}s. Index total: {getattr(faiss_rules_index, 'ntotal', 'N/I')}") | |
except Exception as e: logger.error(f"LOAD_DATA (Rules FAISS build): {e}", exc_info=True) | |
logger.info(f"Total load_data_on_startup took {time.time() - startup_time_start:.2f}s") | |
# --- LLM API Call (from ai-learn) ---
def callAIModel(api_provider_param, model, messages_list, maxTokens=200, stream=False, retries=3, delay=1, temperature=0.7):
    call_start_time = time.time()
    # Use the API_KEYS global; ensure it's up-to-date from the UI if changed.
    keyName = api_provider_param.upper()
    current_api_key = API_KEYS.get(keyName)
    if not current_api_key or (isinstance(current_api_key, str) and f"YOUR_{keyName}_API_KEY" in current_api_key):  # Check for placeholder
        logger.error(f"{api_provider_param} API key missing or is a placeholder in API_KEYS global.")
        raise Exception(f"{api_provider_param} API key missing or placeholder.")
    headers = {"Content-Type": "application/json"}
    url = None
    payload_data_dict = {"model": model, "messages": messages_list, "max_tokens": maxTokens, "stream": stream, "temperature": temperature}
    data_as_string = None
    current_payload = None
    if api_provider_param.startswith("huggingface"):
        headers["Authorization"] = f"Bearer {current_api_key}"
        url = f"{API_URLS['HUGGINGFACE']}{model}"
        prompt_string = "\n".join([f"{m['role'].capitalize() if m['role'] != 'system' else ''}{':' if m['role'] != 'system' else ''} {m['content']}" for m in messages_list]) + "\nAssistant:\n"
        current_payload = {"inputs": prompt_string.strip(), "parameters": {"max_new_tokens": maxTokens, "return_full_text": False, "temperature": temperature if temperature > 0 else 0.01, "do_sample": temperature > 0}}
        if stream: current_payload["parameters"]["stream"] = True
    elif api_provider_param in ["groq", "togetherai", "xai", "openai"]:
        headers["Authorization"] = f"Bearer {current_api_key}"
        url = API_URLS[keyName]
        current_payload = payload_data_dict
    elif api_provider_param == "openrouter":
        headers["Authorization"] = f"Bearer {current_api_key}"
        headers["HTTP-Referer"] = os.getenv("OPENROUTER_REFERRER", "http://localhost")  # OpenRouter requires a Referer
        headers["X-Title"] = os.getenv("OPENROUTER_X_TITLE", "Gradio AI Researcher")  # Optional
        url = API_URLS[keyName]
        # OpenRouter models often carry a provider prefix, e.g., "openai/gpt-3.5-turbo".
        # Ensure payload_data_dict["model"] is correctly formatted.
        current_payload = payload_data_dict
    elif api_provider_param == "cohere":
        headers["Authorization"] = f"Bearer {current_api_key}"
        url = API_URLS["COHERE"]
        current_message_content, cohere_chat_history, system_message_for_cohere_preamble = "", [], ""
        processed_messages = list(messages_list)
        if processed_messages and processed_messages[0]['role'] == 'system': system_message_for_cohere_preamble = processed_messages.pop(0)['content']
        if processed_messages:
            current_message_content = processed_messages[-1]["content"]
            for msg_item in processed_messages[:-1]: cohere_chat_history.append({"role": "USER" if msg_item["role"] == "user" else "CHATBOT", "message": msg_item["content"]})
        elif system_message_for_cohere_preamble: current_message_content = "..."  # Dummy message if only a system prompt was provided
        current_payload = {"message": current_message_content, "chat_history": cohere_chat_history, "model": model, "max_tokens": maxTokens, "stream": stream, "temperature": temperature}
        if system_message_for_cohere_preamble: current_payload["preamble"] = system_message_for_cohere_preamble
    else:
        raise Exception(f"Unsupported API provider: {api_provider_param}")
    if url is None:
        raise Exception(f"URL not configured for API provider: {api_provider_param}")
    first_chunk_logged = False
    for attempt in range(int(retries)):
        attempt_start_time = time.time()
        try:
            request_args = {"headers": headers, "stream": stream, "timeout": 180}
            if data_as_string is not None:
                request_args["data"] = data_as_string
            elif current_payload is not None:
                request_args["json"] = current_payload
            else:
                logger.error(f"callAIModel: No payload determined for {api_provider_param}")
                raise Exception("Payload construction error")
            logger.debug(f"callAIModel [{api_provider_param}/{model}] Attempt {attempt+1}: POST to {url}, Payload keys: {current_payload.keys() if current_payload else 'N/A'}")
            r = requests.post(url, **request_args)
            r.raise_for_status()
            if stream:
                for line in r.iter_lines():
                    if not first_chunk_logged:
                        logger.info(f"callAIModel [{api_provider_param}] Time to first byte/line from stream: {time.time() - attempt_start_time:.2f}s")
                        first_chunk_logged = True
                    if not line: continue
                    s_line = line.decode('utf-8').strip()
                    chunk_to_yield = None
                    if s_line.startswith("data: "): data_content = s_line[len("data: "):]
                    else: data_content = s_line
                    if data_content == "[DONE]": break
                    try:
                        parsed_json = json.loads(data_content)
                        if api_provider_param == "cohere":
                            if parsed_json.get("event_type") == "text-generation" and parsed_json.get("text"): chunk_to_yield = parsed_json["text"]
                            if parsed_json.get("event_type") == "stream-end": break  # Cohere-specific end signal
                        elif parsed_json.get("choices") and parsed_json["choices"][0].get("delta", {}).get("content") is not None:
                            chunk_to_yield = parsed_json["choices"][0]["delta"]["content"]
                        elif parsed_json.get("token", {}).get("text"): chunk_to_yield = parsed_json["token"]["text"]  # HuggingFace TGI format
                    except json.JSONDecodeError:  # For HF non-JSON stream lines
                        if api_provider_param.startswith("huggingface") and not data_content.startswith("{"):
                            # Check if it's the TGI last-message structure
                            try:
                                hf_end_obj = json.loads(data_content)
                                if hf_end_obj.get("generated_text") is not None and hf_end_obj.get("details") is not None:
                                    break  # End of stream for TGI if it sends a final full object
                            except json.JSONDecodeError:
                                chunk_to_yield = data_content  # Assume a raw text chunk
                    if chunk_to_yield is not None: yield chunk_to_yield
            else:  # Non-streaming
                result = r.json()
                logger.info(f"callAIModel [{api_provider_param}] Non-streaming response received in {time.time() - attempt_start_time:.2f}s")
                if api_provider_param.startswith("huggingface"): yield result[0]["generated_text"].strip() if isinstance(result, list) and result and "generated_text" in result[0] else ""
                elif api_provider_param == "cohere": yield result.get("text", "").strip() or (result.get("generations")[0].get("text", "").strip() if result.get("generations") else "")
                else: yield result.get("choices", [{}])[0].get("message", {}).get("content", "").strip()
            logger.info(f"callAIModel [{api_provider_param}] Call successful in {time.time() - call_start_time:.2f}s (attempt {attempt+1})")
            return
        except requests.exceptions.HTTPError as e:
            # Note: check against None explicitly; a 4xx/5xx Response object is falsy, which would hide the body.
            response_text_snippet = e.response.text[:500] if e.response is not None and e.response.text else "No response body"
            logger.warning(f"callAIModel HTTPError {e.response.status_code} for {api_provider_param} (attempt {attempt+1}/{retries}) after {time.time() - attempt_start_time:.2f}s: {response_text_snippet}")
            if e.response.status_code == 401: logger.error(f"API Key invalid for {api_provider_param}. Cannot retry."); raise Exception(f"API Key invalid for {api_provider_param}.")
            if e.response.status_code == 429: delay *= 1.5
            if e.response.status_code >= 500: logger.warning("Server error, retrying...")  # Retrying on 5xx
            else:  # For 4xx errors other than 401/429, don't retry unless specifically handled
                if attempt == retries - 1: raise  # If it's the last attempt, raise it
        except Exception as e:
            logger.warning(f"callAIModel attempt {attempt+1}/{retries} error after {time.time() - attempt_start_time:.2f}s for {api_provider_param}: {e}", exc_info=False)  # exc_info=True is too verbose for normal operation
        if attempt < retries - 1:
            sleep_duration = delay * (1.5 ** attempt)
            logger.info(f"Retrying {api_provider_param} in {sleep_duration:.2f}s...")
            time.sleep(sleep_duration)
        else:
            total_call_duration = time.time() - call_start_time
            logger.error(f"API call to {api_provider_param} failed after {retries} retries over {total_call_duration:.2f}s.")
            raise Exception(f"API call to {api_provider_param} failed after {retries} retries.")
# --- Memory & Insight Functions (from ai-learn, adapted) --- | |
def generate_interaction_metrics(user_input, bot_response, api_provider, model): | |
metric_start_time = time.time() | |
metric_prompt = f"User: \"{user_input}\"\nAI: \"{bot_response}\"\nMetrics: \"takeaway\" (3-7 words), \"response_success_score\" (0.0-1.0), \"future_confidence_score\" (0.0-1.0). JSON ONLY." | |
messages = [{"role": "system", "content": "Output JSON metrics for user-AI interaction."}, {"role": "user", "content": metric_prompt}] | |
try: | |
# Use a potentially faster/cheaper model for metrics | |
metrics_provider = TOOL_DECISION_PROVIDER | |
metrics_model = TOOL_DECISION_MODEL | |
# Override if specific metrics model is set | |
metrics_model_override = os.getenv("METRICS_MODEL") | |
if metrics_model_override and "/" in metrics_model_override: | |
metrics_provider, metrics_model = metrics_model_override.split("/",1) | |
elif metrics_model_override: # assume same provider as TOOL_DECISION_PROVIDER | |
metrics_model = metrics_model_override | |
resp_str = "".join(list(callAIModel( | |
api_provider_param=metrics_provider, model=metrics_model, messages_list=messages, | |
maxTokens=150, stream=False, retries=2, temperature=0.1 | |
))).strip() | |
match = re.search(r"\{.*\}", resp_str, re.DOTALL) | |
if match: metrics_data = json.loads(match.group(0)) | |
else: | |
logger.warning(f"METRICS_GEN: Non-JSON response from {metrics_provider}/{metrics_model}: {resp_str}") | |
return {"takeaway": "N/A", "response_success_score": 0.5, "future_confidence_score": 0.5, "error": "metrics format"} | |
parsed_metrics = { | |
"takeaway": metrics_data.get("takeaway", "N/A"), | |
"response_success_score": float(metrics_data.get("response_success_score", 0.5)), | |
"future_confidence_score": float(metrics_data.get("future_confidence_score", 0.5)), | |
"error": metrics_data.get("error") | |
} | |
logger.info(f"METRICS_GEN: Metrics generated by {metrics_provider}/{metrics_model} in {time.time() - metric_start_time:.2f}s. Data: {parsed_metrics}") | |
return parsed_metrics | |
except Exception as e: | |
logger.error(f"METRICS_GEN Error in {time.time() - metric_start_time:.2f}s: {e}", exc_info=False) | |
return {"takeaway": "N/A", "response_success_score": 0.5, "future_confidence_score": 0.5, "error": str(e)} | |
def add_memory(user_input, interaction_metrics, bot_response): | |
global faiss_memory_index, memory_texts, dimension, embedder | |
if not embedder: logger.error("ADD_MEMORY: Embedder not initialized."); return False | |
add_mem_start = time.time() | |
ts = datetime.utcnow().isoformat() | |
memory_json = json.dumps({"user_input": user_input, "metrics": interaction_metrics, "bot_response": bot_response, "timestamp": ts}) | |
try: | |
text_to_embed = f"User: {user_input}\nAI: {bot_response}\nTakeaway: {interaction_metrics.get('takeaway', 'N/A')}" | |
embedding = np.array(embedder.encode([text_to_embed]), dtype=np.float32) | |
if embedding.shape == (1, dimension) and faiss_memory_index is not None: | |
faiss_memory_index.add(embedding) | |
memory_texts.append(memory_json) | |
if STORAGE_BACKEND == "HF_DATASET" and HF_TOKEN: | |
logger.info(f"ADD_MEMORY: Pushing {len(memory_texts)} memories to HF Hub: {HF_MEMORY_DATASET_NAME}") | |
Dataset.from_dict({"memories": list(memory_texts)}).push_to_hub(HF_MEMORY_DATASET_NAME, token=HF_TOKEN, private=True) | |
elif STORAGE_BACKEND == "SQLITE": | |
with get_sqlite_connection() as conn: conn.execute("INSERT INTO memories (memory_json) VALUES (?)", (memory_json,)); conn.commit() | |
logger.info(f"ADD_MEMORY: Added. RAM:{len(memory_texts)}, FAISS:{faiss_memory_index.ntotal}. Total time: {time.time() - add_mem_start:.2f}s") | |
return True | |
else: | |
logger.warning(f"ADD_MEMORY: FAISS index not init or embedding error (Shape: {embedding.shape}). Time: {time.time() - add_mem_start:.2f}s") | |
return False | |
except Exception as e: | |
logger.error(f"ADD_MEMORY: Error in {time.time() - add_mem_start:.2f}s: {e}", exc_info=True) | |
return False | |
def retrieve_memories(query, k=3): | |
global faiss_memory_index, memory_texts, embedder | |
if not embedder: logger.error("RETRIEVE_MEMORIES: Embedder not initialized."); return [] | |
if not faiss_memory_index or faiss_memory_index.ntotal == 0: return [] | |
retrieve_start = time.time() | |
try: | |
embedding = np.array(embedder.encode([query]), dtype=np.float32) | |
if embedding.ndim == 1: embedding = embedding.reshape(1, -1) # Ensure 2D for FAISS | |
if embedding.shape[1] != dimension: | |
logger.error(f"RETRIEVE_MEMORIES: Query embedding dimension mismatch. Expected {dimension}, got {embedding.shape[1]}. Query: '{query[:30]}...'") | |
return [] | |
_, indices = faiss_memory_index.search(embedding, min(k, faiss_memory_index.ntotal)) | |
results = [json.loads(memory_texts[i]) for i in indices[0] if 0 <= i < len(memory_texts)] | |
logger.debug(f"RETRIEVE_MEMORIES: Found {len(results)} memories in {time.time() - retrieve_start:.4f}s for query '{query[:30]}...'") | |
return results | |
except Exception as e: | |
logger.error(f"RETRIEVE_MEMORIES error in {time.time() - retrieve_start:.4f}s: {e}", exc_info=True) | |
return [] | |
# ... (Insight/Rule functions: remove_insight_from_memory, _add_new_insight_to_store, add_learned_insight, retrieve_learned_insights) ...
# These are complex and involve FAISS updates. For brevity in this combined script,
# I'll include their signatures and key logic points. The full versions are in ai-learn.py.
def remove_insight_from_memory(insight_text_to_remove):  # insight_text_to_remove is a rule
    global rules_texts, faiss_rules_index, embedder, dimension
    if not embedder: logger.error("REMOVE_INSIGHT: Embedder not initialized."); return False
    if insight_text_to_remove not in rules_texts:
        logger.info(f"REMOVE_INSIGHT: Insight '{insight_text_to_remove[:70]}...' not found. Skipping.")
        return False
    # ... (Full logic from ai-learn.py including FAISS rebuild and backend storage update) ...
    logger.info(f"Attempting to remove insight: {insight_text_to_remove}")
    try:
        rules_texts.remove(insight_text_to_remove)
        # Rebuild the FAISS index for rules
        if rules_texts:
            new_embeddings = embedder.encode(rules_texts, convert_to_numpy=True)
            faiss_rules_index = faiss.IndexFlatL2(dimension)
            faiss_rules_index.add(np.array(new_embeddings, dtype=np.float32))
        else:
            faiss_rules_index = faiss.IndexFlatL2(dimension)  # Empty index
        if STORAGE_BACKEND == "SQLITE":
            with get_sqlite_connection() as conn:
                conn.execute("DELETE FROM rules WHERE rule_text = ?", (insight_text_to_remove,))
                conn.commit()
        elif STORAGE_BACKEND == "HF_DATASET" and HF_TOKEN:
            Dataset.from_dict({"rule_text": list(rules_texts)}).push_to_hub(HF_RULES_DATASET_NAME, token=HF_TOKEN, private=True)
        logger.info(f"Insight '{insight_text_to_remove}' removed. FAISS rules total: {faiss_rules_index.ntotal}")
        return True
    except Exception as e:
        logger.error(f"Error removing insight '{insight_text_to_remove}': {e}", exc_info=True)
        # Potentially re-add if removal from the list succeeded but the backend update failed. For simplicity, not implemented here.
        return False
def _add_new_insight_to_store(insight_text):  # insight_text is a rule
    global faiss_rules_index, rules_texts, dimension, embedder
    if not embedder: logger.error("_ADD_NEW_INSIGHT: Embedder not initialized."); return False
    if not insight_text or not isinstance(insight_text, str):
        logger.warning(f"_ADD_NEW_INSIGHT: Invalid or empty insight text: {insight_text}")
        return False
    insight_text = insight_text.strip()
    if insight_text in rules_texts:
        logger.info(f"_ADD_NEW_INSIGHT: Insight '{insight_text[:70]}...' already exists. Skipped.")
        return False
    # ... (Full logic from ai-learn.py including FAISS add and backend storage update) ...
    logger.info(f"Adding new insight: {insight_text}")
    try:
        embedding = np.array(embedder.encode([insight_text]), dtype=np.float32)
        if embedding.shape != (1, dimension):
            logger.error(f"_ADD_NEW_INSIGHT: Embedding shape error for insight. Expected (1, {dimension}), got {embedding.shape}")
            return False
        if faiss_rules_index is None:  # Should have been initialized at startup
            faiss_rules_index = faiss.IndexFlatL2(dimension)
        rules_texts.append(insight_text)
        rules_texts.sort()
        # Rebuild the rules index so FAISS row order stays aligned with the re-sorted rules_texts list
        # (appending the new embedding alone would leave the search indices pointing at the wrong rules).
        all_rule_embeddings = embedder.encode(rules_texts, convert_to_numpy=True)
        faiss_rules_index = faiss.IndexFlatL2(dimension)
        faiss_rules_index.add(np.array(all_rule_embeddings, dtype=np.float32))
        if STORAGE_BACKEND == "SQLITE":
            with get_sqlite_connection() as conn:
                conn.execute("INSERT OR IGNORE INTO rules (rule_text) VALUES (?)", (insight_text,))
                conn.commit()
        elif STORAGE_BACKEND == "HF_DATASET" and HF_TOKEN:
            Dataset.from_dict({"rule_text": list(rules_texts)}).push_to_hub(HF_RULES_DATASET_NAME, token=HF_TOKEN, private=True)
        logger.info(f"Insight '{insight_text}' added. FAISS rules total: {faiss_rules_index.ntotal}")
        return True
    except Exception as e:
        logger.error(f"Error adding insight '{insight_text}': {e}", exc_info=True)
        # Rollback logic from ai-learn is complex; simplified here.
        if insight_text in rules_texts: rules_texts.remove(insight_text)  # Basic rollback
        # FAISS rollback is harder; it may require a full rebuild on error.
        return False
def add_learned_insight(insight_text_with_format):
    insight_text = insight_text_with_format.strip()
    # Basic validation of the format: [TYPE|SCORE] Text
    if not re.match(r"\[(CORE_RULE|RESPONSE_PRINCIPLE|BEHAVIORAL_ADJUSTMENT|GENERAL_LEARNING)\|([\d\.]+?)\](.*)", insight_text, re.I | re.DOTALL):
        logger.warning(f"ADD_LEARNED_INSIGHT: Invalid format for insight: {insight_text[:100]}...")
        return False
    return _add_new_insight_to_store(insight_text)

def retrieve_learned_insights(query, k_insights=3):  # retrieves rules
    global faiss_rules_index, rules_texts, embedder
    if not embedder: logger.error("RETRIEVE_INSIGHTS: Embedder not initialized."); return []
    if not faiss_rules_index or faiss_rules_index.ntotal == 0: return []
    retrieve_start = time.time()
    try:
        embedding = np.array(embedder.encode([query]), dtype=np.float32)
        if embedding.ndim == 1: embedding = embedding.reshape(1, -1)
        if embedding.shape[1] != dimension:
            logger.error(f"RETRIEVE_INSIGHTS: Query embedding dimension mismatch. Expected {dimension}, got {embedding.shape[1]}. Query: '{query[:30]}...'")
            return []
        _, indices = faiss_rules_index.search(embedding, min(k_insights, faiss_rules_index.ntotal))
        results = [rules_texts[i] for i in indices[0] if 0 <= i < len(rules_texts)]
        logger.debug(f"RETRIEVE_INSIGHTS: Found {len(results)} insights in {time.time() - retrieve_start:.4f}s for query '{query[:30]}...'")
        return results
    except Exception as e:
        logger.error(f"RETRIEVE_INSIGHTS error in {time.time() - retrieve_start:.4f}s: {e}", exc_info=True)
        return []
# --- Chat History & Formatting (from ai-learn) --- | |
MAX_HISTORY_TURNS = int(os.getenv("MAX_HISTORY_TURNS", 5)) | |
current_chat_session_history = [] # Global chat history for ai-learn logic | |
def format_insights_for_prompt(retrieved_insights_list): | |
if not retrieved_insights_list: return "No specific guiding principles or learned insights retrieved.", [] | |
parsed = [] | |
for text in retrieved_insights_list: | |
match = re.match(r"\[(CORE_RULE|RESPONSE_PRINCIPLE|BEHAVIORAL_ADJUSTMENT|GENERAL_LEARNING)\|([\d\.]+?)\](.*)", text.strip(), re.DOTALL|re.IGNORECASE) | |
if match: parsed.append({"type": match.group(1).upper().replace(" ","_"), "score": match.group(2), "text": match.group(3).strip(), "original": text}) | |
else: parsed.append({"type": "GENERAL_LEARNING", "score": "0.5", "text": text.strip(), "original": text.strip()}) # Default if format slightly off | |
try: | |
parsed.sort(key=lambda x: float(x["score"]) if x["score"].replace('.', '', 1).isdigit() else -1.0, reverse=True) | |
except ValueError: | |
logger.warning("FORMAT_INSIGHTS: Sort error due to invalid score format in an insight.") | |
grouped = {"CORE_RULE":[],"RESPONSE_PRINCIPLE":[],"BEHAVIORAL_ADJUSTMENT":[],"GENERAL_LEARNING":[]} | |
for p_item in parsed: # Renamed p to p_item to avoid conflict | |
grouped.get(p_item["type"], grouped["GENERAL_LEARNING"]).append(f"- (Score: {p_item['score']}) {p_item['text']}") | |
sections = [f"{k.replace('_',' ').title()}:\n" + "\n".join(v) for k,v in grouped.items() if v] | |
return "\n\n".join(sections) if sections else "No guiding principles retrieved.", parsed | |
# --- Core Interaction Processing (adapted from ai-learn's process_user_interaction) ---
# This function is now a generator for Gradio streaming.
def process_user_interaction_gradio(user_input, api_provider, model, chat_history_with_current_user_msg, custom_system_prompt=None):
    process_start_time = time.time()
    request_id = os.urandom(4).hex()
    logger.info(f"PUI_GRADIO [{request_id}] Start. User: '{user_input[:40]}...' API: {api_provider}/{model} Hist_len:{len(chat_history_with_current_user_msg)}")
    # History string for prompts
    history_str_parts = []
    for t in chat_history_with_current_user_msg[-(MAX_HISTORY_TURNS * 2):]:  # Use the last N turns for prompt context
        role = "User" if t['role'] == 'user' else "AI"
        history_str_parts.append(f"{role}: {t['content']}")
    history_str = "\n".join(history_str_parts)
    yield "status", "<i>[Checking guidelines...]</i>"
    time_before_initial_rag = time.time()
    initial_insights = retrieve_learned_insights(user_input + "\n" + history_str, k_insights=5)  # More context for insight retrieval
    initial_insights_ctx, parsed_initial_insights = format_insights_for_prompt(initial_insights)
    logger.info(f"PUI_GRADIO [{request_id}]: Initial RAG (insights) took {time.time() - time_before_initial_rag:.3f}s. Found {len(initial_insights)} insights. Context: {initial_insights_ctx[:100]}...")
    action_type, action_input = "quick_respond", {}
    user_input_lower = user_input.lower()
    time_before_tool_decision_logic = time.time()
    # Simplified heuristic checks from ai-learn
    simple_keywords = ["hello", "hi", "hey", "thanks", "thank you", "ok", "okay", "yes", "no", "bye", "cool", "great", "awesome", "sounds good", "got it"]
    if len(user_input.split()) <= 4 and any(kw in user_input_lower for kw in simple_keywords) and "?" not in user_input:  # Exclude questions to avoid misclassifying simple queries
        action_type = "quick_respond"
        logger.info(f"PUI_GRADIO [{request_id}]: Heuristic: Simple keyword. Action: quick_respond.")
    elif WEB_SEARCH_ENABLED and ("http://" in user_input or "https://" in user_input):  # Handle a direct URL
        url_match = re.search(r'(https?://[^\s]+)', user_input)
        if url_match:
            action_type = "scrape_url_and_report"
            action_input = {"url": url_match.group(1)}
            logger.info(f"PUI_GRADIO [{request_id}]: Heuristic: URL detected. Action: scrape_url_and_report.")
        # else:  # If it looks like a URL but the regex fails, it may fall through to search
        #     logger.info(f"PUI_GRADIO [{request_id}]: URL-like input but no clean match. May default to search via LLM decision.")
        #     pass  # Let the LLM decide
    # LLM-based tool decision if the input is not simple or a direct URL, and web search is enabled
    if action_type == "quick_respond" and WEB_SEARCH_ENABLED and (len(user_input.split()) > 3 or "?" in user_input or any(w in user_input_lower for w in ["what is", "how to", "explain", "tell me about", "search for", "find information on", "who is", "why"])):
        yield "status", "<i>[Choosing best approach...]</i>"
        # Reduced history snippet for the tool prompt
        history_snippet = "\n".join([f"{msg['role']}: {msg['content']}" for msg in chat_history_with_current_user_msg[-3:]])
        guideline_snippet = initial_insights_ctx[:200].replace('\n', ' ')  # Compact
        tool_prompt_content = f"User Query: \"{user_input}\"\nRecent Conversation Snippet:\n{history_snippet}\nKey Guidelines (summary): {guideline_snippet}...\n\nAvailable Actions & Required Inputs:\n1. `quick_respond`: For simple chat, greetings, or if no external info/memory is needed. (Input: N/A)\n2. `answer_using_conversation_memory`: If the query refers to past specific details of THIS conversation not covered by general guidelines. (Input: N/A)\n3. `search_duckduckgo_and_report`: For general knowledge, facts, current events, or if user asks to search. (Input: `search_engine_query`: string)\n4. `scrape_url_and_report`: If user explicitly provides a URL to summarize or analyze. (Input: `url`: string)\n\nBased on the query and context, select ONLY ONE action and its required input (if any). Output a single JSON object like: {{\"action\": \"chosen_action\", \"action_input\": {{\"param_name\": \"value\"}}}} or {{\"action\": \"chosen_action\", \"action_input\": {{}}}} if no input needed."
        tool_msgs = [
            {"role": "system", "content": "You are a precise routing agent. Your task is to choose the single most appropriate action from the list to address the user's query. Output JSON only."},
            {"role": "user", "content": tool_prompt_content}
        ]
        time_before_tool_llm = time.time()
        try:
            tool_resp_raw = "".join(list(callAIModel(
                api_provider_param=TOOL_DECISION_PROVIDER, model=TOOL_DECISION_MODEL, messages_list=tool_msgs,
                maxTokens=100, stream=False, retries=1, temperature=0.0
            ))).strip()
            logger.info(f"PUI_GRADIO [{request_id}]: Tool Decision LLM ({TOOL_DECISION_PROVIDER}/{TOOL_DECISION_MODEL}) call took {time.time() - time_before_tool_llm:.3f}s. Raw Response: '{tool_resp_raw}'")
            match = re.search(r"\{.*\}", tool_resp_raw, re.DOTALL)  # Extract the JSON part
            if match:
                action_data = json.loads(match.group(0))
                action_type = action_data.get("action", "quick_respond")
                action_input_raw = action_data.get("action_input", {})
                # Ensure action_input is a dict
                if isinstance(action_input_raw, dict):
                    action_input = action_input_raw
                elif isinstance(action_input_raw, str) and action_input_raw:  # Handle the LLM returning a bare string for action_input
                    if "search" in action_type: action_input = {"search_engine_query": action_input_raw}
                    elif "scrape_url" in action_type: action_input = {"url": action_input_raw}
                    else: action_input = {}
                else: action_input = {}
                logger.info(f"PUI_GRADIO [{request_id}]: LLM Tool Decision: Action='{action_type}', Input='{action_input}'")
            else:
                logger.warning(f"PUI_GRADIO [{request_id}]: Tool decision LLM non-JSON or no JSON found. Defaulting to quick_respond. Raw: {tool_resp_raw}")
                action_type = "quick_respond"; action_input = {}
        except Exception as e_tool_llm:
            logger.error(f"PUI_GRADIO [{request_id}]: Tool decision LLM error after {time.time() - time_before_tool_llm:.3f}s: {e_tool_llm}", exc_info=False)
            action_type = "quick_respond"; action_input = {}
    elif not WEB_SEARCH_ENABLED and action_type == "quick_respond":  # If web search is disabled, consider memory
        if len(user_input.split()) > 4 or "?" in user_input or any(w in user_input_lower for w in ["remember", "recall", "what did i say about", "what was"]):
            action_type = "answer_using_conversation_memory"
            logger.info(f"PUI_GRADIO [{request_id}]: Web search disabled, heuristic for memory retrieval. Action: {action_type}")
    logger.info(f"PUI_GRADIO [{request_id}]: Tool decision logic (total) took {time.time() - time_before_tool_decision_logic:.3f}s. Chosen Action: {action_type}, Input: {action_input}")
yield "status", f"<i>[Path: {action_type}. Preparing response...]</i>" | |
final_bot_response_str, mem_ctx, scraped_content_str = "", "No memories reviewed for this path.", "" | |
# final_prompt_hist_for_llm = history_str + "\nAssistant:" # This is already part of history_str construction | |
system_prompt_text = custom_system_prompt or "You are a helpful and concise AI assistant." # Base system prompt | |
user_prompt_content_text = "" | |
time_before_action_execution = time.time() | |
# Construct prompts based on action_type | |
if action_type == "quick_respond": | |
system_prompt_text += " Respond directly to the user's query using the provided guidelines and conversation history for context. Be concise." | |
user_prompt_content_text = f"Conversation History:\n{history_str}\n\nGuiding Principles:\n{initial_insights_ctx}\n\nUser's Current Query: \"{user_input}\"\n\nYour concise and helpful response:" | |
elif action_type == "answer_using_conversation_memory": | |
yield "status", "<i>[Searching conversation memory...]</i>" | |
# Truncate history_str for memory query if too long | |
mem_query_context = history_str[-1000:] # Last 1000 chars of history for mem query context | |
mem_query = f"User's current query: {user_input}\nRelevant conversation context:\n{mem_query_context}" | |
memories = retrieve_memories(mem_query, k=2) | |
if memories: | |
mem_ctx = "Relevant Past Interactions (for your reference):\n" + "\n".join([f"- User: {m.get('user_input','')} -> AI: {m.get('bot_response','')} (Takeaway: {m.get('metrics',{}).get('takeaway','N/A')}, Timestamp: {m.get('timestamp','N/A')})" for m in memories]) | |
else: | |
mem_ctx = "No highly relevant past interactions found in memory for this specific query." | |
logger.info(f"PUI_GRADIO [{request_id}]: Memory retrieval found {len(memories)} items. Context: {mem_ctx[:100]}...") | |
system_prompt_text += " Respond to the user by incorporating relevant information from past interactions (provided below as 'Memory Context') and your general guidelines. Prioritize the user's current query." | |
user_prompt_content_text = f"Conversation History:\n{history_str}\n\nGuiding Principles:\n{initial_insights_ctx}\n\nMemory Context (from previous related interactions):\n{mem_ctx}\n\nUser's Current Query: \"{user_input}\"\n\nYour helpful response (draw from memory context if applicable, otherwise answer generally):" | |
elif WEB_SEARCH_ENABLED and action_type in ["search_duckduckgo_and_report", "search_google_and_report", "scrape_url_and_report"]: | |
query_or_url_for_web = action_input.get("search_engine_query") if "search" in action_type else action_input.get("url") | |
if not query_or_url_for_web: | |
logger.warning(f"PUI_GRADIO [{request_id}]: Missing 'search_engine_query' or 'url' for action {action_type}. Falling back to quick_respond.") | |
action_type = "quick_respond" # Fallback | |
system_prompt_text += " Respond directly. (Note: A web action was attempted but failed due to missing input)." | |
user_prompt_content_text = f"Conversation History:\n{history_str}\n\nGuiding Principles:\n{initial_insights_ctx}\n\nUser's Current Query: \"{user_input}\"\n\nYour concise and helpful response:" | |
else: | |
yield "status", f"<i>[Fetching web: '{query_or_url_for_web[:50]}'...]</i>" | |
web_results_data = [] | |
time_before_scraping = time.time() | |
max_scrape_results = 1 if action_type == "scrape_url_and_report" else 2 # Max 2 search results to summarize | |
try: | |
if action_type == "search_duckduckgo_and_report": | |
web_results_data = search_and_scrape_duckduckgo(query_or_url_for_web, num_results=max_scrape_results) | |
elif action_type == "search_google_and_report": | |
web_results_data = search_and_scrape_google(query_or_url_for_web, num_results=max_scrape_results) # Placeholder, uses DDG | |
elif action_type == "scrape_url_and_report": | |
scrape_res = scrape_url(query_or_url_for_web) | |
if scrape_res and scrape_res.get("content"): web_results_data = [scrape_res] | |
elif scrape_res: web_results_data = [{"url":query_or_url_for_web, "title":"Scrape Error", "content":None, "error":scrape_res.get("error","Unknown scrape error")}] | |
except Exception as e_scrape_call: | |
logger.error(f"PUI_GRADIO [{request_id}]: Error during web tool call for {action_type}: {e_scrape_call}", exc_info=True) | |
web_results_data = [{"url":query_or_url_for_web, "title":"Tool Execution Error", "content":None, "error":str(e_scrape_call)}] | |
logger.info(f"PUI_GRADIO [{request_id}]: Web scraping/fetching took {time.time() - time_before_scraping:.3f}s. Found {len(web_results_data)} results.") | |
if web_results_data: | |
scraped_parts = [] | |
for i, r_item in enumerate(web_results_data): | |
yield "status", f"<i>[Processing web result {i+1}/{len(web_results_data)}: {r_item.get('title','N/A')[:30]}...]</i>" | |
content_for_prompt = (r_item.get('content') or r_item.get('error') or 'N/A') | |
# Truncate individual source if very long, main LLM has token limits | |
max_source_len = 3000 # Max characters per source for the prompt | |
if len(content_for_prompt) > max_source_len: | |
content_for_prompt = content_for_prompt[:max_source_len] + "... (truncated)" | |
scraped_parts.append(f"Source {i+1}:\nURL: {r_item.get('url','N/A')}\nTitle: {r_item.get('title','N/A')}\nContent Snippet:\n{content_for_prompt}\n---") | |
scraped_content_str = "\n".join(scraped_parts) if scraped_parts else "No usable content extracted from web sources." | |
else: | |
scraped_content_str = f"No results or content found from {action_type} for '{query_or_url_for_web}'." | |
yield "status", "<i>[Synthesizing web report...]</i>" | |
system_prompt_text += " You are an AI assistant that generates reports or answers based on web content. Use the provided web content, conversation history, and guidelines. Cite URLs clearly as [Source X] where X is the source number." | |
user_prompt_content_text = f"Conversation History:\n{history_str}\n\nGuiding Principles:\n{initial_insights_ctx}\n\nWeb Content Found:\n{scraped_content_str}\n\nUser's Current Query: \"{user_input}\"\n\nYour report/response (ensure to cite sources like [Source 1], [Source 2], etc., if you use their content):" | |
else: # Should not happen if logic is correct, but as a fallback | |
logger.warning(f"PUI_GRADIO [{request_id}]: Unknown action_type '{action_type}'. Defaulting to quick_respond.") | |
action_type = "quick_respond" # Fallback | |
system_prompt_text += " Respond directly. (Note: An unexpected internal state occurred)." | |
user_prompt_content_text = f"Conversation History:\n{history_str}\n\nGuiding Principles:\n{initial_insights_ctx}\n\nUser's Current Query: \"{user_input}\"\n\nYour concise and helpful response:" | |
logger.info(f"PUI_GRADIO [{request_id}]: Action execution (RAG memory, web scrape, prompt prep) took {time.time() - time_before_action_execution:.3f}s.") | |
# Construct final messages for LLM | |
# Ensure chat_history_with_current_user_msg is OpenAI format | |
# For the main call, we can construct it from scratch using our variables | |
# Or use the chat_history_with_current_user_msg passed in, but ensure it's not too long and system prompt is right | |
# Let's build messages for the final LLM call cleanly | |
final_llm_messages = [] | |
if system_prompt_text: | |
final_llm_messages.append({"role": "system", "content": system_prompt_text}) | |
# Add relevant history turns (excluding the very last user message which is part of user_prompt_content_text or user_input) | |
# For this version, user_prompt_content_text already includes history and user_input, so we just need that. | |
final_llm_messages.append({"role": "user", "content": user_prompt_content_text}) | |
# Debug: Log the first and last parts of the prompt being sent to LLM | |
logger.debug(f"PUI_GRADIO [{request_id}]: Final LLM System Prompt: {system_prompt_text[:200]}...") | |
if len(user_prompt_content_text) > 400 : | |
logger.debug(f"PUI_GRADIO [{request_id}]: Final LLM User Prompt Start: {user_prompt_content_text[:200]}...") | |
logger.debug(f"PUI_GRADIO [{request_id}]: Final LLM User Prompt End: ...{user_prompt_content_text[-200:]}") | |
else: | |
logger.debug(f"PUI_GRADIO [{request_id}]: Final LLM User Prompt: {user_prompt_content_text}") | |
streamed_resp_accumulator = "" | |
time_before_main_llm = time.time() | |
try: | |
response_iterator = callAIModel( | |
api_provider_param=api_provider, model=model, messages_list=final_llm_messages, | |
maxTokens=2000, stream=True, temperature=0.6, retries=1 # Adjusted maxTokens and temp | |
) | |
for chunk in response_iterator: | |
streamed_resp_accumulator += chunk | |
yield "response_chunk", chunk | |
except Exception as e_final_llm: | |
logger.error(f"PUI_GRADIO [{request_id}]: Final LLM call error: {e_final_llm}", exc_info=False) | |
error_response_chunk = f"\n\n(Error during final response generation: {str(e_final_llm)[:150]})" | |
streamed_resp_accumulator += error_response_chunk | |
yield "response_chunk", error_response_chunk | |
logger.info(f"PUI_GRADIO [{request_id}]: Main LLM call (streamed) took {time.time() - time_before_main_llm:.3f}s.") | |
current_final_bot_response_str = streamed_resp_accumulator.strip() or "(No response generated.)" | |
logger.info(f"PUI_GRADIO [{request_id}]: Processing finished. Total wall time: {time.time() - process_start_time:.2f}s. Final response length: {len(current_final_bot_response_str)}") | |
yield "final_response_and_insights", {"response": current_final_bot_response_str, "insights_used": parsed_initial_insights} | |
# --- Deferred Learning (from ai-learn, adapted) ---
def deferred_learning_and_memory(user_input, bot_response, api_provider, model, parsed_insights_for_reflection):
    # This runs in a background thread. `emit` calls are replaced with logging.
    # `socketio.sleep` is replaced with `time.sleep`.
    deferred_start_time = time.time()
    task_id = os.urandom(4).hex()
    logger.info(f"DEFERRED_LEARNING [{task_id}]: START User='{user_input[:30]}...', Bot='{bot_response[:30]}...'")
    try:
        time.sleep(0.01)  # Yield thread control
        metrics = generate_interaction_metrics(user_input, bot_response, api_provider, model)
        logger.info(f"DEFERRED_LEARNING [{task_id}]: Metrics generated: {metrics}")
        # Updating the Gradio UI with metrics from a thread is complex. For now, just log.
        # socketio.emit('receive_message', {'metrics': metrics, 'is_background_metric': True})  # Original call
        add_memory(user_input, metrics, bot_response)
        time.sleep(0.01)
        summary = f"User:\"{user_input}\"\nAI:\"{bot_response}\"\nMetrics(takeaway):{metrics.get('takeaway','N/A')},Success:{metrics.get('response_success_score','N/A')}"
        prev_insights_str = json.dumps([p['original'] for p in parsed_insights_for_reflection if 'original' in p]) if parsed_insights_for_reflection else "None"
        time_before_insight_rag = time.time()
        relevant_existing_rules_for_context = sorted(list(set(
            retrieve_learned_insights(summary, k_insights=10) +
            retrieve_learned_insights(user_input, k_insights=5) +
            retrieve_learned_insights(bot_response, k_insights=3)  # Added bot response context
        )))
        logger.info(f"DEFERRED_LEARNING [{task_id}]: RAG for insight context took {time.time() - time_before_insight_rag:.3f}s. Found {len(relevant_existing_rules_for_context)} unique rules for context.")
        existing_rules_context_str = "\n".join([f"- \"{rule}\"" for rule in relevant_existing_rules_for_context]) if relevant_existing_rules_for_context else "No specific existing rules were pre-fetched as highly relevant for direct comparison."
        # --- Insight Generation Prompt (copied from ai-learn, ensure it's up-to-date) ---
        sys_msg = """You are an expert AI knowledge base curator. Your primary function is to meticulously analyze an interaction and update the AI's guiding principles (insights/rules) to improve its future performance and self-understanding.
You MUST output a JSON list of operation objects. This list can and SHOULD contain MULTIPLE distinct operations if various learnings occurred.
Each operation object in the JSON list must have:
1. "action": A string, either "add" (for entirely new rules) or "update" (to replace an existing rule with a better one).
2. "insight": A string, the full, refined insight text including its [TYPE|SCORE] prefix (e.g., "[CORE_RULE|1.0] My name is Lumina, an AI assistant.").
3. "old_insight_to_replace" (ONLY for "update" action): A string, the *exact, full text* of an existing insight that the new "insight" should replace.
**Your Reflection Process (Consider each step and generate operations accordingly):**
**STEP 1: Core Identity & Purpose Review (Result: Primarily 'update' operations)**
- Examine all `CORE_RULE`s related to my identity (name, fundamental purpose, core unchanging capabilities, origin) from the "Potentially Relevant Existing Rules".
- **CONSOLIDATE & MERGE:** If multiple `CORE_RULE`s state similar aspects (e.g., multiple name declarations like 'Lumina' and 'LearnerAI', or slightly different purpose statements), you MUST merge them into ONE definitive, comprehensive `CORE_RULE`.
- The new "insight" will be this single, merged rule. Propose separate "update" operations to replace *each* redundant or less accurate core identity rule with this new canonical one.
- Prioritize user-assigned names or the most specific, recently confirmed information. If the interaction summary clarifies a name or core function, ensure this is reflected.
**STEP 2: New Distinct Learnings (Result: Primarily 'add' operations)**
- Did I learn any completely new, distinct facts (e.g., "The user's project is codenamed 'Bluefire'")?
- Did I demonstrate or get told about a new skill/capability not previously documented (e.g., "I can now generate mermaid diagrams based on descriptions")?
- Did the user express a strong, general preference that should guide future interactions (e.g., "User prefers responses to start with a direct answer, then explanation")?
- For these, propose 'add' operations. Assign `CORE_RULE` for truly fundamental new facts/capabilities, otherwise `RESPONSE_PRINCIPLE` or `BEHAVIORAL_ADJUSTMENT`. Ensure these are genuinely NEW and not just rephrasing of existing non-core rules.
**STEP 3: Refinements to Existing Behaviors/Principles (Result: 'update' operations for non-core rules)**
- Did I learn to modify or improve an existing behavior, response style, or operational guideline (that is NOT part of core identity)?
- For example, if an existing `RESPONSE_PRINCIPLE` was "Be formal," and the interaction showed the user prefers informality, update that principle.
- Propose 'update' operations for the relevant `RESPONSE_PRINCIPLE` or `BEHAVIORAL_ADJUSTMENT`. Only update if the change is significant.
**General Guidelines:**
- If no new insights, updates, or consolidations are warranted from the interaction, output an empty JSON list: `[]`.
- Ensure the "insight" field (for both add/update) always contains the properly formatted insight string: `[TYPE|SCORE] Text`. TYPE can be `CORE_RULE`, `RESPONSE_PRINCIPLE`, `BEHAVIORAL_ADJUSTMENT`. Scores should reflect confidence/importance.
- Be precise with "old_insight_to_replace" - it must *exactly* match an existing rule string from the "Potentially Relevant Existing Rules" context.
- Aim for a comprehensive set of operations that reflects ALL key learnings from the interaction.
- Output ONLY the JSON list. No other text, explanations, or markdown.
**Example of a comprehensive JSON output with MULTIPLE operations:**
[
{"action": "update", "old_insight_to_replace": "[CORE_RULE|1.0] My designated name is 'LearnerAI'.", "insight": "[CORE_RULE|1.0] I am Lumina, an AI assistant designed to chat, provide information, and remember context like the secret word 'rocksyrup'."},
{"action": "update", "old_insight_to_replace": "[CORE_RULE|1.0] I'm Lumina, the AI designed to chat with you.", "insight": "[CORE_RULE|1.0] I am Lumina, an AI assistant designed to chat, provide information, and remember context like the secret word 'rocksyrup'."},
{"action": "add", "insight": "[CORE_RULE|0.9] I am capable of searching the internet for current weather information if asked."},
{"action": "add", "insight": "[RESPONSE_PRINCIPLE|0.8] When user provides positive feedback, acknowledge it warmly."},
{"action": "update", "old_insight_to_replace": "[RESPONSE_PRINCIPLE|0.7] Avoid mentioning old conversations.", "insight": "[RESPONSE_PRINCIPLE|0.85] Avoid mentioning old conversations unless the user explicitly refers to them or it's highly relevant to the current query."}
]"""
        user_prompt = f"""Interaction Summary:
{summary}
Potentially Relevant Existing Rules (Review these carefully. Your main goal is to consolidate CORE_RULEs and then identify other changes/additions based on the Interaction Summary and these existing rules):
{existing_rules_context_str}
Guiding principles that were considered during THIS interaction (these might offer clues for new rules or refinements):
{prev_insights_str}
Task: Based on your three-step reflection process (Core Identity, New Learnings, Refinements):
1. **Consolidate CORE_RULEs:** Merge similar identity/purpose rules from "Potentially Relevant Existing Rules" into single, definitive statements using "update" operations. Replace multiple old versions with the new canonical one.
2. **Add New Learnings:** Identify and "add" any distinct new facts, skills, or important user preferences learned from the "Interaction Summary".
3. **Update Existing Principles:** "Update" any non-core principles from "Potentially Relevant Existing Rules" if the "Interaction Summary" provided a clear refinement.
Combine all findings into a single JSON list of operations. If there are multiple distinct changes based on the interaction and existing rules, ensure your list reflects all of them. Output JSON only.
"""
insight_msgs = [{"role":"system","content":sys_msg}, {"role":"user","content":user_prompt}] | |
time_before_insight_llm = time.time() | |
# Insight model selection (from ai-learn) | |
insight_gen_provider = TOOL_DECISION_PROVIDER | |
insight_gen_model = TOOL_DECISION_MODEL | |
# Stronger model preference logic (simplified) | |
# Check if current model is considered "strong" (e.g. GPT-4, Claude 3 Opus/Sonnet, Llama3-70b) | |
is_current_model_strong = any(strong_kw in model.lower() for strong_kw in ["gpt-4", "claude-3", "70b", "opus", "sonnet"]) | |
if not is_current_model_strong: | |
# Try to pick a stronger model if available and key exists | |
# Example: use Llama3-70b from Groq if available and current model is 8b | |
if "groq" in API_KEYS and API_KEYS["groq"] and "llama3-70b-8192" in models_data_global_scope.get("groq",[]): | |
insight_gen_provider = "groq" | |
insight_gen_model = "llama3-70b-8192" | |
logger.info(f"DEFERRED_LEARNING [{task_id}]: Upgrading insight model to {insight_gen_provider}/{insight_gen_model}.") | |
elif "openai" in API_KEYS and API_KEYS["openai"] and "gpt-4o-mini" in models_data_global_scope.get("openai",[]): # Or gpt-4o | |
insight_gen_provider = "openai" | |
insight_gen_model = "gpt-4o-mini" # or "gpt-4o" if preferred and available | |
logger.info(f"DEFERRED_LEARNING [{task_id}]: Upgrading insight model to {insight_gen_provider}/{insight_gen_model}.") | |
else: | |
insight_gen_provider = api_provider # Use current model if already strong | |
insight_gen_model = model | |
logger.info(f"DEFERRED_LEARNING [{task_id}]: Using current model ({insight_gen_provider}/{insight_gen_model}) for insights as it's strong or no upgrade path.") | |
raw_llm_json_output = "".join(list(callAIModel( | |
api_provider_param=insight_gen_provider, model=insight_gen_model, messages_list=insight_msgs, | |
maxTokens=2500, # Increased for potentially many operations | |
stream=False, retries=1, temperature=0.05 # Low temp for precision | |
))).strip() | |
logger.info(f"DEFERRED_LEARNING [{task_id}]: Insight LLM ({insight_gen_provider}/{insight_gen_model}) call took {time.time() - time_before_insight_llm:.3f}s. Raw JSON ops: '{raw_llm_json_output[:300]}...'") | |
time.sleep(0.01) | |
insights_processed_count = 0; operations = [] | |
try: # Parsing JSON from LLM | |
json_match = re.search(r"\[\s*(\{.*?\}(?:\s*,\s*\{.*?\})*\s*)?\]", raw_llm_json_output, re.DOTALL) | |
json_to_parse = None | |
if not json_match: # Try to find in markdown code block | |
json_match_markdown = re.search(r"```json\s*(\[.*\])\s*```", raw_llm_json_output, re.DOTALL | re.IGNORECASE) | |
if json_match_markdown: json_to_parse = json_match_markdown.group(1) | |
else: json_to_parse = json_match.group(0) | |
if json_to_parse: operations = json.loads(json_to_parse) | |
else: logger.warning(f"DEFERRED_LEARNING [{task_id}]: Insight LLM output not a JSON list: {raw_llm_json_output}") | |
if not isinstance(operations, list): | |
logger.warning(f"DEFERRED_LEARNING [{task_id}]: Parsed insight ops not a list. Type: {type(operations)}. Raw: {raw_llm_json_output}"); operations = [] | |
if not operations: logger.info(f"DEFERRED_LEARNING [{task_id}]: LLM provided no insight ops or empty/invalid list.") | |
else: logger.info(f"DEFERRED_LEARNING [{task_id}]: LLM provided {len(operations)} insight operation(s).") | |
for op_idx, op in enumerate(operations): | |
if not isinstance(op, dict): logger.warning(f"DEFERRED_LEARNING [{task_id}]: Op {op_idx} not a dict: {op}. Skip."); continue | |
action = op.get("action","").strip().lower() | |
insight_text = op.get("insight","").strip() | |
if not insight_text or not re.match(r"\[(CORE_RULE|RESPONSE_PRINCIPLE|BEHAVIORAL_ADJUSTMENT|GENERAL_LEARNING)\|([\d\.]+?)\](.*)", insight_text, re.I|re.DOTALL): | |
logger.warning(f"DEFERRED_LEARNING [{task_id}]: Invalid insight format or missing text for op {op_idx}: {op}. Insight: '{insight_text}'. Skip."); continue | |
logger.info(f"DEFERRED_LEARNING [{task_id}]: Processing op {op_idx+1}/{len(operations)}: Action='{action}', Insight='{insight_text[:70]}...'") | |
if action == "add": | |
if add_learned_insight(insight_text): insights_processed_count += 1 | |
elif action == "update": | |
old_insight_text = op.get("old_insight_to_replace","").strip() | |
if not old_insight_text: logger.warning(f"DEFERRED_LEARNING [{task_id}]: 'update' op {op_idx} missing 'old_insight_to_replace': {op}. Skip."); continue | |
if old_insight_text == insight_text: logger.info(f"DEFERRED_LEARNING [{task_id}]: Update op {op_idx} has identical old/new insight. Skip."); continue | |
removed_successfully = remove_insight_from_memory(old_insight_text) # This function handles 'not found' | |
if not removed_successfully and old_insight_text in rules_texts: # If it was supposed to be there but removal failed | |
logger.warning(f"DEFERRED_LEARNING [{task_id}]: Update op {op_idx}: Could not remove old '{old_insight_text[:70]}...'. Attempting to add new one anyway.") | |
if add_learned_insight(insight_text): insights_processed_count += 1 | |
else: # Failed to add new insight after potential removal | |
logger.warning(f"DEFERRED_LEARNING [{task_id}]: Update op {op_idx}: Failed to add new '{insight_text[:70]}...'.") | |
if removed_successfully: # We removed old, but couldn't add new. Try to re-add old. | |
logger.error(f"DEFERRED_LEARNING [{task_id}]: CRITICAL - Op {op_idx}: Removed '{old_insight_text}' but failed to add '{insight_text}'. Re-adding old.") | |
if add_learned_insight(old_insight_text): logger.info(f"DEFERRED_LEARNING [{task_id}]: Op {op_idx}: Successfully re-added original '{old_insight_text}'.") | |
else: logger.error(f"DEFERRED_LEARNING [{task_id}]: Op {op_idx}: FAILED to re-add original '{old_insight_text}'. Data may be inconsistent.") | |
else: logger.warning(f"DEFERRED_LEARNING [{task_id}]: Unknown action '{action}' in op {op_idx}: {op}") | |
time.sleep(0.01) | |
except json.JSONDecodeError as e_json: | |
logger.error(f"DEFERRED_LEARNING [{task_id}]: JSONDecodeError processing insight LLM output '{raw_llm_json_output}': {e_json}", exc_info=False) | |
except Exception as e_op_proc: | |
logger.error(f"DEFERRED_LEARNING [{task_id}]: Error processing insight LLM ops: {e_op_proc}", exc_info=True) | |
if insights_processed_count > 0: logger.info(f"DEFERRED_LEARNING [{task_id}]: Finished processing. Total insights effectively added/updated: {insights_processed_count}") | |
elif not operations: pass # No operations proposed, nothing to do. | |
else: logger.info(f"DEFERRED_LEARNING [{task_id}]: LLM provided insight ops, but none resulted in successful add/update.") | |
except Exception as e: | |
logger.error(f"DEFERRED_LEARNING [{task_id}]: CRITICAL ERROR in deferred_learning_and_memory: {e}", exc_info=True) | |
logger.info(f"DEFERRED_LEARNING [{task_id}]: END. Total time: {time.time() - deferred_start_time:.2f}s") | |
# --- Gradio Chat Handler --- | |
def handle_research_chat_submit(user_message, | |
gr_chat_history, | |
groq_api_key_ui, | |
# tavily_api_key_ui, # Tavily key from UI not directly used by ai-learn core, but could be if Tavily tool re-added | |
model_select_ui, | |
system_prompt_ui, # Custom system prompt from UI | |
# These are Gradio output components that will be updated by yielding | |
# research_status_output, detected_outputs_preview, formatted_research_output_display, download_report_button | |
): | |
_chat_msg_in = "" # Clear input box after send | |
_gr_chat_hist = list(gr_chat_history) # Gradio's display history | |
_status = "Initializing..." | |
_detected_outputs_update = gr.Markdown(value="*Intermediate outputs or tool call details might show here...*") # Default state | |
_formatted_output_update = gr.Textbox(value="*Research reports will appear here...*") # Default state | |
_download_btn_update = gr.DownloadButton(interactive=False, value=None, visible=False) # Default state | |
if not user_message.strip(): | |
_status = "Cannot send an empty message." | |
# _gr_chat_hist.append((user_message, "Error: Empty message received.")) # User message is already in history from Gradio's handling | |
if _gr_chat_hist and _gr_chat_hist[-1][0] == user_message: # If Gradio auto-added user msg | |
_gr_chat_hist[-1] = (_gr_chat_hist[-1][0], "Error: Empty message received.") | |
else: # If not auto-added (e.g. if input cleared before submit fn) | |
_gr_chat_hist.append((user_message if user_message else "(Empty)", "Error: Empty message received.")) | |
yield (_chat_msg_in, _gr_chat_hist, _status, _detected_outputs_update, _formatted_output_update, _download_btn_update) | |
return | |
# Add user message to Gradio history with a thinking placeholder | |
_gr_chat_hist.append((user_message, "<i>Thinking...</i>")) | |
yield (_chat_msg_in, _gr_chat_hist, _status, _detected_outputs_update, _formatted_output_update, _download_btn_update) | |
# Update global API_KEYS if UI provides a Groq key | |
# This is a simplistic update; ideally, callAIModel would take keys as args. | |
if groq_api_key_ui and API_KEYS.get("GROQ") != groq_api_key_ui:
if not API_KEYS.get("GROQ") or "YOUR_GROQ_API_KEY" in API_KEYS["GROQ"]: # only update if the .env key is unset or still a placeholder
API_KEYS["GROQ"] = groq_api_key_ui
logger.info("Updated GROQ API key from UI input.")
elif API_KEYS.get("GROQ") and groq_api_key_ui != API_KEYS.get("GROQ"): # A real .env key exists and the UI key differs
logger.warning("Groq API Key in UI differs from .env; using the UI key as a session override.")
# A more robust design would pass keys per request instead of mutating the global API_KEYS dict
# (see the resolve_api_key sketch after this handler).
API_KEYS["GROQ"] = groq_api_key_ui # Overwrite for this session
logger.info("Overwrote GROQ API key with UI input for this session.")
# Provider is Groq based on current UI. Model is from dropdown. | |
api_provider = "groq" | |
model = model_select_ui | |
if not API_KEYS.get("GROQ") or "YOUR_GROQ_API_KEY" in API_KEYS.get("GROQ",""): | |
_gr_chat_hist[-1] = (user_message, "Error: Groq API Key not set. Please set in .env or UI.") | |
_status = "Groq API Key missing." | |
yield (_chat_msg_in, _gr_chat_hist, _status, _detected_outputs_update, _formatted_output_update, _download_btn_update) | |
return | |
# Prepare history for ai-learn's process_user_interaction_gradio | |
# It expects OpenAI format: list of {"role": ..., "content": ...} | |
# current_chat_session_history is ai-learn's global state | |
temp_hist_for_processing = list(current_chat_session_history) # Start with global history | |
temp_hist_for_processing.append({"role": "user", "content": user_message}) # Add current user message | |
# Truncate if too long (logic from ai-learn's handle_message_socket) | |
# System message might be prepended by process_user_interaction_gradio or handled by custom_system_prompt | |
sys_offset = 1 if (temp_hist_for_processing and temp_hist_for_processing[0]['role'] == 'system') else 0 | |
max_llm_hist_items = MAX_HISTORY_TURNS * 2 + 1 # User+AI messages | |
if len(temp_hist_for_processing) > max_llm_hist_items + sys_offset: | |
if sys_offset: | |
temp_hist_for_processing = [temp_hist_for_processing[0]] + temp_hist_for_processing[-(max_llm_hist_items):] | |
else: | |
temp_hist_for_processing = temp_hist_for_processing[-(max_llm_hist_items):] | |
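# Illustration (with a hypothetical MAX_HISTORY_TURNS of 10): the slice above would keep the
# optional system message plus at most the last 21 messages (10 user/assistant turns * 2, plus
# the current user message) as context for this turn's LLM call.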
final_bot_response_text = "" | |
insights_used_for_response = [] | |
try: | |
gradio_process_gen = process_user_interaction_gradio( | |
user_input=user_message, | |
api_provider=api_provider, | |
model=model, | |
chat_history_with_current_user_msg=temp_hist_for_processing, # This is the context for the current turn | |
custom_system_prompt=system_prompt_ui.strip() if system_prompt_ui and system_prompt_ui.strip() else None | |
) | |
current_bot_message_display = "" | |
for update_key, update_value in gradio_process_gen: | |
if update_key == "status": | |
_status = update_value | |
_gr_chat_hist[-1] = (user_message, f"{current_bot_message_display} <i>{_status}</i>" if current_bot_message_display else f"<i>{_status}</i>") | |
elif update_key == "response_chunk": | |
current_bot_message_display += update_value | |
_gr_chat_hist[-1] = (user_message, current_bot_message_display) | |
elif update_key == "final_response_and_insights": | |
final_bot_response_text = update_value["response"] | |
insights_used_for_response = update_value["insights_used"] | |
if not current_bot_message_display and final_bot_response_text: # If no chunks streamed but got final | |
current_bot_message_display = final_bot_response_text | |
_gr_chat_hist[-1] = (user_message, current_bot_message_display or "(No textual response)") | |
_status = "Response complete." | |
_formatted_output_update = gr.Textbox(value=current_bot_message_display) # Show full response in report tab | |
if current_bot_message_display: # Enable download if there's content
report_filename = f"research_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
# gr.DownloadButton expects a file path, so persist the report to a temp file first
import tempfile
report_path = os.path.join(tempfile.gettempdir(), report_filename)
with open(report_path, "w", encoding="utf-8") as f_report:
f_report.write(current_bot_message_display)
_download_btn_update = gr.DownloadButton(label="Download Report",
value=report_path, # Path to the saved report file
visible=True, interactive=True,
elem_id=f"download-btn-{time.time_ns()}") # Unique ID may help with Gradio updates
# No explicit _detected_outputs_update for now, could show scraped content or insights here later | |
yield (_chat_msg_in, _gr_chat_hist, _status, _detected_outputs_update, _formatted_output_update, _download_btn_update) | |
if update_key == "final_response_and_insights": break # End of this turn's processing | |
except Exception as e_handler: | |
logger.error(f"Error in Gradio chat handler: {e_handler}", exc_info=True) | |
error_msg = f"Error processing request: {str(e_handler)[:150]}" | |
_gr_chat_hist[-1] = (user_message, error_msg) | |
_status = error_msg | |
yield (_chat_msg_in, _gr_chat_hist, _status, _detected_outputs_update, _formatted_output_update, _download_btn_update) | |
return | |
# After response, update ai-learn's global history and start deferred learning | |
if final_bot_response_text: # Ensure there was a response | |
current_chat_session_history.append({"role": "user", "content": user_message}) | |
current_chat_session_history.append({"role": "assistant", "content": final_bot_response_text}) | |
# Trim global history (logic from ai-learn) | |
max_persist_hist_items = MAX_HISTORY_TURNS * 2 | |
sys_off_persist = 1 if (current_chat_session_history and current_chat_session_history[0]['role']=='system') else 0 | |
if len(current_chat_session_history) > max_persist_hist_items + sys_off_persist: | |
current_chat_session_history = ([current_chat_session_history[0]] if sys_off_persist else []) + current_chat_session_history[-(max_persist_hist_items):] | |
logger.info(f"Starting deferred learning task for user: '{user_message[:30]}...'") | |
# Run deferred_learning_and_memory in a background thread | |
deferred_thread = threading.Thread( | |
target=deferred_learning_and_memory, | |
args=(user_message, final_bot_response_text, api_provider, model, insights_used_for_response), | |
daemon=True # Daemon threads exit when main program exits | |
) | |
deferred_thread.start() | |
_status = "Response complete. Background learning initiated." # Update status | |
else: | |
_status = "Processing finished, but no final response was generated." | |
# Final yield to update status if it changed after loop | |
yield (_chat_msg_in, _gr_chat_hist, _status, _detected_outputs_update, _formatted_output_update, _download_btn_update) | |
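# --- Hedged sketch: resolving API keys per request instead of mutating API_KEYS ---
# The handler above overwrites API_KEYS["GROQ"] in place whenever the UI supplies a key. A cleaner
# alternative would be to resolve the effective key per call and hand it to the request code
# explicitly. This assumes callAIModel could accept such a key as an argument, which it does not
# today; the helper below is an illustrative sketch, not part of the existing ai-learn API.
def resolve_api_key(provider, ui_key=None, placeholder_marker="YOUR_"):
    """Return the key to use for one request: a UI-supplied key wins, then the .env key, else None."""
    if ui_key and ui_key.strip():
        return ui_key.strip()
    env_key = API_KEYS.get(provider.upper())
    if env_key and placeholder_marker not in env_key:
        return env_key
    return None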
# --- Gradio UI Helper Functions for Memory/Rules --- | |
def ui_view_rules(): | |
logger.info(f"UI_VIEW_RULES: Fetching {len(rules_texts)} rules.") | |
if not rules_texts: return "No rules/insights learned yet." | |
# Sort for consistent display, though _add_new_insight_to_store tries to keep rules_texts sorted | |
return "\n\n---\n\n".join(sorted(list(set(rules_texts)))) | |
def ui_upload_rules(file_obj, progress=gr.Progress()): | |
if not file_obj: return "No file provided for rules upload." | |
try: | |
content = "" | |
with open(file_obj.name, 'r', encoding='utf-8') as f: # file_obj is a tempfile._TemporaryFileWrapper | |
content = f.read() | |
except Exception as e: | |
logger.error(f"UI_UPLOAD_RULES: Error reading file {file_obj.name if hasattr(file_obj,'name') else 'unknown_file'}: {e}") | |
return f"Error reading file: {e}" | |
logger.info(f"UI_UPLOAD_RULES: File '{file_obj.name if hasattr(file_obj,'name') else 'upload.txt'}'. Processing...") | |
if not content: return f"File '{file_obj.name if hasattr(file_obj,'name') else 'upload.txt'}' is empty." | |
added, skipped_dup, fmt_err, proc_err = 0,0,0,0 | |
err_details = [] | |
potential_insights = content.split("\n\n---\n\n") # As per format_insights_for_prompt and ui_view_rules | |
if len(potential_insights) == 1 and "\n" in content and "---" not in content: # check if it's just one rule per line
potential_insights = content.splitlines() | |
total_insights = len(potential_insights) | |
progress(0, desc="Starting rule upload...") | |
for i, line_text_block in enumerate(potential_insights): | |
line = line_text_block.strip() | |
if not line: continue | |
# Validate format: [TYPE|SCORE] Text | |
if not re.match(r"\[(CORE_RULE|RESPONSE_PRINCIPLE|BEHAVIORAL_ADJUSTMENT|GENERAL_LEARNING)\|([\d\.]+?)\](.*)", line, re.I|re.DOTALL): | |
err_details.append(f"Part {i+1} ({line[:20]}...): Invalid format.") | |
fmt_err += 1 | |
logger.warning(f"UI_UPLOAD_RULES: Invalid format for rule: {line}") | |
continue | |
if add_learned_insight(line): # add_learned_insight calls _add_new_insight_to_store | |
added += 1 | |
else: | |
# Check if it was a duplicate (already exists) or a processing error | |
if line in rules_texts: # This check is also in _add_new_insight_to_store but good for stats here | |
skipped_dup += 1 | |
else: | |
proc_err += 1 | |
err_details.append(f"Part {i+1} ({line[:20]}...): Add failed (check server logs).") | |
logger.error(f"UI_UPLOAD_RULES: Failed to add rule: {line} (not a duplicate, processing error).") | |
progress((i+1)/total_insights, desc=f"Processed {i+1}/{total_insights}. Added: {added}, Skipped: {skipped_dup}, Errors: {fmt_err+proc_err}") | |
total_errors = fmt_err + proc_err | |
msg = f"Rules Upload Summary: Processed {total_insights}. Added: {added}, Skipped Duplicates: {skipped_dup}, Format Errors: {fmt_err}, Processing Errors: {proc_err}." | |
if err_details: msg += f" Example Errors: {'; '.join(err_details[:3])}" | |
logger.info(msg) | |
return msg | |
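# Example of a rules file accepted by ui_upload_rules (each rule matches "[TYPE|SCORE] text";
# rules are separated by the "---" delimiter that ui_view_rules emits, or one rule per line).
# The sample rules below are adapted from the insight examples in the learning prompt above:
#
#   [CORE_RULE|1.0] I am Lumina, an AI assistant designed to chat, provide information, and remember context.
#
#   ---
#
#   [RESPONSE_PRINCIPLE|0.8] When user provides positive feedback, acknowledge it warmly.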
def ui_view_memories(): | |
logger.info(f"UI_VIEW_MEMORIES: Fetching {len(memory_texts)} memories.") | |
if not memory_texts: return "No memories stored yet." | |
# The wired output component (memories_display_area) is gr.JSON, so parse each stored JSON string
# into a dict and return the list; Gradio then renders it as structured JSON.
try: | |
mem_list_of_dicts = [json.loads(mem_json) for mem_json in memory_texts] | |
return mem_list_of_dicts # Let Gradio handle JSON display | |
except json.JSONDecodeError as e: | |
logger.error(f"UI_VIEW_MEMORIES: Error decoding memory JSON: {e}") | |
return f"Error displaying memories: Could not parse stored JSON. Details: {e}" | |
def ui_upload_memories(file_obj, progress=gr.Progress()): | |
if not file_obj: return "No file provided for memories upload." | |
content = "" | |
try: | |
with open(file_obj.name, 'r', encoding='utf-8') as f: | |
content = f.read() | |
except Exception as e: | |
logger.error(f"UI_UPLOAD_MEMORIES: Error reading file {file_obj.name if hasattr(file_obj,'name') else 'unknown_file'}: {e}") | |
return f"Error reading file: {e}" | |
logger.info(f"UI_UPLOAD_MEMORIES: File '{file_obj.name if hasattr(file_obj,'name') else 'upload.json'}'. Processing...") | |
if not content: return f"File '{file_obj.name if hasattr(file_obj,'name') else 'upload.json'}' is empty." | |
added, skipped_dup, fmt_err, proc_err = 0,0,0,0 | |
err_details = [] | |
mem_objects = [] | |
try:
try:
mem_objects = json.loads(content)
except json.JSONDecodeError:
mem_objects = None # Whole-file parse failed; fall back to one-object-per-line (JSON Lines) below
if not isinstance(mem_objects, list):
# Try one JSON object per line (JSON Lines)
try:
mem_objects = [json.loads(line) for line in content.splitlines() if line.strip()]
if not all(isinstance(obj, dict) for obj in mem_objects): # Validate again
raise ValueError("Parsed line-by-line JSON, but not all items are objects.")
except Exception as e_lines: # If line-by-line parsing also fails
logger.warning(f"UI_UPLOAD_MEMORIES: Content is neither a JSON list nor valid JSON Lines: {e_lines}")
return "Invalid format: Content must be a JSON list of memory objects, or one JSON object per line."
total_memories = len(mem_objects) | |
progress(0, desc="Starting memory upload...") | |
for i, mem_data in enumerate(mem_objects): | |
if not isinstance(mem_data, dict): | |
err_details.append(f"Item {i+1}: Not a valid JSON object.") | |
fmt_err += 1 | |
continue | |
try: # Validate keys for each memory object | |
if not all(k in mem_data for k in ["user_input", "bot_response", "metrics", "timestamp"]): | |
err_details.append(f"Item {i+1}: Missing required keys (user_input, bot_response, metrics, timestamp).") | |
fmt_err += 1 | |
continue | |
# Check for duplicates (simplified check based on user_input, bot_response, timestamp) | |
is_duplicate = False | |
# This linear scan re-parses every stored memory for each uploaded item, so it can be slow for
# large memory_texts; hashing the key fields once would be faster (see the _memory_signature
# sketch after this function). For now, direct comparison:
temp_mem_sig = (mem_data.get("user_input"), mem_data.get("bot_response"), mem_data.get("timestamp")) | |
for existing_mem_json_str in memory_texts: | |
try: | |
existing_obj = json.loads(existing_mem_json_str) | |
existing_sig = (existing_obj.get("user_input"), existing_obj.get("bot_response"), existing_obj.get("timestamp")) | |
if existing_sig == temp_mem_sig: | |
is_duplicate = True; break | |
except json.JSONDecodeError: continue # Skip malformed existing memory | |
if is_duplicate: | |
skipped_dup += 1 | |
continue | |
# Call ai-learn's add_memory function | |
if add_memory(mem_data["user_input"], mem_data["metrics"], mem_data["bot_response"]): # timestamp is auto-generated by add_memory from ai-learn, or we can use the one from file if preferred. The current `add_memory` generates a new one. | |
added += 1 | |
else: | |
proc_err += 1 | |
err_details.append(f"Item {i+1} ({mem_data.get('user_input','')[:20]}...): add_memory call failed.") | |
except Exception as e_item_proc: | |
proc_err += 1 | |
err_details.append(f"Item {i+1}: Error during processing - {str(e_item_proc)[:30]}") | |
logger.error(f"UI_UPLOAD_MEMORIES: Error processing memory item {i}: {e_item_proc}", exc_info=False) | |
progress((i+1)/total_memories, desc=f"Processed {i+1}/{total_memories}. Added: {added}, Skipped: {skipped_dup}, Errors: {fmt_err+proc_err}") | |
except json.JSONDecodeError as e_json_main: | |
logger.error(f"UI_UPLOAD_MEMORIES: Main JSON parsing error for file '{file_obj.name if hasattr(file_obj,'name') else 'upload.json'}': {e_json_main}") | |
return f"Invalid JSON format in file. Details: {e_json_main}" | |
except Exception as e_outer: | |
logger.error(f"UI_UPLOAD_MEMORIES: General error processing file '{file_obj.name if hasattr(file_obj,'name') else 'upload.json'}': {e_outer}", exc_info=True) | |
return f"General error processing file. Check logs. Details: {e_outer}" | |
total_errors = fmt_err + proc_err | |
msg = f"Memories Upload Summary: Processed {total_memories if mem_objects else 'N/A (parse error)'}. Added: {added}, Skipped Duplicates: {skipped_dup}, Format Errors: {fmt_err}, Processing Errors: {proc_err}." | |
if err_details: msg += f" Example Errors: {'; '.join(err_details[:3])}" | |
logger.info(msg) | |
return msg | |
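# --- Hedged sketch: faster duplicate detection for uploaded memories ---
# ui_upload_memories expects a JSON list (or JSON Lines) of objects with at least "user_input",
# "bot_response", "metrics" and "timestamp" keys. Its in-loop duplicate check re-parses every
# stored memory for each uploaded item; the helpers below illustrate hashing the identifying
# fields once per upload instead. The names are hypothetical and nothing calls them yet.
import hashlib

def _memory_signature(user_input, bot_response, timestamp):
    """Stable hash over the fields ui_upload_memories compares when detecting duplicates."""
    raw = json.dumps([user_input, bot_response, timestamp], ensure_ascii=False)
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()

def _existing_memory_signatures():
    """Build the signature set once per upload instead of re-parsing memory_texts per item."""
    sigs = set()
    for mem_json in memory_texts:
        try:
            obj = json.loads(mem_json)
        except json.JSONDecodeError:
            continue  # Skip malformed stored memories, mirroring the in-loop behaviour
        sigs.add(_memory_signature(obj.get("user_input"), obj.get("bot_response"), obj.get("timestamp")))
    return sigs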
# --- Gradio UI Definition (adapted from node_search) --- | |
DEFAULT_SYSTEM_PROMPT = "You are a helpful AI research assistant. Your primary goal is to answer questions and perform research tasks accurately and thoroughly. You can use tools like web search and page browsing. When providing information from the web, cite your sources if possible. If asked to perform a task beyond your capabilities, explain politely. Be concise unless asked for detail." #This will be passed to PUI_Gradio as custom_system_prompt | |
custom_theme = gr.themes.Base(primary_hue="teal", secondary_hue="purple", neutral_hue="zinc", text_size="sm", spacing_size="sm", radius_size="sm", font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"]) | |
custom_css = """ | |
body { font-family: 'Inter', sans-serif; } | |
.gradio-container { max-width: 95% !important; margin: auto !important; padding-top: 10px !important; } | |
footer { display: none !important; } | |
.gr-button { white-space: nowrap; } | |
.gr-input { border-radius: 8px !important; } | |
.gr-chatbot .message { border-radius: 8px !important; box-shadow: 0 1px 3px rgba(0,0,0,0.05) !important; } | |
#download-btn { min-width: 150px; } /* Example if needed */ | |
.prose { | |
h1 { font-size: 1.875rem; margin-bottom: 0.5em; margin-top: 1em; } | |
h2 { font-size: 1.5rem; margin-bottom: 0.4em; margin-top: 0.8em; } | |
p { margin-bottom: 0.8em; line-height: 1.6; } | |
ul, ol { margin-left: 1.5em; margin-bottom: 0.8em; } | |
code { background-color: #f0f0f0; padding: 0.2em 0.4em; border-radius: 3px; font-size: 0.9em; } | |
pre > code { display: block; padding: 0.8em; overflow-x: auto; } | |
} | |
""" | |
with gr.Blocks(theme=custom_theme, css=custom_css, title="AI Research Mega Agent") as demo: | |
gr.Markdown("# π§ AI Research Mega Agent (with Memory & Learning)", elem_classes="prose") | |
gr.Markdown("Ask questions or research topics. The AI will use its learned knowledge, memory, and web search/browsing tools to find answers and learn from interactions.", elem_classes="prose") | |
with gr.Row(): | |
with gr.Column(scale=1): # Sidebar | |
gr.Markdown("## βοΈ Configuration", elem_classes="prose") | |
with gr.Accordion("API & Model Settings", open=True): | |
with gr.Group(): | |
gr.Markdown("### API Keys", elem_classes="prose") | |
groq_api_key_input = gr.Textbox(label="Groq API Key (Optional, uses .env if set)", type="password", placeholder="gsk_...", info="Needed for LLM. Overrides .env if provided here.") | |
# tavily_api_key_input = gr.Textbox(label="Tavily API Key (Optional)", type="password", placeholder="tvly-...", info="For Tavily search tool (if enabled).") # Tavily not used by ai-learn core by default | |
with gr.Group(): | |
gr.Markdown("### AI Model (Groq)", elem_classes="prose") | |
# Assuming Groq models from ai-learn's models_data_global_scope | |
groq_models_for_ui = models_data_global_scope.get("groq", ["llama3-70b-8192", "llama3-8b-8192", "mixtral-8x7b-32768"]) | |
groq_model_select = gr.Dropdown(label="Groq Model", choices=groq_models_for_ui, value=groq_models_for_ui[0] if groq_models_for_ui else "llama3-70b-8192", info="Select the Groq model for responses.") | |
with gr.Group(): | |
gr.Markdown("### System Prompt (Optional)", elem_classes="prose") | |
# Using node_search's DEFAULT_SYSTEM_PROMPT here | |
groq_system_prompt_input = gr.Textbox(label="Custom System Prompt Base", lines=8, value=DEFAULT_SYSTEM_PROMPT, interactive=True, info="This prompt will be used as a base by the AI. Internal logic may add more context.") | |
with gr.Accordion("Knowledge Management", open=False): | |
gr.Markdown("### Rules (Learned Insights)", elem_classes="prose") | |
view_rules_button = gr.Button("View All Rules") | |
upload_rules_file = gr.UploadButton("Upload Rules File (.txt)", file_types=[".txt"], file_count="single") | |
rules_status_display = gr.Textbox(label="Rules Action Status", interactive=False, lines=2) | |
gr.Markdown("### Memories (Past Interactions)", elem_classes="prose") | |
view_memories_button = gr.Button("View All Memories") | |
upload_memories_file = gr.UploadButton("Upload Memories File (.json)", file_types=[".json"], file_count="single") | |
memories_status_display = gr.Textbox(label="Memories Action Status", interactive=False, lines=2) | |
with gr.Column(scale=3): # Main chat area | |
gr.Markdown("## π¬ AI Research Assistant Chat", elem_classes="prose") | |
research_chatbot_display = gr.Chatbot( | |
label="AI Research Chat", | |
height=600, | |
bubble_full_width=False, | |
avatar_images=(None, "https://raw.githubusercontent.com/huggingface/brand-assets/main/hf-logo-with-title.png"), # HF logo as example | |
show_copy_button=True, | |
render_markdown=True, | |
sanitize_html=True, | |
) | |
with gr.Row(): | |
research_chat_message_input = gr.Textbox(show_label=False, placeholder="Ask your research question or give an instruction...", scale=7, lines=1, max_lines=5,autofocus=True) | |
research_send_chat_button = gr.Button("Send", variant="primary", scale=1) | |
research_status_output = gr.Textbox(label="Agent Status", interactive=False, lines=1, value="Ready. Initializing AI systems...") | |
with gr.Tabs(): | |
with gr.TabItem("π Generated Report/Output"): | |
gr.Markdown("The AI's full response or generated report will appear here.", elem_classes="prose") | |
formatted_research_output_display = gr.Textbox(label="Current Research Output", lines=15, interactive=True, show_copy_button=True, value="*AI responses will appear here...*") | |
download_report_button = gr.DownloadButton(label="Download Report", interactive=False, visible=False, elem_id="download-btn") # Initially hidden | |
with gr.TabItem("π Intermediate Details / Debug"): # Was "Intermediate Outputs Preview" | |
detected_outputs_preview = gr.Markdown(value="*Intermediate outputs, tool call details, or debug information might show here...*") | |
# For rules and memories display within this tab: | |
rules_display_area = gr.TextArea(label="Loaded Rules/Insights (Snapshot)", lines=10, interactive=False, max_lines=20) | |
memories_display_area = gr.JSON(label="Loaded Memories (Snapshot)") # Using gr.JSON for better display | |
# --- Event Handlers --- | |
chat_inputs = [ | |
research_chat_message_input, | |
research_chatbot_display, | |
groq_api_key_input, | |
# tavily_api_key_input, # Not directly used now | |
groq_model_select, | |
groq_system_prompt_input | |
] | |
chat_outputs = [ | |
research_chat_message_input, # To clear it | |
research_chatbot_display, | |
research_status_output, | |
detected_outputs_preview, # Placeholder for now | |
formatted_research_output_display, | |
download_report_button | |
] | |
research_send_chat_button.click( | |
fn=handle_research_chat_submit, | |
inputs=chat_inputs, | |
outputs=chat_outputs | |
) | |
research_chat_message_input.submit( | |
fn=handle_research_chat_submit, | |
inputs=chat_inputs, | |
outputs=chat_outputs | |
) | |
# Rules/Insights Management Handlers | |
view_rules_button.click(fn=ui_view_rules, outputs=rules_display_area) # Display in the Debug tab's area | |
upload_rules_file.upload(fn=ui_upload_rules, inputs=[upload_rules_file], outputs=[rules_status_display], show_progress="full") | |
# Memories Management Handlers | |
view_memories_button.click(fn=ui_view_memories, outputs=memories_display_area) # Display in the Debug tab's area | |
upload_memories_file.upload(fn=ui_upload_memories, inputs=[upload_memories_file], outputs=[memories_status_display], show_progress="full") | |
# Initial status update after app loads | |
def initial_load_status(): | |
if embedder and faiss_memory_index is not None and faiss_rules_index is not None: | |
return f"AI Systems Initialized. Memory Items: {len(memory_texts)}, Rules: {len(rules_texts)}. Ready." | |
else: | |
return "AI Systems Initialization Failed. Check logs. Application may not function correctly." | |
demo.load(fn=initial_load_status, inputs=None, outputs=research_status_output) | |
# --- Main Application Execution --- | |
if __name__ == "__main__": | |
logger.info("Starting Gradio AI Research Mega Agent Application...") | |
# Initialize AI components (DB, FAISS, Embedder) | |
init_sqlite_db() | |
try: | |
load_data_on_startup() | |
except Exception as e: | |
logger.critical(f"FATAL: Error during load_data_on_startup: {e}", exc_info=True) | |
# Decide whether to exit or let Gradio start with a warning
# For now, let it start so the user can potentially see the error in the UI
if not (embedder and dimension and faiss_memory_index is not None and faiss_rules_index is not None): | |
logger.critical("MAIN: Critical components (embedder/FAISS) not initialized after startup. Functionality will be impaired.") | |
# Update status in UI if possible, or rely on initial_load_status in demo.load | |
# Launch Gradio App | |
# Share=True for public link, False for local only. | |
# Debug=True for more logs from Gradio. | |
app_port = int(os.getenv("GRADIO_PORT", 7860)) | |
app_server_name = os.getenv("GRADIO_SERVER_NAME", "0.0.0.0") # "127.0.0.1" for local only, "0.0.0.0" for LAN access | |
logger.info(f"Launching Gradio server on {app_server_name}:{app_port}. Debug: {os.getenv('GRADIO_DEBUG','False')=='True'}") | |
demo.queue().launch( | |
server_name=app_server_name, | |
server_port=app_port, | |
debug=(os.getenv("GRADIO_DEBUG", "False").lower() == "true"), | |
share= (os.getenv("GRADIO_SHARE", "False").lower() == "true"), | |
# inbrowser=True, # Opens browser automatically | |
# prevent_thread_lock=True # May help with threading issues but use with caution | |
) | |
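# Example environment configuration for the launch settings above (values are illustrative):
#   GRADIO_PORT=7860
#   GRADIO_SERVER_NAME=0.0.0.0   # use 127.0.0.1 to restrict access to the local machine
#   GRADIO_DEBUG=true            # extra Gradio logging
#   GRADIO_SHARE=false           # set to true for a temporary public share link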
# Teardown (saving FAISS): Gradio has no clean teardown hook like Flask's teardown_appcontext,
# which ai-learn's original Flask app used. For now, FAISS indices are not saved on exit in this
# Gradio script; if persistence matters, save periodically, on specific actions, or via an
# atexit handler (see the sketch below).
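# --- Hedged sketch: best-effort persistence at process exit ---
# One option is an atexit callback, sketched below under assumptions: the target file paths are
# made up, and the parallel memory_texts / rules_texts lists would also need to be saved (e.g.
# as JSON) for the indices to be reloadable in a useful way. It is left unregistered by default.
def _persist_faiss_on_exit():
    try:
        if faiss_memory_index is not None:
            faiss.write_index(faiss_memory_index, "data/memory_index.faiss")
        if faiss_rules_index is not None:
            faiss.write_index(faiss_rules_index, "data/rules_index.faiss")
        logger.info("Persisted FAISS indices on exit (sketch).")
    except Exception as e_persist:
        logger.error(f"Failed to persist FAISS indices on exit: {e_persist}")
# import atexit
# atexit.register(_persist_faiss_on_exit)  # Uncomment both lines to enable this best-effort hook.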
logger.info("Gradio application has been shut down.") |