Spaces:
Running
Running
Commit
·
b651070
1
Parent(s):
a73c563
fixed.
Browse files
app.py
CHANGED
@@ -1,90 +1,52 @@
|
|
1 |
import gradio as gr
|
2 |
-
import pandas as pd
|
3 |
-
import faiss
|
4 |
-
import time
|
5 |
-
import numpy as np
|
6 |
-
import torch
|
7 |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
8 |
from sentence_transformers import SentenceTransformer
|
9 |
|
10 |
-
# Load
|
11 |
df = pd.read_csv("retrieval_corpus.csv")
|
12 |
index = faiss.read_index("faiss_index.bin")
|
13 |
-
|
14 |
-
# Load embedding model
|
15 |
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
16 |
-
model_id = "stanford-crfm/BioMedLM"
|
17 |
|
|
|
|
|
18 |
bnb_config = BitsAndBytesConfig(
|
19 |
load_in_8bit=True,
|
20 |
llm_int8_threshold=6.0,
|
|
|
21 |
)
|
22 |
-
|
23 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
24 |
tokenizer.pad_token = tokenizer.eos_token
|
25 |
-
|
26 |
generation_model = AutoModelForCausalLM.from_pretrained(
|
27 |
model_id,
|
28 |
-
device_map="auto",
|
29 |
quantization_config=bnb_config,
|
|
|
30 |
)
|
31 |
|
32 |
-
def retrieve_top_k(
|
33 |
-
|
34 |
-
D,
|
35 |
-
|
36 |
-
results["score"] = D[0]
|
37 |
-
return results
|
38 |
-
|
39 |
-
def build_prompt(query, retrieved_docs):
|
40 |
-
context_text = "\n".join([f"- {doc['text']}" for _, doc in retrieved_docs.iterrows()])
|
41 |
-
return f"""[INST] <<SYS>>
|
42 |
-
You are a medical assistant trained on clinical reasoning data. Given the following patient query and related clinical observations, generate a diagnostic explanation or suggestion based on the context.
|
43 |
-
<</SYS>>
|
44 |
-
|
45 |
-
### Patient Query:
|
46 |
-
{query}
|
47 |
|
48 |
-
|
49 |
-
{
|
50 |
-
|
51 |
-
### Diagnostic Explanation:
|
52 |
-
[/INST]
|
53 |
-
"""
|
54 |
|
55 |
def generate_local_answer(prompt, max_new_tokens=512):
|
|
|
56 |
device = torch.device("cpu")
|
57 |
-
print(f"Using device: {device}")
|
58 |
start = time.time()
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
output = generation_model.generate(
|
65 |
-
input_ids=input_ids,
|
66 |
-
attention_mask=attention_mask,
|
67 |
max_new_tokens=max_new_tokens,
|
68 |
-
do_sample=False,
|
69 |
num_beams=1,
|
70 |
)
|
|
|
|
|
71 |
|
72 |
-
|
73 |
-
|
74 |
-
return decoded.split("### Diagnostic Explanation:")[-1].strip()
|
75 |
-
|
76 |
-
def rag_chat(query):
|
77 |
-
top_docs = retrieve_top_k(query, k=5)
|
78 |
-
prompt = build_prompt(query, top_docs)
|
79 |
-
return generate_local_answer(prompt)
|
80 |
-
|
81 |
-
iface = gr.Interface(
|
82 |
-
fn=rag_chat,
|
83 |
-
inputs=gr.Textbox(lines=3, placeholder="Enter a clinical query..."),
|
84 |
-
outputs="text",
|
85 |
-
title="🩺 Clinical Reasoning RAG Assistant",
|
86 |
-
description="Ask a medical question based on MIMIC‑IV‑Ext‑DiReCT’s diagnostic knowledge.",
|
87 |
-
allow_flagging="never"
|
88 |
-
)
|
89 |
-
|
90 |
iface.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
import pandas as pd, faiss, torch
|
|
|
|
|
|
|
|
|
3 |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
4 |
from sentence_transformers import SentenceTransformer
|
5 |
|
6 |
+
# —— Retrieval assets, loaded once at import time ——
# Corpus table; retrieve_top_k() selects rows from it by position.
df = pd.read_csv("retrieval_corpus.csv")

# Prebuilt FAISS index read from disk.
# NOTE(review): presumably built over df's rows in the same order — confirm.
index = faiss.read_index("faiss_index.bin")

# Sentence encoder used to embed incoming queries.
# NOTE(review): assumed to be the same model the index was built with — confirm.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
|
|
10 |
|
11 |
+
# —— BioMedLM generator: 8-bit config with fp32 CPU offload, placed on CPU ——
model_id = "stanford-crfm/BioMedLM"

# The tokenizer's pad token is set to its EOS token so that padding=True
# in generate_local_answer() has a token to pad with.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Quantization settings handed to from_pretrained below.
# NOTE(review): 8-bit bitsandbytes loading combined with an all-CPU device
# map may not actually quantize anything — confirm on the target hardware.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_enable_fp32_cpu_offload=True,
)

# The "" key in device_map maps the whole model to a single device ("cpu").
generation_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": "cpu"},
)
|
25 |
|
26 |
+
def retrieve_top_k(q, k=5):
    """Return up to ``k`` corpus rows nearest to the query, with their scores.

    Args:
        q: Query text to embed and search for.
        k: Maximum number of neighbours to request from the index.

    Returns:
        A copy of the matching rows of ``df`` with an added ``"score"``
        column holding the value FAISS returned for each hit. Fewer than
        ``k`` rows are returned when the index has fewer than ``k`` hits.
    """
    emb = embedding_model.encode([q]).astype("float32")
    D, I = index.search(emb, k)

    # FAISS pads missing neighbours with label -1. Left in place,
    # df.iloc[-1] would silently return the last corpus row as a "hit",
    # so keep only real labels.
    labels = I[0]
    valid = labels >= 0

    res = df.iloc[labels[valid]].copy()
    res["score"] = D[0][valid]
    return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
+
def build_prompt(q, docs):
    """Build the instruction prompt from a query and its retrieved documents.

    Args:
        q: The patient query text.
        docs: DataFrame of retrieved rows; its ``"text"`` column supplies
            the context, one bullet per row.

    Returns:
        The prompt string, ending with the ``### Diagnostic Explanation:``
        section that the model is expected to complete.
    """
    ctx = "\n".join(f"- {text}" for text in docs["text"])
    # The previous body returned a literal "…" placeholder, so neither the
    # query nor the retrieved context ever reached the model.
    # NOTE(review): the context heading is reconstructed from the system
    # message wording; the earlier revision's exact heading is not visible.
    return f"""[INST] <<SYS>>
You are a medical assistant trained on clinical reasoning data. Given the following patient query and related clinical observations, generate a diagnostic explanation or suggestion based on the context.
<</SYS>>

### Patient Query:
{q}

### Related Clinical Observations:
{ctx}

### Diagnostic Explanation:
[/INST]
"""
|
|
|
|
|
|
|
34 |
|
35 |
def generate_local_answer(prompt, max_new_tokens=512):
    """Greedily generate a completion for ``prompt`` and return only the new text.

    Args:
        prompt: Full prompt string to feed to the model.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The decoded continuation with the prompt removed and surrounding
        whitespace stripped.
    """
    import time  # kept local: `time` is not imported at module level

    device = torch.device("cpu")
    start = time.perf_counter()  # monotonic clock, suited to interval timing

    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)
    out = generation_model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=1,
        # Explicit pad id so generate() does not have to pick a fallback.
        pad_token_id=tokenizer.eos_token_id,
    )
    print(f"Gen time: {time.perf_counter()-start:.2f}s")

    # A causal LM's output sequence starts with the prompt tokens. Drop them
    # so the caller gets the answer instead of the prompt echoed back.
    prompt_len = inputs.input_ids.shape[1]
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
|
49 |
|
50 |
+
# Gradio front end: one text box in, one text box out. Each submitted query
# is run through retrieve -> build prompt -> generate.
iface = gr.Interface(
    fn=lambda q: generate_local_answer(build_prompt(q, retrieve_top_k(q))),
    inputs="text",
    outputs="text",
)

# Start serving the interface.
iface.launch()
|