import gradio as gr
import pandas as pd
import faiss
import torch
import numpy as np
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

# ===============================
# Load Retrieval Components
# ===============================
print("Loading corpus and FAISS index...")
df = pd.read_csv("retrieval_corpus.csv")
index = faiss.read_index("faiss_index.bin")

print("Loading embedding model...")
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
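
# Note (assumption): "faiss_index.bin" is expected to hold one vector per row of
# "retrieval_corpus.csv", produced with the same MiniLM model. A minimal offline
# build sketch could look like:
#   corpus_embeddings = embedding_model.encode(df["text"].tolist()).astype("float32")
#   index = faiss.IndexFlatL2(corpus_embeddings.shape[1])
#   index.add(corpus_embeddings)
#   faiss.write_index(index, "faiss_index.bin")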

# ===============================
# Load LLM on CPU
# ===============================
model_id = "BioMistral/BioMistral-7B"
print(f"Loading tokenizer and model: {model_id}")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to("cpu")
# Mistral tokenizers define no pad token by default; reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token
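
# Note (assumption): float16 weights keep CPU memory low, but some CPU kernels are slow
# or unsupported in half precision. If generation fails or crawls, a heavier but safer
# sketch is to load in bfloat16 (or float32) instead:
#   model = AutoModelForCausalLM.from_pretrained(
#       model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
#   ).to("cpu")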

# ===============================
# RAG Pipeline
# ===============================
def get_top_k_chunks(query, k=5):
    """Embed the query and return the k most similar corpus chunks."""
    query_embedding = embedding_model.encode([query])
    scores, indices = index.search(np.array(query_embedding).astype("float32"), k)
    return df.iloc[indices[0]]["text"].tolist()
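
# Note (assumption): the search above must use the same metric the index was built with.
# If the index is inner-product based (e.g. faiss.IndexFlatIP over normalized vectors),
# the query embedding should be normalized the same way before searching:
#   query_vec = np.array(query_embedding).astype("float32")
#   faiss.normalize_L2(query_vec)  # in-place L2 normalization
#   scores, indices = index.search(query_vec, k)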


def build_prompt(query, chunks):
    """Assemble the retrieved chunks and the user query into a single instruction prompt."""
    context = "\n".join(f"{i+1}. {chunk}" for i, chunk in enumerate(chunks))
    prompt = (
        "You are a clinical reasoning assistant. Based on the following medical information, "
        "answer the query with a detailed explanation.\n\n"
        f"Context:\n{context}\n\n"
        f"Query: {query}\n"
        "Answer:"
    )
    return prompt


def generate_diagnosis(query):
    """Retrieve supporting chunks, build the prompt, and generate an answer with the LLM."""
    chunks = get_top_k_chunks(query)
    prompt = build_prompt(query, chunks)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    input_ids = inputs.input_ids.to("cpu")
    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            max_new_tokens=256,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # Keep only the text generated after the final "Answer:" marker.
    answer = generated_text.split("Answer:")[-1].strip()
    return answer, "\n\n".join(chunks)
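
# Quick local smoke test with a hypothetical query, handy before wiring up the UI:
#   answer, context = generate_diagnosis("65-year-old male with chest pain and dyspnea")
#   print(answer)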

# ===============================
# Gradio UI
# ===============================
def run_interface():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("## 🧠 Clinical Diagnosis Assistant (RAG)")
        gr.Markdown("Enter a clinical query. The assistant retrieves relevant medical facts and generates a diagnostic explanation.")
        with gr.Row():
            query_input = gr.Textbox(label="Clinical Query", placeholder="e.g. 65-year-old male with shortness of breath...")
            generate_btn = gr.Button("Generate Diagnosis")
        with gr.Accordion("📄 Retrieved Context", open=False):
            context_output = gr.Textbox(label="Top-5 Retrieved Chunks", lines=10, interactive=False)
        answer_output = gr.Textbox(label="Generated Diagnosis", lines=8)
        generate_btn.click(
            fn=generate_diagnosis,
            inputs=query_input,
            outputs=[answer_output, context_output],
        )
    return demo

# ===============================
# Launch App
# ===============================
if __name__ == "__main__":
    demo = run_interface()
    demo.launch()
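
# Optional (deployment assumption): generation on a CPU Space can take a while per
# request; Gradio's request queue helps avoid client timeouts, e.g.:
#   demo.queue().launch()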