asadsandhu committed on
Commit eb2112e · 1 Parent(s): fed899e

app.py updated: the model is now loaded in plain fp32 on the CPU instead of fp16 with device_map="auto" and disk offload, the retrieval and generation helpers are simplified, and the Gradio UI is reworked to show the retrieved context alongside the generated diagnosis.

Files changed (1)
  1. app.py +89 -117
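Both versions of app.py load two prebuilt artifacts, retrieval_corpus.csv and faiss_index.bin, that are not part of this commit. As a minimal sketch only: the index might have been produced along these lines, assuming the CSV already exists with a `text` column (which the retrieval code below reads) and that a flat L2 index was used; the actual index type is not visible from this diff.

# build_index.py (hypothetical helper, not part of this commit)
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer

df = pd.read_csv("retrieval_corpus.csv")          # must contain a "text" column
model = SentenceTransformer("all-MiniLM-L6-v2")   # same model app.py queries with

# Encode every chunk; FAISS expects float32.
embeddings = model.encode(df["text"].tolist()).astype("float32")

# Exact L2 index; an assumption, since the diff only shows the index being read.
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
faiss.write_index(index, "faiss_index.bin")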
app.py CHANGED
@@ -1,134 +1,106 @@
  import gradio as gr
  import pandas as pd
  import faiss
- import numpy as np
  import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
  from sentence_transformers import SentenceTransformer

- # ----------------------
- # Load Retrieval Corpus & FAISS Index
- # ----------------------
  df = pd.read_csv("retrieval_corpus.csv")
  index = faiss.read_index("faiss_index.bin")

- # ----------------------
- # Load Embedding Model (very lightweight)
- # ----------------------
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

- # ----------------------
- # Load HuggingFace LLM (BioMistral-7B, 8bit CPU)
- # ----------------------
  model_id = "BioMistral/BioMistral-7B"
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- tokenizer.pad_token = tokenizer.eos_token
-
- generation_model = AutoModelForCausalLM.from_pretrained(
      model_id,
-     device_map="auto",
-     offload_folder="offload",
-     offload_state_dict=True,
-     torch_dtype=torch.float16,
-     low_cpu_mem_usage=True,
- )
-
- # ----------------------
- # RAG Functions
- # ----------------------
-
- def retrieve_top_k(query, k=5):
-     query_embedding = embedding_model.encode([query]).astype("float32")
-     D, I = index.search(query_embedding, k)
-     results = df.iloc[I[0]].copy()
-     results["score"] = D[0]
-     return results
-
- def build_prompt(query, retrieved_docs):
-     context_text = "\n".join([
-         f"- {doc['text']}" for _, doc in retrieved_docs.iterrows()
-     ])
-     prompt = f"""[INST] <<SYS>>
- You are a medical assistant trained on clinical reasoning data. Given the following patient query and related clinical observations, generate a diagnostic explanation or suggestion based on the context.
- <</SYS>>
-
- ### Patient Query:
- {query}
-
- ### Clinical Context:
- {context_text}
-
- ### Diagnostic Explanation:
- [/INST]
- """
-     return prompt

- def generate_local_answer(prompt, max_new_tokens=256):  # ✅ Reduced token budget
-     tokens = tokenizer(
-         prompt,
-         return_tensors="pt",
-         padding=True,
-         truncation=True,
-         max_length=1024
-     )
-     input_ids = tokens["input_ids"]
-     attention_mask = tokens["attention_mask"]
-
-     output = generation_model.generate(
-         input_ids=input_ids,
-         attention_mask=attention_mask,
-         max_new_tokens=max_new_tokens,
-         temperature=0.5,
-         do_sample=True,
-         top_k=50,
-         top_p=0.95,
-         pad_token_id=tokenizer.pad_token_id
      )

-     decoded = tokenizer.decode(output[0], skip_special_tokens=True)
-     return decoded.split("### Diagnostic Explanation:")[-1].strip()
-
- # ----------------------
- # Gradio Interface
- # ----------------------
-
- def rag_chat(query):
-     top_docs = retrieve_top_k(query, k=5)
-     prompt = build_prompt(query, top_docs)
-     answer = generate_local_answer(prompt)
-     return answer
-
- custom_css = """
- textarea, .input_textbox {
-     font-size: 1.05rem !important;
- }
- .output-markdown {
-     font-size: 1.08rem !important;
- }
- """
-
- with gr.Blocks(css=custom_css, theme=gr.themes.Default(primary_hue="blue")) as demo:
-     gr.Markdown("""
-     # 🩺 RAGnosis — Clinical Reasoning Assistant
-
-     Enter a natural-language query describing your patient's condition to receive an AI-generated diagnostic reasoning response.
-
-     **Example:**
-     *Patient has shortness of breath, fatigue, and leg swelling.*
-     """)
-
-     with gr.Row():
-         with gr.Column():
-             query_input = gr.Textbox(
-                 lines=4,
-                 label="📝 Patient Query",
-                 placeholder="Enter patient symptoms or findings..."
-             )
-             submit_btn = gr.Button("🔍 Generate Diagnosis")
-
-         with gr.Column():
-             output = gr.Markdown(label="Diagnostic Reasoning")
-
-     submit_btn.click(fn=rag_chat, inputs=query_input, outputs=output)
-
- demo.launch()
  import gradio as gr
  import pandas as pd
  import faiss
  import torch
+ import numpy as np
+
  from sentence_transformers import SentenceTransformer
+ from transformers import AutoTokenizer, AutoModelForCausalLM

+ # ===============================
+ # Load Retrieval Components
+ # ===============================
+ print("Loading corpus and FAISS index...")
  df = pd.read_csv("retrieval_corpus.csv")
  index = faiss.read_index("faiss_index.bin")

+ print("Loading embedding model...")
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

+ # ===============================
+ # Load LLM on CPU
+ # ===============================
  model_id = "BioMistral/BioMistral-7B"

+ print(f"Loading tokenizer and model: {model_id}")
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
      model_id,
+     torch_dtype=torch.float32,
+     low_cpu_mem_usage=True
+ ).to("cpu")

+ tokenizer.pad_token = tokenizer.eos_token
+
+ # ===============================
+ # RAG Pipeline
+ # ===============================
+ def get_top_k_chunks(query, k=5):
+     query_embedding = embedding_model.encode([query])
+     scores, indices = index.search(np.array(query_embedding).astype("float32"), k)
+     return df.iloc[indices[0]]["text"].tolist()
+
+ def build_prompt(query, chunks):
+     context = "\n".join(f"{i+1}. {chunk}" for i, chunk in enumerate(chunks))
+     prompt = (
+         "You are a clinical reasoning assistant. Based on the following medical information, "
+         "answer the query with a detailed explanation.\n\n"
+         f"Context:\n{context}\n\n"
+         f"Query: {query}\n"
+         "Answer:"
      )
+     return prompt

+ def generate_diagnosis(query):
+     chunks = get_top_k_chunks(query)
+     prompt = build_prompt(query, chunks)
+
+     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
+     input_ids = inputs.input_ids.to("cpu")
+
+     with torch.no_grad():
+         output = model.generate(
+             input_ids=input_ids,
+             max_new_tokens=256,
+             do_sample=True,
+             top_k=50,
+             top_p=0.95,
+             temperature=0.7,
+             pad_token_id=tokenizer.eos_token_id
+         )
+
+     generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+     answer = generated_text.split("Answer:")[-1].strip()
+     return answer, "\n\n".join(chunks)
+
+ # ===============================
+ # Gradio UI
+ # ===============================
+ def run_interface():
+     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown("## 🧠 Clinical Diagnosis Assistant (RAG)")
+         gr.Markdown("Enter a clinical query. The assistant retrieves relevant medical facts and generates a diagnostic explanation.")
+
+         with gr.Row():
+             query_input = gr.Textbox(label="Clinical Query", placeholder="e.g. 65-year-old male with shortness of breath...")
+             generate_btn = gr.Button("Generate Diagnosis")
+
+         with gr.Accordion("📄 Retrieved Context", open=False):
+             context_output = gr.Textbox(label="Top-5 Retrieved Chunks", lines=10, interactive=False)
+
+         answer_output = gr.Textbox(label="Generated Diagnosis", lines=8)
+
+         generate_btn.click(
+             fn=generate_diagnosis,
+             inputs=query_input,
+             outputs=[answer_output, context_output]
+         )
+
+     return demo
+
+ # ===============================
+ # Launch App
+ # ===============================
+ if __name__ == "__main__":
+     demo = run_interface()
+     demo.launch()
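Because demo.launch() is now guarded by if __name__ == "__main__", the new app.py can be imported without starting the Gradio server, which makes the pipeline testable from a script. A minimal smoke test, as a sketch (the query string is illustrative; note the model still loads at import time, so the CSV, index, and model weights must all be available):

# smoke_test.py (hypothetical; relies only on functions defined in app.py above)
from app import generate_diagnosis

answer, context = generate_diagnosis(
    "65-year-old male with shortness of breath and leg swelling"
)
print("Retrieved context:\n", context)
print("\nGenerated diagnosis:\n", answer)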