asadsandhu committed
Commit 7ab76b7 · 1 Parent(s): 56dc0cd

Model changed.

Files changed (1)
  1. app.py +11 -13
app.py CHANGED
@@ -18,26 +18,24 @@ index = faiss.read_index("faiss_index.bin")
 embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
 
 # ----------------------
-# Load HuggingFace LLM (BioMistral-7B)
+# Load HuggingFace LLM (Nous-Hermes)
 # ----------------------
 model_id = "BioMistral/BioMistral-7B"
+
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
-    bnb_4bit_quant_type="nf4",
     bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4",
     bnb_4bit_compute_dtype=torch.float16,
 )
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-
 generation_model = AutoModelForCausalLM.from_pretrained(
     model_id,
-    quantization_config=bnb_config,
-    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-    device_map="auto" if torch.cuda.is_available() else None
-).to(device)
+    torch_dtype=torch.float16,
+    device_map="auto",
+    quantization_config=bnb_config
+)
 
 # ----------------------
 # RAG Functions
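
Note: the old code chose a device by hand and ended with .to(device); the new code delegates placement to Accelerate through device_map="auto". This is required rather than cosmetic: transformers rejects .to() on bitsandbytes 4-bit models, so once load_in_4bit takes effect the call has to go. A quick placement check after loading (a sketch, not part of app.py; both attributes are standard transformers APIs):

    print(generation_model.device)                  # e.g. cuda:0 when a GPU is visible
    print(generation_model.get_memory_footprint())  # approximate bytes held by the 4-bit weights
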
@@ -54,7 +52,7 @@ def build_prompt(query, retrieved_docs):
     context_text = "\n".join([
         f"- {doc['text']}" for _, doc in retrieved_docs.iterrows()
     ])
-
+
     prompt = f"""[INST] <<SYS>>
 You are a medical assistant trained on clinical reasoning data. Given the following patient query and related clinical observations, generate a diagnostic explanation or suggestion based on the context.
 <</SYS>>
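
Note: build_prompt renders each retrieved row as a bulleted line and wraps the result in Mistral's [INST] <<SYS>> chat format. For reference, a toy run of the same join (a sketch; the sample rows are invented, only the "text" column name comes from the code):

    import pandas as pd

    # Hypothetical retrieval results, shaped like the app's FAISS hits.
    retrieved_docs = pd.DataFrame({"text": [
        "Patient reports chest pain radiating to the left arm.",
        "ECG shows ST elevation in the inferior leads.",
    ]})

    context_text = "\n".join(f"- {doc['text']}" for _, doc in retrieved_docs.iterrows())
    print(context_text)
    # - Patient reports chest pain radiating to the left arm.
    # - ECG shows ST elevation in the inferior leads.
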
@@ -71,7 +69,7 @@ You are a medical assistant trained on clinical reasoning data. Given the follow
     return prompt
 
 def generate_local_answer(prompt, max_new_tokens=512):
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
     output = generation_model.generate(
         input_ids=input_ids,
         max_new_tokens=max_new_tokens,
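
Note: .cuda() pins the inputs to the GPU, which matches device_map="auto" on a CUDA host but raises on CPU-only machines, where the removed torch.cuda.is_available() guard would still have worked. A device-agnostic variant (a sketch of an alternative, not what this commit does):

    # Send inputs to whatever device Accelerate chose for the model's weights,
    # so the same line runs on GPU and CPU hosts alike.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(generation_model.device)
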
@@ -93,7 +91,7 @@ def rag_chat(query):
     answer = generate_local_answer(prompt)
     return answer
 
-# Optional: CSS for improved UX
+# Optional: basic CSS to enhance layout
 custom_css = """
 textarea, .input_textbox {
     font-size: 1.05rem !important;
@@ -127,4 +125,4 @@ Enter a natural-language query describing your patient's condition to receive an
 
 submit_btn.click(fn=rag_chat, inputs=query_input, outputs=output)
 
-demo.launch()
+demo.launch(share=True)
 
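Note: share=True asks Gradio to open a temporary public *.gradio.live tunnel, which only matters when the app runs outside Spaces; on Hugging Face Spaces the app is already served publicly and Gradio warns that share is unsupported there. A local-run sketch (the commented server_name line is a hypothetical alternative, not in app.py):

    demo.launch(
        share=True,               # temporary public gradio.live URL for local runs
        # server_name="0.0.0.0",  # alternative: expose on the local network without a tunnel
    )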