harry900000 committed
Commit 17d970d · 1 Parent(s): 5559672

add chunking

app.py CHANGED
@@ -55,6 +55,7 @@ import random
 from io import BytesIO
 
 import torch
+
 from cosmos_transfer1.checkpoints import (
     BASE_7B_CHECKPOINT_AV_SAMPLE_PATH,
     BASE_7B_CHECKPOINT_PATH,
@@ -70,14 +71,13 @@ from cosmos_transfer1.diffusion.inference.world_generation_pipeline import (
 )
 from cosmos_transfer1.utils import log, misc
 from cosmos_transfer1.utils.io import read_prompts_from_file, save_video
-
 from helper import parse_arguments
 
 torch.enable_grad(False)
 torch.serialization.add_safe_globals([BytesIO])
 
 
-def inference(cfg, control_inputs) -> Tuple[List[str], List[str]]:
+def inference(cfg, control_inputs, chunking) -> Tuple[List[str], List[str]]:
     video_paths = []
     prompt_paths = []
 
@@ -87,9 +87,10 @@ def inference(cfg, control_inputs) -> Tuple[List[str], List[str]]:
     device_rank = 0
     process_group = None
     if cfg.num_gpus > 1:
-        from cosmos_transfer1.utils import distributed
         from megatron.core import parallel_state
 
+        from cosmos_transfer1.utils import distributed
+
         distributed.init()
         parallel_state.initialize_model_parallel(context_parallel_size=cfg.num_gpus)
         process_group = parallel_state.get_context_parallel_group()
@@ -142,6 +143,7 @@ def inference(cfg, control_inputs) -> Tuple[List[str], List[str]]:
         upsample_prompt=cfg.upsample_prompt,
         offload_prompt_upsampler=cfg.offload_prompt_upsampler,
         process_group=process_group,
+        chunking=chunking,
     )
 
     if cfg.batch_input_path:
@@ -278,6 +280,7 @@ def generate_video(
     negative_prompt="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality.",  # noqa: E501
     seed=42,
     randomize_seed=False,
+    chunking=False,
     progress=gr.Progress(track_tqdm=True),
 ):
     if randomize_seed:
@@ -315,7 +318,7 @@ def generate_video(
     watcher = watch_gpu_memory(10)
 
     # start inference
-    videos, prompts = inference(args, control_inputs)
+    videos, prompts = inference(args, control_inputs, chunking)
 
     # print the generation time
     end_time = time.time()
@@ -361,6 +364,7 @@ with gr.Blocks() as demo:
         randomize_seed_checkbox = gr.Checkbox(label="Randomize Seed", value=False)
         seed_input = gr.Slider(minimum=0, maximum=1000000, value=1, step=1, label="Seed")
 
+        chunking_checkbox = gr.Checkbox(label="Chunking", value=True)
         generate_button = gr.Button("Generate Image")
 
     with gr.Column():
@@ -369,7 +373,16 @@ with gr.Blocks() as demo:
 
     generate_button.click(
         fn=generate_video,
-        inputs=[rgb_video_input, hdmap_input, lidar_input, prompt_input, negative_prompt_input, seed_input, randomize_seed_checkbox],
+        inputs=[
+            rgb_video_input,
+            hdmap_input,
+            lidar_input,
+            prompt_input,
+            negative_prompt_input,
+            seed_input,
+            randomize_seed_checkbox,
+            chunking_checkbox,
+        ],
         outputs=[output_video, output_file, seed_input],
     )
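For context, the UI change above follows the standard Gradio wiring pattern: the new gr.Checkbox is appended to the click handler's inputs list, and its value arrives as a plain bool in the matching positional parameter (chunking). A minimal, runnable sketch of that pattern, with a hypothetical generate_video_stub standing in for the app's real generate_video handler:

import gradio as gr

def generate_video_stub(prompt: str, chunking: bool) -> str:
    # The real handler forwards the flag onward: inference(args, control_inputs, chunking)
    return f"prompt={prompt!r}, chunking={chunking}"

with gr.Blocks() as demo:
    prompt_input = gr.Textbox(label="Prompt")
    chunking_checkbox = gr.Checkbox(label="Chunking", value=True)
    result_box = gr.Textbox(label="Result")
    generate_button = gr.Button("Generate")
    generate_button.click(
        fn=generate_video_stub,
        inputs=[prompt_input, chunking_checkbox],  # list positions map to fn parameters
        outputs=[result_box],
    )

if __name__ == "__main__":
    demo.launch()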
cosmos_transfer1/diffusion/inference/inference_utils.py CHANGED
@@ -710,6 +710,7 @@ def generate_world_from_control(
     x_sigma_max=None,
     augment_sigma=None,
     use_batch_processing: bool = True,
+    chunking: bool = False,
 ) -> Tuple[np.array, list, list]:
     """Generate video using a conditioning video/image input.
 
@@ -723,6 +724,7 @@ def generate_world_from_control(
         seed (int): Random seed for generation
         condition_latent (torch.Tensor): Latent tensor from conditioning video/image file
         num_input_frames (int): Number of input frames
+        chunking: Whether to use the chunking method in generation pipeline
 
     Returns:
         np.array: Generated video frames in shape [T,H,W,C], range [0,255]
@@ -761,6 +763,7 @@ def generate_world_from_control(
         patch_h=h,
         patch_w=w,
         use_batch_processing=use_batch_processing,
+        chunking=chunking,
     )
     return sample
cosmos_transfer1/diffusion/inference/world_generation_pipeline.py CHANGED
@@ -151,6 +151,7 @@ class DiffusionControl2WorldGenerationPipeline(BaseWorldGenerationPipeline):
         regional_prompts: List[str] = None,
         region_definitions: Union[List[List[float]], torch.Tensor] = None,
         waymo_example: bool = False,
+        chunking: bool = False,
     ):
         """Initialize diffusion world generation pipeline.
 
@@ -178,6 +179,7 @@ class DiffusionControl2WorldGenerationPipeline(BaseWorldGenerationPipeline):
             offload_prompt_upsampler: Whether to offload prompt upsampler after use
             process_group: Process group for distributed training
             waymo_example: Whether to use the waymo example post-training checkpoint
+            chunking: Whether to use the chunking method in generation pipeline
         """
         self.num_input_frames = num_input_frames
         self.control_inputs = control_inputs
@@ -201,6 +203,7 @@ class DiffusionControl2WorldGenerationPipeline(BaseWorldGenerationPipeline):
         self.seed = seed
         self.regional_prompts = regional_prompts
         self.region_definitions = region_definitions
+        self.chunking = chunking
 
         super().__init__(
             checkpoint_dir=checkpoint_dir,
@@ -621,6 +624,7 @@ class DiffusionControl2WorldGenerationPipeline(BaseWorldGenerationPipeline):
             sigma_max=self.sigma_max if x_sigma_max is not None else None,
             x_sigma_max=x_sigma_max,
             use_batch_processing=False if is_upscale_case else True,
+            chunking=self.chunking,
        )
         log.info("Completed diffusion sampling")
         log.info("Starting VAE decode")
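The pipeline change is pure parameter threading: the flag is accepted with a backward-compatible default, stored once at construction, and forwarded at sampling time. A self-contained sketch of that pattern (the class and method names below are illustrative stand-ins, not the repo's real signatures):

# Illustrative stand-in for the plumbing added across these files: accept
# `chunking` with a default of False, store it, and pass it one layer down.
class PipelineSketch:
    def __init__(self, chunking: bool = False):
        self.chunking = chunking  # mirrors `self.chunking = chunking` in __init__

    def generate(self) -> str:
        # mirrors the `chunking=self.chunking` keyword at the sampling call site
        return self._sample(chunking=self.chunking)

    @staticmethod
    def _sample(*, chunking: bool = False) -> str:
        return f"sampling with chunking={chunking}"

print(PipelineSketch(chunking=True).generate())  # sampling with chunking=True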
cosmos_transfer1/diffusion/model/model_v2w.py CHANGED
@@ -16,6 +16,7 @@
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
+import numpy as np
 import torch
 from megatron.core import parallel_state
 from torch import Tensor
@@ -167,6 +168,39 @@ class DiffusionV2WModel(DiffusionT2WModel):
             x0_pred_replaced=x0_pred_replaced,
         )
 
+    CHUNKING_SIZE = 4
+    CHUNKING_MODE = "rand_order"  # ["shuffle", "in_order", "rand_order"]
+    IS_STAGGERED = True
+
+    def get_chunks_indices(self, total_flen) -> List[torch.Tensor]:
+        chunks_indices = []
+        if self.CHUNKING_MODE == "shuffle":
+            for index in torch.arange(0, total_flen, 1).split(self.CHUNKING_SIZE):
+                chunks_indices.append(index)
+            np.random.shuffle(chunks_indices)
+        else:
+            first_chunk_end = (
+                int(torch.randint(low=0, high=self.CHUNKING_SIZE, size=(1,)) + 1) if self.IS_STAGGERED else self.CHUNKING_SIZE
+            )
+
+            if first_chunk_end >= total_flen:
+                chunks_indices.append(torch.arange(total_flen))
+            else:
+                chunks_indices.append(torch.arange(first_chunk_end))
+
+                for index in torch.arange(first_chunk_end, total_flen, 1).split(self.CHUNKING_SIZE):
+                    chunks_indices.append(index)
+
+            if self.CHUNKING_MODE == "in_order":
+                pass
+            elif self.CHUNKING_MODE == "rand_order":
+                if np.random.rand() > 0.5:
+                    chunks_indices = chunks_indices[::-1]
+            else:
+                raise NotImplementedError(f"{self.CHUNKING_MODE} mode not implemented!!")
+
+        return chunks_indices
+
     def generate_samples_from_batch(
         self,
         data_batch: Dict,
@@ -182,6 +216,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
         add_input_frames_guidance: bool = False,
         x_sigma_max: Optional[torch.Tensor] = None,
         sigma_max: Optional[float] = None,
+        chunking: bool = False,
         **kwargs,
     ) -> Tensor:
         """Generates video samples conditioned on input frames.
@@ -199,6 +234,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
             condition_video_augment_sigma_in_inference: Noise level for condition augmentation
             add_input_frames_guidance: Whether to apply guidance to input frames
             x_sigma_max: Maximum noise level tensor
+            chunking: Whether to use the chunking method in generation pipeline
 
         Returns:
             Generated video samples tensor
@@ -213,6 +249,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
 
         assert condition_latent is not None, "condition_latent should be provided"
 
+        # try to add chunking here !!!
         x0_fn = self.get_x0_fn_from_batch_with_condition_latent(
             data_batch,
             guidance,
@@ -222,6 +259,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
             condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
             add_input_frames_guidance=add_input_frames_guidance,
             seed=seed,
+            chunking=chunking,
         )
         if sigma_max is None:
             sigma_max = self.sde.sigma_max
@@ -256,6 +294,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
         condition_video_augment_sigma_in_inference: float = None,
         add_input_frames_guidance: bool = False,
         seed: int = 1,
+        chunking: bool = False,
     ) -> Callable:
         """Creates denoising function for conditional video generation.
 
@@ -268,44 +307,92 @@ class DiffusionV2WModel(DiffusionT2WModel):
             condition_video_augment_sigma_in_inference: Noise level for condition augmentation
             add_input_frames_guidance: Whether to apply guidance to input frames
             seed: Random seed for reproducibility
+            chunking: Whether to use the chunking method in generation pipeline
 
         Returns:
             Function that takes noisy input and noise level and returns denoised prediction
         """
-        if is_negative_prompt:
-            condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
-        else:
-            condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)
-
-        condition.video_cond_bool = True
-        condition = self.add_condition_video_indicator_and_video_input_mask(
-            condition_latent, condition, num_condition_t
-        )
-
-        uncondition.video_cond_bool = False if add_input_frames_guidance else True
-        uncondition = self.add_condition_video_indicator_and_video_input_mask(
-            condition_latent, uncondition, num_condition_t
-        )
-
-        def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
-            cond_x0 = self.denoise(
-                noise_x,
-                sigma,
-                condition,
-                condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
-                seed=seed,
-            ).x0_pred_replaced
-            uncond_x0 = self.denoise(
-                noise_x,
-                sigma,
-                uncondition,
-                condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
-                seed=seed,
-            ).x0_pred_replaced
-
-            return cond_x0 + guidance * (cond_x0 - uncond_x0)
-
-        return x0_fn
+        if not chunking:
+            if is_negative_prompt:
+                condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
+            else:
+                condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)
+
+            condition.video_cond_bool = True
+            condition = self.add_condition_video_indicator_and_video_input_mask(
+                condition_latent, condition, num_condition_t
+            )
+
+            uncondition.video_cond_bool = False if add_input_frames_guidance else True
+            uncondition = self.add_condition_video_indicator_and_video_input_mask(
+                condition_latent, uncondition, num_condition_t
+            )
+
+            def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+                cond_x0 = self.denoise(
+                    noise_x,
+                    sigma,
+                    condition,
+                    condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
+                    seed=seed,
+                ).x0_pred_replaced
+                uncond_x0 = self.denoise(
+                    noise_x,
+                    sigma,
+                    uncondition,
+                    condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
+                    seed=seed,
+                ).x0_pred_replaced
+
+                return cond_x0 + guidance * (cond_x0 - uncond_x0)
+
+            return x0_fn
+        else:
+            log.critical("GO CHUNKING !!!")
+
+            def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+                if is_negative_prompt:
+                    condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
+                else:
+                    condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)
+
+                noises = torch.zeros_like(condition_latent)
+                T = condition_latent.shape[2]
+                for chunk_idx in self.get_chunks_indices(T):
+                    latents_ = condition_latent[:, :, chunk_idx, :, :]
+                    log.info(f"chunk_idx: {chunk_idx}, chunk shape: {latents_.shape}")
+                    # controlnet_cond_ = self.controlnet_data[:, chunk_idx]
+
+                    condition.video_cond_bool = True
+                    condition = self.add_condition_video_indicator_and_video_input_mask(
+                        latents_, condition, num_condition_t
+                    )
+
+                    uncondition.video_cond_bool = False if add_input_frames_guidance else True
+                    uncondition = self.add_condition_video_indicator_and_video_input_mask(
+                        latents_, uncondition, num_condition_t
+                    )
+
+                    cond_x0 = self.denoise(
+                        noise_x,
+                        sigma,
+                        condition,
+                        condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
+                        seed=seed,
+                    ).x0_pred_replaced
+                    uncond_x0 = self.denoise(
+                        noise_x,
+                        sigma,
+                        uncondition,
+                        condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
+                        seed=seed,
+                    ).x0_pred_replaced
+
+                    noises[:, :, chunk_idx, :, :] = cond_x0 + guidance * (cond_x0 - uncond_x0)
+
+                # TODO: need scheduler ?
+                return noises
+            return x0_fn
 
     def add_condition_video_indicator_and_video_input_mask(
         self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None
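The chunking logic is self-contained enough to exercise in isolation. Below is a runnable sketch that lifts get_chunks_indices out of the class (constants and control flow copied from the diff above) and mimics how the new x0_fn scatters per-chunk guided predictions back into a full-length latent tensor; the tensor shapes, the guidance value, and the random stand-ins for the denoise outputs are illustrative assumptions:

from typing import List

import numpy as np
import torch

CHUNKING_SIZE = 4
CHUNKING_MODE = "rand_order"  # one of "shuffle", "in_order", "rand_order"
IS_STAGGERED = True

def get_chunks_indices(total_flen: int) -> List[torch.Tensor]:
    # Standalone copy of DiffusionV2WModel.get_chunks_indices from this commit.
    chunks_indices = []
    if CHUNKING_MODE == "shuffle":
        # Fixed-size chunks, visited in random order.
        chunks_indices = list(torch.arange(0, total_flen, 1).split(CHUNKING_SIZE))
        np.random.shuffle(chunks_indices)
    else:
        # A random-length first chunk (1..CHUNKING_SIZE) staggers the chunk
        # boundaries between calls.
        first_chunk_end = (
            int(torch.randint(low=0, high=CHUNKING_SIZE, size=(1,)) + 1)
            if IS_STAGGERED
            else CHUNKING_SIZE
        )
        if first_chunk_end >= total_flen:
            chunks_indices.append(torch.arange(total_flen))
        else:
            chunks_indices.append(torch.arange(first_chunk_end))
            for index in torch.arange(first_chunk_end, total_flen, 1).split(CHUNKING_SIZE):
                chunks_indices.append(index)
        if CHUNKING_MODE == "in_order":
            pass
        elif CHUNKING_MODE == "rand_order":
            if np.random.rand() > 0.5:
                chunks_indices = chunks_indices[::-1]  # reverse chunk order half the time
        else:
            raise NotImplementedError(f"{CHUNKING_MODE} mode not implemented!!")
    return chunks_indices

# Toy stand-in for the chunked x0_fn: scatter per-chunk guided predictions
# back into a full [B, C, T, H, W] tensor along the frame axis.
B, C, T, H, W = 1, 4, 10, 8, 8
guidance = 7.0  # illustrative value
condition_latent = torch.randn(B, C, T, H, W)
out = torch.zeros_like(condition_latent)

for chunk_idx in get_chunks_indices(T):
    # In the commit, cond_x0 and uncond_x0 come from self.denoise(...) per chunk;
    # random tensors stand in for them here.
    cond_x0 = torch.randn(B, C, len(chunk_idx), H, W)
    uncond_x0 = torch.randn(B, C, len(chunk_idx), H, W)
    out[:, :, chunk_idx, :, :] = cond_x0 + guidance * (cond_x0 - uncond_x0)
    print(chunk_idx.tolist())  # e.g. [0, 1], then [2, 3, 4, 5], then [6, 7, 8, 9]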