Spaces:
Runtime error
Runtime error
Create learning.py
Browse files- learning.py +90 -0
learning.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
import logging
|
5 |
+
import time
|
6 |
+
import xml.etree.ElementTree as ET
|
7 |
+
|
8 |
+
from model_logic import call_model_stream, MODELS_BY_PROVIDER
|
9 |
+
from memory_logic import add_memory_entry, retrieve_rules_semantic, remove_rule_entry, add_rule_entry
|
10 |
+
import prompts
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
def generate_interaction_metrics(user_input: str, bot_response: str, provider: str, model_display_name: str, api_key_override: str = None) -> dict:
    """Ask an LLM to score one user/assistant exchange and return the metrics.

    Sends the exchange to a metrics model — the interaction's own
    provider/model unless the METRICS_MODEL env var names a valid
    "provider/model_id" override — and parses the JSON object out of the
    streamed response.

    Args:
        user_input: The user's message for the interaction being scored.
        bot_response: The assistant's reply being scored.
        provider: Default model provider used for metric generation.
        model_display_name: Default model display name for metric generation.
        api_key_override: Optional API key forwarded to call_model_stream.

    Returns:
        dict with keys "takeaway" (str), "response_success_score" (float),
        and "future_confidence_score" (float). On failure an additional
        "error" key describes the problem and both scores default to 0.5.
    """
    def _coerce_score(value, default=0.5):
        # The model may emit null or a non-numeric string for a score; fall
        # back to a neutral default instead of letting float() raise and
        # discard the already-parsed takeaway via the outer except.
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    metric_start_time = time.time()
    logger.info(f"Generating metrics with: {provider}/{model_display_name}")
    metric_prompt_content = prompts.get_metrics_user_prompt(user_input, bot_response)
    metric_messages = [{"role": "system", "content": prompts.METRICS_SYSTEM_PROMPT}, {"role": "user", "content": metric_prompt_content}]
    try:
        metrics_provider, metrics_model_display = provider, model_display_name
        metrics_model_env = os.getenv("METRICS_MODEL")
        if metrics_model_env and "/" in metrics_model_env:
            m_prov, m_id = metrics_model_env.split('/', 1)
            # The override carries a raw model id; map it back to the display
            # name call_model_stream expects. Silently keep the defaults if
            # the override does not match any known model.
            m_disp_name = next((dn for dn, mid in MODELS_BY_PROVIDER.get(m_prov.lower(), {}).get("models", {}).items() if mid == m_id), None)
            if m_disp_name:
                metrics_provider, metrics_model_display = m_prov, m_disp_name
        response_chunks = list(call_model_stream(provider=metrics_provider, model_display_name=metrics_model_display, messages=metric_messages, api_key_override=api_key_override, temperature=0.05, max_tokens=200))
        resp_str = "".join(response_chunks).strip()
        # Models often wrap the JSON in prose or markdown fences; grab the
        # outermost {...} span instead of parsing the raw response.
        json_match = re.search(r"\{.*\}", resp_str, re.DOTALL)
        if not json_match:
            logger.warning(f"METRICS_GEN: Non-JSON response from {metrics_provider}/{metrics_model_display}: '{resp_str}'")
            return {"takeaway": "N/A", "response_success_score": 0.5, "future_confidence_score": 0.5, "error": "metrics format error"}
        metrics_data = json.loads(json_match.group(0))
        parsed_metrics = {
            "takeaway": metrics_data.get("takeaway", "N/A"),
            "response_success_score": _coerce_score(metrics_data.get("response_success_score", 0.5)),
            "future_confidence_score": _coerce_score(metrics_data.get("future_confidence_score", 0.5)),
        }
        logger.info(f"METRICS_GEN: Generated in {time.time() - metric_start_time:.2f}s. Data: {parsed_metrics}")
        return parsed_metrics
    except Exception as e:
        # Metrics are best-effort telemetry; never let a scoring failure
        # propagate to the caller. Keep the traceback in the log.
        logger.error(f"METRICS_GEN Error: {e}", exc_info=True)
        return {"takeaway": "N/A", "response_success_score": 0.5, "future_confidence_score": 0.5, "error": str(e)}
39 |
+
|
40 |
+
def perform_post_interaction_learning(user_input: str, bot_response: str, provider: str, model_disp_name: str, insights_reflected: list[dict], api_key_override: str = None):
    """Run the post-interaction learning pass for one exchange.

    Pipeline: (1) generate metrics for the exchange and store it as a memory
    entry; (2) build an insight prompt from the exchange summary plus up to 10
    semantically similar existing rules; (3) ask an insight model (optionally
    overridden via the INSIGHT_MODEL_OVERRIDE env var, "provider/model_id")
    for an <operations_list> XML of add/update rule operations; (4) apply
    those operations to the rule store and record a digest memory entry if
    any [CORE_RULE...] insights were processed.

    Args:
        user_input: The user's message for the interaction.
        bot_response: The assistant's reply.
        provider: Default model provider for both metrics and insights.
        model_disp_name: Default model display name.
        insights_reflected: Prior insight dicts passed into the insight prompt.
        api_key_override: Optional API key forwarded to call_model_stream.

    Returns:
        None. All failures are caught and logged; nothing propagates.
    """
    # Short random id to correlate this task's log lines.
    task_id = os.urandom(4).hex()
    logger.info(f"LEARNING [{task_id}]: Start User='{user_input[:40]}...'")
    learning_start_time = time.time()
    try:
        metrics = generate_interaction_metrics(user_input, bot_response, provider, model_disp_name, api_key_override)
        add_memory_entry(user_input, metrics, bot_response)
        summary = f"User:\"{user_input}\"\nAI:\"{bot_response}\"\nMetrics(takeaway):{metrics.get('takeaway','N/A')},Success:{metrics.get('response_success_score','N/A')}"
        # Retrieve up to 10 semantically related rules so the insight model can
        # update rather than duplicate existing knowledge.
        existing_rules_ctx = "\n".join([f"- \"{r}\"" for r in retrieve_rules_semantic(f"{summary}\n{user_input}", k=10)]) or "No existing rules context."
        insight_user_prompt = prompts.get_insight_user_prompt(summary, existing_rules_ctx, insights_reflected)
        insight_msgs = [{"role":"system", "content":prompts.INSIGHT_SYSTEM_PROMPT}, {"role":"user", "content":insight_user_prompt}]
        insight_prov, insight_model_disp = provider, model_disp_name
        insight_env_model = os.getenv("INSIGHT_MODEL_OVERRIDE")
        if insight_env_model and "/" in insight_env_model:
            i_p, i_id = insight_env_model.split('/', 1)
            # Map the raw model id from the env override back to its display
            # name; keep the defaults if the override is unknown.
            i_d_n = next((dn for dn, mid in MODELS_BY_PROVIDER.get(i_p.lower(), {}).get("models", {}).items() if mid == i_id), None)
            if i_d_n: insight_prov, insight_model_disp = i_p, i_d_n
        # temperature=0.0 for deterministic rule edits; the stream is fully
        # drained before parsing.
        raw_ops_xml = "".join(list(call_model_stream(provider=insight_prov, model_display_name=insight_model_disp, messages=insight_msgs, api_key_override=api_key_override, temperature=0.0, max_tokens=3500))).strip()
        xml_match = re.search(r"<operations_list>.*</operations_list>", raw_ops_xml, re.DOTALL | re.IGNORECASE)
        if not xml_match:
            # No operations emitted: nothing to learn this round.
            logger.info(f"LEARNING [{task_id}]: No <operations_list> XML found.")
            return

        ops_data_list = []
        # NOTE: ET.fromstring raises on malformed XML; that is intentionally
        # absorbed by the outer except below.
        root = ET.fromstring(xml_match.group(0))
        for op_el in root.findall("operation"):
            # Each child element may be absent or empty; normalize to None.
            action = op_el.find("action").text.strip().lower() if op_el.find("action") is not None and op_el.find("action").text else None
            insight = op_el.find("insight").text.strip() if op_el.find("insight") is not None and op_el.find("insight").text else None
            old_insight = op_el.find("old_insight_to_replace").text.strip() if op_el.find("old_insight_to_replace") is not None and op_el.find("old_insight_to_replace").text else None
            # Only keep operations that name both an action and an insight.
            if action and insight: ops_data_list.append({"action": action, "insight": insight, "old_insight_to_replace": old_insight})

        processed_count, core_learnings = 0, []
        for op_data in ops_data_list:
            # Insights must carry a typed "[TYPE|score]" prefix; skip any that
            # do not match the expected tag format.
            if not re.match(r"\[(CORE_RULE|RESPONSE_PRINCIPLE|BEHAVIORAL_ADJUSTMENT|GENERAL_LEARNING)\|([\d\.]+?)\]", op_data["insight"], re.I|re.DOTALL): continue
            if op_data["action"] == "add":
                success, _ = add_rule_entry(op_data["insight"])
                if success: processed_count += 1;
            elif op_data["action"] == "update" and op_data["old_insight_to_replace"]:
                # The new rule is added even if removing the old one failed —
                # the update degrades to an add rather than being dropped.
                if remove_rule_entry(op_data["old_insight_to_replace"]):
                    logger.info(f"LEARNING [{task_id}]: Removed old rule for update: '{op_data['old_insight_to_replace'][:50]}...'")
                success, _ = add_rule_entry(op_data["insight"])
                if success: processed_count += 1;
            # NOTE(review): source formatting was lost; this CORE_RULE check is
            # reconstructed at loop-body level (applies to add and update ops
            # alike) — confirm against the original repository.
            if op_data["insight"].upper().startswith("[CORE_RULE"):
                core_learnings.append(op_data["insight"])

        if core_learnings:
            # Record a digest memory entry so core-rule changes are visible in
            # the memory store as a system reflection.
            learning_digest = "SYSTEM CORE LEARNING DIGEST:\n" + "\n".join(core_learnings)
            add_memory_entry(user_input="SYSTEM_INTERNAL_REFLECTION", metrics={"takeaway": "Core knowledge refined.", "type": "SYSTEM_REFLECTION"}, bot_response=learning_digest)
        logger.info(f"LEARNING [{task_id}]: Processed {processed_count}/{len(ops_data_list)} insight ops. Total time: {time.time() - learning_start_time:.2f}s")
    except Exception as e:
        # Learning is best-effort background work; log with traceback and
        # never propagate to the caller.
        logger.error(f"LEARNING [{task_id}]: CRITICAL ERROR in learning task: {e}", exc_info=True)
|