Yaron Koresh commited on
Commit
bb63a49
·
verified ·
1 Parent(s): ce53544

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -22
app.py CHANGED
@@ -8,21 +8,20 @@ import requests
8
  import gradio as gr
9
  import numpy as np
10
  from lxml.html import fromstring
11
- #from transformers import pipeline
12
  from torch import multiprocessing as mp
13
  #from torch.multiprocessing import Pool
14
  #from pathos.multiprocessing import ProcessPool as Pool
15
  from pathos.threading import ThreadPool as Pool
16
- #from diffusers.pipelines.flux import FluxPipeline
17
- #from diffusers.utils import export_to_gif
18
- #from huggingface_hub import hf_hub_download
19
- #from safetensors.torch import load_file
20
- from diffusers import DiffusionPipeline
21
- #from diffusers.utils import load_image
22
- #import jax
23
- #import jax.numpy as jnp
24
-
25
- def pipe_t2i():
26
  PIPE = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, token=os.getenv("hf_token")).to("cuda")
27
  return PIPE
28
 
@@ -72,19 +71,20 @@ def generate_random_string(length):
72
  return ''.join(random.choice(characters) for _ in range(length))
73
 
74
  @spaces.GPU(duration=40)
75
- def Piper(name,posi):
 
 
76
  print("starting piper")
77
 
78
- ret1 = pp1(
79
  posi,
80
  height=512,
81
  width=512,
82
- num_inference_steps=4,
83
- max_sequence_length=256,
84
- guidance_scale=0
85
  )
86
- ret1.images[0].save(name)
87
 
 
88
  return name
89
 
90
  css="""
@@ -122,6 +122,7 @@ footer {
122
  js="""
123
  function custom(){
124
  document.querySelector("div#prompt input").setAttribute("maxlength","38")
 
125
  }
126
  """
127
 
@@ -158,13 +159,34 @@ def run(p1,*result):
158
  def main():
159
 
160
  global result
161
- global pp1
 
 
 
162
 
 
 
163
  result=[]
164
- pp1=pipe_t2i()
165
 
166
- mp.set_start_method("spawn", force=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
 
 
168
  with gr.Blocks(theme=gr.themes.Soft(),css=css,js=js) as demo:
169
  with gr.Column(elem_id="col-container"):
170
  gr.Markdown(f"""
@@ -182,8 +204,7 @@ def main():
182
  with gr.Row():
183
  result.append(gr.Image(interactive=False,elem_classes="image-container", label="Result", show_label=False, type='filepath', show_share_button=False))
184
  result.append(gr.Image(interactive=False,elem_classes="image-container", label="Result", show_label=False, type='filepath', show_share_button=False))
185
- result.append(gr.Image(interactive=False,elem_classes="image-container", label="Result", show_label=False, type='filepath', show_share_button=False))
186
-
187
  gr.on(
188
  triggers=[run_button.click, prompt.submit],
189
  fn=run,inputs=[prompt,*result],outputs=result
 
8
  import gradio as gr
9
  import numpy as np
10
  from lxml.html import fromstring
11
+ from transformers import pipeline
12
  from torch import multiprocessing as mp
13
  #from torch.multiprocessing import Pool
14
  #from pathos.multiprocessing import ProcessPool as Pool
15
  from pathos.threading import ThreadPool as Pool
16
+ from diffusers.pipelines.flux import FluxPipeline
17
+ from diffusers.utils import export_to_gif, load_image
18
+ from huggingface_hub import hf_hub_download
19
+ from safetensors.torch import load_file
20
+ from diffusers import DiffusionPipeline, AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
21
+ import jax
22
+ import jax.numpy as jnp
23
+
24
+ def forest_schnell():
 
25
  PIPE = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, token=os.getenv("hf_token")).to("cuda")
26
  return PIPE
27
 
 
71
  return ''.join(random.choice(characters) for _ in range(length))
72
 
73
  @spaces.GPU(duration=40)
74
+ def Piper(name,posi,neg):
75
+ global step
76
+
77
  print("starting piper")
78
 
79
+ out = pipe(
80
  posi,
81
  height=512,
82
  width=512,
83
+ num_inference_steps=step,
84
+ guidance_scale=1
 
85
  )
 
86
 
87
+ export_to_gif(out.frames[0],name)
88
  return name
89
 
90
  css="""
 
122
  js="""
123
  function custom(){
124
  document.querySelector("div#prompt input").setAttribute("maxlength","38")
125
+ document.querySelector("div#prompt2 input").setAttribute("maxlength","38")
126
  }
127
  """
128
 
 
159
  def main():
160
 
161
  global result
162
+ global pipe
163
+ global device
164
+ global step
165
+ global dtype
166
 
167
+ device = "cuda"
168
+ dtype = torch.float16
169
  result=[]
170
+ step = 2
171
 
172
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
173
+ repo = "ByteDance/SDXL-Lightning"
174
+ ckpt = "sdxl_lightning_4step_unet.safetensors" # Use the correct ckpt for your step setting!
175
+
176
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
177
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
178
+
179
+ repo = "ByteDance/AnimateDiff-Lightning"
180
+ ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
181
+
182
+ adapter = MotionAdapter().to(device, dtype)
183
+ adapter.load_state_dict(load_file(hf_hub_download(repo ,ckpt), device=device))
184
+
185
+ pipe = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, unet=unet, torch_dtype=dtype, variant="fp16").to(device)
186
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
187
 
188
+ mp.set_start_method("spawn", force=True)
189
+
190
  with gr.Blocks(theme=gr.themes.Soft(),css=css,js=js) as demo:
191
  with gr.Column(elem_id="col-container"):
192
  gr.Markdown(f"""
 
204
  with gr.Row():
205
  result.append(gr.Image(interactive=False,elem_classes="image-container", label="Result", show_label=False, type='filepath', show_share_button=False))
206
  result.append(gr.Image(interactive=False,elem_classes="image-container", label="Result", show_label=False, type='filepath', show_share_button=False))
207
+
 
208
  gr.on(
209
  triggers=[run_button.click, prompt.submit],
210
  fn=run,inputs=[prompt,*result],outputs=result