broadfield-dev committed
Commit 500bf1e · verified · Parent(s): ae7ced4

Create learning.py

Files changed (1)
  1. learning.py +90 -0
learning.py ADDED
@@ -0,0 +1,90 @@
import os
import json
import re
import logging
import time
import xml.etree.ElementTree as ET

from model_logic import call_model_stream, MODELS_BY_PROVIDER
from memory_logic import add_memory_entry, retrieve_rules_semantic, remove_rule_entry, add_rule_entry
import prompts

logger = logging.getLogger(__name__)

def generate_interaction_metrics(user_input: str, bot_response: str, provider: str, model_display_name: str, api_key_override: str = None) -> dict:
    metric_start_time = time.time()
    logger.info(f"Generating metrics with: {provider}/{model_display_name}")
    metric_prompt_content = prompts.get_metrics_user_prompt(user_input, bot_response)
    metric_messages = [{"role": "system", "content": prompts.METRICS_SYSTEM_PROMPT}, {"role": "user", "content": metric_prompt_content}]
    try:
        metrics_provider, metrics_model_display = provider, model_display_name
        # Optional override via the METRICS_MODEL env var, formatted as "provider/model_id".
        metrics_model_env = os.getenv("METRICS_MODEL")
        if metrics_model_env and "/" in metrics_model_env:
            m_prov, m_id = metrics_model_env.split('/', 1)
            # Map the raw model id back to its display name, as call_model_stream expects.
            m_disp_name = next((dn for dn, mid in MODELS_BY_PROVIDER.get(m_prov.lower(), {}).get("models", {}).items() if mid == m_id), None)
            if m_disp_name:
                metrics_provider, metrics_model_display = m_prov, m_disp_name
        response_chunks = list(call_model_stream(provider=metrics_provider, model_display_name=metrics_model_display, messages=metric_messages, api_key_override=api_key_override, temperature=0.05, max_tokens=200))
        resp_str = "".join(response_chunks).strip()
        # Extract the first JSON object from the model output; anything else is a format error.
        json_match = re.search(r"\{.*\}", resp_str, re.DOTALL)
        if json_match:
            metrics_data = json.loads(json_match.group(0))
        else:
            logger.warning(f"METRICS_GEN: Non-JSON response from {metrics_provider}/{metrics_model_display}: '{resp_str}'")
            return {"takeaway": "N/A", "response_success_score": 0.5, "future_confidence_score": 0.5, "error": "metrics format error"}
        parsed_metrics = {"takeaway": metrics_data.get("takeaway", "N/A"), "response_success_score": float(metrics_data.get("response_success_score", 0.5)), "future_confidence_score": float(metrics_data.get("future_confidence_score", 0.5))}
        logger.info(f"METRICS_GEN: Generated in {time.time() - metric_start_time:.2f}s. Data: {parsed_metrics}")
        return parsed_metrics
    except Exception as e:
        logger.error(f"METRICS_GEN Error: {e}")
        return {"takeaway": "N/A", "response_success_score": 0.5, "future_confidence_score": 0.5, "error": str(e)}

def perform_post_interaction_learning(user_input: str, bot_response: str, provider: str, model_disp_name: str, insights_reflected: list[dict], api_key_override: str = None):
    task_id = os.urandom(4).hex()
    logger.info(f"LEARNING [{task_id}]: Start User='{user_input[:40]}...'")
    learning_start_time = time.time()
    try:
        # Score the interaction, then persist it as a memory entry.
        metrics = generate_interaction_metrics(user_input, bot_response, provider, model_disp_name, api_key_override)
        add_memory_entry(user_input, metrics, bot_response)
        summary = f"User:\"{user_input}\"\nAI:\"{bot_response}\"\nMetrics(takeaway):{metrics.get('takeaway','N/A')},Success:{metrics.get('response_success_score','N/A')}"
        # Give the insight model the most relevant existing rules as context.
        existing_rules_ctx = "\n".join([f"- \"{r}\"" for r in retrieve_rules_semantic(f"{summary}\n{user_input}", k=10)]) or "No existing rules context."
        insight_user_prompt = prompts.get_insight_user_prompt(summary, existing_rules_ctx, insights_reflected)
        insight_msgs = [{"role": "system", "content": prompts.INSIGHT_SYSTEM_PROMPT}, {"role": "user", "content": insight_user_prompt}]
        insight_prov, insight_model_disp = provider, model_disp_name
        # Optional override via INSIGHT_MODEL_OVERRIDE, formatted as "provider/model_id".
        insight_env_model = os.getenv("INSIGHT_MODEL_OVERRIDE")
        if insight_env_model and "/" in insight_env_model:
            i_p, i_id = insight_env_model.split('/', 1)
            i_d_n = next((dn for dn, mid in MODELS_BY_PROVIDER.get(i_p.lower(), {}).get("models", {}).items() if mid == i_id), None)
            if i_d_n:
                insight_prov, insight_model_disp = i_p, i_d_n
        raw_ops_xml = "".join(list(call_model_stream(provider=insight_prov, model_display_name=insight_model_disp, messages=insight_msgs, api_key_override=api_key_override, temperature=0.0, max_tokens=3500))).strip()
        xml_match = re.search(r"<operations_list>.*</operations_list>", raw_ops_xml, re.DOTALL | re.IGNORECASE)
        if not xml_match:
            logger.info(f"LEARNING [{task_id}]: No <operations_list> XML found.")
            return

        # Parse each <operation> element into a plain dict.
        ops_data_list = []
        root = ET.fromstring(xml_match.group(0))
        for op_el in root.findall("operation"):
            action = op_el.find("action").text.strip().lower() if op_el.find("action") is not None and op_el.find("action").text else None
            insight = op_el.find("insight").text.strip() if op_el.find("insight") is not None and op_el.find("insight").text else None
            old_insight = op_el.find("old_insight_to_replace").text.strip() if op_el.find("old_insight_to_replace") is not None and op_el.find("old_insight_to_replace").text else None
            if action and insight:
                ops_data_list.append({"action": action, "insight": insight, "old_insight_to_replace": old_insight})

        processed_count, core_learnings = 0, []
        for op_data in ops_data_list:
            # Insights must carry a "[TYPE|score]" tag (e.g. "[CORE_RULE|0.9] ..."); skip anything else.
            if not re.match(r"\[(CORE_RULE|RESPONSE_PRINCIPLE|BEHAVIORAL_ADJUSTMENT|GENERAL_LEARNING)\|([\d\.]+?)\]", op_data["insight"], re.I | re.DOTALL):
                continue
            if op_data["action"] == "add":
                success, _ = add_rule_entry(op_data["insight"])
                if success:
                    processed_count += 1
            elif op_data["action"] == "update" and op_data["old_insight_to_replace"]:
                # An update is remove-then-add; the add proceeds even if the old rule was not found.
                if remove_rule_entry(op_data["old_insight_to_replace"]):
                    logger.info(f"LEARNING [{task_id}]: Removed old rule for update: '{op_data['old_insight_to_replace'][:50]}...'")
                success, _ = add_rule_entry(op_data["insight"])
                if success:
                    processed_count += 1
            if op_data["insight"].upper().startswith("[CORE_RULE"):
                core_learnings.append(op_data["insight"])

        if core_learnings:
            # Record a digest of core-rule changes as a system-level memory entry.
            learning_digest = "SYSTEM CORE LEARNING DIGEST:\n" + "\n".join(core_learnings)
            add_memory_entry(user_input="SYSTEM_INTERNAL_REFLECTION", metrics={"takeaway": "Core knowledge refined.", "type": "SYSTEM_REFLECTION"}, bot_response=learning_digest)
        logger.info(f"LEARNING [{task_id}]: Processed {processed_count}/{len(ops_data_list)} insight ops. Total time: {time.time() - learning_start_time:.2f}s")
    except Exception as e:
        logger.error(f"LEARNING [{task_id}]: CRITICAL ERROR in learning task: {e}", exc_info=True)
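
Note: generate_interaction_metrics extracts the first {...} block from the model output, so the metrics model is expected to reply with a single JSON object. A minimal sketch of that shape, inferred from the parsing code above (the values here are illustrative, not from the source):

import json

# Illustrative only: these are the keys the parser reads; missing scores default to 0.5.
sample = '{"takeaway": "User wanted a concise answer.", "response_success_score": 0.9, "future_confidence_score": 0.8}'
metrics_data = json.loads(sample)
assert {"takeaway", "response_success_score", "future_confidence_score"} <= metrics_data.keys()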
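
Likewise, the insight pass only acts on an <operations_list> element whose operations carry a [TYPE|score]-tagged insight. A hedged sketch of XML that would parse cleanly under the ElementTree traversal above (rule text and scores are illustrative):

import xml.etree.ElementTree as ET

sample_ops = """<operations_list>
  <operation>
    <action>add</action>
    <insight>[RESPONSE_PRINCIPLE|0.8] Prefer concise answers with one worked example.</insight>
  </operation>
  <operation>
    <action>update</action>
    <old_insight_to_replace>[CORE_RULE|0.9] Old wording of the rule.</old_insight_to_replace>
    <insight>[CORE_RULE|0.9] New wording of the rule.</insight>
  </operation>
</operations_list>"""

root = ET.fromstring(sample_ops)
for op in root.findall("operation"):
    print(op.find("action").text, "->", op.find("insight").text)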
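
Finally, a minimal usage sketch. The provider and model names below are placeholders, not values from this repo; they would need to exist in MODELS_BY_PROVIDER, and the commented-out env overrides use the "provider/model_id" format implied by the split('/', 1) parsing:

import os
from learning import generate_interaction_metrics, perform_post_interaction_learning

# Hypothetical names; substitute a provider and display name defined in MODELS_BY_PROVIDER.
PROVIDER, MODEL_DISPLAY = "groq", "Llama 3 8B"

# Optional: route the metrics/insight calls to a different model ("provider/model_id").
# os.environ["METRICS_MODEL"] = "groq/llama3-8b-8192"
# os.environ["INSIGHT_MODEL_OVERRIDE"] = "groq/llama3-8b-8192"

user_input = "How do I reverse a list in Python?"
bot_response = "Use my_list[::-1], or list(reversed(my_list)) if you need a new list."

metrics = generate_interaction_metrics(user_input, bot_response, PROVIDER, MODEL_DISPLAY)
print(metrics)  # e.g. {"takeaway": ..., "response_success_score": ..., "future_confidence_score": ...}

# Full learning pass: store the memory, mine rule operations, and apply them.
perform_post_interaction_learning(user_input, bot_response, PROVIDER, MODEL_DISPLAY, insights_reflected=[])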