import gradio as gr
import pandas as pd
import faiss
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

# ===============================
# Load Retrieval Components
# ===============================
print("Loading corpus and FAISS index...")
df = pd.read_csv("retrieval_corpus.csv")
index = faiss.read_index("faiss_index.bin")

print("Loading embedding model...")
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# ===============================
# Load LLM on CPU
# ===============================
model_id = "BioMistral/BioMistral-7B"
print(f"Loading tokenizer and model: {model_id}")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # float16 matmuls are often unimplemented or very slow on CPU;
    # float32 (or bfloat16 on recent PyTorch) is the safer choice here.
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to("cpu")
model.eval()
tokenizer.pad_token = tokenizer.eos_token

# ===============================
# RAG Pipeline
# ===============================
def get_top_k_chunks(query, k=5):
    """Embed the query and return the k nearest corpus chunks from the FAISS index."""
    query_embedding = embedding_model.encode([query])
    _scores, indices = index.search(np.asarray(query_embedding, dtype="float32"), k)
    # FAISS pads with -1 when fewer than k results exist; drop those.
    valid = [i for i in indices[0] if i != -1]
    return df.iloc[valid]["text"].tolist()


def build_prompt(query, chunks):
    """Assemble the retrieved chunks and the query into a single instruction prompt."""
    context = "\n".join(f"{i + 1}. {chunk}" for i, chunk in enumerate(chunks))
    return (
        "You are a clinical reasoning assistant. Based on the following medical information, "
        "answer the query with a detailed explanation.\n\n"
        f"Context:\n{context}\n\n"
        f"Query: {query}\n"
        "Answer:"
    )


def generate_diagnosis(query):
    """Run the full RAG loop: retrieve, build the prompt, generate, and extract the answer."""
    chunks = get_top_k_chunks(query)
    prompt = build_prompt(query, chunks)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    with torch.no_grad():
        output = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,  # avoids the "attention mask not set" warning
            max_new_tokens=256,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # The decoded output echoes the prompt; keep only the text after the final "Answer:" marker.
    answer = generated_text.split("Answer:")[-1].strip()
    return answer, "\n\n".join(chunks)

# ===============================
# Gradio UI
# ===============================
def run_interface():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("## 🧠 Clinical Diagnosis Assistant (RAG)")
        gr.Markdown(
            "Enter a clinical query. The assistant retrieves relevant medical facts "
            "and generates a diagnostic explanation."
        )
        with gr.Row():
            query_input = gr.Textbox(
                label="Clinical Query",
                placeholder="e.g. 65-year-old male with shortness of breath...",
            )
            generate_btn = gr.Button("Generate Diagnosis")
        with gr.Accordion("📄 Retrieved Context", open=False):
            context_output = gr.Textbox(label="Top-5 Retrieved Chunks", lines=10, interactive=False)
        answer_output = gr.Textbox(label="Generated Diagnosis", lines=8)
        generate_btn.click(
            fn=generate_diagnosis,
            inputs=query_input,
            outputs=[answer_output, context_output],
        )
    return demo

# ===============================
# Launch App
# ===============================
if __name__ == "__main__":
    demo = run_interface()
    demo.launch()
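
# ===============================
# Offline index build (sketch)
# ===============================
# The script above assumes "retrieval_corpus.csv" (with a "text" column) and
# "faiss_index.bin" already exist. This is a minimal sketch of how they could
# be produced with the same embedding model, so query and corpus vectors live
# in the same space. The source file name "corpus.csv" is hypothetical. Run
# once, separately, before launching the app:
#
#     import faiss
#     import pandas as pd
#     from sentence_transformers import SentenceTransformer
#
#     df = pd.read_csv("corpus.csv")  # hypothetical source; must have a "text" column
#     model = SentenceTransformer("all-MiniLM-L6-v2")
#     embeddings = model.encode(df["text"].tolist()).astype("float32")
#
#     # Exact L2 search over the raw embeddings; matches index.search() above.
#     index = faiss.IndexFlatL2(embeddings.shape[1])
#     index.add(embeddings)
#
#     faiss.write_index(index, "faiss_index.bin")
#     df.to_csv("retrieval_corpus.csv", index=False)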