Spaces:
Sleeping
Sleeping
Yaron Koresh
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -172,7 +172,7 @@ def main():
|
|
172 |
global progress
|
173 |
|
174 |
device = "cuda"
|
175 |
-
dtype = torch.
|
176 |
result=[]
|
177 |
step = 2
|
178 |
|
@@ -183,8 +183,8 @@ def main():
|
|
183 |
repo = "ByteDance/SDXL-Lightning"
|
184 |
ckpt = f"sdxl_lightning_{step}step_unet.safetensors"
|
185 |
|
186 |
-
|
187 |
-
|
188 |
|
189 |
repo = "ByteDance/AnimateDiff-Lightning"
|
190 |
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
|
@@ -192,7 +192,7 @@ def main():
|
|
192 |
adapter = MotionAdapter().to(device, dtype)
|
193 |
adapter.load_state_dict(load_file(hf_hub_download(repo ,ckpt), device=device))
|
194 |
|
195 |
-
pipe = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype, variant="fp16").to(device)
|
196 |
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
|
197 |
|
198 |
mp.set_start_method("spawn", force=True)
|
|
|
172 |
global progress
|
173 |
|
174 |
device = "cuda"
|
175 |
+
dtype = torch.float16
|
176 |
result=[]
|
177 |
step = 2
|
178 |
|
|
|
183 |
repo = "ByteDance/SDXL-Lightning"
|
184 |
ckpt = f"sdxl_lightning_{step}step_unet.safetensors"
|
185 |
|
186 |
+
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, dtype)
|
187 |
+
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
|
188 |
|
189 |
repo = "ByteDance/AnimateDiff-Lightning"
|
190 |
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
|
|
|
192 |
adapter = MotionAdapter().to(device, dtype)
|
193 |
adapter.load_state_dict(load_file(hf_hub_download(repo ,ckpt), device=device))
|
194 |
|
195 |
+
pipe = AnimateDiffPipeline.from_pretrained(base, unet=unet, motion_adapter=adapter, torch_dtype=dtype, variant="fp16").to(device, dtype)
|
196 |
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
|
197 |
|
198 |
mp.set_start_method("spawn", force=True)
|