File size: 3,749 Bytes
2bda614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03adce4
2bda614
ff20a1e
2bda614
 
 
ff20a1e
2bda614
 
 
0619f86
 
03adce4
 
2bda614
 
ff20a1e
0619f86
03adce4
0619f86
2bda614
ff20a1e
2bda614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03adce4
2bda614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0619f86
2bda614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03adce4
8690055
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bda614
03adce4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import gradio as gr
import pandas as pd
import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer

# ----------------------
# Load Retrieval Corpus & FAISS Index
# ----------------------
# NOTE(review): retrieval looks rows up positionally from FAISS labels, so the
# CSV row order presumably matches the order vectors were added to the index —
# confirm against the script that built both files.
df = pd.read_csv("retrieval_corpus.csv")
index = faiss.read_index("faiss_index.bin")

# ----------------------
# Load Embedding Model
# ----------------------
# Used to embed incoming queries before searching the FAISS index.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# ----------------------
# Load HuggingFace LLM (BioMistral-7B)
# ----------------------
model_id = "BioMistral/BioMistral-7B"

# Device that tokenized inputs are moved to before generation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_id)

if torch.cuda.is_available():
    # 4-bit NF4 quantization with double quantization; matmuls run in fp16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    # device_map="auto" already places the quantized weights on the GPU(s).
    # Calling `.to(device)` afterwards is rejected by transformers for
    # bitsandbytes 4-bit/8-bit models, so it must not be chained here.
    generation_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        torch_dtype=torch.float16,
        device_map="auto",
    )
else:
    # bitsandbytes 4-bit loading needs a CUDA device, so fall back to a plain
    # float32 load on CPU (slow and memory hungry for a 7B model, but it runs).
    bnb_config = None
    generation_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
    ).to(device)


# ----------------------
# RAG Functions
# ----------------------

def retrieve_top_k(query, k=5):
    """Return the corpus rows closest to *query* according to the FAISS index.

    Args:
        query: Free-text query; embedded with the module-level
            ``embedding_model``.
        k: Maximum number of neighbours to request from the index.

    Returns:
        A copy of the matching rows of the module-level ``df`` with an extra
        ``score`` column holding the value FAISS reported for each hit. Fewer
        than ``k`` rows are returned when the index cannot supply ``k``
        neighbours.
    """
    query_embedding = embedding_model.encode([query]).astype("float32")
    distances, labels = index.search(query_embedding, k)

    # FAISS pads the result with label -1 when it finds fewer than k
    # neighbours. Passing -1 to ``iloc`` would silently select the LAST corpus
    # row as a bogus hit, so those slots are dropped.
    found = labels[0] >= 0
    results = df.iloc[labels[0][found]].copy()
    results["score"] = distances[0][found]
    return results

def build_prompt(query, retrieved_docs):
    """Assemble the instruction prompt sent to the generation model.

    Args:
        query: The patient query text, inserted under "### Patient Query:".
        retrieved_docs: DataFrame with a ``text`` column; every row becomes
            one "- ..." bullet under "### Clinical Context:".

    Returns:
        The complete prompt string, ending with the
        "### Diagnostic Explanation:" marker, "[/INST]" and a newline.
    """
    # One bullet per retrieved row, in the order the rows are iterated.
    bullets = []
    for _label, row in retrieved_docs.iterrows():
        bullets.append("- {}".format(row["text"]))
    context_block = "\n".join(bullets)

    # The prompt is laid out line by line; the trailing empty entry yields the
    # final newline after "[/INST]".
    template_lines = (
        "[INST] <<SYS>>",
        "You are a medical assistant trained on clinical reasoning data. Given the following patient query and related clinical observations, generate a diagnostic explanation or suggestion based on the context.",
        "<</SYS>>",
        "",
        "### Patient Query:",
        "{}".format(query),
        "",
        "### Clinical Context:",
        context_block,
        "",
        "### Diagnostic Explanation:",
        "[/INST]",
        "",
    )
    return "\n".join(template_lines)

def generate_local_answer(prompt, max_new_tokens=512):
    """Sample a completion for *prompt* from the local generation model.

    Args:
        prompt: Full prompt text (as produced by ``build_prompt``).
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The generated text with surrounding whitespace stripped. Only the
        newly generated tokens are decoded, so the prompt (including its
        trailing "[/INST]" tag) never leaks into the answer.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = encoded["input_ids"].shape[-1]

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.inference_mode():
        output = generation_model.generate(
            input_ids=encoded["input_ids"],
            # Explicit mask avoids generate() guessing which tokens are padding.
            attention_mask=encoded["attention_mask"],
            max_new_tokens=max_new_tokens,
            temperature=0.5,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            # Set the pad id explicitly so generate() does not have to fall
            # back to EOS (with a warning) when the tokenizer defines no pad
            # token.
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + completion for causal LMs; keep only the
    # completion. Decoding the whole sequence and splitting on the marker left
    # "[/INST]" at the start of every answer.
    completion_ids = output[0][prompt_length:]
    decoded = tokenizer.decode(completion_ids, skip_special_tokens=True)

    # If the model repeats the section header itself, keep what follows it.
    return decoded.split("### Diagnostic Explanation:")[-1].strip()

# ----------------------
# Gradio Interface
# ----------------------

def rag_chat(query):
    """Answer a patient query with retrieval-augmented generation.

    Retrieves the 5 nearest corpus rows for *query*, builds the prompt from
    them and returns the model's generated answer.

    Args:
        query: Text entered by the user; may be empty or ``None``.

    Returns:
        The generated answer, or a short request for input when *query* is
        empty or whitespace-only.
    """
    # An empty textbox arrives as "": skip retrieval and the expensive LLM
    # call instead of generating an answer to nothing.
    if query is None or not query.strip():
        return "Please enter a patient query."

    top_docs = retrieve_top_k(query, k=5)
    prompt = build_prompt(query, top_docs)
    return generate_local_answer(prompt)

# Optional: CSS for improved UX.
# Passed to gr.Blocks below; enlarges the font of textareas / `.input_textbox`
# elements and of `.output-markdown` elements.
custom_css = """
textarea, .input_textbox {
    font-size: 1.05rem !important;
}
.output-markdown {
    font-size: 1.08rem !important;
}
"""

# Build the UI: a Markdown header followed by one row with two columns —
# the query textbox plus submit button, and the Markdown answer area.
with gr.Blocks(css=custom_css, theme=gr.themes.Default(primary_hue="blue")) as demo:
    gr.Markdown("""
# 🩺 RAGnosis — Clinical Reasoning Assistant

Enter a natural-language query describing your patient's condition to receive an AI-generated diagnostic reasoning response.

**Example:**  
*Patient has shortness of breath, fatigue, and leg swelling.*
""")

    with gr.Row():
        # First column: user input and the button that triggers generation.
        with gr.Column():
            query_input = gr.Textbox(
                lines=4,
                label="📝 Patient Query",
                placeholder="Enter patient symptoms or findings..."
            )
            submit_btn = gr.Button("🔍 Generate Diagnosis")

        # Second column: where the generated answer is displayed.
        with gr.Column():
            output = gr.Markdown(label="🧠 Diagnostic Reasoning")

    # Clicking the button calls rag_chat with the textbox value and shows the
    # returned string in the Markdown component.
    submit_btn.click(fn=rag_chat, inputs=query_input, outputs=output)

# Starts the app when this file is executed; there is no __main__ guard, so
# importing the module also launches it.
demo.launch()