asadsandhu committed on
Commit
dd74b32
·
1 Parent(s): a2cbc8f

Model Changes.

Browse files
Files changed (1) hide show
  1. app.py +17 -47
app.py CHANGED
@@ -3,45 +3,23 @@ import pandas as pd
3
  import faiss
4
  import numpy as np
5
  import torch
6
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
7
  from sentence_transformers import SentenceTransformer
8
 
9
- # ----------------------
10
- # Load Retrieval Corpus & FAISS Index
11
- # ----------------------
12
  df = pd.read_csv("retrieval_corpus.csv")
13
  index = faiss.read_index("faiss_index.bin")
14
 
15
- # ----------------------
16
- # Load Embedding Model
17
- # ----------------------
18
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
19
 
20
- # ----------------------
21
- # Load HuggingFace LLM (Nous-Hermes)
22
- # ----------------------
23
- model_id = "BioMistral/BioMistral-7B"
24
-
25
- bnb_config = BitsAndBytesConfig(
26
- load_in_4bit=True,
27
- bnb_4bit_use_double_quant=True,
28
- bnb_4bit_quant_type="nf4",
29
- bnb_4bit_compute_dtype=torch.float16,
30
- )
31
 
32
  tokenizer = AutoTokenizer.from_pretrained(model_id)
33
- tokenizer.pad_token = tokenizer.eos_token
34
- tokenizer.save_pretrained("fixed_tokenizer")
35
- generation_model = AutoModelForCausalLM.from_pretrained(
36
- model_id,
37
- torch_dtype=torch.float16,
38
- device_map="auto",
39
- quantization_config=bnb_config
40
- )
41
 
42
- # ----------------------
43
- # RAG Functions
44
- # ----------------------
45
 
46
  def retrieve_top_k(query, k=5):
47
  query_embedding = embedding_model.encode([query]).astype("float32")
@@ -51,11 +29,8 @@ def retrieve_top_k(query, k=5):
51
  return results
52
 
53
  def build_prompt(query, retrieved_docs):
54
- context_text = "\n".join([
55
- f"- {doc['text']}" for _, doc in retrieved_docs.iterrows()
56
- ])
57
-
58
- prompt = f"""[INST] <<SYS>>
59
  You are a medical assistant trained on clinical reasoning data. Given the following patient query and related clinical observations, generate a diagnostic explanation or suggestion based on the context.
60
  <</SYS>>
61
 
@@ -68,14 +43,14 @@ You are a medical assistant trained on clinical reasoning data. Given the follow
68
  ### Diagnostic Explanation:
69
  [/INST]
70
  """
71
- return prompt
72
 
73
  def generate_local_answer(prompt, max_new_tokens=512):
74
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
75
  print(f"Using device: {device}")
76
- inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)
77
- input_ids = inputs["input_ids"]
78
- attention_mask = inputs["attention_mask"]
 
79
  output = generation_model.generate(
80
  input_ids=input_ids,
81
  attention_mask=attention_mask,
@@ -88,23 +63,18 @@ def generate_local_answer(prompt, max_new_tokens=512):
88
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
89
  return decoded.split("### Diagnostic Explanation:")[-1].strip()
90
 
91
- # ----------------------
92
- # Gradio Interface
93
- # ----------------------
94
-
95
  def rag_chat(query):
96
  top_docs = retrieve_top_k(query, k=5)
97
  prompt = build_prompt(query, top_docs)
98
- answer = generate_local_answer(prompt)
99
- return answer
100
 
101
  iface = gr.Interface(
102
  fn=rag_chat,
103
  inputs=gr.Textbox(lines=3, placeholder="Enter a clinical query..."),
104
  outputs="text",
105
  title="🩺 Clinical Reasoning RAG Assistant",
106
- description="Ask a medical question based on MIMIC-IV-Ext-DiReCT's diagnostic knowledge.",
107
  allow_flagging="never"
108
  )
109
 
110
- iface.launch()
 
3
  import faiss
4
  import numpy as np
5
  import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from sentence_transformers import SentenceTransformer
8
 
9
+ # Load retrieval corpus & FAISS index
 
 
10
  df = pd.read_csv("retrieval_corpus.csv")
11
  index = faiss.read_index("faiss_index.bin")
12
 
13
+ # Load embedding model
 
 
14
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
15
 
16
+ # Load generation model: BioMedLM 2.7B (CPU-friendly biomedical model)
17
+ model_id = "stanford-crfm/BioMedLM"
 
 
 
 
 
 
 
 
 
18
 
19
  tokenizer = AutoTokenizer.from_pretrained(model_id)
20
+ tokenizer.pad_token = tokenizer.eos_token # fix padding issue
 
 
 
 
 
 
 
21
 
22
+ generation_model = AutoModelForCausalLM.from_pretrained(model_id)
 
 
23
 
24
  def retrieve_top_k(query, k=5):
25
  query_embedding = embedding_model.encode([query]).astype("float32")
 
29
  return results
30
 
31
  def build_prompt(query, retrieved_docs):
32
+ context_text = "\n".join([f"- {doc['text']}" for _, doc in retrieved_docs.iterrows()])
33
+ return f"""[INST] <<SYS>>
 
 
 
34
  You are a medical assistant trained on clinical reasoning data. Given the following patient query and related clinical observations, generate a diagnostic explanation or suggestion based on the context.
35
  <</SYS>>
36
 
 
43
  ### Diagnostic Explanation:
44
  [/INST]
45
  """
 
46
 
47
  def generate_local_answer(prompt, max_new_tokens=512):
48
+ device = torch.device("cpu")
49
  print(f"Using device: {device}")
50
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True)
51
+ input_ids = inputs["input_ids"].to(device)
52
+ attention_mask = inputs["attention_mask"].to(device)
53
+
54
  output = generation_model.generate(
55
  input_ids=input_ids,
56
  attention_mask=attention_mask,
 
63
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
64
  return decoded.split("### Diagnostic Explanation:")[-1].strip()
65
 
 
 
 
 
66
  def rag_chat(query):
67
  top_docs = retrieve_top_k(query, k=5)
68
  prompt = build_prompt(query, top_docs)
69
+ return generate_local_answer(prompt)
 
70
 
71
  iface = gr.Interface(
72
  fn=rag_chat,
73
  inputs=gr.Textbox(lines=3, placeholder="Enter a clinical query..."),
74
  outputs="text",
75
  title="🩺 Clinical Reasoning RAG Assistant",
76
+ description="Ask a medical question based on MIMICIVExtDiReCTs diagnostic knowledge.",
77
  allow_flagging="never"
78
  )
79
 
80
+ iface.launch()