YulianSa commited on
Commit
05d74b5
·
1 Parent(s): 5e6f2af
Files changed (1) hide show
  1. infer_api.py +1 -1
infer_api.py CHANGED
@@ -279,6 +279,7 @@ def run_multiview_infer(data, pipeline, cfg: TestConfig, num_levels=3):
279
  generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)
280
 
281
  if torch.cuda.is_available():
 
282
  pipeline.to(device)
283
 
284
  images_cond = []
@@ -344,7 +345,6 @@ def load_multiview_pipeline(cfg):
344
  pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
345
  cfg.pretrained_path,
346
  torch_dtype=torch.float16,)
347
- pipeline.unet.enable_xformers_memory_efficient_attention()
348
  return pipeline
349
 
350
 
 
279
  generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)
280
 
281
  if torch.cuda.is_available():
282
+ pipeline.unet.enable_xformers_memory_efficient_attention()
283
  pipeline.to(device)
284
 
285
  images_cond = []
 
345
  pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
346
  cfg.pretrained_path,
347
  torch_dtype=torch.float16,)
 
348
  return pipeline
349
 
350