import gradio as gr
import pandas as pd
import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer

# ----------------------
# Load Retrieval Corpus & FAISS Index
# ----------------------
df = pd.read_csv("retrieval_corpus.csv")
index = faiss.read_index("faiss_index.bin")
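# NOTE (assumption): faiss_index.bin is expected to have been built offline from
# the "text" column of retrieval_corpus.csv with the same embedding model loaded
# below. A minimal sketch of that preprocessing step (not part of this app):
#   emb = SentenceTransformer("all-MiniLM-L6-v2").encode(df["text"].tolist()).astype("float32")
#   idx = faiss.IndexFlatL2(emb.shape[1])
#   idx.add(emb)
#   faiss.write_index(idx, "faiss_index.bin")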

# ----------------------
# Load Embedding Model
# ----------------------
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# ----------------------
# Load Hugging Face LLM (BioMistral-7B, 4-bit quantized)
# ----------------------
model_id = "BioMistral/BioMistral-7B"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # Mistral-style tokenizers have no pad token by default; reuse EOS for padding
tokenizer.save_pretrained("fixed_tokenizer")  # persist the patched tokenizer locally for reuse
generation_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    quantization_config=bnb_config
)
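# NOTE (assumption): loading in 4-bit requires a CUDA GPU with the bitsandbytes
# package installed. A rough CPU-only fallback (much slower) would be to drop the
# quantization config, e.g.:
#   generation_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)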

# ----------------------
# RAG Functions
# ----------------------

def retrieve_top_k(query, k=5):
    query_embedding = embedding_model.encode([query]).astype("float32")
    D, I = index.search(query_embedding, k)
    results = df.iloc[I[0]].copy()
    results["score"] = D[0]  # raw FAISS distance scores for the top-k hits
    return results

def build_prompt(query, retrieved_docs):
    context_text = "\n".join([
        f"- {doc['text']}" for _, doc in retrieved_docs.iterrows()
    ])

    prompt = f"""[INST] <<SYS>>
You are a medical assistant trained on clinical reasoning data. Given the following patient query and related clinical observations, generate a diagnostic explanation or suggestion based on the context.
<</SYS>>

### Patient Query:
{query}

### Clinical Context:
{context_text}

### Diagnostic Explanation:
[/INST]
"""
    return prompt

def generate_local_answer(prompt, max_new_tokens=512):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    output = generation_model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=0.5,
        do_sample=True,
        top_k=50,
        top_p=0.95,
    )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    return decoded.split("### Diagnostic Explanation:")[-1].strip()

# ----------------------
# Gradio Interface
# ----------------------

def rag_chat(query):
    top_docs = retrieve_top_k(query, k=5)
    prompt = build_prompt(query, top_docs)
    answer = generate_local_answer(prompt)
    return answer
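
# Example usage (hypothetical query), handy as a quick smoke test before launching the UI:
#   print(rag_chat("Elderly patient with fever, productive cough, and pleuritic chest pain"))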

iface = gr.Interface(
    fn=rag_chat,
    inputs=gr.Textbox(lines=3, placeholder="Enter a clinical query..."),
    outputs="text",
    title="🩺 Clinical Reasoning RAG Assistant",
    description="Ask a medical question based on MIMIC-IV-Ext-DiReCT's diagnostic knowledge.",
    allow_flagging="never"
)

iface.launch()