from typing import Any, Dict, List, Tuple
from pathlib import Path
import os
import sys
import hashlib
import json
import random
import math

import numpy as np
import torch
import wandb
from einops import rearrange, repeat
from safetensors.torch import load_file, save_file
from accelerate.logging import get_logger
from accelerate.utils import gather_object
from diffusers import (
    AutoencoderKLCogVideoX,
    CogVideoXDPMScheduler,
    CogVideoXImageToVideoPipeline,
    CogVideoXTransformer3DModel,
)
from diffusers.utils.export_utils import export_to_video
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from typing_extensions import override

from finetune.pipeline.flovd_FVSM_cogvideox_controlnet_pipeline import FloVDCogVideoXControlnetImageToVideoPipeline
from finetune.constants import LOG_LEVEL, LOG_NAME
from finetune.schemas import Args, Components, State
from finetune.trainer import Trainer
from finetune.utils import (
    cast_training_params,
    free_memory,
    get_memory_statistics,
    string_to_filename,
    unwrap_model,
)
from finetune.datasets.utils import (
    preprocess_image_with_resize,
    preprocess_video_with_resize,
    load_binary_mask_compressed,
)
from finetune.modules.cogvideox_controlnet import CogVideoXControlnet
from finetune.modules.cogvideox_custom_model import CustomCogVideoXTransformer3DModel
from finetune.modules.camera_sampler import SampleManualCam
from finetune.modules.camera_flow_generator import CameraFlowGenerator
from finetune.modules.utils import get_camera_flow_generator_input, forward_bilinear_splatting

from ..utils import register

# Make sibling packages importable when this file is run directly
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

logger = get_logger(LOG_NAME, LOG_LEVEL)
class FloVDCogVideoXI2VControlnetTrainer(Trainer):
    UNLOAD_LIST = ["text_encoder"]

    def __init__(self, args: Args) -> None:
        super().__init__(args)

        # Camera trajectory sampler used during validation
        self.CameraSampler = SampleManualCam()
    def load_components(self) -> Dict[str, Any]:
        # TODO. Change the pipeline and ...
        components = Components()
        model_path = str(self.args.model_path)

        components.pipeline_cls = FloVDCogVideoXControlnetImageToVideoPipeline

        components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
        components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")

        # components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
        components.transformer = CustomCogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")

        # The controlnet branch is initialized from the base transformer weights with a reduced depth
        additional_kwargs = {
            'num_layers': self.args.controlnet_transformer_num_layers,
            'out_proj_dim_factor': self.args.controlnet_out_proj_dim_factor,
            'out_proj_dim_zero_init': self.args.controlnet_out_proj_zero_init,
            'notextinflow': self.args.notextinflow,
        }
        components.controlnet = CogVideoXControlnet.from_pretrained(model_path, subfolder="transformer", **additional_kwargs)

        components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
        components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

        return components
    def initialize_pipeline(self) -> FloVDCogVideoXControlnetImageToVideoPipeline:
        # TODO. Change the pipeline and ...
        pipe = FloVDCogVideoXControlnetImageToVideoPipeline(
            tokenizer=self.components.tokenizer,
            text_encoder=unwrap_model(self.accelerator, self.components.text_encoder),
            vae=unwrap_model(self.accelerator, self.components.vae),
            transformer=unwrap_model(self.accelerator, self.components.transformer),
            controlnet=unwrap_model(self.accelerator, self.components.controlnet),
            scheduler=self.components.scheduler,
        )
        return pipe

    def initialize_flow_generator(self, ckpt_path):
        depth_estimator_kwargs = {
            "target": 'modules.depth_warping.depth_warping.DepthWarping_wrapper',
            "kwargs": {
                "ckpt_path": ckpt_path,
                "model_config": {
                    "max_depth": 20,
                    "encoder": 'vitb',
                    "features": 128,
                    "out_channels": [96, 192, 384, 768],
                },
            },
        }

        return CameraFlowGenerator(depth_estimator_kwargs)
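    # Note (editorial): the depth-estimator settings above (encoder 'vitb', features=128,
    # out_channels=[96, 192, 384, 768]) match a ViT-B style monocular depth backbone, and
    # max_depth=20 suggests a metric, indoor-range checkpoint; the exact model is whatever
    # `DepthWarping_wrapper` loads from `ckpt_path`.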
    def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
        ret = {"encoded_videos": [], "prompt_embedding": [], "images": [], "encoded_flow": []}

        for sample in samples:
            ret["encoded_videos"].append(sample["encoded_video"])
            ret["prompt_embedding"].append(sample["prompt_embedding"])
            ret["images"].append(sample["image"])
            ret["encoded_flow"].append(sample["encoded_flow"])

        ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
        ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
        ret["images"] = torch.stack(ret["images"])
        ret["encoded_flow"] = torch.stack(ret["encoded_flow"])

        return ret
    def compute_loss(self, batch) -> torch.Tensor:
        prompt_embedding = batch["prompt_embedding"]  # [B, seq_len, hidden_size]
        latent = batch["encoded_videos"]              # [B, C, F, H, W]
        images = batch["images"]                      # [B, C, H, W]
        latent_flow = batch["encoded_flow"]           # [B, C, F, H, W]

        patch_size_t = self.state.transformer_config.patch_size_t  # None in the I2V setting
        if patch_size_t is not None:
            ncopy = latent.shape[2] % patch_size_t
            # Copy the first frame ncopy times so the frame count is divisible by patch_size_t
            first_frame = latent[:, :, :1, :, :]  # [B, C, 1, H, W]
            latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
            assert latent.shape[2] % patch_size_t == 0

        batch_size, num_channels, num_frames, height, width = latent.shape

        # Get prompt embeddings
        _, seq_len, _ = prompt_embedding.shape
        prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent.dtype)

        # Add frame dimension to images [B, C, H, W] -> [B, C, F, H, W]
        images = images.unsqueeze(2)

        # Add noise to the conditioning images before encoding them
        image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device)
        image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
        noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
        image_latent_dist = self.components.vae.encode(noisy_images.to(dtype=self.components.vae.dtype)).latent_dist
        image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor

        # Sample a timestep for each sample. When time sampling is enabled, timesteps are drawn
        # only from the range in which the controlnet is active: guidance_start/guidance_end are
        # fractions of the (reversed) denoising trajectory, so they map to
        # [(1 - end) * T, (1 - start) * T] in scheduler timesteps.
        if self.args.enable_time_sampling:
            if self.args.time_sampling_type == "truncated_normal":
                time_sampling_dict = {
                    'mean': self.args.time_sampling_mean,
                    'std': self.args.time_sampling_std,
                    'a': 1 - self.args.controlnet_guidance_end,
                    'b': 1 - self.args.controlnet_guidance_start,
                }
                timesteps = torch.nn.init.trunc_normal_(
                    torch.empty(batch_size, device=latent.device), **time_sampling_dict
                ) * self.components.scheduler.config.num_train_timesteps
            elif self.args.time_sampling_type == "truncated_uniform":
                timesteps = torch.randint(
                    int((1 - self.args.controlnet_guidance_end) * self.components.scheduler.config.num_train_timesteps),
                    int((1 - self.args.controlnet_guidance_start) * self.components.scheduler.config.num_train_timesteps),
                    (batch_size,), device=latent.device,
                )
        else:
            timesteps = torch.randint(
                0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
            )
        timesteps = timesteps.long()
        # From [B, C, F, H, W] to [B, F, C, H, W]
        latent = latent.permute(0, 2, 1, 3, 4)
        latent_flow = latent_flow.permute(0, 2, 1, 3, 4)
        image_latents = image_latents.permute(0, 2, 1, 3, 4)
        assert (latent.shape[0], *latent.shape[2:]) == (image_latents.shape[0], *image_latents.shape[2:]) == (latent_flow.shape[0], *latent_flow.shape[2:])

        # Pad image_latents to the same frame count as latent
        padding_shape = (latent.shape[0], latent.shape[1] - 1, *latent.shape[2:])
        latent_padding = image_latents.new_zeros(padding_shape)
        image_latents = torch.cat([image_latents, latent_padding], dim=1)

        # Add noise to latent
        noise = torch.randn_like(latent)
        latent_noisy = self.components.scheduler.add_noise(latent, noise, timesteps)

        # Concatenate latent and image_latents in the channel dimension
        # latent_img_flow_noisy = torch.cat([latent_noisy, image_latents, latent_flow], dim=2)
        latent_img_noisy = torch.cat([latent_noisy, image_latents], dim=2)

        # Prepare rotary embeds
        vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1)
        transformer_config = self.state.transformer_config
        rotary_emb = (
            self.prepare_rotary_positional_embeddings(
                height=height * vae_scale_factor_spatial,
                width=width * vae_scale_factor_spatial,
                num_frames=num_frames,
                transformer_config=transformer_config,
                vae_scale_factor_spatial=vae_scale_factor_spatial,
                device=self.accelerator.device,
            )
            if transformer_config.use_rotary_positional_embeddings
            else None
        )

        # Predict noise. The ofs embedding is only used by CogVideoX 1.5.
        ofs_emb = (
            None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
        )

        # Controlnet forward pass
        controlnet_states = self.components.controlnet(
            hidden_states=latent_noisy,
            encoder_hidden_states=prompt_embedding,
            image_rotary_emb=rotary_emb,
            controlnet_hidden_states=latent_flow,
            timestep=timesteps,
            return_dict=False,
        )[0]
        if isinstance(controlnet_states, (tuple, list)):
            controlnet_states = [x.to(dtype=self.state.weight_dtype) for x in controlnet_states]
        else:
            controlnet_states = controlnet_states.to(dtype=self.state.weight_dtype)

        # Transformer forward pass
        predicted_noise = self.components.transformer(
            hidden_states=latent_img_noisy,
            encoder_hidden_states=prompt_embedding,
            controlnet_states=controlnet_states,
            controlnet_weights=self.args.controlnet_weights,
            timestep=timesteps,
            # ofs=ofs_emb,
            image_rotary_emb=rotary_emb,
            return_dict=False,
        )[0]
        # Convert the v-prediction into a predicted clean latent and compute a weighted MSE loss
        latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_noisy, timesteps)

        alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
        weights = 1 / (1 - alphas_cumprod)
        while len(weights.shape) < len(latent_pred.shape):
            weights = weights.unsqueeze(-1)

        loss = torch.mean((weights * (latent_pred - latent) ** 2).reshape(batch_size, -1), dim=1)
        loss = loss.mean()

        return loss
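    # Loss recap (editorial note): with a v-prediction model, get_velocity(model_output, x_t, t)
    # returns sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v_pred, i.e. the model's estimate
    # of the clean latent x_0. The MSE against the ground-truth latent is then weighted by
    # 1 / (1 - alpha_bar_t), which up-weights low-noise timesteps, mirroring the upstream
    # CogVideoX finetuning objective.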
    def prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        transformer_config: Dict,
        vae_scale_factor_spatial: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
        grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)

        if transformer_config.patch_size_t is None:
            base_num_frames = num_frames
        else:
            base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t

        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=transformer_config.attention_head_dim,
            crops_coords=None,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
            grid_type="slice",
            max_size=(grid_height, grid_width),
            device=device,
        )

        return freqs_cos, freqs_sin
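    # For orientation (editorial note): with the stock CogVideoX-I2V configuration
    # (spatial VAE factor 8, patch_size 2) and a 480x720 training resolution, the RoPE grid
    # is 480 / (8 * 2) = 30 by 720 / (8 * 2) = 45 tokens per latent frame, and `num_frames`
    # here is the latent frame count (13 for a 49-frame clip with 4x temporal compression).
    # Exact values depend on the checkpoint's config.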
    # Validation
    def prepare_for_validation(self):
        # Validation samples are loaded from the preprocessed dataset:
        # data_root/
        #   - metadata_revised.jsonl
        #   - video_latent/<train_resolution>/
        #   - prompt_embeddings_revised/
        #   - first_frames/
        #   - flow_direct_f_latent/
        data_root = self.args.data_root
        metadata_path = data_root / "metadata_revised.jsonl"
        assert metadata_path.is_file(), "For this dataset type, metadata_revised.jsonl is required in the data root"

        # Each metadata line is a JSON object with at least "video_path", "hash_code", and "prompt"
        metadata = []
        with open(metadata_path, "r") as f:
            for line in f:
                metadata.append(json.loads(line))
        metadata = random.sample(metadata, self.args.max_scene)

        prompts = [x["prompt"] for x in metadata]
        prompt_embeddings = [data_root / "prompt_embeddings_revised" / (x["hash_code"] + '.safetensors') for x in metadata]
        videos = [data_root / "video_latent" / "x".join(str(x) for x in self.args.train_resolution) / (x["hash_code"] + '.safetensors') for x in metadata]
        images = [data_root / "first_frames" / (x["hash_code"] + '.png') for x in metadata]
        flows = [data_root / "flow_direct_f_latent" / (x["hash_code"] + '.safetensors') for x in metadata]

        # Load prompt embeddings, video latents, and flow latents; keep image paths as-is
        validation_prompts = []
        validation_prompt_embeddings = []
        validation_video_latents = []
        validation_images = []
        validation_flow_latents = []
        for prompt, prompt_embedding, video_latent, image, flow_latent in zip(prompts, prompt_embeddings, videos, images, flows):
            validation_prompts.append(prompt)
            validation_prompt_embeddings.append(load_file(prompt_embedding)["prompt_embedding"].unsqueeze(0))
            validation_video_latents.append(load_file(video_latent)["encoded_video"].unsqueeze(0))
            validation_flow_latents.append(load_file(flow_latent)["encoded_flow_f"].unsqueeze(0))
            # validation_images.append(preprocess_image_with_resize(image, self.args.train_resolution[1], self.args.train_resolution[2]))
            validation_images.append(image)

        validation_videos = [None] * len(validation_prompts)

        self.state.validation_prompts = validation_prompts
        self.state.validation_prompt_embeddings = validation_prompt_embeddings
        self.state.validation_images = validation_images
        self.state.validation_videos = validation_videos
        self.state.validation_video_latents = validation_video_latents
        self.state.validation_flow_latents = validation_flow_latents
    def validation_step(
        self, eval_data: Dict[str, Any], pipe: FloVDCogVideoXControlnetImageToVideoPipeline
    ) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
        """
        Return the data that needs to be saved. For videos, the data format is List[PIL],
        and for images, the data format is PIL.
        """
        prompt_embedding, image, flow_latent = eval_data["prompt_embedding"], eval_data["image"], eval_data["flow_latent"]

        video_generate = pipe(
            num_frames=self.state.train_frames,
            height=self.state.train_height,
            width=self.state.train_width,
            prompt=None,
            prompt_embeds=prompt_embedding,
            image=image,
            flow_latent=flow_latent,
            generator=self.state.generator,
            num_inference_steps=50,
            controlnet_guidance_start=self.args.controlnet_guidance_start,
            controlnet_guidance_end=self.args.controlnet_guidance_end,
        ).frames[0]

        return [("synthesized_video", video_generate)]
    def validate(self, step: int) -> None:
        # TODO: clean up this validation routine
        logger.info("Starting validation")

        accelerator = self.accelerator
        num_validation_samples = len(self.state.validation_prompts)

        if num_validation_samples == 0:
            logger.warning("No validation samples found. Skipping validation.")
            return

        self.components.controlnet.eval()
        torch.set_grad_enabled(False)

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")

        ##### Initialize pipeline #####
        pipe = self.initialize_pipeline()
        camera_flow_generator = self.initialize_flow_generator(ckpt_path=self.args.depth_ckpt_path).to(device=self.accelerator.device, dtype=self.state.weight_dtype)

        if self.state.using_deepspeed:
            # model_cpu_offload is not compatible with DeepSpeed,
            # so move all components used by the pipeline to the device instead.
            # pipe.to(self.accelerator.device, dtype=self.state.weight_dtype)
            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["controlnet"])
        else:
            # If not using DeepSpeed, use model_cpu_offload to reduce memory usage
            # (or pipe.enable_sequential_cpu_offload() to reduce it further).
            pipe.enable_model_cpu_offload(device=self.accelerator.device)

            # Convert all model weights to the training dtype.
            # Note: this also casts LoRA weights in self.components.transformer to the training dtype rather than keeping them in fp32.
            pipe = pipe.to(dtype=self.state.weight_dtype)
        #################################

        # 'training' reconstructs videos/flows from the dataset latents;
        # 'inference' synthesizes flow from freshly sampled camera trajectories.
        inference_type = ['training', 'inference']
        for infer_type in inference_type:
            all_processes_artifacts = []
            for i in range(num_validation_samples):
                if self.state.using_deepspeed and self.accelerator.deepspeed_plugin.zero_stage != 3:
                    # Skip the current validation sample on all processes but one
                    if i % accelerator.num_processes != accelerator.process_index:
                        continue

                prompt = self.state.validation_prompts[i]
                image = self.state.validation_images[i]
                video = self.state.validation_videos[i]
                video_latent = self.state.validation_video_latents[i].permute(0, 2, 1, 3, 4)  # [B, F, C, H, W] (e.g., [B, 13, 16, 60, 90])
                prompt_embedding = self.state.validation_prompt_embeddings[i]
                flow_latent = self.state.validation_flow_latents[i].permute(0, 2, 1, 3, 4)  # [B, F, C, H, W] (e.g., [B, 13, 16, 60, 90])

                if image is not None:
                    image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
                    image_torch = image.detach().clone()
                    # Convert image tensor (C, H, W) to a PIL image
                    image = image.to(torch.uint8)
                    image = image.permute(1, 2, 0).cpu().numpy()
                    image = Image.fromarray(image)

                if video is not None:
                    video = preprocess_video_with_resize(
                        video, self.state.train_frames, self.state.train_height, self.state.train_width
                    )
                    # Convert video tensor (F, C, H, W) to a list of PIL images
                    video = video.round().clamp(0, 255).to(torch.uint8)
                    video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]
                else:
                    if infer_type == 'training':
                        with torch.cuda.amp.autocast(enabled=True, dtype=self.state.weight_dtype):
                            # Workaround: the first VAE decode occasionally fails; retry once after the warm-up attempt.
                            try:
                                video_decoded = decode_latents(video_latent.to(self.accelerator.device), self.components.vae)
                            except Exception:
                                video_decoded = decode_latents(video_latent.to(self.accelerator.device), self.components.vae)
                        video = ((video_decoded + 1.) / 2. * 255.)[0].permute(1, 0, 2, 3).float().clip(0., 255.).to(torch.uint8)
                        video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]

                    with torch.cuda.amp.autocast(enabled=True, dtype=self.state.weight_dtype):
                        # Same warm-up workaround for decoding the flow latents.
                        try:
                            flow_decoded = decode_flow(flow_latent.to(self.accelerator.device), self.components.vae, flow_scale_factor=[60, 36])  # (BF)CHW (C=2)
                        except Exception:
                            flow_decoded = decode_flow(flow_latent.to(self.accelerator.device), self.components.vae, flow_scale_factor=[60, 36])

                # Prepare camera flow
                if infer_type == 'inference':
                    with torch.cuda.amp.autocast(enabled=True, dtype=self.state.weight_dtype):
                        camparam, cam_name = self.CameraSampler.sample()
                        camera_flow_generator_input = get_camera_flow_generator_input(image_torch, camparam, device=self.accelerator.device, speed=0.5)
                        image_torch = ((image_torch.unsqueeze(0) / 255.) * 2. - 1.).to(self.accelerator.device)
                        camera_flow, log_dict = camera_flow_generator(image_torch, camera_flow_generator_input)
                        camera_flow = camera_flow.to(self.accelerator.device)

                        # Workaround: the first flow encode occasionally fails; retry once after the warm-up attempt.
                        try:
                            flow_latent = rearrange(encode_flow(camera_flow, self.components.vae, flow_scale_factor=[60, 36]), 'b c f h w -> b f c h w').to(self.accelerator.device, self.state.weight_dtype)
                        except Exception:
                            flow_latent = rearrange(encode_flow(camera_flow, self.components.vae, flow_scale_factor=[60, 36]), 'b c f h w -> b f c h w').to(self.accelerator.device, self.state.weight_dtype)

                logger.debug(
                    f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
                    main_process_only=False,
                )
                # validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
                validation_artifacts = self.validation_step({"prompt_embedding": prompt_embedding, "image": image, "flow_latent": flow_latent}, pipe)

                if (
                    self.state.using_deepspeed
                    and self.accelerator.deepspeed_plugin.zero_stage == 3
                    and not accelerator.is_main_process
                ):
                    continue

                prompt_filename = string_to_filename(prompt)[:25]
                # Calculate a hash of the reversed prompt as a unique identifier
                reversed_prompt = prompt[::-1]
                hash_suffix = hashlib.md5(reversed_prompt.encode()).hexdigest()[:5]

                artifacts = {
                    "image": {"type": "image", "value": image},
                    "video": {"type": "video", "value": video},
                }
                for art_idx, (artifact_type, artifact_value) in enumerate(validation_artifacts):
                    artifacts.update({f"artifact_{art_idx}": {"type": artifact_type, "value": artifact_value}})

                if infer_type == 'training':
                    # Log flow-warped frames reconstructed from the decoded dataset flow
                    image_tensor = repeat(
                        rearrange(torch.tensor(np.array(image)).to(flow_decoded.device, torch.float), 'h w c -> 1 c h w'),
                        'b c h w -> (b f) c h w', f=flow_decoded.size(0),
                    )  # range (0, 255), (BF) C H W
                    warped_video = forward_bilinear_splatting(image_tensor, flow_decoded.to(torch.float))  # an occlusion mask from the dataset could be used here if available
                    frame_list = []
                    for frame in warped_video:
                        frame = frame.permute(1, 2, 0).float().detach().cpu().numpy().clip(0, 255).astype(np.uint8)
                        frame_list.append(Image.fromarray(frame))
                    artifacts.update({"artifact_warped_video": {"type": 'warped_video', "value": frame_list}})

                if infer_type == 'inference':
                    # Log the depth-warped frames produced by the camera flow generator
                    warped_video = log_dict['depth_warped_frames']
                    frame_list = []
                    for frame in warped_video:
                        frame = (frame + 1.) / 2. * 255.
                        frame = frame.permute(1, 2, 0).float().detach().cpu().numpy().clip(0, 255).astype(np.uint8)
                        frame_list.append(Image.fromarray(frame))
                    artifacts.update({"artifact_warped_video": {"type": 'warped_video', "value": frame_list}})

                logger.debug(
                    f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
                    main_process_only=False,
                )

                for key, value in list(artifacts.items()):
                    artifact_type = value["type"]
                    artifact_value = value["value"]
                    if artifact_type not in ["image", "video", "warped_video", "synthesized_video"] or artifact_value is None:
                        continue

                    extension = "png" if artifact_type == "image" else "mp4"
                    if artifact_type == "warped_video":
                        filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}-{infer_type}_warped_video.{extension}"
                    elif artifact_type == "synthesized_video":
                        filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}-{infer_type}_synthesized_video.{extension}"
                    else:
                        filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}-{infer_type}.{extension}"
                    validation_path = self.args.output_dir / "validation_res"
                    validation_path.mkdir(parents=True, exist_ok=True)
                    filename = str(validation_path / filename)

                    if artifact_type == "image":
                        logger.debug(f"Saving image to {filename}")
                        artifact_value.save(filename)
                        artifact_value = wandb.Image(filename)
                    elif artifact_type in ("video", "warped_video", "synthesized_video"):
                        logger.debug(f"Saving video to {filename}")
                        export_to_video(artifact_value, filename, fps=self.args.gen_fps)
                        artifact_value = wandb.Video(filename, caption=prompt)

                    all_processes_artifacts.append(artifact_value)

            all_artifacts = gather_object(all_processes_artifacts)

            if accelerator.is_main_process:
                tracker_key = "validation"
                for tracker in accelerator.trackers:
                    if tracker.name == "wandb":
                        image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
                        video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
                        tracker.log(
                            {
                                tracker_key: {f"images_{infer_type}": image_artifacts, f"videos_{infer_type}": video_artifacts},
                            },
                            step=step,
                        )

        ########## Clean up ##########
        if self.state.using_deepspeed:
            del pipe
            # Unload models except those needed for training
            self.__move_components_to_cpu(unload_list=self.UNLOAD_LIST)
        else:
            pipe.remove_all_hooks()
            del pipe
            # Reload models except those not needed for training
            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
            self.components.controlnet.to(self.accelerator.device, dtype=self.state.weight_dtype)

            # Cast trainable weights back to fp32 to match the dtype set when the model was prepared
            cast_training_params([self.components.controlnet], dtype=torch.float32)

        del camera_flow_generator
        free_memory()
        accelerator.wait_for_everyone()
        ################################

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
        torch.cuda.reset_peak_memory_stats(accelerator.device)

        torch.set_grad_enabled(True)
        self.components.controlnet.train()
    # Name-mangled helper (class-private)
    def __move_components_to_device(self, dtype, ignore_list: List[str] = []):
        ignore_list = set(ignore_list)
        components = self.components.model_dump()
        for name, component in components.items():
            if not isinstance(component, type) and hasattr(component, "to"):
                if name not in ignore_list:
                    setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))

    # Name-mangled helper (class-private)
    def __move_components_to_cpu(self, unload_list: List[str] = []):
        unload_list = set(unload_list)
        components = self.components.model_dump()
        for name, component in components.items():
            if not isinstance(component, type) and hasattr(component, "to"):
                if name in unload_list:
                    setattr(self.components, name, component.to("cpu"))


register("cogvideox-flovd", "controlnet", FloVDCogVideoXI2VControlnetTrainer)
#--------------------------------------------------------------------------------------------------
# Standalone helper functions
def encode_text(prompt: str, components, device) -> torch.Tensor:
    prompt_token_ids = components.tokenizer(
        prompt,
        padding="max_length",
        max_length=components.transformer.config.max_text_seq_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    prompt_token_ids = prompt_token_ids.input_ids
    prompt_embedding = components.text_encoder(prompt_token_ids.to(device))[0]
    return prompt_embedding
def encode_video(video: torch.Tensor, vae) -> torch.Tensor:
    # Shape of the input video: [B, C, F, H, W]
    video = video.to(vae.device, dtype=vae.dtype)
    latent_dist = vae.encode(video).latent_dist
    latent = latent_dist.sample() * vae.config.scaling_factor
    return latent


def decode_latents(latents: torch.Tensor, vae) -> torch.Tensor:
    latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
    latents = 1 / vae.config.scaling_factor * latents
    frames = vae.decode(latents).sample
    return frames
def compute_optical_flow(raft, ctxt, trgt, raft_iter=20, chunk=2, only_forward=True):
    num_frames = ctxt.shape[0]
    chunk_size = (num_frames // chunk) + 1

    flow_f_list = []
    if not only_forward:
        flow_b_list = []

    for i in range(chunk):
        start = chunk_size * i
        end = chunk_size * (i + 1)
        with torch.no_grad():
            flow_f = raft(ctxt[start:end], trgt[start:end], num_flow_updates=raft_iter)[-1]
            if not only_forward:
                flow_b = raft(trgt[start:end], ctxt[start:end], num_flow_updates=raft_iter)[-1]

        flow_f_list.append(flow_f)
        if not only_forward:
            flow_b_list.append(flow_b)

    flow_f = torch.cat(flow_f_list)
    if not only_forward:
        flow_b = torch.cat(flow_b_list)
        return flow_f, flow_b
    else:
        return flow_f, None
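# Illustrative sketch (editorial addition, not used by the trainer): one way to build a `raft`
# callable for `compute_optical_flow` with torchvision's pretrained RAFT. The helper name and the
# choice of raft_large are assumptions, not part of the original pipeline; inputs are expected in
# the [-1, 1] range used by the torchvision weights.
def _build_raft_for_reference(device: torch.device = torch.device("cpu")):
    from torchvision.models.optical_flow import Raft_Large_Weights, raft_large  # local import to keep the dependency optional

    raft = raft_large(weights=Raft_Large_Weights.DEFAULT).eval().to(device)
    # compute_optical_flow calls raft(ctxt, trgt, num_flow_updates=...) and keeps the last
    # refinement, which matches torchvision's RAFT forward signature.
    return raft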
def encode_flow(flow, vae, flow_scale_factor):
    # flow: (BF, C, H, W) with C=2
    # flow_scale_factor: [sf_x, sf_y]
    assert flow.ndim == 4
    num_frames, _, height, width = flow.shape

    # Normalize optical flow (ndim: 4 -> 5 -> 4)
    flow = rearrange(flow, '(b f) c h w -> b f c h w', b=1)
    flow_norm = adaptive_normalize(flow, flow_scale_factor[0], flow_scale_factor[1])
    flow_norm = rearrange(flow_norm, 'b f c h w -> (b f) c h w', b=1)

    # The video VAE expects 3 channels, so fill the third channel with the mean of f_x and f_y
    flow_norm_extended = torch.empty((num_frames, 3, height, width)).to(flow_norm)
    flow_norm_extended[:, :2] = flow_norm
    flow_norm_extended[:, -1:] = flow_norm.mean(dim=1, keepdim=True)
    flow_norm_extended = rearrange(flow_norm_extended, '(b f) c h w -> b c f h w', f=num_frames)

    return encode_video(flow_norm_extended, vae)
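# Editorial note: encode_flow and decode_flow (below) form a round trip. The 2-channel flow is
# normalized, padded to 3 channels for the RGB-trained video VAE, and encoded; on the way back the
# third (mean) channel is discarded before unnormalizing. The [60, 36] scale factors used in
# validate() must therefore match on both sides for the flow values to be recovered.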
def decode_flow(flow_latent, vae, flow_scale_factor):
    flow_latent = flow_latent.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
    flow_latent = 1 / vae.config.scaling_factor * flow_latent
    flow = vae.decode(flow_latent).sample  # BCFHW

    # Discard the third channel (the mean of f_x and f_y added in encode_flow)
    flow = flow[:, :2].detach().clone()

    # Unnormalize optical flow
    flow = rearrange(flow, 'b c f h w -> b f c h w')
    flow = adaptive_unnormalize(flow, flow_scale_factor[0], flow_scale_factor[1])
    flow = rearrange(flow, 'b f c h w -> (b f) c h w')

    return flow  # (BF, C, H, W)
def adaptive_normalize(flow, sf_x, sf_y):
    # flow: (B, F, C, H, W) optical flow
    # Signed-sqrt compression: y = sign(x) * sqrt(|x| / sf + 1e-7),
    # clipped to +/- sqrt(W / sf_x) for x and +/- sqrt(H / sf_y) for y
    assert flow.ndim == 5, 'Set the shape of the flow input as (B, F, C, H, W)'
    assert sf_x is not None and sf_y is not None
    b, f, c, h, w = flow.shape

    max_clip_x = math.sqrt(w / sf_x) * 1.0
    max_clip_y = math.sqrt(h / sf_y) * 1.0

    flow_norm = flow.detach().clone()
    flow_x = flow[:, :, 0].detach().clone()
    flow_y = flow[:, :, 1].detach().clone()

    flow_x_norm = torch.sign(flow_x) * torch.sqrt(torch.abs(flow_x) / sf_x + 1e-7)
    flow_y_norm = torch.sign(flow_y) * torch.sqrt(torch.abs(flow_y) / sf_y + 1e-7)

    flow_norm[:, :, 0] = torch.clamp(flow_x_norm, min=-max_clip_x, max=max_clip_x)
    flow_norm[:, :, 1] = torch.clamp(flow_y_norm, min=-max_clip_y, max=max_clip_y)

    return flow_norm


def adaptive_unnormalize(flow, sf_x, sf_y):
    # flow: (B, F, C, H, W) normalized optical flow
    # Inverse of adaptive_normalize (exact wherever the forward mapping was not clipped)
    assert flow.ndim == 5, 'Set the shape of the flow input as (B, F, C, H, W)'
    assert sf_x is not None and sf_y is not None

    flow_orig = flow.detach().clone()
    flow_x = flow[:, :, 0].detach().clone()
    flow_y = flow[:, :, 1].detach().clone()

    flow_orig[:, :, 0] = torch.sign(flow_x) * sf_x * (flow_x**2 - 1e-7)
    flow_orig[:, :, 1] = torch.sign(flow_y) * sf_y * (flow_y**2 - 1e-7)

    return flow_orig
#--------------------------------------------------------------------------------------------------
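
# Minimal sanity check (editorial addition, not part of the training pipeline): verifies that
# adaptive_unnormalize inverts adaptive_normalize for flows small enough to avoid clipping.
# The shapes and scale factors below are illustrative only.
if __name__ == "__main__":
    dummy_flow = (torch.rand(1, 3, 2, 8, 8) - 0.5) * 10.0  # (B, F, C, H, W), values in [-5, 5]
    sf_x, sf_y = 60, 36
    recovered = adaptive_unnormalize(adaptive_normalize(dummy_flow, sf_x, sf_y), sf_x, sf_y)
    assert torch.allclose(recovered, dummy_flow, atol=1e-4), "round trip failed"
    print("adaptive_normalize/adaptive_unnormalize round trip OK")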