jason-moore commited on
Commit
34b3b8f
·
1 Parent(s): d85b2a7
Files changed (2) hide show
  1. app.py +12 -1
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  from transformers.utils import logging
4
 
@@ -8,10 +9,18 @@ logger = logging.get_logger("transformers")
8
 
9
  # Load model directly from your Hugging Face repository
10
  def load_model():
11
-
12
  tokenizer = AutoTokenizer.from_pretrained("omi-health/sum-small", trust_remote_code=False)
13
  model = AutoModelForCausalLM.from_pretrained("omi-health/sum-small", trust_remote_code=False)
14
 
 
 
 
 
 
 
 
 
 
15
  return model, tokenizer
16
 
17
  def generate_soap_note(doctor_patient_conversation):
@@ -34,6 +43,8 @@ Please generate a structured SOAP (Subjective, Objective, Assessment, Plan) note
34
  max_length=tokenizer.model_max_length
35
  )
36
 
 
 
37
  generate_ids = model.generate(
38
  inputs.input_ids,
39
  attention_mask=inputs.attention_mask, # Explicitly pass attention mask
 
1
  import gradio as gr
2
+ import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from transformers.utils import logging
5
 
 
9
 
10
  # Load model directly from your Hugging Face repository
11
  def load_model():
 
12
  tokenizer = AutoTokenizer.from_pretrained("omi-health/sum-small", trust_remote_code=False)
13
  model = AutoModelForCausalLM.from_pretrained("omi-health/sum-small", trust_remote_code=False)
14
 
15
+ # Move model to GPU if available
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ model = model.to(device)
18
+
19
+ print(f"Using device: {device}")
20
+ if device == "cuda":
21
+ print(f"GPU: {torch.cuda.get_device_name(0)}")
22
+ print(f"Memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
23
+
24
  return model, tokenizer
25
 
26
  def generate_soap_note(doctor_patient_conversation):
 
43
  max_length=tokenizer.model_max_length
44
  )
45
 
46
+ inputs = {k: v.to(device) for k, v in inputs.items()}
47
+
48
  generate_ids = model.generate(
49
  inputs.input_ids,
50
  attention_mask=inputs.attention_mask, # Explicitly pass attention mask
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- --extra-index-url https://download.pytorch.org/whl/cu113
2
  torch
3
  transformers>=4.36.0
4
  gradio>=3.50.0
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu118
2
  torch
3
  transformers>=4.36.0
4
  gradio>=3.50.0