Commit · 03adce4
1 Parent(s): 0619f86
Model Changes. Lightweight model used.
app.py CHANGED
@@ -18,9 +18,9 @@ index = faiss.read_index("faiss_index.bin")
 embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
 
 # ----------------------
-# Load HuggingFace LLM (
+# Load HuggingFace LLM (BioMistral-7B)
 # ----------------------
-model_id = "
+model_id = "royalhaze/BioMistral-7B"
 
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
@@ -29,14 +29,15 @@ bnb_config = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.float16,
 )
 
-tokenizer = AutoTokenizer.from_pretrained(model_id)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
 generation_model = AutoModelForCausalLM.from_pretrained(
     model_id,
+    quantization_config=bnb_config if torch.cuda.is_available() else None,
     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-    device_map="auto" if torch.cuda.is_available() else None
-    quantization_config=bnb_config if torch.cuda.is_available() else None
+    device_map="auto" if torch.cuda.is_available() else None
 ).to(device)
 
 # ----------------------
@@ -54,7 +55,7 @@ def build_prompt(query, retrieved_docs):
     context_text = "\n".join([
         f"- {doc['text']}" for _, doc in retrieved_docs.iterrows()
     ])
-
+
     prompt = f"""[INST] <<SYS>>
 You are a medical assistant trained on clinical reasoning data. Given the following patient query and related clinical observations, generate a diagnostic explanation or suggestion based on the context.
 <</SYS>>
@@ -71,7 +72,6 @@ You are a medical assistant trained on clinical reasoning data. Given the follow
     return prompt
 
 def generate_local_answer(prompt, max_new_tokens=512):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
     output = generation_model.generate(
         input_ids=input_ids,
@@ -94,7 +94,7 @@ def rag_chat(query):
     answer = generate_local_answer(prompt)
     return answer
 
-# Optional:
+# Optional: CSS for improved UX
 custom_css = """
 textarea, .input_textbox {
     font-size: 1.05rem !important;
@@ -128,4 +128,4 @@ Enter a natural-language query describing your patient's condition to receive an
 
     submit_btn.click(fn=rag_chat, inputs=query_input, outputs=output)
 
-    demo.launch(
+demo.launch()