Yaron Koresh committed on
Commit
788d672
·
verified ·
1 Parent(s): 9955312

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -3
app.py CHANGED
@@ -9,14 +9,14 @@ import gradio as gr
9
  import numpy as np
10
  from lxml.html import fromstring
11
  from pathos.threading import ThreadPool as Pool
12
- from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
13
  from diffusers.pipelines.flux import FluxPipeline
14
  from diffusers.utils import export_to_gif
15
  from huggingface_hub import hf_hub_download
16
  from safetensors.torch import load_file
17
 
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
- dtype = torch.bfloat16
20
 
21
  step = 4
22
  repo = "ByteDance/AnimateDiff-Lightning"
@@ -27,7 +27,20 @@ base = "black-forest-labs/FLUX.1-dev"
27
  adapter = MotionAdapter().to(device, dtype)
28
  adapter.load_state_dict(load_file(hf_hub_download(repo ,ckpt), device=device))
29
  pipe = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype, token=os.getenv("hf_token")).to(device)
30
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  def translate(text,lang):
33
 
 
9
  import numpy as np
10
  from lxml.html import fromstring
11
  from pathos.threading import ThreadPool as Pool
12
+ from diffusers import AnimateDiffPipeline, MotionAdapter, DDIMScheduler
13
  from diffusers.pipelines.flux import FluxPipeline
14
  from diffusers.utils import export_to_gif
15
  from huggingface_hub import hf_hub_download
16
  from safetensors.torch import load_file
17
 
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ dtype = torch.float16
20
 
21
  step = 4
22
  repo = "ByteDance/AnimateDiff-Lightning"
 
27
  adapter = MotionAdapter().to(device, dtype)
28
  adapter.load_state_dict(load_file(hf_hub_download(repo ,ckpt), device=device))
29
  pipe = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype, token=os.getenv("hf_token")).to(device)
30
+ scheduler = DDIMScheduler.from_pretrained(
31
+ base,
32
+ subfolder="scheduler",
33
+ clip_sample=False,
34
+ timestep_spacing="linspace",
35
+ beta_schedule="linear",
36
+ steps_offset=1,
37
+ )
38
+ pipe.scheduler = scheduler
39
+ pipe.enable_vae_slicing()
40
+ pipe.enable_model_cpu_offload()
41
+
42
+
43
+
44
 
45
  def translate(text,lang):
46