cocktailpeanut commited on
Commit
0d03cf7
Β·
1 Parent(s): ac9cb8b
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -21,6 +21,12 @@ from einops import rearrange
21
  import math
22
 
23
 
 
 
 
 
 
 
24
 
25
  #### Description ####
26
  title = r"""<h1 align="center">CustomNet: Object Customization with Variable-Viewpoints in Text-to-Image Diffusion Models</h1>"""
@@ -65,7 +71,7 @@ If you have any questions, please feel free to reach me out at <b>yuanzy22@mails
65
  negtive_prompt = ""
66
 
67
  # load model
68
- device = torch.device("cuda")
69
  preprocess_model = load_preprocess_model()
70
  config = OmegaConf.load("configs/config_customnet.yaml")
71
  model = instantiate_from_config(config.model)
@@ -299,4 +305,4 @@ if __name__=="__main__":
299
  parser.add_argument("--port", type=int, default=12345)
300
  args = parser.parse_args()
301
 
302
- main(args)
 
21
  import math
22
 
23
 
24
+ if torch.cuda.is_available():
25
+ device = "cuda"
26
+ elif torch.backends.mps.is_available():
27
+ device = "mps"
28
+ else:
29
+ device = "cpu"
30
 
31
  #### Description ####
32
  title = r"""<h1 align="center">CustomNet: Object Customization with Variable-Viewpoints in Text-to-Image Diffusion Models</h1>"""
 
71
  negtive_prompt = ""
72
 
73
  # load model
74
+ device = torch.device(device)
75
  preprocess_model = load_preprocess_model()
76
  config = OmegaConf.load("configs/config_customnet.yaml")
77
  model = instantiate_from_config(config.model)
 
305
  parser.add_argument("--port", type=int, default=12345)
306
  args = parser.parse_args()
307
 
308
+ main(args)