Commit 34b3b8f · Use CUDA
Parent: d85b2a7

Files changed:
- app.py (+12, -1)
- requirements.txt (+1, -1)
app.py CHANGED

@@ -1,4 +1,5 @@
 import gradio as gr
+import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.utils import logging
 
@@ -8,10 +9,18 @@ logger = logging.get_logger("transformers")
 
 # Load model directly from your Hugging Face repository
 def load_model():
-    
     tokenizer = AutoTokenizer.from_pretrained("omi-health/sum-small", trust_remote_code=False)
     model = AutoModelForCausalLM.from_pretrained("omi-health/sum-small", trust_remote_code=False)
 
+    # Move model to GPU if available
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = model.to(device)
+
+    print(f"Using device: {device}")
+    if device == "cuda":
+        print(f"GPU: {torch.cuda.get_device_name(0)}")
+        print(f"Memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
+
     return model, tokenizer
 
 def generate_soap_note(doctor_patient_conversation):
@@ -34,6 +43,8 @@ Please generate a structured SOAP (Subjective, Objective, Assessment, Plan) note
         max_length=tokenizer.model_max_length
     )
 
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
     generate_ids = model.generate(
         inputs.input_ids,
         attention_mask=inputs.attention_mask,  # Explicitly pass attention mask
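Taken together, the app.py changes follow the usual device-placement pattern: pick "cuda" when a GPU is available, move the model's weights there once at load time, and move each tokenized batch to the same device before calling generate(). Below is a minimal, self-contained sketch of that pattern as read from this diff; the generation arguments are placeholders, since the hunk cuts off before the full generate() call. Two caveats worth flagging: the added dict comprehension returns a plain dict, so the later inputs.input_ids attribute access would fail (the BatchEncoding returned by the tokenizer has its own .to() method, used in the sketch instead), and the diff's `device` variable is local to load_model(), so the sketch reads model.device rather than assuming a module-level name.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model():
    tokenizer = AutoTokenizer.from_pretrained("omi-health/sum-small", trust_remote_code=False)
    model = AutoModelForCausalLM.from_pretrained("omi-health/sum-small", trust_remote_code=False)
    # Select the device once and move the weights there.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    return model, tokenizer

model, tokenizer = load_model()

def generate_soap_note(conversation):
    inputs = tokenizer(conversation, return_tensors="pt",
                       truncation=True, max_length=tokenizer.model_max_length)
    # BatchEncoding.to() moves every tensor in the batch and, unlike a plain
    # dict comprehension, keeps attribute access (inputs.input_ids) working.
    # model.device avoids relying on a `device` variable from another scope.
    inputs = inputs.to(model.device)
    generate_ids = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,  # Explicitly pass attention mask
        max_new_tokens=512,  # placeholder; the real arguments fall outside this hunk
    )
    return tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0]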
requirements.txt CHANGED

@@ -1,4 +1,4 @@
---extra-index-url https://download.pytorch.org/whl/
+--extra-index-url https://download.pytorch.org/whl/cu118
 torch
 transformers>=4.36.0
 gradio>=3.50.0
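The requirements.txt change points pip at PyTorch's CUDA 11.8 wheel index, so the unpinned torch requirement resolves to a cu118 build rather than whatever the previous index URL served. A quick sanity check that the right wheel landed, runnable in the Space's environment:

import torch

print(torch.__version__)          # a cu118 wheel reports a version like "2.1.0+cu118"
print(torch.version.cuda)         # CUDA version the wheel was built against; None on CPU-only builds
print(torch.cuda.is_available())  # True only if a usable GPU is visible at runtime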