asadsandhu committed
Commit 03adce4 · 1 Parent(s): 0619f86

Model changes: lightweight model used.

Files changed (1)
  1. app.py +9 -9
app.py CHANGED
@@ -18,9 +18,9 @@ index = faiss.read_index("faiss_index.bin")
 embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
 
 # ----------------------
-# Load HuggingFace LLM (Nous-Hermes)
+# Load HuggingFace LLM (BioMistral-7B)
 # ----------------------
-model_id = "NousResearch/Nous-Hermes-2-Mistral-7B-DPO"
+model_id = "royalhaze/BioMistral-7B"
 
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
@@ -29,14 +29,15 @@ bnb_config = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.float16,
 )
 
-tokenizer = AutoTokenizer.from_pretrained(model_id)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
 generation_model = AutoModelForCausalLM.from_pretrained(
     model_id,
+    quantization_config=bnb_config if torch.cuda.is_available() else None,
     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-    device_map="auto" if torch.cuda.is_available() else None,
-    quantization_config=bnb_config if torch.cuda.is_available() else None
+    device_map="auto" if torch.cuda.is_available() else None
 ).to(device)
 
 # ----------------------
@@ -54,7 +55,7 @@ def build_prompt(query, retrieved_docs):
     context_text = "\n".join([
         f"- {doc['text']}" for _, doc in retrieved_docs.iterrows()
     ])
-
+
     prompt = f"""[INST] <<SYS>>
 You are a medical assistant trained on clinical reasoning data. Given the following patient query and related clinical observations, generate a diagnostic explanation or suggestion based on the context.
 <</SYS>>
@@ -71,7 +72,6 @@ You are a medical assistant trained on clinical reasoning data. Given the follow
     return prompt
 
 def generate_local_answer(prompt, max_new_tokens=512):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
     output = generation_model.generate(
         input_ids=input_ids,
@@ -94,7 +94,7 @@ def rag_chat(query):
     answer = generate_local_answer(prompt)
     return answer
 
-# Optional: basic CSS to enhance layout
+# Optional: CSS for improved UX
 custom_css = """
 textarea, .input_textbox {
     font-size: 1.05rem !important;
@@ -128,4 +128,4 @@ Enter a natural-language query describing your patient's condition to receive an
 
 submit_btn.click(fn=rag_chat, inputs=query_input, outputs=output)
 
-demo.launch(share=True)
+demo.launch()
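
For reference, the updated loading path can be exercised outside the Gradio app with a short standalone sketch (a hypothetical smoke test, not part of app.py; it assumes transformers, accelerate, and bitsandbytes are installed, and the prompt string is a made-up placeholder). One caveat: recent transformers releases reject .to() on 4-bit bitsandbytes models, so this sketch only moves the model explicitly on the CPU path and lets device_map="auto" handle placement on GPU.

# Hypothetical smoke test for the new loading path (not part of app.py).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "royalhaze/BioMistral-7B"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 4-bit quantization is only applied when a GPU is present.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config if torch.cuda.is_available() else None,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
)
if not torch.cuda.is_available():
    # .to() is not supported on 4-bit quantized models, so only move the
    # model explicitly when running in float32 on CPU.
    model = model.to(device)

prompt = "[INST] Patient presents with fever and productive cough. [/INST]"  # placeholder
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
output = model.generate(input_ids=input_ids, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))

Separately, dropping share=True from demo.launch() stops Gradio from requesting a public share tunnel, which a hosted Space does not need.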