Yaron Koresh commited on
Commit
c350929
·
verified ·
1 Parent(s): b328782

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -201,13 +201,13 @@ def main():
201
  ckpt = f"sdxl_lightning_{step}step_unet.safetensors"
202
 
203
  unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, dtype)
204
- unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), map_location=device), strict=False)
205
 
206
  repo = "ByteDance/AnimateDiff-Lightning"
207
  ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
208
 
209
  adapter = MotionAdapter().to(device, dtype)
210
- adapter.load_state_dict(load_file(hf_hub_download(repo ,ckpt), device=device))
211
 
212
  pipe = AnimateDiffPipeline.from_pretrained(base, unet=unet, motion_adapter=adapter, torch_dtype=dtype, variant="fp16").to(device, dtype)
213
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
 
201
  ckpt = f"sdxl_lightning_{step}step_unet.safetensors"
202
 
203
  unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, dtype)
204
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
205
 
206
  repo = "ByteDance/AnimateDiff-Lightning"
207
  ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
208
 
209
  adapter = MotionAdapter().to(device, dtype)
210
+ adapter.load_state_dict(load_file(hf_hub_download(repo ,ckpt), device=device), strict=False)
211
 
212
  pipe = AnimateDiffPipeline.from_pretrained(base, unet=unet, motion_adapter=adapter, torch_dtype=dtype, variant="fp16").to(device, dtype)
213
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")