Yaron Koresh commited on
Commit
186741a
·
verified ·
1 Parent(s): f1950bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -31
app.py CHANGED
@@ -18,27 +18,10 @@ from diffusers.utils import export_to_gif, load_image
18
  from diffusers.models.modeling_utils import ModelMixin
19
  from huggingface_hub import hf_hub_download
20
  from safetensors.torch import load_file, save_file
21
- from diffusers import DiffusionPipeline, AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
22
  import jax
23
  import jax.numpy as jnp
24
 
25
- class Model(nn.Module):
26
- def __init__(self):
27
- super().__init__()
28
- self.register_buffer('buffer', torch.ones(1, 1))
29
-
30
- def forward(self, x):
31
- new_tensor = torch.randn(1, 1)
32
- self.buffer = torch.cat([self.buffer, new_tensor], dim=0)
33
- return self
34
-
35
- def dict2model(dict):
36
- model = Model()
37
- m = model(dict)
38
- mix = ModelMixin()
39
- mix(m)
40
- return mix
41
-
42
  def forest_schnell():
43
  PIPE = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, token=os.getenv("hf_token")).to("cuda")
44
  return PIPE
@@ -223,23 +206,14 @@ def main():
223
 
224
  adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
225
 
226
- repo="stabilityai/sd-vae-ft-mse-original"
227
- ckpt="vae-ft-mse-840000-ema-pruned.safetensors"
228
- vae = dict2model(load_file(hf_hub_download(repo, ckpt), device=device))
229
-
230
- repo="ByteDance/SDXL-Lightning"
231
- ckpt=f"sdxl_lightning_{step}step_unet.safetensors"
232
- unet = dict2model(load_file(hf_hub_download(repo, ckpt), device=device))
233
-
234
- #repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
235
 
236
- #repo = "emilianJR/epiCRealism"
237
- #ckpt = "unet/diffusion_pytorch_model.safetensors"
238
- #unet = load_file(hf_hub_download(repo, ckpt), device=device)
239
-
240
  repo = "ByteDance/AnimateDiff-Lightning"
241
  ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
242
  base = "black-forest-labs/FLUX.1-schnell"
 
 
243
 
244
  pipe = AnimateDiffPipeline.from_pretrained(base, vae=vae, motion_adapter=adapter, feature_extractor=None, image_encoder=None, unet=unet, torch_dtype=dtype, token=os.getenv("hf_token")).to(device)
245
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
 
18
  from diffusers.models.modeling_utils import ModelMixin
19
  from huggingface_hub import hf_hub_download
20
  from safetensors.torch import load_file, save_file
21
+ from diffusers import DiffusionPipeline, AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL, UNet3DConditionModel
22
  import jax
23
  import jax.numpy as jnp
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def forest_schnell():
26
  PIPE = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, token=os.getenv("hf_token")).to("cuda")
27
  return PIPE
 
206
 
207
  adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
208
 
209
+ vae = AutoencoderKL.from_single_file("stabilityai/sd-vae-ft-mse-original/vae-ft-mse-840000-ema-pruned.safetensors")
210
+ unet = UNet3DConditionModel.to(device, dtype).load_state_dict(load_file(hf_hub_download("ByteDance/SDXL-Lightning", f"sdxl_lightning_{step}step_unet.safetensors"), device=device))
 
 
 
 
 
 
 
211
 
 
 
 
 
212
  repo = "ByteDance/AnimateDiff-Lightning"
213
  ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
214
  base = "black-forest-labs/FLUX.1-schnell"
215
+ #base = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
216
+ #base = "emilianJR/epiCRealism"
217
 
218
  pipe = AnimateDiffPipeline.from_pretrained(base, vae=vae, motion_adapter=adapter, feature_extractor=None, image_encoder=None, unet=unet, torch_dtype=dtype, token=os.getenv("hf_token")).to(device)
219
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")