harry900000 committed
Commit 74308ee · 1 Parent(s): 4953ce6

Make the chunking size a function argument and add a slider to control it
app.py CHANGED
@@ -296,7 +296,7 @@ def generate_video(
     negative_prompt="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality.",  # noqa: E501
     seed=42,
     randomize_seed=False,
-    chunking=False,
+    chunking=None,
     progress=gr.Progress(track_tqdm=True),
 ):
     _dt = datetime.datetime.now(tz=datetime.timezone(datetime.timedelta(hours=8))).strftime("%Y-%m-%d_%H.%M.%S")
@@ -338,6 +338,8 @@ def generate_video(
     watcher = watch_gpu_memory(10, lambda x: log.debug(f"GPU memory usage: {x} (MiB)"))
 
     # start inference
+    if chunking <= 0:
+        chunking = None
     videos, prompts = inference(args, control_inputs, chunking)
 
     # print the generation time
@@ -386,7 +388,7 @@ with gr.Blocks() as demo:
             randomize_seed_checkbox = gr.Checkbox(label="Randomize Seed", value=False)
             seed_input = gr.Slider(minimum=0, maximum=1000000, value=1, step=1, label="Seed")
 
-            chunking_checkbox = gr.Checkbox(label="Chunking", value=True)
+            chunking_input = gr.Slider(minimum=0, maximum=121, value=4, step=1, label="Chunking size")
             generate_button = gr.Button("Generate Image")
 
         with gr.Column():
@@ -403,7 +405,7 @@
             negative_prompt_input,
             seed_input,
             randomize_seed_checkbox,
-            chunking_checkbox,
+            chunking_input,
         ],
         outputs=[output_video, output_file, seed_input],
     )
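
The control flow these hunks introduce: the Gradio slider always delivers a non-negative integer to generate_video(), and the new guard maps a slider value of 0 to None, i.e. "chunking disabled", before the value reaches inference(). Note that the guard assumes an integer: calling generate_video() directly with the new chunking=None default would compare None against 0 and raise a TypeError, though the slider never produces that case. A minimal sketch of the normalization, where normalize_chunking is a hypothetical standalone helper (the commit inlines the check in generate_video()):

from typing import Optional


def normalize_chunking(value: Optional[int]) -> Optional[int]:
    # Mirrors the new `if chunking <= 0: chunking = None` guard in app.py.
    # The explicit None check is an addition for this sketch: the in-app
    # guard never sees None because the slider always supplies an int.
    return None if value is None or value <= 0 else value


assert normalize_chunking(0) is None  # slider at 0 -> chunking disabled
assert normalize_chunking(4) == 4     # slider default -> chunks of 4 latent frames
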
cosmos_transfer1/diffusion/inference/inference_utils.py CHANGED
@@ -710,7 +710,7 @@ def generate_world_from_control(
     x_sigma_max=None,
     augment_sigma=None,
     use_batch_processing: bool = True,
-    chunking: bool = False,
+    chunking: Optional[int] = None,
 ) -> Tuple[np.array, list, list]:
     """Generate video using a conditioning video/image input.
 
@@ -724,7 +724,7 @@ def generate_world_from_control(
         seed (int): Random seed for generation
         condition_latent (torch.Tensor): Latent tensor from conditioning video/image file
         num_input_frames (int): Number of input frames
-        chunking: Whether to use the chunking method in generation pipeline
+        chunking: Chunking size; if None, chunking is disabled
 
     Returns:
         np.array: Generated video frames in shape [T,H,W,C], range [0,255]
cosmos_transfer1/diffusion/inference/world_generation_pipeline.py CHANGED
@@ -151,7 +151,7 @@ class DiffusionControl2WorldGenerationPipeline(BaseWorldGenerationPipeline):
         regional_prompts: List[str] = None,
         region_definitions: Union[List[List[float]], torch.Tensor] = None,
         waymo_example: bool = False,
-        chunking: bool = False,
+        chunking: Optional[int] = None,
     ):
         """Initialize diffusion world generation pipeline.
 
@@ -179,7 +179,7 @@ class DiffusionControl2WorldGenerationPipeline(BaseWorldGenerationPipeline):
             offload_prompt_upsampler: Whether to offload prompt upsampler after use
             process_group: Process group for distributed training
             waymo_example: Whether to use the waymo example post-training checkpoint
-            chunking: Whether to use the chunking method in generation pipeline
+            chunking: Chunking size; if None, chunking is disabled
         """
         self.num_input_frames = num_input_frames
         self.control_inputs = control_inputs
cosmos_transfer1/diffusion/model/model_v2w.py CHANGED
@@ -168,19 +168,18 @@ class DiffusionV2WModel(DiffusionT2WModel):
             x0_pred_replaced=x0_pred_replaced,
         )
 
-    CHUNKING_SIZE = 4
     CHUNKING_MODE = "rand_order"  # ["shuffle", "in_order", "rand_order"]
     IS_STAGGERED = True
 
-    def get_chunks_indices(self, total_flen) -> List[torch.Tensor]:
+    def get_chunks_indices(self, total_flen, chunking_size) -> List[torch.Tensor]:
         chunks_indices = []
         if self.CHUNKING_MODE == "shuffle":
-            for index in torch.arange(0, total_flen, 1).split(self.CHUNKING_SIZE):
+            for index in torch.arange(0, total_flen, 1).split(chunking_size):
                 chunks_indices.append(index)
             np.random.shuffle(chunks_indices)
         else:
             first_chunk_end = (
-                int(torch.randint(low=0, high=self.CHUNKING_SIZE, size=(1,)) + 1) if self.IS_STAGGERED else self.CHUNKING_SIZE
+                int(torch.randint(low=0, high=chunking_size, size=(1,)) + 1) if self.IS_STAGGERED else chunking_size
             )
 
             if first_chunk_end >= total_flen:
@@ -188,7 +187,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
             else:
                 chunks_indices.append(torch.arange(first_chunk_end))
 
-            for index in torch.arange(first_chunk_end, total_flen, 1).split(self.CHUNKING_SIZE):
+            for index in torch.arange(first_chunk_end, total_flen, 1).split(chunking_size):
                 chunks_indices.append(index)
 
         if self.CHUNKING_MODE == "in_order":
@@ -216,7 +215,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
         add_input_frames_guidance: bool = False,
         x_sigma_max: Optional[torch.Tensor] = None,
         sigma_max: Optional[float] = None,
-        chunking: bool = False,
+        chunking: Optional[int] = None,
         **kwargs,
     ) -> Tensor:
         """Generates video samples conditioned on input frames.
@@ -234,7 +233,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
             condition_video_augment_sigma_in_inference: Noise level for condition augmentation
             add_input_frames_guidance: Whether to apply guidance to input frames
             x_sigma_max: Maximum noise level tensor
-            chunking: Whether to use the chunking method in generation pipeline
+            chunking: Chunking size; if None, chunking is disabled
 
         Returns:
             Generated video samples tensor
@@ -294,7 +293,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
         condition_video_augment_sigma_in_inference: float = None,
         add_input_frames_guidance: bool = False,
         seed: int = 1,
-        chunking: bool = False,
+        chunking: Optional[int] = None,
     ) -> Callable:
         """Creates denoising function for conditional video generation.
 
@@ -307,12 +306,12 @@ class DiffusionV2WModel(DiffusionT2WModel):
             condition_video_augment_sigma_in_inference: Noise level for condition augmentation
             add_input_frames_guidance: Whether to apply guidance to input frames
             seed: Random seed for reproducibility
-            chunking: Whether to use the chunking method in generation pipeline
+            chunking: Chunking size; if None, chunking is disabled
 
         Returns:
             Function that takes noisy input and noise level and returns denoised prediction
         """
-        if not chunking:
+        if chunking is None:
             if is_negative_prompt:
                 condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
             else:
@@ -348,8 +347,6 @@ class DiffusionV2WModel(DiffusionT2WModel):
 
             return x0_fn
         else:
-            log.critical("GO CHUNKING !!!")
-
             def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
                 if is_negative_prompt:
                     condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
@@ -358,7 +355,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
 
                 noises = torch.zeros_like(condition_latent)
                 T = condition_latent.shape[2]
-                for chunk_idx in self.get_chunks_indices(T):
+                for chunk_idx in self.get_chunks_indices(T, chunking):
                     latents_ = condition_latent[:, :, chunk_idx, :, :]
                     log.info(f"chunk_idx: {chunk_idx}, chunk shape: {latents_.shape}")
                     # controlnet_cond_ = self.controlnet_data[:, chunk_idx]
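
For reference, the chunk-scheduling logic that model_v2w.py now parameterizes can be read as a standalone function. The sketch below is assembled from the hunks above and is only a sketch: the body of the first_chunk_end >= total_flen branch and the "in_order"/"rand_order" handling after the chunk loop are elided in the diff, so those two spots are assumptions, not the commit's exact code.

from typing import List

import numpy as np
import torch


def get_chunks_indices(total_flen: int, chunking_size: int,
                       mode: str = "rand_order", staggered: bool = True) -> List[torch.Tensor]:
    """Split frame indices [0, total_flen) into chunks of at most chunking_size."""
    chunks_indices = []
    if mode == "shuffle":
        # Fixed-size chunks over the whole range, then shuffled as whole chunks.
        for index in torch.arange(0, total_flen, 1).split(chunking_size):
            chunks_indices.append(index)
        np.random.shuffle(chunks_indices)
    else:
        # Staggering draws the first chunk's length from [1, chunking_size] so
        # that chunk boundaries vary between calls.
        first_chunk_end = int(torch.randint(low=0, high=chunking_size, size=(1,)) + 1) if staggered else chunking_size
        if first_chunk_end >= total_flen:
            chunks_indices.append(torch.arange(total_flen))  # assumption: branch body elided in the diff
        else:
            chunks_indices.append(torch.arange(first_chunk_end))
            for index in torch.arange(first_chunk_end, total_flen, 1).split(chunking_size):
                chunks_indices.append(index)
        if mode == "rand_order":
            np.random.shuffle(chunks_indices)  # assumption: "rand_order"/"in_order" handling elided in the diff
    return chunks_indices


# With 16 latent frames and the slider default of 4: one short staggered first
# chunk, then fixed-size chunks covering the remaining frames.
print(get_chunks_indices(16, 4, mode="in_order"))

Each returned index tensor is then used to slice the latent time axis (condition_latent[:, :, chunk_idx, :, :]) in the chunked x0_fn above.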