Commit 17d970d
Parent(s): 5559672

add chunking
app.py
CHANGED

@@ -55,6 +55,7 @@ import random
 from io import BytesIO
 
 import torch
+
 from cosmos_transfer1.checkpoints import (
     BASE_7B_CHECKPOINT_AV_SAMPLE_PATH,
     BASE_7B_CHECKPOINT_PATH,
@@ -70,14 +71,13 @@ from cosmos_transfer1.diffusion.inference.world_generation_pipeline import (
 )
 from cosmos_transfer1.utils import log, misc
 from cosmos_transfer1.utils.io import read_prompts_from_file, save_video
-
 from helper import parse_arguments
 
 torch.enable_grad(False)
 torch.serialization.add_safe_globals([BytesIO])
 
 
-def inference(cfg, control_inputs) -> Tuple[List[str], List[str]]:
+def inference(cfg, control_inputs, chunking) -> Tuple[List[str], List[str]]:
     video_paths = []
     prompt_paths = []
 
@@ -87,9 +87,10 @@ def inference(cfg, control_inputs) -> Tuple[List[str], List[str]]:
     device_rank = 0
     process_group = None
     if cfg.num_gpus > 1:
-        from cosmos_transfer1.utils import distributed
         from megatron.core import parallel_state
 
+        from cosmos_transfer1.utils import distributed
+
         distributed.init()
         parallel_state.initialize_model_parallel(context_parallel_size=cfg.num_gpus)
         process_group = parallel_state.get_context_parallel_group()
@@ -142,6 +143,7 @@ def inference(cfg, control_inputs) -> Tuple[List[str], List[str]]:
         upsample_prompt=cfg.upsample_prompt,
         offload_prompt_upsampler=cfg.offload_prompt_upsampler,
         process_group=process_group,
+        chunking=chunking,
     )
 
     if cfg.batch_input_path:
@@ -278,6 +280,7 @@ def generate_video(
    negative_prompt="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality.", # noqa: E501
    seed=42,
    randomize_seed=False,
+   chunking=False,
    progress=gr.Progress(track_tqdm=True),
 ):
    if randomize_seed:
@@ -315,7 +318,7 @@ def generate_video(
     watcher = watch_gpu_memory(10)
 
     # start inference
-    videos, prompts = inference(args, control_inputs)
+    videos, prompts = inference(args, control_inputs, chunking)
 
     # print the generation time
     end_time = time.time()
@@ -361,6 +364,7 @@ with gr.Blocks() as demo:
             randomize_seed_checkbox = gr.Checkbox(label="Randomize Seed", value=False)
             seed_input = gr.Slider(minimum=0, maximum=1000000, value=1, step=1, label="Seed")
 
+            chunking_checkbox = gr.Checkbox(label="Chunking", value=True)
             generate_button = gr.Button("Generate Image")
 
         with gr.Column():
@@ -369,7 +373,16 @@
 
     generate_button.click(
         fn=generate_video,
-        inputs=[
+        inputs=[
+            rgb_video_input,
+            hdmap_input,
+            lidar_input,
+            prompt_input,
+            negative_prompt_input,
+            seed_input,
+            randomize_seed_checkbox,
+            chunking_checkbox,
+        ],
         outputs=[output_video, output_file, seed_input],
     )
 
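The hunk above wires a new "Chunking" checkbox into the Gradio click handler. As a reminder of how Gradio delivers widget state, here is a minimal, self-contained sketch (not part of the commit; component names are illustrative): a gr.Checkbox listed in inputs reaches the handler as a plain bool, which is exactly how chunking_checkbox feeds the new chunking parameter of generate_video.

# Minimal Gradio sketch, independent of app.py: a checkbox wired through
# `inputs` is delivered to the handler as a Python bool.
import gradio as gr

def handler(prompt: str, chunking: bool) -> str:
    # `chunking` is True/False depending on the checkbox state at click time.
    return f"prompt={prompt!r}, chunking={chunking}"

with gr.Blocks() as demo:
    prompt_input = gr.Textbox(label="Prompt")
    chunking_checkbox = gr.Checkbox(label="Chunking", value=True)
    run_button = gr.Button("Run")
    result_box = gr.Textbox(label="Result")
    run_button.click(fn=handler, inputs=[prompt_input, chunking_checkbox], outputs=[result_box])

# demo.launch()  # uncomment to try the sketch locally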
cosmos_transfer1/diffusion/inference/inference_utils.py
CHANGED

@@ -710,6 +710,7 @@ def generate_world_from_control(
    x_sigma_max=None,
    augment_sigma=None,
    use_batch_processing: bool = True,
+   chunking: bool = False,
 ) -> Tuple[np.array, list, list]:
    """Generate video using a conditioning video/image input.
 
@@ -723,6 +724,7 @@
        seed (int): Random seed for generation
        condition_latent (torch.Tensor): Latent tensor from conditioning video/image file
        num_input_frames (int): Number of input frames
+       chunking: Whether to use the chunking method in generation pipeline
 
    Returns:
        np.array: Generated video frames in shape [T,H,W,C], range [0,255]
@@ -761,6 +763,7 @@
        patch_h=h,
        patch_w=w,
        use_batch_processing=use_batch_processing,
+       chunking=chunking,
    )
    return sample
 
cosmos_transfer1/diffusion/inference/world_generation_pipeline.py
CHANGED

@@ -151,6 +151,7 @@ class DiffusionControl2WorldGenerationPipeline(BaseWorldGenerationPipeline):
        regional_prompts: List[str] = None,
        region_definitions: Union[List[List[float]], torch.Tensor] = None,
        waymo_example: bool = False,
+       chunking: bool = False,
    ):
        """Initialize diffusion world generation pipeline.
 
@@ -178,6 +179,7 @@ class DiffusionControl2WorldGenerationPipeline(BaseWorldGenerationPipeline):
            offload_prompt_upsampler: Whether to offload prompt upsampler after use
            process_group: Process group for distributed training
            waymo_example: Whether to use the waymo example post-training checkpoint
+           chunking: Whether to use the chunking method in generation pipeline
        """
        self.num_input_frames = num_input_frames
        self.control_inputs = control_inputs
@@ -201,6 +203,7 @@ class DiffusionControl2WorldGenerationPipeline(BaseWorldGenerationPipeline):
        self.seed = seed
        self.regional_prompts = regional_prompts
        self.region_definitions = region_definitions
+       self.chunking = chunking
 
        super().__init__(
            checkpoint_dir=checkpoint_dir,
@@ -621,6 +624,7 @@ class DiffusionControl2WorldGenerationPipeline(BaseWorldGenerationPipeline):
            sigma_max=self.sigma_max if x_sigma_max is not None else None,
            x_sigma_max=x_sigma_max,
            use_batch_processing=False if is_upscale_case else True,
+           chunking=self.chunking,
        )
        log.info("Completed diffusion sampling")
        log.info("Starting VAE decode")
cosmos_transfer1/diffusion/model/model_v2w.py
CHANGED

@@ -16,6 +16,7 @@
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
+import numpy as np
 import torch
 from megatron.core import parallel_state
 from torch import Tensor
@@ -167,6 +168,39 @@ class DiffusionV2WModel(DiffusionT2WModel):
            x0_pred_replaced=x0_pred_replaced,
        )
 
+   CHUNKING_SIZE = 4
+   CHUNKING_MODE = "rand_order"  # ["shuffle", "in_order", "rand_order"]
+   IS_STAGGERED = True
+
+   def get_chunks_indices(self, total_flen) -> List[torch.Tensor]:
+       chunks_indices = []
+       if self.CHUNKING_MODE == "shuffle":
+           for index in torch.arange(0, total_flen, 1).split(self.CHUNKING_SIZE):
+               chunks_indices.append(index)
+           np.random.shuffle(chunks_indices)
+       else:
+           first_chunk_end = (
+               int(torch.randint(low=0, high=self.CHUNKING_SIZE, size=(1,)) + 1) if self.IS_STAGGERED else self.CHUNKING_SIZE
+           )
+
+           if first_chunk_end >= total_flen:
+               chunks_indices.append(torch.arange(total_flen))
+           else:
+               chunks_indices.append(torch.arange(first_chunk_end))
+
+               for index in torch.arange(first_chunk_end, total_flen, 1).split(self.CHUNKING_SIZE):
+                   chunks_indices.append(index)
+
+           if self.CHUNKING_MODE == "in_order":
+               pass
+           elif self.CHUNKING_MODE == "rand_order":
+               if np.random.rand() > 0.5:
+                   chunks_indices = chunks_indices[::-1]
+           else:
+               raise NotImplementedError(f"{self.CHUNKING_MODE} mode not implemented!!")
+
+       return chunks_indices
+
    def generate_samples_from_batch(
        self,
        data_batch: Dict,
@@ -182,6 +216,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
        add_input_frames_guidance: bool = False,
        x_sigma_max: Optional[torch.Tensor] = None,
        sigma_max: Optional[float] = None,
+       chunking: bool = False,
        **kwargs,
    ) -> Tensor:
        """Generates video samples conditioned on input frames.
@@ -199,6 +234,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
            condition_video_augment_sigma_in_inference: Noise level for condition augmentation
            add_input_frames_guidance: Whether to apply guidance to input frames
            x_sigma_max: Maximum noise level tensor
+           chunking: Whether to use the chunking method in generation pipeline
 
        Returns:
            Generated video samples tensor
@@ -213,6 +249,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
 
        assert condition_latent is not None, "condition_latent should be provided"
 
+       # try to add chunking here !!!
        x0_fn = self.get_x0_fn_from_batch_with_condition_latent(
            data_batch,
            guidance,
@@ -222,6 +259,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
            condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
            add_input_frames_guidance=add_input_frames_guidance,
            seed=seed,
+           chunking=chunking,
        )
        if sigma_max is None:
            sigma_max = self.sde.sigma_max
@@ -256,6 +294,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
        condition_video_augment_sigma_in_inference: float = None,
        add_input_frames_guidance: bool = False,
        seed: int = 1,
+       chunking: bool = False,
    ) -> Callable:
        """Creates denoising function for conditional video generation.
 
@@ -268,44 +307,92 @@ class DiffusionV2WModel(DiffusionT2WModel):
            condition_video_augment_sigma_in_inference: Noise level for condition augmentation
            add_input_frames_guidance: Whether to apply guidance to input frames
            seed: Random seed for reproducibility
+           chunking: Whether to use the chunking method in generation pipeline
 
        Returns:
            Function that takes noisy input and noise level and returns denoised prediction
        """
-       if …  [removed: the previous, un-chunked body of this function; its text was not captured in this view]
+       if not chunking:
+           if is_negative_prompt:
+               condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
+           else:
+               condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)
+
+           condition.video_cond_bool = True
+           condition = self.add_condition_video_indicator_and_video_input_mask(
+               condition_latent, condition, num_condition_t
+           )
 
+           uncondition.video_cond_bool = False if add_input_frames_guidance else True
+           uncondition = self.add_condition_video_indicator_and_video_input_mask(
+               condition_latent, uncondition, num_condition_t
+           )
 
+           def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+               cond_x0 = self.denoise(
+                   noise_x,
+                   sigma,
+                   condition,
+                   condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
+                   seed=seed,
+               ).x0_pred_replaced
+               uncond_x0 = self.denoise(
+                   noise_x,
+                   sigma,
+                   uncondition,
+                   condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
+                   seed=seed,
+               ).x0_pred_replaced
+
+               return cond_x0 + guidance * (cond_x0 - uncond_x0)
+
+           return x0_fn
+       else:
+           log.critical("GO CHUNKING !!!")
+
+           def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+               if is_negative_prompt:
+                   condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
+               else:
+                   condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)
+
+               noises = torch.zeros_like(condition_latent)
+               T = condition_latent.shape[2]
+               for chunk_idx in self.get_chunks_indices(T):
+                   latents_ = condition_latent[:, :, chunk_idx, :, :]
+                   log.info(f"chunk_idx: {chunk_idx}, chunk shape: {latents_.shape}")
+                   # controlnet_cond_ = self.controlnet_data[:, chunk_idx]
+
+                   condition.video_cond_bool = True
+                   condition = self.add_condition_video_indicator_and_video_input_mask(
+                       latents_, condition, num_condition_t
+                   )
+
+                   uncondition.video_cond_bool = False if add_input_frames_guidance else True
+                   uncondition = self.add_condition_video_indicator_and_video_input_mask(
+                       latents_, uncondition, num_condition_t
+                   )
+
+                   cond_x0 = self.denoise(
+                       noise_x,
+                       sigma,
+                       condition,
+                       condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
+                       seed=seed,
+                   ).x0_pred_replaced
+                   uncond_x0 = self.denoise(
+                       noise_x,
+                       sigma,
+                       uncondition,
+                       condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
+                       seed=seed,
+                   ).x0_pred_replaced
+
+                   noises[:, :, chunk_idx, :, :] = cond_x0 + guidance * (cond_x0 - uncond_x0)
+
+               # TODO: need scheduler ?
+               return noises
+           return x0_fn
 
    def add_condition_video_indicator_and_video_input_mask(
        self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None