Spaces:
Sleeping
Sleeping
Commit
·
86d4de7
1
Parent(s):
ffa32ca
Modified.
Browse files
app.py
CHANGED
@@ -67,18 +67,25 @@ You are a medical assistant trained on clinical reasoning data. Given the follow
|
|
67 |
"""
|
68 |
return prompt
|
69 |
|
70 |
-
# ✅ FIXED generate_local_answer
|
71 |
def generate_local_answer(prompt, max_new_tokens=512):
|
72 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
output = generation_model.generate(
|
75 |
input_ids=input_ids,
|
|
|
76 |
max_new_tokens=max_new_tokens,
|
77 |
temperature=0.5,
|
78 |
do_sample=True,
|
79 |
top_k=50,
|
80 |
top_p=0.95,
|
81 |
)
|
|
|
82 |
decoded = tokenizer.decode(output[0], skip_special_tokens=True)
|
83 |
return decoded.split("### Diagnostic Explanation:")[-1].strip()
|
84 |
|
@@ -126,5 +133,4 @@ Enter a natural-language query describing your patient's condition to receive an
|
|
126 |
|
127 |
submit_btn.click(fn=rag_chat, inputs=query_input, outputs=output)
|
128 |
|
129 |
-
# ✅ Use `share=False` inside Hugging Face Spaces
|
130 |
demo.launch(share=False)
|
|
|
67 |
"""
|
68 |
return prompt
|
69 |
|
|
|
70 |
def generate_local_answer(prompt, max_new_tokens=512):
|
71 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
72 |
+
|
73 |
+
# Tokenize with attention mask
|
74 |
+
tokens = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
|
75 |
+
input_ids = tokens["input_ids"].to(device)
|
76 |
+
attention_mask = tokens["attention_mask"].to(device)
|
77 |
+
|
78 |
+
# Generate
|
79 |
output = generation_model.generate(
|
80 |
input_ids=input_ids,
|
81 |
+
attention_mask=attention_mask,
|
82 |
max_new_tokens=max_new_tokens,
|
83 |
temperature=0.5,
|
84 |
do_sample=True,
|
85 |
top_k=50,
|
86 |
top_p=0.95,
|
87 |
)
|
88 |
+
|
89 |
decoded = tokenizer.decode(output[0], skip_special_tokens=True)
|
90 |
return decoded.split("### Diagnostic Explanation:")[-1].strip()
|
91 |
|
|
|
133 |
|
134 |
submit_btn.click(fn=rag_chat, inputs=query_input, outputs=output)
|
135 |
|
|
|
136 |
demo.launch(share=False)
|