asadsandhu committed on
Commit
e531b46
·
1 Parent(s): b651070

Updated model.

Browse files
Files changed (1) hide show
  1. app.py +91 -40
app.py CHANGED
@@ -1,52 +1,103 @@
1
  import gradio as gr
2
- import pandas as pd, faiss, torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
 
 
4
  from sentence_transformers import SentenceTransformer
 
5
 
6
- # —— Load data & embedding model ——
 
 
 
7
  df = pd.read_csv("retrieval_corpus.csv")
8
  index = faiss.read_index("faiss_index.bin")
 
 
9
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
10
 
11
- # —— Quantized BioMedLM with CPU offload ——
12
- model_id = "stanford-crfm/BioMedLM"
13
- bnb_config = BitsAndBytesConfig(
14
- load_in_8bit=True,
15
- llm_int8_threshold=6.0,
16
- llm_int8_enable_fp32_cpu_offload=True,
17
- )
18
- tokenizer = AutoTokenizer.from_pretrained(model_id)
19
- tokenizer.pad_token = tokenizer.eos_token
20
- generation_model = AutoModelForCausalLM.from_pretrained(
21
  model_id,
22
- quantization_config=bnb_config,
23
- device_map={"": "cpu"},
24
  )
 
 
 
 
 
 
 
 
 
25
 
26
- def retrieve_top_k(q, k=5):
27
- emb = embedding_model.encode([q]).astype("float32")
28
- D,I = index.search(emb, k)
29
- res = df.iloc[I[0]].copy(); res["score"]=D[0]; return res
30
-
31
- def build_prompt(q, docs):
32
- ctx = "\n".join(f"- {d['text']}" for _,d in docs.iterrows())
33
- return f"""[INST] <<SYS>>…[/INST]""" # your existing template
34
-
35
- def generate_local_answer(prompt, max_new_tokens=512):
36
- import time
37
- device = torch.device("cpu")
38
- start = time.time()
39
- inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)
40
- out = generation_model.generate(
41
- input_ids=inputs.input_ids,
42
- attention_mask=inputs.attention_mask,
43
- max_new_tokens=max_new_tokens,
44
- do_sample=False,
45
- num_beams=1,
46
  )
47
- print(f"Gen time: {time.time()-start:.2f}s")
48
- return tokenizer.decode(out[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- iface = gr.Interface(fn=lambda q: generate_local_answer(build_prompt(q, retrieve_top_k(q))),
51
- inputs="text", outputs="text")
52
- iface.launch()
 
 
 
 
1
  import gradio as gr
2
+ import pandas as pd
3
+ import faiss
4
+ import torch
5
+ import numpy as np
6
+
7
  from sentence_transformers import SentenceTransformer
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
 
10
+ # ===============================
11
+ # Load Retrieval Components
12
+ # ===============================
13
# Retrieval side: the text corpus, its prebuilt FAISS index, and the sentence
# encoder used to embed incoming queries. All three are module-level globals
# read by the pipeline functions below.
print("Loading corpus and FAISS index...")
df = pd.read_csv("retrieval_corpus.csv")
index = faiss.read_index("faiss_index.bin")

print("Loading embedding model...")
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Generation side: the tokenizer comes from the base BioMistral repo while the
# weights come from the checkpoint named in `model_id`.
# NOTE(review): the checkpoint name suggests bitsandbytes 8-bit weights;
# confirm that it actually loads and runs on CPU-only hardware.
model_id = "PrunaAI/BioMistral-7B-bnb-8bit-smashed"

tokenizer = AutoTokenizer.from_pretrained("BioMistral/BioMistral-7B")
# Reuse EOS as the padding token so padding is always defined.
tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): trust_remote_code=True lets the Hub repository execute its own
# Python code at load time; confirm the repository is trusted.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map=None,  # CPU only
)
31
+
32
+ # ===============================
33
+ # RAG Pipeline
34
+ # ===============================
35
def get_top_k_chunks(query, k=5):
    """Return the text of the corpus rows nearest to *query*.

    The query is embedded with the module-level ``embedding_model`` and looked
    up in the module-level FAISS ``index``; the hits are mapped back to the
    ``"text"`` column of ``df`` by row position.

    Args:
        query: The query string to embed.
        k: Maximum number of neighbours to request from the index.

    Returns:
        A list of at most ``k`` text chunks, in the order FAISS returned them.
    """
    query_embedding = embedding_model.encode([query])
    _, indices = index.search(np.array(query_embedding).astype("float32"), k)

    # FAISS pads the label array with -1 when it finds fewer than k
    # neighbours. Passing -1 to ``iloc`` would silently return the LAST corpus
    # row as if it were a hit, so drop those placeholders first.
    valid_rows = [int(row) for row in indices[0] if row >= 0]
    return df.iloc[valid_rows]["text"].tolist()
39
 
40
def build_prompt(query, chunks):
    """Assemble the LLM prompt from the query and the retrieved chunks.

    Args:
        query: The user's question, inserted verbatim after ``Query:``.
        chunks: Iterable of context strings; each is numbered from 1 and
            placed on its own line under ``Context:``.

    Returns:
        The full prompt string, ending with ``Answer:``.
    """
    numbered_lines = []
    for position, text in enumerate(chunks, start=1):
        numbered_lines.append(f"{position}. {text}")
    context = "\n".join(numbered_lines)

    pieces = [
        "You are a clinical reasoning assistant. Based on the following medical information, ",
        "answer the query with a detailed explanation.\n\n",
        f"Context:\n{context}\n\n",
        f"Query: {query}\n",
        "Answer:",
    ]
    return "".join(pieces)
50
+
51
def generate_diagnosis(query):
    """Run retrieval + generation for one clinical query.

    Args:
        query: Free-text clinical query from the UI.

    Returns:
        A ``(answer, context)`` tuple of strings: the generated answer and the
        retrieved chunks joined by blank lines. For an empty or
        whitespace-only query a short hint is returned with an empty context
        and neither retrieval nor generation is run.
    """
    # Skip the expensive retrieval/generation path when there is nothing to ask.
    if not query or not query.strip():
        return "Please enter a clinical query.", ""

    chunks = get_top_k_chunks(query)
    prompt = build_prompt(query, chunks)

    # NOTE(review): with default right-side truncation, a prompt longer than
    # max_length loses its tail (the "Query: ... Answer:" part) — confirm the
    # retrieved context always fits.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    input_ids = inputs.input_ids.to("cpu")
    # Pass the mask explicitly: pad and EOS share one token id, so it cannot
    # be inferred reliably from the ids alone.
    attention_mask = inputs.attention_mask.to("cpu")

    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=256,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens; decode only what
    # was generated after them. Splitting the decoded text on "Answer:" would
    # cut the reply if the model emits that marker itself, and would return
    # the whole prompt if truncation had removed it.
    prompt_length = input_ids.shape[-1]
    new_tokens = output[0][prompt_length:]
    answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return answer, "\n\n".join(chunks)
72
+
73
+ # ===============================
74
+ # Gradio UI
75
+ # ===============================
76
def run_interface():
    """Build the Gradio Blocks UI for the assistant.

    The layout has a query textbox and a button in one row, a collapsed
    accordion holding the retrieved context, and a textbox for the generated
    answer. Clicking the button calls ``generate_diagnosis`` with the query
    and routes its two return values to the answer and context boxes.

    Returns:
        The constructed ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # The emoji in the heading and accordion label were mis-encoded
        # (UTF-8 bytes rendered as mojibake); restored to the real characters.
        gr.Markdown("## 🧠 Clinical Diagnosis Assistant (RAG)")
        gr.Markdown("Enter a clinical query. The assistant retrieves relevant medical facts and generates a diagnostic explanation.")

        with gr.Row():
            query_input = gr.Textbox(label="Clinical Query", placeholder="e.g. 65-year-old male with shortness of breath...")
            generate_btn = gr.Button("Generate Diagnosis")

        with gr.Accordion("📄 Retrieved Context", open=False):
            context_output = gr.Textbox(label="Top-5 Retrieved Chunks", lines=10, interactive=False)

        answer_output = gr.Textbox(label="Generated Diagnosis", lines=8)

        # Output order must match the (answer, context) tuple returned by
        # generate_diagnosis.
        generate_btn.click(
            fn=generate_diagnosis,
            inputs=query_input,
            outputs=[answer_output, context_output],
        )

    return demo
97
 
98
+ # ===============================
99
+ # Launch App
100
+ # ===============================
101
if __name__ == "__main__":
    # Build and serve the UI only when this file is executed as a script.
    demo = run_interface()
    demo.launch()