import gradio as gr
import pandas as pd
import faiss
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer

# ----------------------
# Load Retrieval Corpus & FAISS Index
# ----------------------
df = pd.read_csv("retrieval_corpus.csv")
index = faiss.read_index("faiss_index.bin")

# ----------------------
# Load Embedding Model
# ----------------------
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# ----------------------
# Load Hugging Face LLM (BioMistral-7B, 4-bit quantized)
# ----------------------
model_id = "BioMistral/BioMistral-7B"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # Mistral tokenizers ship without a pad token
tokenizer.save_pretrained("fixed_tokenizer")  # persist the patched tokenizer for reuse

generation_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    quantization_config=bnb_config,
)

# ----------------------
# RAG Functions
# ----------------------
def retrieve_top_k(query, k=5):
    """Embed the query and return the k nearest corpus rows with their distances."""
    query_embedding = embedding_model.encode([query]).astype("float32")
    D, I = index.search(query_embedding, k)
    results = df.iloc[I[0]].copy()
    results["score"] = D[0]
    return results


def build_prompt(query, retrieved_docs):
    """Assemble an instruction prompt from the query and the retrieved context."""
    context_text = "\n".join(
        f"- {doc['text']}" for _, doc in retrieved_docs.iterrows()
    )
    prompt = f"""[INST] <<SYS>>
You are a medical assistant trained on clinical reasoning data.
Given the following patient query and related clinical observations,
generate a diagnostic explanation or suggestion based on the context.
<</SYS>>

### Patient Query:
{query}

### Clinical Context:
{context_text}

### Diagnostic Explanation:
[/INST]
"""
    return prompt


def generate_local_answer(prompt, max_new_tokens=512):
    """Generate a completion and return only the text after the explanation header."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)

    output = generation_model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        temperature=0.5,
        do_sample=True,
        top_k=50,
        top_p=0.95,
    )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    return decoded.split("### Diagnostic Explanation:")[-1].strip()

# ----------------------
# Gradio Interface
# ----------------------
def rag_chat(query):
    top_docs = retrieve_top_k(query, k=5)
    prompt = build_prompt(query, top_docs)
    return generate_local_answer(prompt)


iface = gr.Interface(
    fn=rag_chat,
    inputs=gr.Textbox(lines=3, placeholder="Enter a clinical query..."),
    outputs="text",
    title="🩺 Clinical Reasoning RAG Assistant",
    description="Ask a medical question based on MIMIC-IV-Ext-DiReCT's diagnostic knowledge.",
    allow_flagging="never",
)

iface.launch()
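
# ----------------------
# One-Time Index Build (sketch)
# ----------------------
# The app above assumes "retrieval_corpus.csv" and "faiss_index.bin" already
# exist. The commented-out snippet below is a minimal sketch of how such an
# index could be built offline with the same embedding model, assuming the
# corpus CSV has a "text" column; the column name and the exact-search
# IndexFlatL2 choice are assumptions, not taken from the app itself. Run it
# once in a separate session before launching the app.
#
# import faiss
# import pandas as pd
# from sentence_transformers import SentenceTransformer
#
# corpus = pd.read_csv("retrieval_corpus.csv")
# embedder = SentenceTransformer("all-MiniLM-L6-v2")
# embeddings = embedder.encode(corpus["text"].tolist()).astype("float32")
# index = faiss.IndexFlatL2(embeddings.shape[1])  # L2 distances, matching the "score" column above
# index.add(embeddings)
# faiss.write_index(index, "faiss_index.bin")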