RAGnosis / app.py
asadsandhu's picture
Updated.
876d145
raw
history blame
2.64 kB
import gradio as gr
import pandas as pd
import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
# --- Retrieval assets ---------------------------------------------------
# The corpus table and the FAISS index are read from files in the current
# working directory.
# NOTE(review): presumably row i of the CSV corresponds to vector i of the
# index -- confirm against the script that built both files.
df = pd.read_csv("retrieval_corpus.csv")
index = faiss.read_index("faiss_index.bin")

# Sentence encoder used to embed incoming queries before searching the index.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# --- Generation model ---------------------------------------------------
model_id = "stanford-crfm/BioMedLM"

# Request 8-bit weight loading through bitsandbytes.
bnb_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=6.0)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the end-of-sequence token for padding so that calls to
# tokenizer(..., padding=True) have a pad token to work with.
tokenizer.pad_token = tokenizer.eos_token

generation_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)
def retrieve_top_k(query, k=5):
    """Return the corpus rows closest to *query* according to the FAISS index.

    Args:
        query: Free-text query; it is embedded with ``embedding_model``.
        k: Maximum number of neighbours to request from the index.

    Returns:
        A copy of the matching rows of ``df``, in the order the index
        returned them, with an extra ``score`` column holding the value the
        index reported for each hit. Fewer than ``k`` rows are returned when
        the index cannot supply ``k`` neighbours.
    """
    query_embedding = embedding_model.encode([query]).astype("float32")
    distances, indices = index.search(query_embedding, k)

    # FAISS pads missing neighbours with the id -1. Passing -1 to ``iloc``
    # would silently select the *last* corpus row, so drop those slots (and
    # their scores) before looking anything up.
    hit_ids = indices[0]
    is_hit = hit_ids >= 0

    results = df.iloc[hit_ids[is_hit]].copy()
    results["score"] = distances[0][is_hit]
    return results
def build_prompt(query, retrieved_docs):
    """Assemble the generation prompt from a query and its retrieved rows.

    Args:
        query: The user's question, inserted under "### Patient Query:".
        retrieved_docs: DataFrame whose ``text`` column supplies the context;
            each row becomes one ``- `` bullet under "### Clinical Context:".

    Returns:
        The complete prompt string, ending with the
        "### Diagnostic Explanation:" heading and the closing ``[/INST]`` tag.
    """
    bullet_lines = (f"- {row['text']}" for _, row in retrieved_docs.iterrows())
    context_text = "\n".join(bullet_lines)

    # The literal below is flush-left on purpose: any indentation inside the
    # triple-quoted string would become part of the prompt.
    prompt = f"""[INST] <<SYS>>
You are a medical assistant trained on clinical reasoning data. Given the following patient query and related clinical observations, generate a diagnostic explanation or suggestion based on the context.
<</SYS>>
### Patient Query:
{query}
### Clinical Context:
{context_text}
### Diagnostic Explanation:
[/INST]
"""
    return prompt
def generate_local_answer(prompt, max_new_tokens=512):
    """Sample a completion for *prompt* from the local causal language model.

    Args:
        prompt: Full prompt text to condition on.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The generated continuation only (the prompt is not echoed back),
        stripped of surrounding whitespace. If the model itself emits the
        "### Diagnostic Explanation:" heading, only the text after its last
        occurrence is returned.
    """
    # Follow the model's actual placement instead of assuming CPU: the model
    # is loaded with device_map="auto", so a hard-coded device can leave the
    # inputs on a different device than the weights.
    device = generation_model.device
    print(f"Using device: {device}")

    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    output = generation_model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=0.5,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        # Pass the pad id explicitly so generate() does not have to guess it.
        pad_token_id=tokenizer.pad_token_id,
    )

    # For a causal LM, generate() returns prompt + continuation. Decoding the
    # whole sequence and splitting on the heading left the prompt's trailing
    # "[/INST]" at the start of the answer, so decode only the new tokens.
    prompt_length = input_ids.shape[-1]
    new_tokens = output[0][prompt_length:]
    decoded = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return decoded.split("### Diagnostic Explanation:")[-1].strip()
def rag_chat(query):
    """Answer *query* end to end: retrieve context, build the prompt, generate.

    Args:
        query: The text entered by the user.

    Returns:
        The string produced by ``generate_local_answer`` for a prompt built
        from the five rows returned by ``retrieve_top_k``.
    """
    context_rows = retrieve_top_k(query, k=5)
    return generate_local_answer(build_prompt(query, context_rows))
# Gradio front end: one multi-line textbox in, plain text out, backed by
# rag_chat. Flagging is switched off.
iface = gr.Interface(
    title="🩺 Clinical Reasoning RAG Assistant",
    description="Ask a medical question based on MIMIC‑IV‑Ext‑DiReCT’s diagnostic knowledge.",
    fn=rag_chat,
    inputs=gr.Textbox(placeholder="Enter a clinical query...", lines=3),
    outputs="text",
    allow_flagging="never",
)

# Start serving as soon as this module is executed.
iface.launch()