stzhao commited on
Commit
9fd6eb6
·
verified ·
1 Parent(s): 2efad28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -7,6 +7,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
7
  # # Set up environment
8
  # os.environ['CUDA_VISIBLE_DEVICES'] = "0"
9
 
 
 
 
 
 
10
  # Load models
11
  def load_models():
12
  model_name = "X-ART/LeX-Enhancer-full"
@@ -22,7 +27,8 @@ def load_models():
22
  "X-ART/LeX-Lumina",
23
  torch_dtype=torch.bfloat16
24
  )
25
- pipe.to("cuda")
 
26
 
27
  return model, tokenizer, pipe
28
 
 
7
  # # Set up environment
8
  # os.environ['CUDA_VISIBLE_DEVICES'] = "0"
9
 
10
+ if torch.cuda.is_available():
11
+ torch_dtype = torch.bfloat16
12
+ else:
13
+ torch_dtype = torch.float32
14
+
15
  # Load models
16
  def load_models():
17
  model_name = "X-ART/LeX-Enhancer-full"
 
27
  "X-ART/LeX-Lumina",
28
  torch_dtype=torch.bfloat16
29
  )
30
+ device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ pipe.to(device, torch_dtype)
32
 
33
  return model, tokenizer, pipe
34