# Hugging Face Spaces app (Space status when this copy was captured: Sleeping).
import gradio as gr | |
import pandas as pd | |
import faiss | |
import numpy as np | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
from sentence_transformers import SentenceTransformer | |
# Load the retrieval corpus and the FAISS index built over it. build_prompt()
# reads the corpus "text" column, and retrieve_top_k() maps FAISS hit ids to
# rows of df by position — NOTE(review): assumes row i of the index was built
# from df.iloc[i]; confirm both files come from the same export.
df = pd.read_csv("retrieval_corpus.csv")
index = faiss.read_index("faiss_index.bin")

# Sentence-embedding model used to encode queries; presumably the same model
# that produced the vectors stored in faiss_index.bin — TODO confirm.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

model_id = "stanford-crfm/BioMedLM"

# bitsandbytes 8-bit loading requires a CUDA device. On a CPU-only host, fall
# back to a regular (unquantized) load instead of failing at startup.
if torch.cuda.is_available():
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0,
    )
else:
    bnb_config = None

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse EOS as the padding token so tokenizer(..., padding=True) works
# (presumably the tokenizer defines no pad token of its own).
tokenizer.pad_token = tokenizer.eos_token

generation_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
)
def retrieve_top_k(query, k=5):
    """Return up to ``k`` corpus rows nearest to ``query`` in the FAISS index.

    The query is embedded with ``embedding_model`` and searched in ``index``.
    Hit ids are mapped back to rows of ``df`` by position.

    Args:
        query: Free-text query to embed.
        k: Maximum number of rows to return. It is clamped to the number of
            vectors in the index.

    Returns:
        A copy of the matching ``df`` rows, in the order FAISS ranked them,
        with an added ``"score"`` column. The column holds the value FAISS
        returned for each hit; whether that is a distance (lower is better)
        or a similarity depends on the index type — NOTE(review): confirm.
        May contain fewer than ``k`` rows, or none.
    """
    k = min(k, index.ntotal)
    if k <= 0:
        # Empty index or non-positive k: FAISS rejects k <= 0, so there is
        # nothing to search.
        positions = np.empty(0, dtype="int64")
        scores = np.empty(0, dtype="float32")
    else:
        query_embedding = embedding_model.encode([query]).astype("float32")
        distances, ids = index.search(query_embedding, k)
        # FAISS fills slots it could not populate with id -1. df.iloc[-1]
        # would silently return the last corpus row, so drop those slots.
        keep = ids[0] >= 0
        positions = ids[0][keep]
        scores = distances[0][keep]
    results = df.iloc[positions].copy()
    results["score"] = scores
    return results
def build_prompt(query, retrieved_docs):
    """Assemble the generation prompt from a query and its retrieved passages.

    Each retrieved passage becomes one ``- <text>`` bullet under the
    "### Clinical Context:" heading. The prompt ends with the
    "### Diagnostic Explanation:" cue, then ``[/INST]`` and a newline.

    NOTE(review): ``[INST]``/``<<SYS>>`` are Llama-2 chat markers; confirm
    the target model was trained to recognise them rather than treating them
    as plain text.

    Args:
        query: The user's clinical question, inserted verbatim.
        retrieved_docs: DataFrame with a ``"text"`` column, one row per
            passage.

    Returns:
        The complete prompt string.
    """
    bullet_lines = []
    for _row_label, passage in retrieved_docs.iterrows():
        bullet_lines.append(f"- {passage['text']}")
    context_text = "\n".join(bullet_lines)

    system_message = (
        "You are a medical assistant trained on clinical reasoning data. "
        "Given the following patient query and related clinical observations, "
        "generate a diagnostic explanation or suggestion based on the context."
    )
    sections = [
        "[INST] <<SYS>>",
        system_message,
        "<</SYS>>",
        "### Patient Query:",
        f"{query}",
        "### Clinical Context:",
        context_text,
        "### Diagnostic Explanation:",
        "[/INST]",
        "",  # yields the trailing newline after [/INST]
    ]
    return "\n".join(sections)
def generate_local_answer(prompt, max_new_tokens=512):
    """Generate a completion for ``prompt`` with the module-level causal LM.

    Sampling uses temperature 0.5, top-k 50 and top-p 0.95, so repeated calls
    can return different text.

    Args:
        prompt: Full prompt text (e.g. as produced by ``build_prompt``).
        max_new_tokens: Upper bound on generated tokens. It is reduced when
            needed so that prompt plus completion fit within the model's
            position limit.

    Returns:
        Only the newly generated text, with surrounding whitespace stripped.

    Raises:
        ValueError: If the prompt alone fills the model's position limit.
    """
    # Send inputs to the device holding the model's weights. With
    # device_map="auto" that may be a GPU, and a hard-coded CPU device would
    # cause a device-mismatch error inside generate().
    device = generation_model.device
    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    prompt_len = input_ids.shape[-1]

    # Models whose config exposes n_positions (GPT-2-style) have a fixed
    # number of learned positions, and generating past it fails. Cap the
    # new-token budget to the remaining room.
    max_positions = getattr(generation_model.config, "n_positions", None)
    if max_positions is not None:
        room = max_positions - prompt_len
        if room <= 0:
            raise ValueError(
                f"Prompt is {prompt_len} tokens, which leaves no room in the "
                f"model's {max_positions}-token window; shorten the query or "
                "retrieve fewer passages."
            )
        max_new_tokens = min(max_new_tokens, room)

    output = generation_model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=0.5,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.pad_token_id,
    )
    # For a decoder-only model, generate() returns prompt + completion.
    # Decode only the tokens after the prompt. The previous approach split the
    # decoded text on "### Diagnostic Explanation:", which left the prompt's
    # trailing "[/INST]" in every answer and truncated the answer whenever the
    # model emitted that marker itself.
    new_tokens = output[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def rag_chat(query):
    """Answer a clinical query with retrieval-augmented generation.

    Retrieves the 5 nearest corpus passages, builds a prompt from them, and
    returns the model's generated explanation.

    Args:
        query: Text from the Gradio textbox. ``None`` or whitespace-only
            input is treated as empty.

    Returns:
        The generated answer. For empty input, returns a short message asking
        the user to enter a query.
    """
    query = (query or "").strip()
    if not query:
        # Avoid spending a full retrieval pass and LLM generation on an
        # empty submission.
        return "Please enter a clinical query."
    top_docs = retrieve_top_k(query, k=5)
    prompt = build_prompt(query, top_docs)
    return generate_local_answer(prompt)
# Gradio UI: a single multi-line textbox whose contents are passed to
# rag_chat(); the returned string is shown as plain text. Flagging is disabled.
_query_input = gr.Textbox(lines=3, placeholder="Enter a clinical query...")

iface = gr.Interface(
    fn=rag_chat,
    inputs=_query_input,
    outputs="text",
    title="🩺 Clinical Reasoning RAG Assistant",
    description="Ask a medical question based on MIMIC‑IV‑Ext‑DiReCT’s diagnostic knowledge.",
    allow_flagging="never",
)

# Start the web server (runs unconditionally when this module executes).
iface.launch()