import os
import json
import re
import logging
import time
import xml.etree.ElementTree as ET

from model_logic import call_model_stream, MODELS_BY_PROVIDER
from memory_logic import add_memory_entry, retrieve_rules_semantic, remove_rule_entry, add_rule_entry
import prompts

logger = logging.getLogger(__name__)
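
# Configuration note (inferred from the env lookups in this module, not from external docs):
#   METRICS_MODEL           - "provider/model_id" override for generate_interaction_metrics
#   INSIGHT_MODEL_OVERRIDE  - "provider/model_id" override for the insight pass in
#                             perform_post_interaction_learning
# If unset, or the model id cannot be resolved via MODELS_BY_PROVIDER, the caller-supplied
# provider/model is used instead.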
def generate_interaction_metrics(user_input: str, bot_response: str, provider: str, model_display_name: str, api_key_override: str = None) -> dict:
    """Score a single user/bot exchange with a lightweight LLM call and return a takeaway plus success/confidence scores."""
    metric_start_time = time.time()
    logger.info(f"Generating metrics with: {provider}/{model_display_name}")
    metric_prompt_content = prompts.get_metrics_user_prompt(user_input, bot_response)
    metric_messages = [{"role": "system", "content": prompts.METRICS_SYSTEM_PROMPT}, {"role": "user", "content": metric_prompt_content}]
    try:
        metrics_provider, metrics_model_display = provider, model_display_name
        metrics_model_env = os.getenv("METRICS_MODEL")
        if metrics_model_env and "/" in metrics_model_env:
            m_prov, m_id = metrics_model_env.split('/', 1)
            m_disp_name = next((dn for dn, mid in MODELS_BY_PROVIDER.get(m_prov.lower(), {}).get("models", {}).items() if mid == m_id), None)
            if m_disp_name: metrics_provider, metrics_model_display = m_prov, m_disp_name
        response_chunks = list(call_model_stream(provider=metrics_provider, model_display_name=metrics_model_display, messages=metric_messages, api_key_override=api_key_override, temperature=0.05, max_tokens=200))
        resp_str = "".join(response_chunks).strip()
        # The metrics model is expected to answer with a single JSON object; grab the first {...} span.
        json_match = re.search(r"\{.*\}", resp_str, re.DOTALL)
        if json_match: metrics_data = json.loads(json_match.group(0))
        else:
            logger.warning(f"METRICS_GEN: Non-JSON response from {metrics_provider}/{metrics_model_display}: '{resp_str}'")
            return {"takeaway": "N/A", "response_success_score": 0.5, "future_confidence_score": 0.5, "error": "metrics format error"}
        parsed_metrics = {"takeaway": metrics_data.get("takeaway", "N/A"), "response_success_score": float(metrics_data.get("response_success_score", 0.5)), "future_confidence_score": float(metrics_data.get("future_confidence_score", 0.5))}
        logger.info(f"METRICS_GEN: Generated in {time.time() - metric_start_time:.2f}s. Data: {parsed_metrics}")
        return parsed_metrics
    except Exception as e:
        logger.error(f"METRICS_GEN Error: {e}")
        return {"takeaway": "N/A", "response_success_score": 0.5, "future_confidence_score": 0.5, "error": str(e)}
def perform_post_interaction_learning(user_input: str, bot_response: str, provider: str, model_disp_name: str, insights_reflected: list[dict], api_key_override: str = None):
    """Post-interaction learning pass: score the exchange, store it in memory, then ask an insight model for rule add/update operations and apply them."""
    task_id = os.urandom(4).hex()
    logger.info(f"LEARNING [{task_id}]: Start User='{user_input[:40]}...'")
    learning_start_time = time.time()
    try:
        metrics = generate_interaction_metrics(user_input, bot_response, provider, model_disp_name, api_key_override)
        add_memory_entry(user_input, metrics, bot_response)
        summary = f"User:\"{user_input}\"\nAI:\"{bot_response}\"\nMetrics(takeaway):{metrics.get('takeaway','N/A')},Success:{metrics.get('response_success_score','N/A')}"
        existing_rules_ctx = "\n".join([f"- \"{r}\"" for r in retrieve_rules_semantic(f"{summary}\n{user_input}", k=10)]) or "No existing rules context."
        insight_user_prompt = prompts.get_insight_user_prompt(summary, existing_rules_ctx, insights_reflected)
        insight_msgs = [{"role":"system", "content":prompts.INSIGHT_SYSTEM_PROMPT}, {"role":"user", "content":insight_user_prompt}]
        insight_prov, insight_model_disp = provider, model_disp_name
        insight_env_model = os.getenv("INSIGHT_MODEL_OVERRIDE")
        if insight_env_model and "/" in insight_env_model:
            i_p, i_id = insight_env_model.split('/', 1)
            i_d_n = next((dn for dn, mid in MODELS_BY_PROVIDER.get(i_p.lower(), {}).get("models", {}).items() if mid == i_id), None)
            if i_d_n: insight_prov, insight_model_disp = i_p, i_d_n
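        # Expected reply shape (inferred from the parsing below; the exact instructions live in
        # prompts.INSIGHT_SYSTEM_PROMPT, which is not shown here):
        # <operations_list>
        #   <operation>
        #     <action>add | update</action>
        #     <insight>[CORE_RULE|0.9] rule text...</insight>
        #     <old_insight_to_replace>exact text of the rule being replaced (update only)</old_insight_to_replace>
        #   </operation>
        # </operations_list>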
        raw_ops_xml = "".join(list(call_model_stream(provider=insight_prov, model_display_name=insight_model_disp, messages=insight_msgs, api_key_override=api_key_override, temperature=0.0, max_tokens=3500))).strip()
        xml_match = re.search(r"<operations_list>.*</operations_list>", raw_ops_xml, re.DOTALL | re.IGNORECASE)
        if not xml_match:
            logger.info(f"LEARNING [{task_id}]: No <operations_list> XML found.")
            return
        ops_data_list = []
        root = ET.fromstring(xml_match.group(0))
        for op_el in root.findall("operation"):
            action = op_el.find("action").text.strip().lower() if op_el.find("action") is not None and op_el.find("action").text else None
            insight = op_el.find("insight").text.strip() if op_el.find("insight") is not None and op_el.find("insight").text else None
            old_insight = op_el.find("old_insight_to_replace").text.strip() if op_el.find("old_insight_to_replace") is not None and op_el.find("old_insight_to_replace").text else None
            if action and insight: ops_data_list.append({"action": action, "insight": insight, "old_insight_to_replace": old_insight})
        processed_count, core_learnings = 0, []
        for op_data in ops_data_list:
            # Skip any insight that lacks a recognised "[TYPE|score]" prefix.
            if not re.match(r"\[(CORE_RULE|RESPONSE_PRINCIPLE|BEHAVIORAL_ADJUSTMENT|GENERAL_LEARNING)\|([\d\.]+?)\]", op_data["insight"], re.I | re.DOTALL): continue
            if op_data["action"] == "add":
                success, _ = add_rule_entry(op_data["insight"])
                if success: processed_count += 1
            elif op_data["action"] == "update" and op_data["old_insight_to_replace"]:
                if remove_rule_entry(op_data["old_insight_to_replace"]):
                    logger.info(f"LEARNING [{task_id}]: Removed old rule for update: '{op_data['old_insight_to_replace'][:50]}...'")
                success, _ = add_rule_entry(op_data["insight"])
                if success: processed_count += 1
            if op_data["insight"].upper().startswith("[CORE_RULE"):
                core_learnings.append(op_data["insight"])
        if core_learnings:
            learning_digest = "SYSTEM CORE LEARNING DIGEST:\n" + "\n".join(core_learnings)
            add_memory_entry(user_input="SYSTEM_INTERNAL_REFLECTION", metrics={"takeaway": "Core knowledge refined.", "type": "SYSTEM_REFLECTION"}, bot_response=learning_digest)
        logger.info(f"LEARNING [{task_id}]: Processed {processed_count}/{len(ops_data_list)} insight ops. Total time: {time.time() - learning_start_time:.2f}s")
    except Exception as e:
        logger.error(f"LEARNING [{task_id}]: CRITICAL ERROR in learning task: {e}", exc_info=True)