Yaron Koresh commited on
Commit
c56b0e3
·
verified ·
1 Parent(s): c87c14d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -172,7 +172,7 @@ def main():
172
  global progress
173
 
174
  device = "cuda"
175
- dtype = torch.bfloat16
176
  result=[]
177
  step = 2
178
 
@@ -183,8 +183,8 @@ def main():
183
  repo = "ByteDance/SDXL-Lightning"
184
  ckpt = f"sdxl_lightning_{step}step_unet.safetensors"
185
 
186
- #unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, dtype)
187
- #unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
188
 
189
  repo = "ByteDance/AnimateDiff-Lightning"
190
  ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
@@ -192,7 +192,7 @@ def main():
192
  adapter = MotionAdapter().to(device, dtype)
193
  adapter.load_state_dict(load_file(hf_hub_download(repo ,ckpt), device=device))
194
 
195
- pipe = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype, variant="fp16").to(device)
196
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
197
 
198
  mp.set_start_method("spawn", force=True)
 
172
  global progress
173
 
174
  device = "cuda"
175
+ dtype = torch.float16
176
  result=[]
177
  step = 2
178
 
 
183
  repo = "ByteDance/SDXL-Lightning"
184
  ckpt = f"sdxl_lightning_{step}step_unet.safetensors"
185
 
186
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, dtype)
187
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
188
 
189
  repo = "ByteDance/AnimateDiff-Lightning"
190
  ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
 
192
  adapter = MotionAdapter().to(device, dtype)
193
  adapter.load_state_dict(load_file(hf_hub_download(repo ,ckpt), device=device))
194
 
195
+ pipe = AnimateDiffPipeline.from_pretrained(base, unet=unet, motion_adapter=adapter, torch_dtype=dtype, variant="fp16").to(device, dtype)
196
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
197
 
198
  mp.set_start_method("spawn", force=True)