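"""Clinical Diagnosis Assistant (RAG) – Gradio Space.

Retrieves the top-k most relevant chunks from a pre-built FAISS index over a
medical corpus and feeds them, together with the user's clinical query, to
BioMistral-7B running on CPU to generate a diagnostic explanation.
"""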
import gradio as gr
import pandas as pd
import faiss
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
# ===============================
# Load Retrieval Components
# ===============================
print("Loading corpus and FAISS index...")
df = pd.read_csv("retrieval_corpus.csv")
index = faiss.read_index("faiss_index.bin")
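# Assumption: faiss_index.bin was built from the rows of retrieval_corpus.csv in the
# same order and with the same embedding model, so FAISS indices map directly to df rows.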
print("Loading embedding model...")
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
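# all-MiniLM-L6-v2 produces 384-dimensional embeddings; the FAISS index must have been
# created with the same dimensionality for index.search() to work.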
# ===============================
# Load LLM on CPU
# ===============================
model_id = "BioMistral/BioMistral-7B"
print(f"Loading tokenizer and model: {model_id}")
tokenizer = AutoTokenizer.from_pretrained(model_id)
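# Weights are loaded in float16 to halve memory. Note: some CPU ops have limited fp16
# support; if generation fails or is very slow, float32 (or bfloat16) is a safer CPU default.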
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to("cpu")
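# Mistral-style tokenizers typically ship without a pad token, so reuse EOS for padding.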
tokenizer.pad_token = tokenizer.eos_token
# ===============================
# RAG Pipeline
# ===============================
def get_top_k_chunks(query, k=5):
    """Embed the query and return the texts of the k nearest corpus chunks."""
    query_embedding = embedding_model.encode([query])
    scores, indices = index.search(np.array(query_embedding).astype("float32"), k)
    return df.iloc[indices[0]]["text"].tolist()
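# Example (hypothetical query):
#   get_top_k_chunks("65-year-old male with shortness of breath", k=5)
# returns the texts of the five corpus chunks whose embeddings are closest to the query.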
def build_prompt(query, chunks):
    """Assemble the RAG prompt: numbered context chunks followed by the query."""
    context = "\n".join(f"{i+1}. {chunk}" for i, chunk in enumerate(chunks))
    prompt = (
        "You are a clinical reasoning assistant. Based on the following medical information, "
        "answer the query with a detailed explanation.\n\n"
        f"Context:\n{context}\n\n"
        f"Query: {query}\n"
        "Answer:"
    )
    return prompt
def generate_diagnosis(query):
    """Retrieve context, build the prompt, and generate a diagnostic explanation."""
    chunks = get_top_k_chunks(query)
    prompt = build_prompt(query, chunks)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    input_ids = inputs.input_ids.to("cpu")
    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            max_new_tokens=256,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # Keep only the text after the final "Answer:" marker from the prompt.
    answer = generated_text.split("Answer:")[-1].strip()
    return answer, "\n\n".join(chunks)
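# The (answer, retrieved context) return order matches the outputs list wired to the
# "Generate Diagnosis" button in the Gradio UI below.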
# ===============================
# Gradio UI
# ===============================
def run_interface():
    """Build the Gradio Blocks UI and wire the button to the RAG pipeline."""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("## 🧠 Clinical Diagnosis Assistant (RAG)")
        gr.Markdown("Enter a clinical query. The assistant retrieves relevant medical facts and generates a diagnostic explanation.")
        with gr.Row():
            query_input = gr.Textbox(label="Clinical Query", placeholder="e.g. 65-year-old male with shortness of breath...")
            generate_btn = gr.Button("Generate Diagnosis")
        with gr.Accordion("📄 Retrieved Context", open=False):
            context_output = gr.Textbox(label="Top-5 Retrieved Chunks", lines=10, interactive=False)
        answer_output = gr.Textbox(label="Generated Diagnosis", lines=8)
        generate_btn.click(
            fn=generate_diagnosis,
            inputs=query_input,
            outputs=[answer_output, context_output],
        )
    return demo
# ===============================
# Launch App
# ===============================
if __name__ == "__main__":
    demo = run_interface()
    demo.launch()