Yaron Koresh commited on
Commit
a0d20c5
·
verified ·
1 Parent(s): 925da88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -15,8 +15,9 @@ from torch import multiprocessing as mp
15
  from pathos.threading import ThreadPool as Pool
16
  from diffusers.pipelines.flux import FluxPipeline
17
  from diffusers.utils import export_to_gif, load_image
 
18
  from huggingface_hub import hf_hub_download
19
- from safetensors.torch import load_file
20
  from diffusers import DiffusionPipeline, AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
21
  import jax
22
  import jax.numpy as jnp
@@ -207,12 +208,14 @@ def main():
207
 
208
  repo="stabilityai/sd-vae-ft-mse-original"
209
  ckpt="vae-ft-mse-840000-ema-pruned.safetensors"
210
- vae = load_file(hf_hub_download(repo, ckpt), device=device)
 
211
 
212
  repo="ByteDance/SDXL-Lightning"
213
  ckpt=f"sdxl_lightning_{step}step_unet.safetensors"
214
- unet = load_file(hf_hub_download(repo, ckpt), device=device)
215
-
 
216
  #repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
217
 
218
  #repo = "emilianJR/epiCRealism"
 
15
  from pathos.threading import ThreadPool as Pool
16
  from diffusers.pipelines.flux import FluxPipeline
17
  from diffusers.utils import export_to_gif, load_image
18
+ from diffusers.models.modeling_utils import ModelMixin
19
  from huggingface_hub import hf_hub_download
20
+ from safetensors.torch import load_file, save_file
21
  from diffusers import DiffusionPipeline, AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
22
  import jax
23
  import jax.numpy as jnp
 
208
 
209
  repo="stabilityai/sd-vae-ft-mse-original"
210
  ckpt="vae-ft-mse-840000-ema-pruned.safetensors"
211
+ save_file(load_file(hf_hub_download(repo, ckpt), device=device),"./vae")
212
+ vae = ModelMixin("./vae")
213
 
214
  repo="ByteDance/SDXL-Lightning"
215
  ckpt=f"sdxl_lightning_{step}step_unet.safetensors"
216
+ save_file(load_file(hf_hub_download(repo, ckpt), device=device),"./unet")
217
+ unet = ModelMixin("./unet")
218
+
219
  #repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
220
 
221
  #repo = "emilianJR/epiCRealism"