Commit · 74308ee
1 Parent(s): 4953ce6

make chunking size as a function argument & add a slider to control it
app.py  CHANGED

@@ -296,7 +296,7 @@ def generate_video(
     negative_prompt="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality.", # noqa: E501
     seed=42,
     randomize_seed=False,
-    chunking=
+    chunking=None,
     progress=gr.Progress(track_tqdm=True),
 ):
     _dt = datetime.datetime.now(tz=datetime.timezone(datetime.timedelta(hours=8))).strftime("%Y-%m-%d_%H.%M.%S")
@@ -338,6 +338,8 @@ def generate_video(
     watcher = watch_gpu_memory(10, lambda x: log.debug(f"GPU memory usage: {x} (MiB)"))

     # start inference
+    if chunking <= 0:
+        chunking = None
     videos, prompts = inference(args, control_inputs, chunking)

     # print the generation time
@@ -386,7 +388,7 @@ with gr.Blocks() as demo:
             randomize_seed_checkbox = gr.Checkbox(label="Randomize Seed", value=False)
             seed_input = gr.Slider(minimum=0, maximum=1000000, value=1, step=1, label="Seed")

-
+            chunking_input = gr.Slider(minimum=0, maximum=121, value=4, step=1, label="Chunking size")
             generate_button = gr.Button("Generate Image")

         with gr.Column():
@@ -403,7 +405,7 @@ with gr.Blocks() as demo:
             negative_prompt_input,
             seed_input,
             randomize_seed_checkbox,
-
+            chunking_input,
         ],
         outputs=[output_video, output_file, seed_input],
     )
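The new "Chunking size" slider passes its integer value straight through to generate_video, and the added guard maps any non-positive value to None, so a slider setting of 0 disables chunking downstream. A minimal sketch of that mapping (the helper name is invented here for illustration; the commit itself inlines the check inside generate_video):

from typing import Optional

def normalize_chunking(slider_value: int) -> Optional[int]:
    # Mirror the guard added in app.py: 0 (the slider minimum) means
    # "chunking disabled"; any positive value is used as the chunk size.
    return None if slider_value <= 0 else int(slider_value)

assert normalize_chunking(0) is None   # slider at 0 -> chunking disabled
assert normalize_chunking(4) == 4      # slider default -> chunks of 4 frames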
cosmos_transfer1/diffusion/inference/inference_utils.py  CHANGED

@@ -710,7 +710,7 @@ def generate_world_from_control(
     x_sigma_max=None,
     augment_sigma=None,
     use_batch_processing: bool = True,
-    chunking:
+    chunking: Optional[int] = None,
 ) -> Tuple[np.array, list, list]:
     """Generate video using a conditioning video/image input.

@@ -724,7 +724,7 @@ def generate_world_from_control(
         seed (int): Random seed for generation
         condition_latent (torch.Tensor): Latent tensor from conditioning video/image file
         num_input_frames (int): Number of input frames
-        chunking:
+        chunking: Chunking size, if None, chunking is disabled

     Returns:
         np.array: Generated video frames in shape [T,H,W,C], range [0,255]
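The only change to generate_world_from_control is the new chunking keyword, an Optional[int] where None keeps the previous single-pass behaviour. A toy illustration of that convention (this function is not part of the repository; it only demonstrates the None-disables-chunking semantics the docstring describes):

from typing import List, Optional

def split_frame_indices(total_frames: int, chunking: Optional[int] = None) -> List[List[int]]:
    # chunking=None: treat the whole clip as one block.
    if chunking is None:
        return [list(range(total_frames))]
    # Otherwise split the frame indices into blocks of at most `chunking` frames.
    return [list(range(s, min(s + chunking, total_frames))) for s in range(0, total_frames, chunking)]

print(split_frame_indices(10))     # [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]
print(split_frame_indices(10, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]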
cosmos_transfer1/diffusion/inference/world_generation_pipeline.py  CHANGED

@@ -151,7 +151,7 @@ class DiffusionControl2WorldGenerationPipeline(BaseWorldGenerationPipeline):
         regional_prompts: List[str] = None,
         region_definitions: Union[List[List[float]], torch.Tensor] = None,
         waymo_example: bool = False,
-        chunking:
+        chunking: Optional[int] = None,
     ):
         """Initialize diffusion world generation pipeline.

@@ -179,7 +179,7 @@ class DiffusionControl2WorldGenerationPipeline(BaseWorldGenerationPipeline):
             offload_prompt_upsampler: Whether to offload prompt upsampler after use
             process_group: Process group for distributed training
             waymo_example: Whether to use the waymo example post-training checkpoint
-            chunking:
+            chunking: Chunking size, if None, chunking is disabled
         """
         self.num_input_frames = num_input_frames
         self.control_inputs = control_inputs
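The pipeline constructor gains the same Optional[int] chunking option. How the pipeline stores and forwards it is not shown in this diff; the sketch below is a hypothetical illustration (class and attribute names invented) of the usual pattern of keeping the option on the instance and passing it on to the model-level sampler:

from typing import Callable, Optional

class ToyControl2WorldPipeline:
    def __init__(self, num_input_frames: int, chunking: Optional[int] = None):
        self.num_input_frames = num_input_frames
        # None disables chunking, per the docstring added in this commit.
        self.chunking = chunking

    def generate(self, sampler: Callable, latent):
        # Forward the stored option to the sampling call.
        return sampler(latent, chunking=self.chunking)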
cosmos_transfer1/diffusion/model/model_v2w.py  CHANGED

@@ -168,19 +168,18 @@ class DiffusionV2WModel(DiffusionT2WModel):
             x0_pred_replaced=x0_pred_replaced,
         )

-    CHUNKING_SIZE = 4
     CHUNKING_MODE = "rand_order"  # ["shuffle", "in_order", "rand_order"]
     IS_STAGGERED = True

-    def get_chunks_indices(self, total_flen) -> List[torch.Tensor]:
+    def get_chunks_indices(self, total_flen, chunking_size) -> List[torch.Tensor]:
         chunks_indices = []
         if self.CHUNKING_MODE == "shuffle":
-            for index in torch.arange(0, total_flen, 1).split(
+            for index in torch.arange(0, total_flen, 1).split(chunking_size):
                 chunks_indices.append(index)
             np.random.shuffle(chunks_indices)
         else:
             first_chunk_end = (
-                int(torch.randint(low=0, high=
+                int(torch.randint(low=0, high=chunking_size, size=(1,)) + 1) if self.IS_STAGGERED else chunking_size
             )

             if first_chunk_end >= total_flen:
@@ -188,7 +187,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
             else:
                 chunks_indices.append(torch.arange(first_chunk_end))

-            for index in torch.arange(first_chunk_end, total_flen, 1).split(
+            for index in torch.arange(first_chunk_end, total_flen, 1).split(chunking_size):
                 chunks_indices.append(index)

         if self.CHUNKING_MODE == "in_order":
@@ -216,7 +215,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
         add_input_frames_guidance: bool = False,
         x_sigma_max: Optional[torch.Tensor] = None,
         sigma_max: Optional[float] = None,
-        chunking:
+        chunking: Optional[int] = None,
         **kwargs,
     ) -> Tensor:
         """Generates video samples conditioned on input frames.
@@ -234,7 +233,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
             condition_video_augment_sigma_in_inference: Noise level for condition augmentation
             add_input_frames_guidance: Whether to apply guidance to input frames
             x_sigma_max: Maximum noise level tensor
-            chunking:
+            chunking: Chunking size, if None, chunking is disabled

         Returns:
             Generated video samples tensor
@@ -294,7 +293,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
         condition_video_augment_sigma_in_inference: float = None,
         add_input_frames_guidance: bool = False,
         seed: int = 1,
-        chunking:
+        chunking: Optional[int] = None,
     ) -> Callable:
         """Creates denoising function for conditional video generation.

@@ -307,12 +306,12 @@ class DiffusionV2WModel(DiffusionT2WModel):
             condition_video_augment_sigma_in_inference: Noise level for condition augmentation
             add_input_frames_guidance: Whether to apply guidance to input frames
             seed: Random seed for reproducibility
-            chunking:
+            chunking: Chunking size, if None, chunking is disabled

         Returns:
             Function that takes noisy input and noise level and returns denoised prediction
         """
-        if
+        if chunking is None:
             if is_negative_prompt:
                 condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
             else:
@@ -348,8 +347,6 @@ class DiffusionV2WModel(DiffusionT2WModel):

             return x0_fn
         else:
-            log.critical("GO CHUNKING !!!")
-
             def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
                 if is_negative_prompt:
                     condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
@@ -358,7 +355,7 @@ class DiffusionV2WModel(DiffusionT2WModel):

                 noises = torch.zeros_like(condition_latent)
                 T = condition_latent.shape[2]
-                for chunk_idx in self.get_chunks_indices(T):
+                for chunk_idx in self.get_chunks_indices(T, chunking):
                     latents_ = condition_latent[:, :, chunk_idx, :, :]
                     log.info(f"chunk_idx: {chunk_idx}, chunk shape: {latents_.shape}")
                     # controlnet_cond_ = self.controlnet_data[:, chunk_idx]