Yaron Koresh committed on
Commit
7219c3f
·
verified ·
1 Parent(s): c56b0e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -15
app.py CHANGED
@@ -74,23 +74,34 @@ def generate_random_string(length):
74
  characters = string.ascii_letters + string.digits
75
  return ''.join(random.choice(characters) for _ in range(length))
76
 
77
- @spaces.GPU(duration=40)
78
- def Piper(name,posi):
79
  global step
 
 
 
80
 
81
  print("starting piper")
82
-
 
 
 
 
 
 
 
83
  out = pipe(
84
- posi,
85
  height=512,
86
  width=512,
87
  num_inference_steps=step,
88
  guidance_scale=1,
89
  callback=progress_callback,
90
- callback_step=1
 
91
  )
92
 
93
- export_to_gif(out.frames[0],name)
94
  return name
95
 
96
  css="""
@@ -132,10 +143,10 @@ function custom(){
132
  }
133
  """
134
 
135
- def infer(p):
136
  print("infer: started")
137
 
138
- p1 = p["a"]
139
  name = generate_random_string(12)+".png"
140
 
141
  _do = ['beautiful', 'playful', 'photographed', 'realistic', 'dynamic poze', 'deep field', 'reasonable coloring', 'rough texture', 'best quality', 'focused']
@@ -143,17 +154,17 @@ def infer(p):
143
  _do.append(f'{p1}')
144
  posi = " ".join(_do)
145
 
146
- return Piper(name,posi)
147
 
148
- def run(p1,*result):
149
 
150
  p1_en = translate(p1,"english")
151
- p = {"a":p1_en}
152
  ln = len(result)
153
  print("images: "+str(ln))
154
  rng = list(range(ln))
155
 
156
- arr = [p for _ in rng]
157
  pool = Pool(ln)
158
  out = list(pool.imap(infer,arr))
159
  pool.close()
@@ -170,7 +181,13 @@ def main():
170
  global step
171
  global dtype
172
  global progress
173
-
 
 
 
 
 
 
174
  device = "cuda"
175
  dtype = torch.float16
176
  result=[]
@@ -184,7 +201,7 @@ def main():
184
  ckpt = f"sdxl_lightning_{step}step_unet.safetensors"
185
 
186
  unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, dtype)
187
- unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
188
 
189
  repo = "ByteDance/AnimateDiff-Lightning"
190
  ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
@@ -209,6 +226,23 @@ def main():
209
  container=False,
210
  max_lines=1
211
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  with gr.Row():
213
  run_button = gr.Button("START",elem_classes="btn",scale=0)
214
  with gr.Row():
@@ -217,7 +251,7 @@ def main():
217
 
218
  gr.on(
219
  triggers=[run_button.click, prompt.submit],
220
- fn=run,inputs=[prompt,*result],outputs=result
221
  )
222
  demo.queue().launch()
223
 
 
74
  characters = string.ascii_letters + string.digits
75
  return ''.join(random.choice(characters) for _ in range(length))
76
 
77
+ @spaces.GPU(duration=45)
78
+ def Piper(name,positive_prompt,motion):
79
  global step
80
+ global fps
81
+ global time
82
+ global last_motion
83
 
84
  print("starting piper")
85
+
86
+ if last_motion != motion:
87
+ pipe.unload_lora_weights()
88
+ if motion != "":
89
+ pipe.load_lora_weights(motion, adapter_name="motion")
90
+ pipe.set_adapters(["motion"], [0.7])
91
+ last_motion = motion
92
+
93
  out = pipe(
94
+ positive_prompt,
95
  height=512,
96
  width=512,
97
  num_inference_steps=step,
98
  guidance_scale=1,
99
  callback=progress_callback,
100
+ callback_steps=1,
101
+ num_frames=fps*time
102
  )
103
 
104
+ export_to_gif(out.frames[0],name,fps=fps)
105
  return name
106
 
107
  css="""
 
143
  }
144
  """
145
 
146
+ def infer(pm):
147
  print("infer: started")
148
 
149
+ p1 = pm["p"]
150
  name = generate_random_string(12)+".gif"
151
 
152
  _do = ['beautiful', 'playful', 'photographed', 'realistic', 'dynamic poze', 'deep field', 'reasonable coloring', 'rough texture', 'best quality', 'focused']
 
154
  _do.append(f'{p1}')
155
  posi = " ".join(_do)
156
 
157
+ return Piper(name,posi,pm["m"])
158
 
159
+ def run(m,p1,*result):
160
 
161
  p1_en = translate(p1,"english")
162
+ pm = {"p":p1_en,"m":m}
163
  ln = len(result)
164
  print("images: "+str(ln))
165
  rng = list(range(ln))
166
 
167
+ arr = [pm for _ in rng]
168
  pool = Pool(ln)
169
  out = list(pool.imap(infer,arr))
170
  pool.close()
 
181
  global step
182
  global dtype
183
  global progress
184
+ global fps
185
+ global time
186
+ global last_motion
187
+
188
+ last_motion=None
189
+ fps=40
190
+ time=5
191
  device = "cuda"
192
  dtype = torch.float16
193
  result=[]
 
201
  ckpt = f"sdxl_lightning_{step}step_unet.safetensors"
202
 
203
  unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, dtype)
204
+ unet.load_state_dict(torch.load(hf_hub_download(repo, ckpt), map_location=device), strict=False)
205
 
206
  repo = "ByteDance/AnimateDiff-Lightning"
207
  ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
 
226
  container=False,
227
  max_lines=1
228
  )
229
+ with gr.Row():
230
+ motion = gr.Dropdown(
231
+ label='Motion',
232
+ choices=[
233
+ ("Default", ""),
234
+ ("Zoom in", "guoyww/animatediff-motion-lora-zoom-in"),
235
+ ("Zoom out", "guoyww/animatediff-motion-lora-zoom-out"),
236
+ ("Tilt up", "guoyww/animatediff-motion-lora-tilt-up"),
237
+ ("Tilt down", "guoyww/animatediff-motion-lora-tilt-down"),
238
+ ("Pan left", "guoyww/animatediff-motion-lora-pan-left"),
239
+ ("Pan right", "guoyww/animatediff-motion-lora-pan-right"),
240
+ ("Roll left", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
241
+ ("Roll right", "guoyww/animatediff-motion-lora-rolling-clockwise"),
242
+ ],
243
+ value="",
244
+ interactive=True
245
+ )
246
  with gr.Row():
247
  run_button = gr.Button("START",elem_classes="btn",scale=0)
248
  with gr.Row():
 
251
 
252
  gr.on(
253
  triggers=[run_button.click, prompt.submit],
254
+ fn=run,inputs=[motion,prompt,*result],outputs=result
255
  )
256
  demo.queue().launch()
257