|
"""Wisal: a Gradio chatbot for autism-related Q&A.

Pipeline: query correction/translation -> relevance gate -> (LLM generation +
web search + RAG) -> reranking -> answer drafting -> hallucination check ->
optional translation back to the user's language.
"""

import asyncio
import datetime
import os
import re

import gradio as gr
import langdetect
from dotenv import load_dotenv
from openai import OpenAI

import Old_Document
import RAG_Domain_know_doc
import User_Specific_Documents
from RAG import rag_autism
from web_search import search_autism
from prompt_template import (
    Prompt_template_translation,
    Prompt_template_LLM_Generation,
    Prompt_template_Reranker,
    Prompt_template_Wisal,
    Prompt_template_Halluciations,
    Prompt_template_paraphrasing,
    Prompt_template_Translate_to_original,
    Prompt_template_relevance,
    Prompt_template_User_document_prompt,
)

# Read configuration from the environment (.env is honoured via python-dotenv).
load_dotenv()
env = os.getenv("ENVIRONMENT", "production")
DEEPINFRA_API_KEY = os.getenv("DEEPINFRA_API_KEY")  # assumed env var name

# DeepInfra exposes an OpenAI-compatible API, so the official client is reused
# with a custom base_url.
openai = OpenAI(
    api_key=DEEPINFRA_API_KEY,
    base_url="https://api.deepinfra.com/v1/openai",
)

SESSION_ID = "default"
def call_llm(model: str, messages: list[dict], temperature: float = 0.0, **kwargs) -> str:
    """Run one chat completion and return the stripped message text."""
    resp = openai.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        **kwargs,
    )
    return resp.choices[0].message.content.strip()


def is_greeting(text: str) -> bool:
    """Match simple greetings so they can be answered without the full pipeline."""
    return bool(re.search(r"\b(hi|hello|hey|good (morning|afternoon|evening))\b", text, re.I))
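# Minimal usage sketch (assumes a valid DEEPINFRA_API_KEY in the environment):
#   call_llm("Qwen/Qwen3-32B", [{"role": "user", "content": "What is ASD?"}])
#   is_greeting("hey there")  # -> True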
|
|
|
def process_query(query: str, first_turn: bool = False, session_id: str = "default"):
    intro = ""
    process_log = []

    # First turn with an empty query: just greet the user.
    if first_turn and (not query or query.strip() == ""):
        intro = "Hello! I’m Wisal, an AI assistant developed by Compumacy AI, specializing in Autism Spectrum Disorders. How can I help you today?"
        process_log.append(intro)
        _save_process_log(process_log)
        return intro

    if is_greeting(query):
        greeting = intro + "Hello! I’m Wisal, your AI assistant developed by Compumacy AI. How can I help you today?"
        process_log.append(f"Greeting detected.\n{greeting}")
        _save_process_log(process_log)
        return greeting

    # Correct/translate the raw query before any downstream step.
    corrected_query = call_llm(
        model="Qwen/Qwen3-32B",
        messages=[{"role": "user", "content": Prompt_template_translation.format(query=query)}],
        reasoning_effort="none",
    )
    process_log.append(f"Corrected Query: {corrected_query}")

    # Relevance gate: only autism-related queries run the full pipeline.
    relevance = call_llm(
        model="Qwen/Qwen3-32B",
        messages=[{"role": "user", "content": Prompt_template_relevance.format(corrected_query=corrected_query)}],
        reasoning_effort="none",
    )
    process_log.append(f"Relevance Check: {relevance}")

    if relevance != "RELATED":
        # Out-of-scope queries get a short refusal instead of an empty reply.
        process_log.append("Query not autism-related.")
        _save_process_log(process_log)
        return "I’m Wisal, an assistant specialized in Autism Spectrum Disorders, so I can’t help with that. Please ask me an autism-related question."

    # Pass the original query along so the reply language can be detected later.
    return process_autism_pipeline(query, corrected_query, process_log, intro)
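# process_autism_pipeline fans the corrected query out to three evidence
# sources (direct LLM generation, web search, RAG retrieval), reranks the
# candidates, drafts an answer, screens it for hallucinations, and finally
# translates the answer back into the user's original language if needed.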
|
|
|
def process_autism_pipeline(query, corrected_query, process_log, intro):
    # Evidence source 1: web search.
    web_search_resp = asyncio.run(search_autism(corrected_query))
    web_answer = web_search_resp.get("answer", "")
    process_log.append(f"Web Search: {web_answer}")

    # Evidence source 2: direct LLM generation.
    gen_prompt = Prompt_template_LLM_Generation.format(new_query=corrected_query)
    generated = call_llm(
        model="Qwen/Qwen3-32B",
        messages=[{"role": "user", "content": gen_prompt}],
        reasoning_effort="none",
    )
    process_log.append(f"LLM Generated: {generated}")

    # Evidence source 3: RAG over the ingested domain documents.
    rag_resp = asyncio.run(rag_autism(corrected_query, top_k=3))
    rag_contexts = rag_resp.get("answer", [])
    process_log.append(f"RAG Contexts: {rag_contexts}")

    # Rerank all candidate answers against the query.
    answers_list = f"[1] {generated}\n[2] {web_answer}\n" + "\n".join(
        f"[{i + 3}] {c}" for i, c in enumerate(rag_contexts)
    )
    rerank_prompt = Prompt_template_Reranker.format(new_query=corrected_query, answers_list=answers_list)
    reranked = call_llm(
        model="Qwen/Qwen3-32B",
        messages=[{"role": "user", "content": rerank_prompt}],
        reasoning_effort="none",
    )
    process_log.append(f"Reranked: {reranked}")

    # Draft the answer from the top-ranked evidence.
    wisal_prompt = Prompt_template_Wisal.format(new_query=corrected_query, document=reranked)
    wisal = call_llm(
        model="Qwen/Qwen3-32B",
        messages=[{"role": "user", "content": wisal_prompt}],
        reasoning_effort="none",
    )
    process_log.append(f"Wisal Answer: {wisal}")

    # Grade the draft for hallucinations against the generated evidence.
    halluc_prompt = Prompt_template_Halluciations.format(
        new_query=corrected_query,
        answer=wisal,
        document=generated,
    )
    halluc = call_llm(
        model="Qwen/Qwen3-32B",
        messages=[{"role": "user", "content": halluc_prompt}],
        reasoning_effort="none",
    )
    process_log.append(f"Hallucination Score: {halluc}")
    score_match = re.search(r"Score:\s*(\d+)", halluc)
    score = int(score_match.group(1)) if score_match else 3

    # A score of 2 or 3 signals possible hallucination: paraphrase the evidence
    # and redraft the answer.
    if score in (2, 3):
        paraphrased = call_llm(
            model="Qwen/Qwen3-32B",
            messages=[{"role": "user", "content": Prompt_template_paraphrasing.format(document=generated)}],
            reasoning_effort="none",
        )
        wisal = call_llm(
            model="Qwen/Qwen3-32B",
            messages=[{"role": "user", "content": Prompt_template_Wisal.format(new_query=corrected_query, document=paraphrased)}],
            reasoning_effort="none",
        )
        process_log.append(f"Paraphrased Wisal: {wisal}")

    # Detect the language of the original query (corrected_query has already
    # been normalised to English) and translate the answer back if needed.
    try:
        detected_lang = langdetect.detect(query)
    except Exception:
        detected_lang = "en"

    if detected_lang != "en":
        result = call_llm(
            model="Qwen/Qwen3-32B",
            messages=[{"role": "user", "content": Prompt_template_Translate_to_original.format(query=query, document=wisal)}],
            reasoning_effort="none",
        )
        process_log.append(f"Translated Back: {result}")
    else:
        result = wisal

    process_log.append(f"Final Result: {result}")
    _save_process_log(process_log)
    return intro + result
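# search_autism and rag_autism are assumed to be coroutines; asyncio.run() gives
# each call its own event loop, which is the simplest fit for Gradio's
# synchronous callback model.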
|
|
|
|
|
def _save_process_log(log_lines):
    """Write one timestamped log file per query under ./logs."""
    logs_dir = os.path.join(os.path.dirname(__file__), "logs")
    os.makedirs(logs_dir, exist_ok=True)

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    log_filename = os.path.join(logs_dir, f"log_{timestamp}.txt")
    try:
        with open(log_filename, "w", encoding="utf-8") as f:
            for line in log_lines:
                f.write(str(line) + "\n\n")
    except Exception:
        # Logging must never take down the user-facing pipeline.
        pass
|
|
|
|
|
|
|
def main_pipeline_interface(query):
    return process_query(query, first_turn=True)


def main_pipeline_with_doc_and_history(query, doc_file, doc_type, history):
    response = main_pipeline_with_doc(query, doc_file, doc_type)
    updated_history = history + f"\nUser: {query}\nWisal: {response}\n"
    return response, updated_history
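# Note: main_pipeline_with_doc_and_history keeps a plain-text transcript, while
# pipeline_with_history (used by the UI below) keeps Gradio's list-of-pairs
# chat history; both delegate the actual work to main_pipeline_with_doc.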
|
|
|
def main_pipeline_with_doc(query, doc_file, doc_type):
    # No document attached: fall back to the plain text pipeline.
    if doc_file is None or doc_type == "None":
        return process_query(query, first_turn=True)

    # Persist the upload under ./uploaded_docs, keeping only the base name to
    # avoid path traversal.
    safe_filename = os.path.basename(getattr(doc_file, "name", str(doc_file)))
    upload_dir = os.path.join(os.path.dirname(__file__), "uploaded_docs")
    os.makedirs(upload_dir, exist_ok=True)
    save_path = os.path.join(upload_dir, safe_filename)

    # Gradio may hand over either a file-like object or a path string.
    if hasattr(doc_file, "read"):
        file_bytes = doc_file.read()
    else:
        with open(str(doc_file), "rb") as f:
            file_bytes = f.read()

    with open(save_path, "wb") as f:
        f.write(file_bytes)

    # Route to the matching ingest/answer module.
    if doc_type == "Knowledge Document":
        status = RAG_Domain_know_doc.ingest_file(save_path)
        answer = RAG_Domain_know_doc.answer_question(query)
        return f"[Knowledge Document Uploaded]\n{status}\n\n{answer}"
    elif doc_type == "User-Specific Document":
        status = User_Specific_Documents.ingest_file(save_path)
        answer = User_Specific_Documents.answer_question(query)
        return f"[User-Specific Document Uploaded]\n{status}\n\n{answer}"
    elif doc_type == "Old Document":
        status = Old_Document.ingest_file(save_path)
        answer = Old_Document.answer_question(query)
        return f"[Old Document Uploaded]\n{status}\n\n{answer}"
    else:
        return "Invalid document type."
|
|
|
def pipeline_with_history(message, doc_file, doc_type, history):
    if not message.strip():
        return history, ""
    response = main_pipeline_with_doc(message, doc_file, doc_type)
    history = history + [[message, response]]
    return history, ""
|
|
|
with gr.Blocks(title="Wisal Chatbot", theme=gr.themes.Base()) as demo:
    gr.Markdown("# 🤖 Wisal: Autism AI Assistant")

    chatbot = gr.Chatbot(label="Wisal Chat", height=500)

    with gr.Row():
        user_input = gr.Textbox(placeholder="Type your question here...", label="", lines=1)
        send_btn = gr.Button("Send")

    doc_file = gr.File(label="📎 Upload Document (PDF, DOCX, TXT)", file_types=[".pdf", ".docx", ".txt"])
    doc_type = gr.Radio(
        ["None", "Knowledge Document", "User-Specific Document", "Old Document"],
        value="None",
        label="Document Type",
    )

    send_btn.click(
        fn=pipeline_with_history,
        inputs=[user_input, doc_file, doc_type, chatbot],
        outputs=[chatbot, user_input],
    )

    clear_btn = gr.Button("Clear Chat")
    clear_btn.click(lambda: [], outputs=[chatbot])


if __name__ == "__main__":
    demo.launch(debug=True)
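# A possible deployment tweak (not part of the original setup): bind to all
# interfaces with demo.launch(server_name="0.0.0.0", server_port=7860).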