asadsandhu committed
Commit 10c6208 · 1 Parent(s): 9760e23
Files changed (1)
1. app.py +98 -91
app.py CHANGED
@@ -1,106 +1,113 @@
  import gradio as gr
  import pandas as pd
  import faiss
- import torch
  import numpy as np
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
  from sentence_transformers import SentenceTransformer
- from transformers import AutoTokenizer, AutoModelForCausalLM

- # ===============================
- # Load Retrieval Components
- # ===============================
- print("Loading corpus and FAISS index...")
  df = pd.read_csv("retrieval_corpus.csv")
  index = faiss.read_index("faiss_index.bin")

- print("Loading embedding model...")
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

- # ===============================
- # Load LLM on CPU
- # ===============================
- model_id = "BioMistral/BioMistral-7B"

- print(f"Loading tokenizer and model: {model_id}")
  tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(
-     model_id,
-     torch_dtype=torch.float16,
-     low_cpu_mem_usage=True,
- ).to("cpu")
-
- tokenizer.pad_token = tokenizer.eos_token
-
- # ===============================
- # RAG Pipeline
- # ===============================
- def get_top_k_chunks(query, k=5):
-     query_embedding = embedding_model.encode([query])
-     scores, indices = index.search(np.array(query_embedding).astype("float32"), k)
-     return df.iloc[indices[0]]["text"].tolist()
-
- def build_prompt(query, chunks):
-     context = "\n".join(f"{i+1}. {chunk}" for i, chunk in enumerate(chunks))
-     prompt = (
-         "You are a clinical reasoning assistant. Based on the following medical information, "
-         "answer the query with a detailed explanation.\n\n"
-         f"Context:\n{context}\n\n"
-         f"Query: {query}\n"
-         "Answer:"
-     )
      return prompt

- def generate_diagnosis(query):
-     chunks = get_top_k_chunks(query)
-     prompt = build_prompt(query, chunks)
-
-     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
-     input_ids = inputs.input_ids.to("cpu")
-
-     with torch.no_grad():
-         output = model.generate(
-             input_ids=input_ids,
-             max_new_tokens=256,
-             do_sample=True,
-             top_k=50,
-             top_p=0.95,
-             temperature=0.7,
-             pad_token_id=tokenizer.eos_token_id
-         )
-
-     generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-     answer = generated_text.split("Answer:")[-1].strip()
-     return answer, "\n\n".join(chunks)
-
- # ===============================
- # Gradio UI
- # ===============================
- def run_interface():
-     with gr.Blocks(theme=gr.themes.Soft()) as demo:
-         gr.Markdown("## 🧠 Clinical Diagnosis Assistant (RAG)")
-         gr.Markdown("Enter a clinical query. The assistant retrieves relevant medical facts and generates a diagnostic explanation.")
-
-         with gr.Row():
-             query_input = gr.Textbox(label="Clinical Query", placeholder="e.g. 65-year-old male with shortness of breath...")
-             generate_btn = gr.Button("Generate Diagnosis")
-
-         with gr.Accordion("📄 Retrieved Context", open=False):
-             context_output = gr.Textbox(label="Top-5 Retrieved Chunks", lines=10, interactive=False)
-
-         answer_output = gr.Textbox(label="Generated Diagnosis", lines=8)
-
-         generate_btn.click(
-             fn=generate_diagnosis,
-             inputs=query_input,
-             outputs=[answer_output, context_output]
-         )
-
-     return demo
-
- # ===============================
- # Launch App
- # ===============================
- if __name__ == "__main__":
-     demo = run_interface()
-     demo.launch()

  import gradio as gr
  import pandas as pd
  import faiss
  import numpy as np
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
  from sentence_transformers import SentenceTransformer

+ # ----------------------
+ # Load Retrieval Corpus & FAISS Index
+ # ----------------------
  df = pd.read_csv("retrieval_corpus.csv")
  index = faiss.read_index("faiss_index.bin")

+ # ----------------------
+ # Load Embedding Model
+ # ----------------------
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

+ # ----------------------
+ # Load Lightweight HuggingFace Model (FLAN-T5-Base)
+ # ----------------------
+ model_id = "google/flan-t5-base"
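+ # flan-t5-base is a ~250M-parameter seq2seq model, small enough for the
+ # CPU-only setup here, unlike the 7B causal LM it replaces.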

  tokenizer = AutoTokenizer.from_pretrained(model_id)
+ generation_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+
+ # ----------------------
+ # RAG Functions
+ # ----------------------
+
+ def retrieve_top_k(query, k=5):
+     query_embedding = embedding_model.encode([query]).astype("float32")
+     D, I = index.search(query_embedding, k)
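+     # index.search returns distances D and indices I, each shaped (1, k);
+     # df.iloc[I[0]] assumes the index was built in df row order, and for an
+     # L2 index a smaller distance means a closer match.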
+     results = df.iloc[I[0]].copy()
+     results["score"] = D[0]
+     return results
+
+ def build_prompt(query, retrieved_docs):
+     context_text = "\n".join([
+         f"- {doc['text']}" for _, doc in retrieved_docs.iterrows()
+     ])
+
+     prompt = f"""You are a medical assistant trained on clinical reasoning data.
+ Given the following patient query and related clinical observations, generate a diagnostic explanation.
+
+ Patient Query:
+ {query}
+
+ Clinical Context:
+ {context_text}
+
+ Diagnostic Explanation:"""
      return prompt

+ def generate_local_answer(prompt, max_new_tokens=256):
+     input_ids = tokenizer(prompt, return_tensors="pt").input_ids  # CPU only
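+     # Note: flan-t5 is trained on inputs up to 512 tokens; tokenizing with
+     # truncation=True, max_length=512 would guard against overly long
+     # retrieved contexts.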
+     output = generation_model.generate(
+         input_ids=input_ids,
+         max_new_tokens=max_new_tokens,
+         temperature=0.7,
+         do_sample=True,
+         top_k=50,
+         top_p=0.95,
+     )
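+     # do_sample=True with temperature/top-k/top-p samples a different answer
+     # on each run; generate() already executes under torch.no_grad(), so the
+     # torch import at the top is not strictly needed for inference.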
+     decoded = tokenizer.decode(output[0], skip_special_tokens=True)
+     return decoded.strip()
+
+ # ----------------------
+ # Gradio Interface
+ # ----------------------
+
+ def rag_chat(query):
+     top_docs = retrieve_top_k(query, k=5)
+     prompt = build_prompt(query, top_docs)
+     answer = generate_local_answer(prompt)
+     return answer
+
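+ # Example round trip: rag_chat("Patient has shortness of breath, fatigue,
+ # and leg swelling.") embeds the query, pulls the 5 nearest corpus chunks,
+ # formats them into the prompt, and returns the model's sampled explanation.
+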
+ # Optional: basic CSS to enhance layout
+ custom_css = """
+ textarea, .input_textbox {
+     font-size: 1.05rem !important;
+ }
+ .output-markdown {
+     font-size: 1.08rem !important;
+ }
+ """
+
+ with gr.Blocks(css=custom_css, theme=gr.themes.Default(primary_hue="blue")) as demo:
+     gr.Markdown("""
+     # 🩺 RAGnosis Clinical Reasoning Assistant
+
+     Enter a natural-language query describing your patient's condition to receive an AI-generated diagnostic reasoning response.
+
+     **Example:**
+     *Patient has shortness of breath, fatigue, and leg swelling.*
+     """)
+
+     with gr.Row():
+         with gr.Column():
+             query_input = gr.Textbox(
+                 lines=4,
+                 label="📝 Patient Query",
+                 placeholder="Enter patient symptoms or findings..."
+             )
+             submit_btn = gr.Button("🔍 Generate Diagnosis")
+
+         with gr.Column():
+             output = gr.Markdown(label="🧠 Diagnostic Reasoning")
+
+     submit_btn.click(fn=rag_chat, inputs=query_input, outputs=output)
+
+ demo.launch(share=True)
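+
+ # share=True spins up a public gradio.live tunnel when run locally; on
+ # Hugging Face Spaces it is unnecessary (the Space URL is already public)
+ # and Gradio ignores it.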