linoyts HF Staff committed on
Commit
2f7883b
·
verified ·
1 Parent(s): 8ae1c05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -17
app.py CHANGED
@@ -169,17 +169,24 @@ def process_video_for_pose(video):
169
 
170
  return pose_video
171
 
172
def process_video_for_control(video, control_type):
    """Process video based on the selected control type"""
    # Guard-clause dispatch: each recognised control type maps to its
    # dedicated preprocessor; anything else passes the video through untouched.
    if control_type == "pose":
        return process_video_for_pose(video)
    if control_type == "depth":
        return process_video_for_depth(video)
    if control_type == "canny":
        return process_video_for_canny(video)
    return video
182
-
 
 
 
 
 
 
183
  @spaces.GPU(duration=160)
184
  def generate_video(
185
  reference_video,
@@ -213,6 +220,13 @@ def generate_video(
213
  # Handle seed
214
  if randomize_seed:
215
  seed = random.randint(0, 2**32 - 1)
 
 
 
 
 
 
 
216
 
217
  progress(0.05, desc="Loading control LoRA...")
218
 
@@ -221,20 +235,14 @@ def generate_video(
221
 
222
  # Loads video into a list of pil images
223
  video = load_video(reference_video)
224
- progress(0.1, desc="Processing video for control...")
225
 
226
  # Process video based on control type
227
- processed_video = process_video_for_control(video, control_type)
228
- processed_video = read_video(processed_video) # turns to tensor
229
 
230
- progress(0.2, desc="Preparing generation parameters...")
231
 
232
- # Calculate number of frames from duration (24 fps)
233
- fps = 24
234
- num_frames = int(duration * fps) + 1 # +1 for proper frame count
235
- # Ensure num_frames is valid for the model (multiple of temporal compression + 1)
236
- temporal_compression = pipeline.vae_temporal_compression_ratio
237
- num_frames = ((num_frames - 1) // temporal_compression) * temporal_compression + 1
238
 
239
  # Calculate downscaled dimensions
240
  downscale_factor = 2 / 3
@@ -451,14 +459,20 @@ with gr.Blocks() as demo:
451
  label="Generated Video",
452
  height=400
453
  )
 
 
 
 
454
 
455
 
456
 
457
  # Event handlers
458
  generate_btn.click(
 
 
459
  fn=generate_video,
460
  inputs=[
461
- reference_video,
462
  prompt,
463
  control_type,
464
  current_lora_state,
 
169
 
170
  return pose_video
171
 
172
def process_video_for_control(reference_video, control_type):
    """Preprocess a reference video into a control video and export it as mp4.

    Parameters
    ----------
    reference_video : str
        Path to the input video file (as provided by the Gradio Video input).
    control_type : str
        One of "canny", "depth" or "pose"; any other value skips preprocessing.

    Returns
    -------
    str
        Path to an mp4 file containing the control video. When no
        preprocessing applies, the original ``reference_video`` path is
        returned unchanged.
    """
    # No preprocessing requested: the input file already IS the control
    # video, so return its path directly. The previous implementation fed
    # the path string into export_to_video (which expects frames) and would
    # fail on this branch; it also loaded the video needlessly.
    if control_type not in ("canny", "depth", "pose"):
        return reference_video

    # Load the video into a list of PIL images for the frame-level processors.
    video = load_video(reference_video)
    if control_type == "canny":
        processed_video = process_video_for_canny(video)
    elif control_type == "depth":
        processed_video = process_video_for_depth(video)
    else:  # "pose" — the only remaining option after the guard above
        processed_video = process_video_for_pose(video)

    # Export at 24 fps to match the generation fps used downstream.
    # delete=False because the path outlives this function: Gradio needs the
    # file to display it and to feed the chained generate_video step.
    fps = 24
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
        output_path = tmp_file.name
    export_to_video(processed_video, output_path, fps=fps)
    return output_path
188
+
189
+
190
  @spaces.GPU(duration=160)
191
  def generate_video(
192
  reference_video,
 
220
  # Handle seed
221
  if randomize_seed:
222
  seed = random.randint(0, 2**32 - 1)
223
+
224
+ # Calculate number of frames from duration (24 fps)
225
+ fps = 24
226
+ num_frames = int(duration * fps) + 1 # +1 for proper frame count
227
+ # Ensure num_frames is valid for the model (multiple of temporal compression + 1)
228
+ temporal_compression = pipeline.vae_temporal_compression_ratio
229
+ num_frames = ((num_frames - 1) // temporal_compression) * temporal_compression + 1
230
 
231
  progress(0.05, desc="Loading control LoRA...")
232
 
 
235
 
236
  # Loads video into a list of pil images
237
  video = load_video(reference_video)
238
+ # progress(0.1, desc="Processing video for control...")
239
 
240
  # Process video based on control type
241
+ #processed_video = process_video_for_control(video, control_type)
 
242
 
243
+ processed_video = read_video(video) # turns to tensor
244
 
245
+ progress(0.2, desc="Preparing generation parameters...")
 
 
 
 
 
246
 
247
  # Calculate downscaled dimensions
248
  downscale_factor = 2 / 3
 
459
  label="Generated Video",
460
  height=400
461
  )
462
+ control_video = gr.Video(
463
+ label="Control Video",
464
+ height=400
465
+ )
466
 
467
 
468
 
469
  # Event handlers
470
  generate_btn.click(
471
+ fn = process_video_for_control,
472
+ inputs = [reference_video, control_type], outputs = [control_video]).then(
473
  fn=generate_video,
474
  inputs=[
475
+ control_video,
476
  prompt,
477
  control_type,
478
  current_lora_state,