import os
import json
import logging
import tempfile
from dotenv import load_dotenv
import gradio as gr
load_dotenv()
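
# Memory/rules storage configuration. These environment variables are set
# before memory_logic is imported below so that it can pick up the backend
# and the Hugging Face dataset repos when it initializes.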
MEMORY_STORAGE_TYPE = "HF_DATASET"
HF_DATASET_MEMORY_REPO = "broadfield-dev/ai-brain"
HF_DATASET_RULES_REPO = "broadfield-dev/ai-rules"
os.environ['STORAGE_BACKEND'] = MEMORY_STORAGE_TYPE
if MEMORY_STORAGE_TYPE == "HF_DATASET":
    os.environ['HF_MEMORY_DATASET_REPO'] = HF_DATASET_MEMORY_REPO
    os.environ['HF_RULES_DATASET_REPO'] = HF_DATASET_RULES_REPO
from model_logic import get_available_providers, get_model_display_names_for_provider, get_default_model_display_name_for_provider
from memory_logic import (
    initialize_memory_system, add_memory_entry, get_all_memories_cached, clear_all_memory_data_backend,
    add_rule_entry, remove_rule_entry, get_all_rules_cached, clear_all_rules_data_backend,
    save_faiss_indices_to_disk, STORAGE_BACKEND as MEMORY_STORAGE_BACKEND, SQLITE_DB_PATH as MEMORY_SQLITE_PATH,
    HF_MEMORY_DATASET_REPO as MEMORY_HF_MEM_REPO, HF_RULES_DATASET_REPO as MEMORY_HF_RULES_REPO,
)
from tools.orchestrator import orchestrate_and_respond
from learning import perform_post_interaction_learning
from utils import load_rules_from_file, load_memories_from_file
from prompts import DEFAULT_SYSTEM_PROMPT
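
# Verbose logging for the app itself; noisy third-party libraries are capped at WARNING.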
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(threadName)s - %(message)s')
logger = logging.getLogger(__name__)
for lib_name in ["urllib3", "requests", "huggingface_hub", "PIL.PngImagePlugin", "matplotlib", "gradio_client.client", "multipart.multipart", "httpx", "sentence_transformers", "faiss", "datasets"]:
    logging.getLogger(lib_name).setLevel(logging.WARNING)  # getLogger() never returns None, so no existence check is needed
MAX_HISTORY_TURNS = int(os.getenv("MAX_HISTORY_TURNS", 7))
LOAD_RULES_FILE = os.getenv("LOAD_RULES_FILE")
LOAD_MEMORIES_FILE = os.getenv("LOAD_MEMORIES_FILE")
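
# Module-level chat history: shared by every client of this process, not per-session.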
current_chat_session_history = []
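
# Generator-based chat handler: each `yield` pushes a full UI update
# (cleared input, chatbot history, status, insights markdown, response textbox,
# download button, rules text, memories JSON) so the interface streams progress.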
def handle_gradio_chat_submit(user_msg_txt: str, gr_hist_list: list, sel_prov_name: str, sel_model_disp_name: str, ui_api_key: str | None, cust_sys_prompt: str):
    global current_chat_session_history
    cleared_input, updated_gr_hist, status_txt = "", list(gr_hist_list), "Initializing..."
    updated_rules_text = ui_refresh_rules_display_fn()
    updated_mems_json = ui_refresh_memories_display_fn()
    def_detect_out_md = gr.Markdown(visible=False)
    def_fmt_out_txt = gr.Textbox(value="*Waiting...*", interactive=True, show_copy_button=True)
    def_dl_btn = gr.DownloadButton(interactive=False, value=None, visible=False)
    if not user_msg_txt.strip():
        status_txt = "Error: Empty message."
        updated_gr_hist.append((user_msg_txt or "(Empty)", status_txt))
        yield (cleared_input, updated_gr_hist, status_txt, def_detect_out_md, def_fmt_out_txt, def_dl_btn, updated_rules_text, updated_mems_json)
        return
    updated_gr_hist.append((user_msg_txt, "<i>Thinking...</i>"))
    yield (cleared_input, updated_gr_hist, status_txt, def_detect_out_md, def_fmt_out_txt, def_dl_btn, updated_rules_text, updated_mems_json)
    internal_hist = list(current_chat_session_history)
    internal_hist.append({"role": "user", "content": user_msg_txt})
    hist_len_check = MAX_HISTORY_TURNS * 2
    if len(internal_hist) > hist_len_check:
        current_chat_session_history = internal_hist[-hist_len_check:]
        internal_hist = list(current_chat_session_history)
    final_bot_resp_acc, insights_used_parsed = "", []
    temp_dl_file_path = None
    try:
        processor_gen = orchestrate_and_respond(user_input=user_msg_txt, provider_name=sel_prov_name, model_display_name=sel_model_disp_name, chat_history_for_prompt=internal_hist, custom_system_prompt=cust_sys_prompt.strip() or None, ui_api_key_override=ui_api_key.strip() if ui_api_key else None)
        curr_bot_disp_msg = ""
        for upd_type, upd_data in processor_gen:
            if upd_type == "status":
                status_txt = upd_data
                if updated_gr_hist and updated_gr_hist[-1][0] == user_msg_txt:
                    updated_gr_hist[-1] = (user_msg_txt, f"{curr_bot_disp_msg} <i>{status_txt}</i>" if curr_bot_disp_msg else f"<i>{status_txt}</i>")
            elif upd_type == "response_chunk":
                curr_bot_disp_msg += upd_data
                if updated_gr_hist and updated_gr_hist[-1][0] == user_msg_txt:
                    updated_gr_hist[-1] = (user_msg_txt, curr_bot_disp_msg)
            elif upd_type == "final_response_and_insights":
                final_bot_resp_acc, insights_used_parsed = upd_data["response"], upd_data["insights_used"]
                status_txt = "Response generated. Processing learning..."
                if not curr_bot_disp_msg and final_bot_resp_acc:
                    curr_bot_disp_msg = final_bot_resp_acc
                if updated_gr_hist and updated_gr_hist[-1][0] == user_msg_txt:
                    updated_gr_hist[-1] = (user_msg_txt, curr_bot_disp_msg or "(No text)")
                def_fmt_out_txt = gr.Textbox(value=curr_bot_disp_msg, interactive=True, show_copy_button=True)
                if curr_bot_disp_msg and not curr_bot_disp_msg.startswith("Error:"):
                    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".md", encoding='utf-8') as tmpfile:
                        tmpfile.write(curr_bot_disp_msg)
                        temp_dl_file_path = tmpfile.name
                    def_dl_btn = gr.DownloadButton(value=temp_dl_file_path, visible=True, interactive=True)
                insights_md_content = "### Insights Considered (Pre-Response):\n" + ("\n".join([f"- **[{i.get('type','N/A')}|{i.get('score','N/A')}]** {i.get('text','N/A')[:100]}..." for i in insights_used_parsed[:3]]) if insights_used_parsed else "*None specific.*")
                def_detect_out_md = gr.Markdown(value=insights_md_content, visible=bool(insights_used_parsed))
            yield (cleared_input, updated_gr_hist, status_txt, def_detect_out_md, def_fmt_out_txt, def_dl_btn, updated_rules_text, updated_mems_json)
            if upd_type == "final_response_and_insights":
                break
    except Exception as e:
        logger.error(f"Chat handler error: {e}", exc_info=True)
        status_txt = f"Error: {str(e)[:100]}"
        error_message_for_chat = f"Sorry, an error occurred: {str(e)[:100]}"
        if updated_gr_hist and updated_gr_hist[-1][0] == user_msg_txt:
            updated_gr_hist[-1] = (user_msg_txt, error_message_for_chat)
        else:
            updated_gr_hist.append((user_msg_txt, error_message_for_chat))
        yield (cleared_input, updated_gr_hist, status_txt, gr.Markdown(value="*Error processing request.*", visible=True), gr.Textbox(value=error_message_for_chat, interactive=True), def_dl_btn, ui_refresh_rules_display_fn(), ui_refresh_memories_display_fn())
        if temp_dl_file_path and os.path.exists(temp_dl_file_path):
            os.unlink(temp_dl_file_path)
        return
    if final_bot_resp_acc and not final_bot_resp_acc.startswith("Error:"):
        current_chat_session_history.extend([{"role": "user", "content": user_msg_txt}, {"role": "assistant", "content": final_bot_resp_acc}])
        if len(current_chat_session_history) > MAX_HISTORY_TURNS * 2:
            current_chat_session_history = current_chat_session_history[-(MAX_HISTORY_TURNS * 2):]
        status_txt = "<i>[Performing post-interaction learning...]</i>"
        yield (cleared_input, updated_gr_hist, status_txt, def_detect_out_md, def_fmt_out_txt, def_dl_btn, ui_refresh_rules_display_fn(), ui_refresh_memories_display_fn())
        try:
            perform_post_interaction_learning(user_input=user_msg_txt, bot_response=final_bot_resp_acc, provider=sel_prov_name, model_disp_name=sel_model_disp_name, insights_reflected=insights_used_parsed, api_key_override=ui_api_key.strip() if ui_api_key else None)
            status_txt = "Response & Learning Complete."
        except Exception as e_learn:
            logger.error(f"Error during post-interaction learning: {e_learn}", exc_info=True)
            status_txt = "Response complete. Error during learning."
    else:
        status_txt = final_bot_resp_acc or "Processing finished; no valid response."
    updated_rules_text = ui_refresh_rules_display_fn()
    updated_mems_json = ui_refresh_memories_display_fn()
    yield (cleared_input, updated_gr_hist, status_txt, def_detect_out_md, def_fmt_out_txt, def_dl_btn, updated_rules_text, updated_mems_json)
    if temp_dl_file_path and os.path.exists(temp_dl_file_path):
        os.unlink(temp_dl_file_path)
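
# --- Rules management helpers ---
# The download helpers return either a temp-file path (consumed by a
# gr.DownloadButton) or an updated button component when there is nothing to export.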
def ui_refresh_rules_display_fn():
    return "\n\n---\n\n".join(get_all_rules_cached()) or "No rules found."
def ui_download_rules_action_fn():
    rules_content = "\n\n---\n\n".join(get_all_rules_cached())
    if not rules_content.strip():
        gr.Warning("No rules to download.")
        return gr.DownloadButton(value=None, interactive=False, label="No Rules")
    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt", encoding='utf-8') as tmpfile:
        tmpfile.write(rules_content)
        return tmpfile.name
def ui_upload_rules_action_fn(uploaded_file_obj, progress=gr.Progress()):
    if not uploaded_file_obj:
        return "No file provided."
    added, skipped, errors = load_rules_from_file(uploaded_file_obj.name, progress_callback=lambda p, d: progress(p, desc=d))
    return f"Rules Upload: Added: {added}, Skipped (duplicates): {skipped}, Errors: {errors}."
def ui_refresh_memories_display_fn():
    return get_all_memories_cached() or []

def ui_download_memories_action_fn():
    memories = get_all_memories_cached()
    if not memories:
        gr.Warning("No memories to download.")
        return gr.DownloadButton(value=None, interactive=False, label="No Memories")
    jsonl_content = "\n".join([json.dumps(mem) for mem in memories])
    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".jsonl", encoding='utf-8') as tmpfile:
        tmpfile.write(jsonl_content)
        return tmpfile.name
def ui_upload_memories_action_fn(uploaded_file_obj, progress=gr.Progress()):
    if not uploaded_file_obj:
        return "No file provided."
    added, format_err, save_err = load_memories_from_file(uploaded_file_obj.name, progress_callback=lambda p, d: progress(p, desc=d))
    return f"Memories Upload: Added: {added}, Format Errors: {format_err}, Save Errors: {save_err}."
def save_edited_rules_action_fn(edited_rules_text: str, progress=gr.Progress()):
    if not edited_rules_text.strip():
        return "No rules text to save."
    potential_rules = edited_rules_text.split("\n\n---\n\n")
    if len(potential_rules) == 1 and "\n" in edited_rules_text:
        potential_rules = [r.strip() for r in edited_rules_text.splitlines() if r.strip()]
    unique_rules = sorted(list(set(filter(None, [r.strip() for r in potential_rules]))))
    if not unique_rules:
        return "No unique, non-empty rules found."
    added, skipped, errors, total = 0, 0, 0, len(unique_rules)
    progress(0, desc=f"Saving {total} unique rules...")
    for idx, rule_text in enumerate(unique_rules):
        success, status_msg = add_rule_entry(rule_text)
        if success:
            added += 1
        elif status_msg == "duplicate":
            skipped += 1
        else:
            errors += 1
        progress((idx + 1) / total, desc=f"Processed {idx+1}/{total} rules...")
    return f"Editor Save: Added: {added}, Skipped (duplicates): {skipped}, Errors: {errors} from {total} unique rules."
def app_load_fn():
    logger.info("App loading. Initializing systems...")
    initialize_memory_system()
    # The LOAD_*_FILE env vars are optional; skip seeding when they are unset
    # (assumes the load_* helpers expect a real file path).
    rules_added, rules_skipped, rules_errors = load_rules_from_file(LOAD_RULES_FILE) if LOAD_RULES_FILE else (0, 0, 0)
    mems_added, mems_format_errors, mems_save_errors = load_memories_from_file(LOAD_MEMORIES_FILE) if LOAD_MEMORIES_FILE else (0, 0, 0)
    status = f"Ready. Rules loaded: {rules_added}. Memories loaded: {mems_added}."
    return (status, ui_refresh_rules_display_fn(), ui_refresh_memories_display_fn(), gr.Markdown(visible=False), gr.Textbox(value="*Waiting...*", interactive=True), gr.DownloadButton(interactive=False, visible=False))
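
# --- Gradio UI ---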
with gr.Blocks(theme=gr.themes.Soft(), css=".gr-button { margin: 5px; } .status-text { font-size: 0.9em; color: #555; }") as demo:
    gr.Markdown("# 🤖 AI Research Agent")
    with gr.Row(variant="compact"):
        agent_stat_tb = gr.Textbox(label="Agent Status", value="Initializing...", interactive=False, elem_classes=["status-text"], scale=4)
        with gr.Column(scale=1, min_width=150):
            memory_backend_info_tb = gr.Textbox(label="Memory Backend", value=MEMORY_STORAGE_BACKEND, interactive=False)
            hf_repos_display = gr.Textbox(label="HF Repos", value=f"M: {MEMORY_HF_MEM_REPO}, R: {MEMORY_HF_RULES_REPO}", interactive=False, visible=MEMORY_STORAGE_BACKEND == "HF_DATASET")
    with gr.Row():
        with gr.Sidebar():
            gr.Markdown("## ⚙️ Configuration")
            with gr.Group():
                api_key_tb = gr.Textbox(label="API Key (Override)", type="password", placeholder="Uses .env if blank")
                available_providers = get_available_providers()
                default_provider = available_providers[0] if available_providers else None
                prov_sel_dd = gr.Dropdown(label="AI Provider", choices=available_providers, value=default_provider, interactive=True)
                model_sel_dd = gr.Dropdown(label="AI Model", choices=get_model_display_names_for_provider(default_provider) if default_provider else [], value=get_default_model_display_name_for_provider(default_provider), interactive=True)
            with gr.Group():
                sys_prompt_tb = gr.Textbox(label="System Prompt", lines=8, value=DEFAULT_SYSTEM_PROMPT, interactive=True)
            if MEMORY_STORAGE_BACKEND == "RAM":
                save_faiss_sidebar_btn = gr.Button("Save FAISS Indices", variant="secondary")
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.TabItem("💬 Chat & Research"):
                    main_chat_disp = gr.Chatbot(height=400, show_copy_button=True, render_markdown=True)
                    with gr.Row(variant="compact"):
                        user_msg_tb = gr.Textbox(show_label=False, placeholder="Ask your research question...", scale=7, lines=1)
                        send_btn = gr.Button("Send", variant="primary", scale=1, min_width=100)
                    with gr.Accordion("📝 Detailed Response & Insights", open=False):
                        fmt_report_tb = gr.Textbox(label="Full AI Response", lines=8, interactive=True, show_copy_button=True)
                        dl_report_btn = gr.DownloadButton("Download Report", value=None, interactive=False, visible=False)
                        detect_out_md = gr.Markdown(visible=False)
                with gr.TabItem("🧠 Knowledge Base"):
                    with gr.Row(equal_height=True):
                        with gr.Column():
                            gr.Markdown("### 📜 Rules Management")
                            rules_disp_ta = gr.TextArea(label="Current Rules", lines=10, interactive=True)
                            save_edited_rules_btn = gr.Button("💾 Save Edited Text", variant="primary")
                            with gr.Row(variant="compact"):
                                dl_rules_btn = gr.DownloadButton("⬇️ Download Rules")
                                clear_rules_btn = gr.Button("🗑️ Clear All Rules", variant="stop")
                            upload_rules_fobj = gr.File(label="Upload Rules File (.txt/.jsonl)", file_types=[".txt", ".jsonl"])
                            rules_stat_tb = gr.Textbox(label="Rules Status", interactive=False, lines=1)
                        with gr.Column():
                            gr.Markdown("### 📚 Memories Management")
                            mems_disp_json = gr.JSON(label="Current Memories", value=[])
                            with gr.Row(variant="compact"):
                                dl_mems_btn = gr.DownloadButton("⬇️ Download Memories")
                                clear_mems_btn = gr.Button("🗑️ Clear All Memories", variant="stop")
                            upload_mems_fobj = gr.File(label="Upload Memories File (.json/.jsonl)", file_types=[".json", ".jsonl"])
                            mems_stat_tb = gr.Textbox(label="Memories Status", interactive=False, lines=1)
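
    # --- Event wiring: the Send button and Enter in the textbox share one handler ---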
    prov_sel_dd.change(lambda p: gr.Dropdown(choices=get_model_display_names_for_provider(p), value=get_default_model_display_name_for_provider(p), interactive=True), prov_sel_dd, model_sel_dd)
    chat_ins = [user_msg_tb, main_chat_disp, prov_sel_dd, model_sel_dd, api_key_tb, sys_prompt_tb]
    chat_outs = [user_msg_tb, main_chat_disp, agent_stat_tb, detect_out_md, fmt_report_tb, dl_report_btn, rules_disp_ta, mems_disp_json]
    chat_event_args = {"fn": handle_gradio_chat_submit, "inputs": chat_ins, "outputs": chat_outs}
    send_btn.click(**chat_event_args)
    user_msg_tb.submit(**chat_event_args)
    dl_rules_btn.click(ui_download_rules_action_fn, None, dl_rules_btn)
    save_edited_rules_btn.click(save_edited_rules_action_fn, [rules_disp_ta], [rules_stat_tb]).then(ui_refresh_rules_display_fn, outputs=rules_disp_ta)
    upload_rules_fobj.upload(ui_upload_rules_action_fn, [upload_rules_fobj], [rules_stat_tb]).then(ui_refresh_rules_display_fn, outputs=rules_disp_ta)
    clear_rules_btn.click(lambda: ("Cleared." if clear_all_rules_data_backend() else "Error."), outputs=rules_stat_tb).then(ui_refresh_rules_display_fn, outputs=rules_disp_ta)
    dl_mems_btn.click(ui_download_memories_action_fn, None, dl_mems_btn)
    upload_mems_fobj.upload(ui_upload_memories_action_fn, [upload_mems_fobj], [mems_stat_tb]).then(ui_refresh_memories_display_fn, outputs=mems_disp_json)
    clear_mems_btn.click(lambda: ("Cleared." if clear_all_memory_data_backend() else "Error."), outputs=mems_stat_tb).then(ui_refresh_memories_display_fn, outputs=mems_disp_json)
    if MEMORY_STORAGE_BACKEND == "RAM" and 'save_faiss_sidebar_btn' in locals():
        save_faiss_sidebar_btn.click(lambda: (gr.Info("Saved FAISS to disk.") if save_faiss_indices_to_disk() is None else gr.Error("Error saving FAISS.")), None, None)
    app_load_outputs = [agent_stat_tb, rules_disp_ta, mems_disp_json, detect_out_md, fmt_report_tb, dl_report_btn]
    demo.load(fn=app_load_fn, inputs=None, outputs=app_load_outputs)
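
# Local launch; host and port are overridable via GRADIO_SERVER_NAME and GRADIO_PORT.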
if __name__ == "__main__":
    app_port = int(os.getenv("GRADIO_PORT", 7860))
    app_server = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
    logger.info(f"Launching Gradio server: http://{app_server}:{app_port}")
    demo.queue().launch(server_name=app_server, server_port=app_port)