File size: 3,513 Bytes
2bda614
e531b46
 
10c6208
098c01c
 
2bda614
098c01c
2bda614
098c01c
 
 
 
2bda614
 
e531b46
098c01c
2bda614
e31fef3
098c01c
 
 
 
2833445
098c01c
2833445
098c01c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e531b46
 
098c01c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gradio as gr
import pandas as pd
import faiss
import torch
import numpy as np
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

# ===============================
# Load Retrieval Components
# ===============================
# Startup side effect: these statements run at import time and read both
# artifacts via relative paths, i.e. from the current working directory.
print("Loading corpus and FAISS index...")
# Corpus table; get_top_k_chunks reads its "text" column by row position.
df = pd.read_csv("retrieval_corpus.csv")
# NOTE(review): presumably vector i in this index corresponds to row i of
# `df` (get_top_k_chunks feeds the returned ids straight into df.iloc) --
# confirm against the script that built the index.
index = faiss.read_index("faiss_index.bin")

print("Loading embedding model...")
# Encoder used for queries in get_top_k_chunks.
# NOTE(review): this should be the same model that produced the vectors in
# faiss_index.bin, otherwise distances are meaningless -- confirm.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# ===============================
# Load LLM on CPU
# ===============================
model_id = "BioMistral/BioMistral-7B"

print(f"Loading tokenizer and model: {model_id}")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# The model is executed on CPU. float16 is a poor fit there: on many PyTorch
# releases the CPU matmul kernels are not implemented for half precision and
# generation fails with errors such as
#   RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
# bfloat16 keeps the same 2-bytes-per-parameter memory footprint but has
# working CPU kernels, so it is used instead.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,  # avoid materialising a second full copy while loading
).to("cpu")

# Reuse EOS as the padding token so tokenizer and generate() agree on one id
# (generate_diagnosis passes eos_token_id as pad_token_id).
tokenizer.pad_token = tokenizer.eos_token

# ===============================
# RAG Pipeline
# ===============================
def get_top_k_chunks(query, k=5):
    """Return the text of the corpus rows nearest to *query*.

    The query is embedded with the module-level ``embedding_model`` and
    searched against the module-level FAISS ``index``; the resulting ids are
    used as positional row numbers into ``df``.

    Args:
        query: Free-text query string.
        k: Maximum number of chunks to return.

    Returns:
        A list of at most ``k`` strings taken from ``df["text"]``, in the
        order returned by the index. May be shorter than ``k`` (or empty)
        when the index holds fewer than ``k`` vectors.
    """
    if k <= 0:
        return []

    query_embedding = embedding_model.encode([query])
    # FAISS expects a C-contiguous float32 matrix of shape (n_queries, dim).
    query_vectors = np.ascontiguousarray(query_embedding, dtype="float32")
    _, indices = index.search(query_vectors, k)

    # FAISS pads the result with -1 when fewer than k neighbours exist.
    # Passing -1 to iloc would silently return the *last* corpus row as if it
    # were a hit, so drop those placeholders first.
    hit_rows = [int(row) for row in indices[0] if row >= 0]
    return df.iloc[hit_rows]["text"].tolist()

def build_prompt(query, chunks):
    """Assemble the LLM prompt from the query and its retrieved chunks.

    Chunks are listed one per line, numbered from 1, under a ``Context:``
    heading; the query follows, and the prompt ends with ``Answer:`` so the
    model continues from there.

    Args:
        query: The user's query text.
        chunks: Iterable of retrieved context strings.

    Returns:
        The complete prompt string.
    """
    numbered_lines = []
    for position, chunk in enumerate(chunks, start=1):
        numbered_lines.append(f"{position}. {chunk}")
    context = "\n".join(numbered_lines)

    instruction = (
        "You are a clinical reasoning assistant. Based on the following medical information, "
        "answer the query with a detailed explanation.\n\n"
    )
    context_section = f"Context:\n{context}\n\n"
    query_section = f"Query: {query}\n"

    return "".join([instruction, context_section, query_section, "Answer:"])

def generate_diagnosis(query):
    """Run the full RAG pipeline for one query.

    Retrieves context with ``get_top_k_chunks``, builds the prompt with
    ``build_prompt``, samples a continuation from the module-level ``model``
    and returns it together with the retrieved context.

    Args:
        query: Clinical query text entered by the user.

    Returns:
        A ``(answer, context)`` tuple of strings: the generated text and the
        retrieved chunks joined by blank lines. For a blank query a short
        hint is returned as the answer and the context is empty.
    """
    # Nothing to retrieve or generate for a blank query; skip the (slow)
    # model call instead of answering an empty prompt.
    if not query or not query.strip():
        return "Please enter a clinical query.", ""

    chunks = get_top_k_chunks(query)
    prompt = build_prompt(query, chunks)

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    input_ids = inputs.input_ids.to("cpu")
    # Pass the mask explicitly: pad and EOS share one id, so generate()
    # cannot reliably infer which positions are padding on its own.
    attention_mask = inputs.attention_mask.to("cpu")

    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=256,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )

    # For a decoder-only model the returned sequence is prompt + continuation.
    # Slice the prompt off by token count rather than splitting the decoded
    # text on "Answer:": that marker may also occur in the query, the context
    # or the continuation, or may have been cut off by truncation.
    prompt_length = input_ids.shape[-1]
    new_tokens = output[0][prompt_length:]
    answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return answer, "\n\n".join(chunks)

# ===============================
# Gradio UI
# ===============================
def run_interface():
    """Build the Gradio Blocks UI and return it without launching.

    Layout: a title and description, a row holding the query textbox and the
    trigger button, a collapsed accordion showing the retrieved context, and
    a textbox for the generated answer. Clicking the button calls
    ``generate_diagnosis`` with the query and fills the answer and context
    boxes from its two return values.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    soft_theme = gr.themes.Soft()

    with gr.Blocks(theme=soft_theme) as app:
        # Header
        gr.Markdown("## 🧠 Clinical Diagnosis Assistant (RAG)")
        gr.Markdown("Enter a clinical query. The assistant retrieves relevant medical facts and generates a diagnostic explanation.")

        # Input row: query on the left, trigger button on the right.
        with gr.Row():
            query_box = gr.Textbox(
                label="Clinical Query",
                placeholder="e.g. 65-year-old male with shortness of breath...",
            )
            submit_button = gr.Button("Generate Diagnosis")

        # Retrieved context is collapsed by default and read-only.
        with gr.Accordion("📄 Retrieved Context", open=False):
            context_box = gr.Textbox(
                label="Top-5 Retrieved Chunks",
                lines=10,
                interactive=False,
            )

        answer_box = gr.Textbox(label="Generated Diagnosis", lines=8)

        # Output order matches generate_diagnosis's (answer, context) return.
        result_targets = [answer_box, context_box]
        submit_button.click(
            fn=generate_diagnosis,
            inputs=query_box,
            outputs=result_targets,
        )

    return app

# ===============================
# Launch App
# ===============================
# Build and serve the UI only when this file is executed as a script.
# Note that the corpus, index and models at module level are still loaded on
# a plain import; only the UI construction and launch are guarded here.
if __name__ == "__main__":
    demo = run_interface()
    # launch() is called without arguments, i.e. with Gradio's defaults.
    demo.launch()