import os
import json
import logging
import tempfile
from dotenv import load_dotenv
import gradio as gr

load_dotenv()
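
# Storage configuration. These values are exported to os.environ below because
# memory_logic appears to read them at import time (it re-exports
# STORAGE_BACKEND and the repo names as constants), so they must be set before
# the `from memory_logic import ...` statement.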

MEMORY_STORAGE_TYPE = "HF_DATASET"
HF_DATASET_MEMORY_REPO = "broadfield-dev/ai-brain"
HF_DATASET_RULES_REPO = "broadfield-dev/ai-rules"

os.environ['STORAGE_BACKEND'] = MEMORY_STORAGE_TYPE
if MEMORY_STORAGE_TYPE == "HF_DATASET":
    os.environ['HF_MEMORY_DATASET_REPO'] = HF_DATASET_MEMORY_REPO
    os.environ['HF_RULES_DATASET_REPO'] = HF_DATASET_RULES_REPO

from model_logic import get_available_providers, get_model_display_names_for_provider, get_default_model_display_name_for_provider
from memory_logic import (
    initialize_memory_system, add_memory_entry, get_all_memories_cached, clear_all_memory_data_backend,
    add_rule_entry, remove_rule_entry, get_all_rules_cached, clear_all_rules_data_backend,
    save_faiss_indices_to_disk, STORAGE_BACKEND as MEMORY_STORAGE_BACKEND, SQLITE_DB_PATH as MEMORY_SQLITE_PATH,
    HF_MEMORY_DATASET_REPO as MEMORY_HF_MEM_REPO, HF_RULES_DATASET_REPO as MEMORY_HF_RULES_REPO
)
from tools.orchestrator import orchestrate_and_respond
from learning import perform_post_interaction_learning
from utils import load_rules_from_file, load_memories_from_file
from prompts import DEFAULT_SYSTEM_PROMPT

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(threadName)s - %(message)s')
logger = logging.getLogger(__name__)
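# Quiet the noisiest third-party loggers so app-level INFO output stays readable.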
for lib_name in ["urllib3", "requests", "huggingface_hub", "PIL.PngImagePlugin", "matplotlib", "gradio_client.client", "multipart.multipart", "httpx", "sentence_transformers", "faiss", "datasets"]:
    logging.getLogger(lib_name).setLevel(logging.WARNING)

MAX_HISTORY_TURNS = int(os.getenv("MAX_HISTORY_TURNS", 7))
LOAD_RULES_FILE = os.getenv("LOAD_RULES_FILE")
LOAD_MEMORIES_FILE = os.getenv("LOAD_MEMORIES_FILE")
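
# Module-level chat history, shared by every connected client. That is fine for
# a single-user deployment; a multi-user app would keep history in per-session
# state (e.g. gr.State) instead.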
current_chat_session_history = []

def handle_gradio_chat_submit(user_msg_txt: str, gr_hist_list: list, sel_prov_name: str, sel_model_disp_name: str, ui_api_key: str|None, cust_sys_prompt: str):
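    """Handle a chat submission as a generator of incremental UI updates.

    Each yield is a tuple matching `chat_outs`: (input box, chatbot history,
    status text, insights markdown, full-response textbox, download button,
    rules view, memories view), so the UI refreshes while the orchestrator
    streams its response.
    """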
    global current_chat_session_history
    cleared_input, updated_gr_hist, status_txt = "", list(gr_hist_list), "Initializing..."
    updated_rules_text = ui_refresh_rules_display_fn()
    updated_mems_json = ui_refresh_memories_display_fn()
    def_detect_out_md = gr.Markdown(visible=False)
    def_fmt_out_txt = gr.Textbox(value="*Waiting...*", interactive=True, show_copy_button=True)
    def_dl_btn = gr.DownloadButton(interactive=False, value=None, visible=False)

    if not user_msg_txt.strip():
        status_txt = "Error: Empty message."
        updated_gr_hist.append((user_msg_txt or "(Empty)", status_txt))
        yield (cleared_input, updated_gr_hist, status_txt, def_detect_out_md, def_fmt_out_txt, def_dl_btn, updated_rules_text, updated_mems_json)
        return

    updated_gr_hist.append((user_msg_txt, "<i>Thinking...</i>"))
    yield (cleared_input, updated_gr_hist, status_txt, def_detect_out_md, def_fmt_out_txt, def_dl_btn, updated_rules_text, updated_mems_json)

    internal_hist = list(current_chat_session_history)
    internal_hist.append({"role": "user", "content": user_msg_txt})
    hist_len_check = MAX_HISTORY_TURNS * 2
    if len(internal_hist) > hist_len_check:
        current_chat_session_history = internal_hist[-hist_len_check:]
        internal_hist = list(current_chat_session_history)

    final_bot_resp_acc, insights_used_parsed = "", []
    temp_dl_file_path = None

    try:
        processor_gen = orchestrate_and_respond(user_input=user_msg_txt, provider_name=sel_prov_name, model_display_name=sel_model_disp_name, chat_history_for_prompt=internal_hist, custom_system_prompt=cust_sys_prompt.strip() or None, ui_api_key_override=ui_api_key.strip() if ui_api_key else None)
        curr_bot_disp_msg = ""
        for upd_type, upd_data in processor_gen:
            if upd_type == "status":
                status_txt = upd_data
                if updated_gr_hist and updated_gr_hist[-1][0] == user_msg_txt:
                    updated_gr_hist[-1] = (user_msg_txt, f"{curr_bot_disp_msg} <i>{status_txt}</i>" if curr_bot_disp_msg else f"<i>{status_txt}</i>")
            elif upd_type == "response_chunk":
                curr_bot_disp_msg += upd_data
                if updated_gr_hist and updated_gr_hist[-1][0] == user_msg_txt:
                    updated_gr_hist[-1] = (user_msg_txt, curr_bot_disp_msg)
            elif upd_type == "final_response_and_insights":
                final_bot_resp_acc, insights_used_parsed = upd_data["response"], upd_data["insights_used"]
                status_txt = "Response generated. Processing learning..."
                if not curr_bot_disp_msg and final_bot_resp_acc : curr_bot_disp_msg = final_bot_resp_acc
                if updated_gr_hist and updated_gr_hist[-1][0] == user_msg_txt: updated_gr_hist[-1] = (user_msg_txt, curr_bot_disp_msg or "(No text)")
                def_fmt_out_txt = gr.Textbox(value=curr_bot_disp_msg, interactive=True, show_copy_button=True)
                if curr_bot_disp_msg and not curr_bot_disp_msg.startswith("Error:"):
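                    # Write the finished response to a temp .md file so the
                    # DownloadButton has a file to serve; it is unlinked once
                    # the handler finishes.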
                    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".md", encoding='utf-8') as tmpfile:
                        tmpfile.write(curr_bot_disp_msg)
                        temp_dl_file_path = tmpfile.name
                    def_dl_btn = gr.DownloadButton(value=temp_dl_file_path, visible=True, interactive=True)
                insights_md_content = "### Insights Considered (Pre-Response):\n" + ("\n".join([f"- **[{i.get('type','N/A')}|{i.get('score','N/A')}]** {i.get('text','N/A')[:100]}..." for i in insights_used_parsed[:3]]) if insights_used_parsed else "*None specific.*")
                def_detect_out_md = gr.Markdown(value=insights_md_content, visible=bool(insights_used_parsed))

            yield (cleared_input, updated_gr_hist, status_txt, def_detect_out_md, def_fmt_out_txt, def_dl_btn, updated_rules_text, updated_mems_json)
            if upd_type == "final_response_and_insights": break
    except Exception as e:
        logger.error(f"Chat handler error: {e}", exc_info=True); status_txt = f"Error: {str(e)[:100]}"
        error_message_for_chat = f"Sorry, an error occurred: {str(e)[:100]}"
        if updated_gr_hist and updated_gr_hist[-1][0] == user_msg_txt: updated_gr_hist[-1] = (user_msg_txt, error_message_for_chat)
        else: updated_gr_hist.append((user_msg_txt, error_message_for_chat))
        yield (cleared_input, updated_gr_hist, status_txt, gr.Markdown(value="*Error processing request.*", visible=True), gr.Textbox(value=error_message_for_chat, interactive=True), def_dl_btn, ui_refresh_rules_display_fn(), ui_refresh_memories_display_fn())
        if temp_dl_file_path and os.path.exists(temp_dl_file_path): os.unlink(temp_dl_file_path)
        return

    if final_bot_resp_acc and not final_bot_resp_acc.startswith("Error:"):
        current_chat_session_history.extend([{"role": "user", "content": user_msg_txt}, {"role": "assistant", "content": final_bot_resp_acc}])
        if len(current_chat_session_history) > MAX_HISTORY_TURNS * 2: current_chat_session_history = current_chat_session_history[-(MAX_HISTORY_TURNS * 2):]
        status_txt = "<i>[Performing post-interaction learning...]</i>"
        yield (cleared_input, updated_gr_hist, status_txt, def_detect_out_md, def_fmt_out_txt, def_dl_btn, ui_refresh_rules_display_fn(), ui_refresh_memories_display_fn())
        try:
            perform_post_interaction_learning(user_input=user_msg_txt, bot_response=final_bot_resp_acc, provider=sel_prov_name, model_disp_name=sel_model_disp_name, insights_reflected=insights_used_parsed, api_key_override=ui_api_key.strip() if ui_api_key else None)
            status_txt = "Response & Learning Complete."
        except Exception as e_learn:
            logger.error(f"Error during post-interaction learning: {e_learn}", exc_info=True)
            status_txt = "Response complete. Error during learning."
    else: status_txt = final_bot_resp_acc or "Processing finished; no valid response."

    updated_rules_text = ui_refresh_rules_display_fn()
    updated_mems_json = ui_refresh_memories_display_fn()
    yield (cleared_input, updated_gr_hist, status_txt, def_detect_out_md, def_fmt_out_txt, def_dl_btn, updated_rules_text, updated_mems_json)

    if temp_dl_file_path and os.path.exists(temp_dl_file_path): os.unlink(temp_dl_file_path)
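
# --- Knowledge-base UI helpers ---
# Thin wrappers around memory_logic for displaying, downloading, and uploading
# rules and memories from the Gradio widgets defined below.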

def ui_refresh_rules_display_fn(): return "\n\n---\n\n".join(get_all_rules_cached()) or "No rules found."
def ui_download_rules_action_fn():
    rules_content = "\n\n---\n\n".join(get_all_rules_cached())
    if not rules_content.strip():
        gr.Warning("No rules to download.")
        return gr.DownloadButton(value=None, interactive=False, label="No Rules")
    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt", encoding='utf-8') as tmpfile:
        tmpfile.write(rules_content)
        return tmpfile.name
def ui_upload_rules_action_fn(uploaded_file_obj, progress=gr.Progress()):
    if not uploaded_file_obj: return "No file provided."
    added, skipped, errors = load_rules_from_file(uploaded_file_obj.name, progress_callback=lambda p, d: progress(p, desc=d))
    return f"Rules Upload: Added: {added}, Skipped (duplicates): {skipped}, Errors: {errors}."

def ui_refresh_memories_display_fn(): return get_all_memories_cached() or []
def ui_download_memories_action_fn():
    memories = get_all_memories_cached()
    if not memories:
        gr.Warning("No memories to download.")
        return gr.DownloadButton(value=None, interactive=False, label="No Memories")
    jsonl_content = "\n".join([json.dumps(mem) for mem in memories])
    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".jsonl", encoding='utf-g') as tmpfile:
        tmpfile.write(jsonl_content)
        return tmpfile.name
def ui_upload_memories_action_fn(uploaded_file_obj, progress=gr.Progress()):
    if not uploaded_file_obj: return "No file provided."
    added, format_err, save_err = load_memories_from_file(uploaded_file_obj.name, progress_callback=lambda p, d: progress(p, desc=d))
    return f"Memories Upload: Added: {added}, Format Errors: {format_err}, Save Errors: {save_err}."

def save_edited_rules_action_fn(edited_rules_text: str, progress=gr.Progress()):
    if not edited_rules_text.strip(): return "No rules text to save."
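    # Rules are delimited the same way ui_refresh_rules_display_fn joins them
    # ("\n\n---\n\n"); fall back to one rule per line if the user pasted plain
    # lines instead.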
    potential_rules = edited_rules_text.split("\n\n---\n\n")
    if len(potential_rules) == 1 and "\n" in edited_rules_text:
        potential_rules = [r.strip() for r in edited_rules_text.splitlines() if r.strip()]
    unique_rules = sorted(list(set(filter(None, [r.strip() for r in potential_rules]))))
    if not unique_rules: return "No unique, non-empty rules found."
    added, skipped, errors, total = 0, 0, 0, len(unique_rules)
    progress(0, desc=f"Saving {total} unique rules...")
    for idx, rule_text in enumerate(unique_rules):
        success, status_msg = add_rule_entry(rule_text)
        if success: added += 1
        elif status_msg == "duplicate": skipped += 1
        else: errors += 1
        progress((idx + 1) / total, desc=f"Processed {idx+1}/{total} rules...")
    return f"Editor Save: Added: {added}, Skipped (duplicates): {skipped}, Errors: {errors} from {total} unique rules."

def app_load_fn():
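    """Initialize the memory system and, if LOAD_RULES_FILE / LOAD_MEMORIES_FILE
    are set, seed rules and memories from those files."""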
    logger.info("App loading. Initializing systems...")
    initialize_memory_system()
    rules_added, rules_skipped, rules_errors = load_rules_from_file(LOAD_RULES_FILE) if LOAD_RULES_FILE else (0, 0, 0)
    mems_added, mems_format_errors, mems_save_errors = load_memories_from_file(LOAD_MEMORIES_FILE) if LOAD_MEMORIES_FILE else (0, 0, 0)
    status = f"Ready. Rules loaded: {rules_added}. Memories loaded: {mems_added}."
    return (status, ui_refresh_rules_display_fn(), ui_refresh_memories_display_fn(), gr.Markdown(visible=False), gr.Textbox(value="*Waiting...*", interactive=True), gr.DownloadButton(interactive=False, visible=False))

with gr.Blocks(theme=gr.themes.Soft(), css=".gr-button { margin: 5px; } .status-text { font-size: 0.9em; color: #555; }") as demo:
    gr.Markdown("# ๐Ÿค– AI Research Agent")
    with gr.Row(variant="compact"):
        agent_stat_tb = gr.Textbox(label="Agent Status", value="Initializing...", interactive=False, elem_classes=["status-text"], scale=4)
        with gr.Column(scale=1, min_width=150):
            memory_backend_info_tb = gr.Textbox(label="Memory Backend", value=MEMORY_STORAGE_BACKEND, interactive=False)
            hf_repos_display = gr.Textbox(label="HF Repos", value=f"M: {MEMORY_HF_MEM_REPO}, R: {MEMORY_HF_RULES_REPO}", interactive=False, visible=MEMORY_STORAGE_BACKEND == "HF_DATASET")
    with gr.Row():
        with gr.Sidebar():
            gr.Markdown("## โš™๏ธ Configuration")
            with gr.Group():
                api_key_tb = gr.Textbox(label="API Key (Override)", type="password", placeholder="Uses .env if blank")
                available_providers = get_available_providers()
                default_provider = available_providers[0] if available_providers else None
                prov_sel_dd = gr.Dropdown(label="AI Provider", choices=available_providers, value=default_provider, interactive=True)
                model_sel_dd = gr.Dropdown(label="AI Model", choices=get_model_display_names_for_provider(default_provider) if default_provider else [], value=get_default_model_display_name_for_provider(default_provider), interactive=True)
            with gr.Group():
                sys_prompt_tb = gr.Textbox(label="System Prompt", lines=8, value=DEFAULT_SYSTEM_PROMPT, interactive=True)
            if MEMORY_STORAGE_BACKEND == "RAM":
                save_faiss_sidebar_btn = gr.Button("Save FAISS Indices", variant="secondary")
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.TabItem("💬 Chat & Research"):
                    main_chat_disp = gr.Chatbot(height=400, show_copy_button=True, render_markdown=True)
                    with gr.Row(variant="compact"):
                        user_msg_tb = gr.Textbox(show_label=False, placeholder="Ask your research question...", scale=7, lines=1)
                        send_btn = gr.Button("Send", variant="primary", scale=1, min_width=100)
                    with gr.Accordion("📝 Detailed Response & Insights", open=False):
                        fmt_report_tb = gr.Textbox(label="Full AI Response", lines=8, interactive=True, show_copy_button=True)
                        dl_report_btn = gr.DownloadButton("Download Report", value=None, interactive=False, visible=False)
                        detect_out_md = gr.Markdown(visible=False)
                with gr.TabItem("🧠 Knowledge Base"):
                    with gr.Row(equal_height=True):
                        with gr.Column():
                            gr.Markdown("### ๐Ÿ“œ Rules Management")
                            rules_disp_ta = gr.TextArea(label="Current Rules", lines=10, interactive=True)
                            save_edited_rules_btn = gr.Button("💾 Save Edited Text", variant="primary")
                            with gr.Row(variant="compact"):
                                dl_rules_btn = gr.DownloadButton("⬇️ Download Rules")
                                clear_rules_btn = gr.Button("🗑️ Clear All Rules", variant="stop")
                            upload_rules_fobj = gr.File(label="Upload Rules File (.txt/.jsonl)", file_types=[".txt", ".jsonl"])
                            rules_stat_tb = gr.Textbox(label="Rules Status", interactive=False, lines=1)
                        with gr.Column():
                            gr.Markdown("### ๐Ÿ“š Memories Management")
                            mems_disp_json = gr.JSON(label="Current Memories", value=[])
                            with gr.Row(variant="compact"):
                                dl_mems_btn = gr.DownloadButton("⬇️ Download Memories")
                                clear_mems_btn = gr.Button("🗑️ Clear All Memories", variant="stop")
                            upload_mems_fobj = gr.File(label="Upload Memories File (.json/.jsonl)", file_types=[".json", ".jsonl"])
                            mems_stat_tb = gr.Textbox(label="Memories Status", interactive=False, lines=1)

    prov_sel_dd.change(lambda p: gr.Dropdown(choices=get_model_display_names_for_provider(p), value=get_default_model_display_name_for_provider(p), interactive=True), prov_sel_dd, model_sel_dd)
    chat_ins = [user_msg_tb, main_chat_disp, prov_sel_dd, model_sel_dd, api_key_tb, sys_prompt_tb]
    chat_outs = [user_msg_tb, main_chat_disp, agent_stat_tb, detect_out_md, fmt_report_tb, dl_report_btn, rules_disp_ta, mems_disp_json]
    chat_event_args = {"fn": handle_gradio_chat_submit, "inputs": chat_ins, "outputs": chat_outs}
    send_btn.click(**chat_event_args)
    user_msg_tb.submit(**chat_event_args)

    dl_rules_btn.click(ui_download_rules_action_fn, None, dl_rules_btn)
    save_edited_rules_btn.click(save_edited_rules_action_fn, [rules_disp_ta], [rules_stat_tb]).then(ui_refresh_rules_display_fn, outputs=rules_disp_ta)
    upload_rules_fobj.upload(ui_upload_rules_action_fn, [upload_rules_fobj], [rules_stat_tb]).then(ui_refresh_rules_display_fn, outputs=rules_disp_ta)
    clear_rules_btn.click(lambda: ("Cleared." if clear_all_rules_data_backend() else "Error."), outputs=rules_stat_tb).then(ui_refresh_rules_display_fn, outputs=rules_disp_ta)

    dl_mems_btn.click(ui_download_memories_action_fn, None, dl_mems_btn)
    upload_mems_fobj.upload(ui_upload_memories_action_fn, [upload_mems_fobj], [mems_stat_tb]).then(ui_refresh_memories_display_fn, outputs=mems_disp_json)
    clear_mems_btn.click(lambda: ("Cleared." if clear_all_memory_data_backend() else "Error."), outputs=mems_stat_tb).then(ui_refresh_memories_display_fn, outputs=mems_disp_json)

    if MEMORY_STORAGE_BACKEND == "RAM" and 'save_faiss_sidebar_btn' in locals():
        save_faiss_sidebar_btn.click(lambda: (gr.Info("Saved FAISS to disk.") if save_faiss_indices_to_disk() is None else gr.Error("Error saving FAISS.")), None, None)

    app_load_outputs = [agent_stat_tb, rules_disp_ta, mems_disp_json, detect_out_md, fmt_report_tb, dl_report_btn]
    demo.load(fn=app_load_fn, inputs=None, outputs=app_load_outputs)

if __name__ == "__main__":
    app_port = int(os.getenv("GRADIO_PORT", 7860))
    app_server = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
    logger.info(f"Launching Gradio server: http://{app_server}:{app_port}")
    demo.queue().launch(server_name=app_server, server_port=app_port)