asadsandhu committed on
Commit e31fef3 · 1 Parent(s): 0fced6a
Files changed (2)
  1. app.py +86 -89
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,106 +1,103 @@
 import gradio as gr
 import pandas as pd
 import faiss
-import torch
 import numpy as np
-
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 from sentence_transformers import SentenceTransformer
-from transformers import AutoTokenizer, AutoModelForCausalLM
 
-# ===============================
-# Load Retrieval Components
-# ===============================
-print("Loading corpus and FAISS index...")
+# ----------------------
+# Load Retrieval Corpus & FAISS Index
+# ----------------------
 df = pd.read_csv("retrieval_corpus.csv")
 index = faiss.read_index("faiss_index.bin")
 
-print("Loading embedding model...")
+# ----------------------
+# Load Embedding Model
+# ----------------------
 embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
 
-# ===============================
-# Load LLM on CPU
-# ===============================
+# ----------------------
+# Load Hugging Face LLM (BioMistral-7B)
+# ----------------------
 model_id = "BioMistral/BioMistral-7B"
 
-print(f"Loading tokenizer and model: {model_id}")
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.float16,
+)
+
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(
+generation_model = AutoModelForCausalLM.from_pretrained(
     model_id,
-    torch_dtype=torch.float32,
-    low_cpu_mem_usage=True
-).to("cpu")
-
-tokenizer.pad_token = tokenizer.eos_token
-
-# ===============================
-# RAG Pipeline
-# ===============================
-def get_top_k_chunks(query, k=5):
-    query_embedding = embedding_model.encode([query])
-    scores, indices = index.search(np.array(query_embedding).astype("float32"), k)
-    return df.iloc[indices[0]]["text"].tolist()
-
-def build_prompt(query, chunks):
-    context = "\n".join(f"{i+1}. {chunk}" for i, chunk in enumerate(chunks))
-    prompt = (
-        "You are a clinical reasoning assistant. Based on the following medical information, "
-        "answer the query with a detailed explanation.\n\n"
-        f"Context:\n{context}\n\n"
-        f"Query: {query}\n"
-        "Answer:"
-    )
+    torch_dtype=torch.float16,
+    device_map="auto",
+    quantization_config=bnb_config
+)
+
+# ----------------------
+# RAG Functions
+# ----------------------
+
+def retrieve_top_k(query, k=5):
+    query_embedding = embedding_model.encode([query]).astype("float32")
+    D, I = index.search(query_embedding, k)
+    results = df.iloc[I[0]].copy()
+    results["score"] = D[0]
+    return results
+
+def build_prompt(query, retrieved_docs):
+    context_text = "\n".join([
+        f"- {doc['text']}" for _, doc in retrieved_docs.iterrows()
+    ])
+
+    prompt = f"""[INST] <<SYS>>
+You are a medical assistant trained on clinical reasoning data. Given the following patient query and related clinical observations, generate a diagnostic explanation or suggestion based on the context.
+<</SYS>>
+
+### Patient Query:
+{query}
+
+### Clinical Context:
+{context_text}
+
+### Diagnostic Explanation:
+[/INST]
+"""
     return prompt
 
-def generate_diagnosis(query):
-    chunks = get_top_k_chunks(query)
-    prompt = build_prompt(query, chunks)
-
-    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
-    input_ids = inputs.input_ids.to("cpu")
-
-    with torch.no_grad():
-        output = model.generate(
-            input_ids=input_ids,
-            max_new_tokens=256,
-            do_sample=True,
-            top_k=50,
-            top_p=0.95,
-            temperature=0.7,
-            pad_token_id=tokenizer.eos_token_id
-        )
-
-    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-    answer = generated_text.split("Answer:")[-1].strip()
-    return answer, "\n\n".join(chunks)
-
-# ===============================
-# Gradio UI
-# ===============================
-def run_interface():
-    with gr.Blocks(theme=gr.themes.Soft()) as demo:
-        gr.Markdown("## 🧠 Clinical Diagnosis Assistant (RAG)")
-        gr.Markdown("Enter a clinical query. The assistant retrieves relevant medical facts and generates a diagnostic explanation.")
-
-        with gr.Row():
-            query_input = gr.Textbox(label="Clinical Query", placeholder="e.g. 65-year-old male with shortness of breath...")
-            generate_btn = gr.Button("Generate Diagnosis")
-
-        with gr.Accordion("📄 Retrieved Context", open=False):
-            context_output = gr.Textbox(label="Top-5 Retrieved Chunks", lines=10, interactive=False)
-
-        answer_output = gr.Textbox(label="Generated Diagnosis", lines=8)
-
-        generate_btn.click(
-            fn=generate_diagnosis,
-            inputs=query_input,
-            outputs=[answer_output, context_output]
-        )
-
-    return demo
-
-# ===============================
-# Launch App
-# ===============================
-if __name__ == "__main__":
-    demo = run_interface()
-    demo.launch()
+def generate_local_answer(prompt, max_new_tokens=512):
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
+    output = generation_model.generate(
+        input_ids=input_ids,
+        max_new_tokens=max_new_tokens,
+        temperature=0.5,
+        do_sample=True,
+        top_k=50,
+        top_p=0.95,
+    )
+    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
+    return decoded.split("### Diagnostic Explanation:")[-1].strip()
+
+# ----------------------
+# Gradio Interface
+# ----------------------
+
+def rag_chat(query):
+    top_docs = retrieve_top_k(query, k=5)
+    prompt = build_prompt(query, top_docs)
+    answer = generate_local_answer(prompt)
+    return answer
+
+iface = gr.Interface(
+    fn=rag_chat,
+    inputs=gr.Textbox(lines=3, placeholder="Enter a clinical query..."),
+    outputs="text",
+    title="🩺 Clinical Reasoning RAG Assistant",
+    description="Ask a medical question based on MIMIC-IV-Ext-DiReCT's diagnostic knowledge.",
+    allow_flagging="never"
+)
+
+iface.launch()
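The retrieval half of the new pipeline can be exercised on its own, which is useful because the 7B generation model dominates load time. Below is a minimal smoke-test sketch, assuming the repo's retrieval_corpus.csv (with a "text" column) and faiss_index.bin were built from the same corpus with the all-MiniLM-L6-v2 embeddings that app.py uses; the query string is only an example.

import faiss
import pandas as pd
from sentence_transformers import SentenceTransformer

# Load the same artifacts app.py expects in the working directory.
df = pd.read_csv("retrieval_corpus.csv")
index = faiss.read_index("faiss_index.bin")
embedder = SentenceTransformer("all-MiniLM-L6-v2")

query = "65-year-old male with shortness of breath"  # example query only
query_embedding = embedder.encode([query]).astype("float32")
D, I = index.search(query_embedding, 5)  # top-5 distances and row indices

for rank, (score, row) in enumerate(zip(D[0], I[0]), start=1):
    print(f"{rank}. score={score:.3f}  {df.iloc[row]['text'][:100]}")

If the printed chunks look relevant, retrieve_top_k and build_prompt will feed the model sensible context.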
requirements.txt CHANGED
@@ -4,4 +4,6 @@ faiss-cpu
 torch
 gradio
 accelerate
-sentencepiece
+sentencepiece
+bitsandbytes
+blobfile
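bitsandbytes backs the new 4-bit NF4 load, and its quantization kernels are CUDA-only, so this path needs GPU hardware at runtime. The following sketch checks the quantized load in isolation, using the same config as app.py; the footprint figures in the comments are expectations for a 7B model, not measured values.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Same 4-bit NF4 config as app.py: 4-bit weights, double quantization,
# float16 compute for the dequantized matmuls.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    "BioMistral/BioMistral-7B",
    quantization_config=bnb_config,
    device_map="auto",
)
# Roughly 4 to 5 GB is expected for a 7B model in 4-bit, versus the
# ~28 GB the old float32 CPU load needed.
print(f"{model.get_memory_footprint() / 1e9:.1f} GB")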