asadsandhu commited on
Commit
b651070
·
1 Parent(s): a73c563
Files changed (1) hide show
  1. app.py +23 -61
app.py CHANGED
@@ -1,90 +1,52 @@
1
  import gradio as gr
2
- import pandas as pd
3
- import faiss
4
- import time
5
- import numpy as np
6
- import torch
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
8
  from sentence_transformers import SentenceTransformer
9
 
10
- # Load retrieval corpus & FAISS index
11
  df = pd.read_csv("retrieval_corpus.csv")
12
  index = faiss.read_index("faiss_index.bin")
13
-
14
- # Load embedding model
15
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
16
- model_id = "stanford-crfm/BioMedLM"
17
 
 
 
18
  bnb_config = BitsAndBytesConfig(
19
  load_in_8bit=True,
20
  llm_int8_threshold=6.0,
 
21
  )
22
-
23
  tokenizer = AutoTokenizer.from_pretrained(model_id)
24
  tokenizer.pad_token = tokenizer.eos_token
25
-
26
  generation_model = AutoModelForCausalLM.from_pretrained(
27
  model_id,
28
- device_map="auto",
29
  quantization_config=bnb_config,
 
30
  )
31
 
32
- def retrieve_top_k(query, k=5):
33
- query_embedding = embedding_model.encode([query]).astype("float32")
34
- D, I = index.search(query_embedding, k)
35
- results = df.iloc[I[0]].copy()
36
- results["score"] = D[0]
37
- return results
38
-
39
- def build_prompt(query, retrieved_docs):
40
- context_text = "\n".join([f"- {doc['text']}" for _, doc in retrieved_docs.iterrows()])
41
- return f"""[INST] <<SYS>>
42
- You are a medical assistant trained on clinical reasoning data. Given the following patient query and related clinical observations, generate a diagnostic explanation or suggestion based on the context.
43
- <</SYS>>
44
-
45
- ### Patient Query:
46
- {query}
47
 
48
- ### Clinical Context:
49
- {context_text}
50
-
51
- ### Diagnostic Explanation:
52
- [/INST]
53
- """
54
 
55
  def generate_local_answer(prompt, max_new_tokens=512):
 
56
  device = torch.device("cpu")
57
- print(f"Using device: {device}")
58
  start = time.time()
59
-
60
- inputs = tokenizer(prompt, return_tensors="pt", padding=True)
61
- input_ids = inputs["input_ids"].to(device)
62
- attention_mask = inputs["attention_mask"].to(device)
63
-
64
- output = generation_model.generate(
65
- input_ids=input_ids,
66
- attention_mask=attention_mask,
67
  max_new_tokens=max_new_tokens,
68
- do_sample=False, # ← GREEDY
69
  num_beams=1,
70
  )
 
 
71
 
72
- decoded = tokenizer.decode(output[0], skip_special_tokens=True)
73
- print(f"Time taken: {time.time() - start:.2f}s")
74
- return decoded.split("### Diagnostic Explanation:")[-1].strip()
75
-
76
- def rag_chat(query):
77
- top_docs = retrieve_top_k(query, k=5)
78
- prompt = build_prompt(query, top_docs)
79
- return generate_local_answer(prompt)
80
-
81
- iface = gr.Interface(
82
- fn=rag_chat,
83
- inputs=gr.Textbox(lines=3, placeholder="Enter a clinical query..."),
84
- outputs="text",
85
- title="🩺 Clinical Reasoning RAG Assistant",
86
- description="Ask a medical question based on MIMIC‑IV‑Ext‑DiReCT’s diagnostic knowledge.",
87
- allow_flagging="never"
88
- )
89
-
90
  iface.launch()
 
1
  import gradio as gr
2
+ import pandas as pd, faiss, torch
 
 
 
 
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
4
  from sentence_transformers import SentenceTransformer
5
 
6
+ # —— Load data & embedding model ——
7
  df = pd.read_csv("retrieval_corpus.csv")
8
  index = faiss.read_index("faiss_index.bin")
 
 
9
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
 
10
 
11
+ # —— Quantized BioMedLM with CPU offload ——
12
+ model_id = "stanford-crfm/BioMedLM"
13
  bnb_config = BitsAndBytesConfig(
14
  load_in_8bit=True,
15
  llm_int8_threshold=6.0,
16
+ llm_int8_enable_fp32_cpu_offload=True,
17
  )
 
18
  tokenizer = AutoTokenizer.from_pretrained(model_id)
19
  tokenizer.pad_token = tokenizer.eos_token
 
20
  generation_model = AutoModelForCausalLM.from_pretrained(
21
  model_id,
 
22
  quantization_config=bnb_config,
23
+ device_map={"": "cpu"},
24
  )
25
 
26
+ def retrieve_top_k(q, k=5):
27
+ emb = embedding_model.encode([q]).astype("float32")
28
+ D,I = index.search(emb, k)
29
+ res = df.iloc[I[0]].copy(); res["score"]=D[0]; return res
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ def build_prompt(q, docs):
32
+ ctx = "\n".join(f"- {d['text']}" for _,d in docs.iterrows())
33
+ return f"""[INST] <<SYS>>…[/INST]""" # your existing template
 
 
 
34
 
35
  def generate_local_answer(prompt, max_new_tokens=512):
36
+ import time
37
  device = torch.device("cpu")
 
38
  start = time.time()
39
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)
40
+ out = generation_model.generate(
41
+ input_ids=inputs.input_ids,
42
+ attention_mask=inputs.attention_mask,
 
 
 
 
43
  max_new_tokens=max_new_tokens,
44
+ do_sample=False,
45
  num_beams=1,
46
  )
47
+ print(f"Gen time: {time.time()-start:.2f}s")
48
+ return tokenizer.decode(out[0], skip_special_tokens=True)
49
 
50
+ iface = gr.Interface(fn=lambda q: generate_local_answer(build_prompt(q, retrieve_top_k(q))),
51
+ inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  iface.launch()