import gradio as gr
import pandas as pd
import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
# ----------------------
# Load Retrieval Corpus & FAISS Index
# ----------------------
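# retrieval_corpus.csv and faiss_index.bin are built offline; the FAISS vectors
# are assumed to be stored in the same order as the CSV rows, so search indices
# can be mapped back to rows with df.iloc.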
df = pd.read_csv("retrieval_corpus.csv")
index = faiss.read_index("faiss_index.bin")
# ----------------------
# Load Embedding Model
# ----------------------
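# Queries must be embedded with the same model that produced the index vectors.
# all-MiniLM-L6-v2 outputs 384-dimensional sentence embeddings.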
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
# ----------------------
# Load Hugging Face LLM (BioMistral-7B)
# ----------------------
model_id = "BioMistral/BioMistral-7B"
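# 4-bit NF4 quantization (with double quantization) keeps the 7B model small
# enough for a single modest GPU; matmuls are computed in float16.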
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
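# Mistral-style tokenizers ship without a pad token; reuse EOS so that
# tokenizer(..., padding=True) works at generation time.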
tokenizer.pad_token = tokenizer.eos_token
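# Persist the patched tokenizer so the pad-token fix can be reloaded later.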
tokenizer.save_pretrained("fixed_tokenizer")
generation_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    quantization_config=bnb_config,
)
# ----------------------
# RAG Functions
# ----------------------
def retrieve_top_k(query, k=5):
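    # Embed the query, search the FAISS index, and return the top-k corpus rows
    # with their raw FAISS scores (D = distances/similarities, I = row indices).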
    query_embedding = embedding_model.encode([query]).astype("float32")
    D, I = index.search(query_embedding, k)
    results = df.iloc[I[0]].copy()
    results["score"] = D[0]
    return results
def build_prompt(query, retrieved_docs):
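    # Join the retrieved passages into a bulleted context block, then wrap the
    # query and context in the [INST] ... [/INST] instruction template.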
    context_text = "\n".join([
        f"- {doc['text']}" for _, doc in retrieved_docs.iterrows()
    ])
    prompt = f"""[INST] <<SYS>>
You are a medical assistant trained on clinical reasoning data. Given the following patient query and related clinical observations, generate a diagnostic explanation or suggestion based on the context.
<</SYS>>
### Patient Query:
{query}
### Clinical Context:
{context_text}
### Diagnostic Explanation:
[/INST]
"""
    return prompt
def generate_local_answer(prompt, max_new_tokens=512):
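    # Tokenize the prompt, sample a completion from the quantized model, and
    # return only the text after the "### Diagnostic Explanation:" heading.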
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    output = generation_model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=0.5,
        do_sample=True,
        top_k=50,
        top_p=0.95,
    )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    return decoded.split("### Diagnostic Explanation:")[-1].strip()
# ----------------------
# Gradio Interface
# ----------------------
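# End-to-end pipeline: retrieve context -> build prompt -> generate answer.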
def rag_chat(query):
    top_docs = retrieve_top_k(query, k=5)
    prompt = build_prompt(query, top_docs)
    answer = generate_local_answer(prompt)
    return answer
iface = gr.Interface(
    fn=rag_chat,
    inputs=gr.Textbox(lines=3, placeholder="Enter a clinical query..."),
    outputs="text",
    title="🩺 Clinical Reasoning RAG Assistant",
    description="Ask a medical question based on MIMIC-IV-Ext-DiReCT's diagnostic knowledge.",
    allow_flagging="never",
)
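# launch() starts the web UI; on Hugging Face Spaces this serves the hosted demo.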
iface.launch()