import gradio as gr
import pandas as pd
import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
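# Required packages (assumed from the imports and model loading below): gradio, pandas,
# numpy, torch, transformers, accelerate, bitsandbytes, faiss-cpu (or faiss-gpu),
# sentence-transformers.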
# ----------------------
# Load Retrieval Corpus & FAISS Index
# ----------------------
df = pd.read_csv("retrieval_corpus.csv")
index = faiss.read_index("faiss_index.bin")
# ----------------------
# Load Embedding Model
# ----------------------
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
# ----------------------
# Load HuggingFace LLM (BioMistral-7B)
# ----------------------
model_id = "BioMistral/BioMistral-7B"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_id)
# NOTE: 4-bit quantization needs CUDA, and accelerate places the quantized model on the
# GPU itself, so .to(device) must not be called on it; fall back to fp32 on CPU instead.
if torch.cuda.is_available():
    generation_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        torch_dtype=torch.float16,
        device_map="auto",
    )
else:
    generation_model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float32
    ).to(device)
# ----------------------
# RAG Functions
# ----------------------
def retrieve_top_k(query, k=5):
    """Embed the query and return the k nearest corpus rows with their FAISS scores."""
    query_embedding = embedding_model.encode([query]).astype("float32")
    D, I = index.search(query_embedding, k)  # D: distances/scores, I: corpus row indices
    results = df.iloc[I[0]].copy()
    results["score"] = D[0]
    return results
def build_prompt(query, retrieved_docs):
    context_text = "\n".join([
        f"- {doc['text']}" for _, doc in retrieved_docs.iterrows()
    ])
    prompt = f"""[INST] <<SYS>>
You are a medical assistant trained on clinical reasoning data. Given the following patient query and related clinical observations, generate a diagnostic explanation or suggestion based on the context.
<</SYS>>

### Patient Query:
{query}

### Clinical Context:
{context_text}

### Diagnostic Explanation:
[/INST]
"""
    return prompt
def generate_local_answer(prompt, max_new_tokens=512):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(generation_model.device)
    output = generation_model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=0.5,
        do_sample=True,
        top_k=50,
        top_p=0.95,
    )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    # The decoded text echoes the prompt, so keep only what follows the answer header.
    return decoded.split("### Diagnostic Explanation:")[-1].strip()
# ----------------------
# Gradio Interface
# ----------------------
def rag_chat(query):
    top_docs = retrieve_top_k(query, k=5)
    prompt = build_prompt(query, top_docs)
    answer = generate_local_answer(prompt)
    return answer
# Optional: CSS for improved UX
custom_css = """
textarea, .input_textbox {
font-size: 1.05rem !important;
}
.output-markdown {
font-size: 1.08rem !important;
}
"""
with gr.Blocks(css=custom_css, theme=gr.themes.Default(primary_hue="blue")) as demo:
    gr.Markdown("""
# 🩺 RAGnosis — Clinical Reasoning Assistant
Enter a natural-language query describing your patient's condition to receive an AI-generated diagnostic reasoning response.

**Example:**
*Patient has shortness of breath, fatigue, and leg swelling.*
""")
    with gr.Row():
        with gr.Column():
            query_input = gr.Textbox(
                lines=4,
                label="📝 Patient Query",
                placeholder="Enter patient symptoms or findings...",
            )
            submit_btn = gr.Button("🔍 Generate Diagnosis")
        with gr.Column():
            output = gr.Markdown(label="🧠 Diagnostic Reasoning")
    submit_btn.click(fn=rag_chat, inputs=query_input, outputs=output)

demo.launch()
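# ----------------------
# Reference: offline index build (not executed by the app)
# ----------------------
# A minimal sketch of how faiss_index.bin could be produced from the corpus, assuming a
# flat L2 index over all-MiniLM-L6-v2 embeddings of the "text" column:
#
#   texts = pd.read_csv("retrieval_corpus.csv")["text"].tolist()
#   embeddings = SentenceTransformer("all-MiniLM-L6-v2").encode(texts).astype("float32")
#   index = faiss.IndexFlatL2(embeddings.shape[1])
#   index.add(embeddings)
#   faiss.write_index(index, "faiss_index.bin")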