File size: 23,264 Bytes
78360e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 |
from mmgp import offload
import argparse
import os
import random
from datetime import datetime
from pathlib import Path
from diffusers.utils import logging
from typing import Optional, List, Union
import yaml
from wan.utils.utils import calculate_new_dimensions
import imageio
import json
import numpy as np
import torch
from safetensors import safe_open
from PIL import Image
from transformers import (
T5EncoderModel,
T5Tokenizer,
AutoModelForCausalLM,
AutoProcessor,
AutoTokenizer,
)
from huggingface_hub import hf_hub_download
from .models.autoencoders.causal_video_autoencoder import (
CausalVideoAutoencoder,
)
from .models.transformers.symmetric_patchifier import SymmetricPatchifier
from .models.transformers.transformer3d import Transformer3DModel
from .pipelines.pipeline_ltx_video import (
ConditioningItem,
LTXVideoPipeline,
LTXMultiScalePipeline,
)
from .schedulers.rf import RectifiedFlowScheduler
from .utils.skip_layer_strategy import SkipLayerStrategy
from .models.autoencoders.latent_upsampler import LatentUpsampler
from .pipelines import crf_compressor
import cv2
# Upper bounds for generated clips (presumably pixel width/height and frame count).
MAX_WIDTH = 1280
MAX_HEIGHT = 720
MAX_NUM_FRAMES = 257

# Module-wide logger obtained through the diffusers logging helper.
logger = logging.get_logger("LTX-Video")
def get_total_gpu_memory():
    """Return the total memory of CUDA device 0 in GiB, or 0 when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return 0
    bytes_per_gib = 1024**3
    device_props = torch.cuda.get_device_properties(0)
    return device_props.total_memory / bytes_per_gib
def get_device():
    """Return the preferred torch device name: "cuda", then "mps", otherwise "cpu"."""
    for device_name, is_available in (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    ):
        if is_available():
            return device_name
    return "cpu"
def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, Image.Image],
    target_height: int = 512,
    target_width: int = 768,
    just_crop: bool = False,
) -> torch.Tensor:
    """Load an image, center-crop it to the target aspect ratio and convert it to a tensor.

    Args:
        image_input: Either a file path (str) or a PIL Image object. Inputs that are
            not already in RGB mode (e.g. "L", "RGBA", "P") are converted to RGB so the
            result always has three channels.
        target_height: Desired height of the output tensor.
        target_width: Desired width of the output tensor.
        just_crop: If True, only crop the image to the target aspect ratio without
            resizing it to ``target_width`` x ``target_height``.

    Returns:
        A float tensor of shape (1, 3, 1, H, W). Pixel values are mapped with
        ``x / 127.5 - 1`` after the CRF compression step.

    Raises:
        ValueError: If ``image_input`` is neither a path nor a PIL image.
    """
    if isinstance(image_input, str):
        # The context manager releases the file handle once the pixels are decoded.
        with Image.open(image_input) as opened:
            image = opened.convert("RGB")
    elif isinstance(image_input, Image.Image):
        # Without conversion, mode "L" yields a 2-D array (the HWC permute below then
        # fails) and "RGBA" yields a 4-channel tensor.
        image = image_input if image_input.mode == "RGB" else image_input.convert("RGB")
    else:
        raise ValueError("image_input must be either a file path or a PIL Image object")

    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height
    if aspect_ratio_frame > aspect_ratio_target:
        # Source is wider than the target: keep the full height, trim left/right evenly.
        new_width = int(input_height * aspect_ratio_target)
        new_height = input_height
        x_start = (input_width - new_width) // 2
        y_start = 0
    else:
        # Source is taller (or equal): keep the full width, trim top/bottom evenly.
        new_width = input_width
        new_height = int(input_width / aspect_ratio_target)
        x_start = 0
        y_start = (input_height - new_height) // 2
    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    if not just_crop:
        image = image.resize((target_width, target_height))

    pixels = np.array(image)
    pixels = cv2.GaussianBlur(pixels, (3, 3), 0)
    frame_tensor = torch.from_numpy(pixels).float()
    # The compressor is fed values divided by 255; the result is rescaled by 255.
    frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0
    frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
    frame_tensor = (frame_tensor / 127.5) - 1.0
    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
    return frame_tensor.unsqueeze(0).unsqueeze(2)
def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    """Split the size difference between source and target across opposite sides.

    When a difference is odd, the extra pixel goes to the bottom / right side.

    Returns:
        ``(left, right, top, bottom)`` -- the order ``torch.nn.functional.pad``
        uses for the last two dimensions.
    """

    def _split(total: int) -> tuple[int, int]:
        leading = total // 2
        return leading, total - leading

    top, bottom = _split(target_height - source_height)
    left, right = _split(target_width - source_width)
    return (left, right, top, bottom)
def seed_everething(seed: int):
    """Seed Python's ``random``, NumPy and torch, plus CUDA / MPS when available."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
class LTXV:
    """Load the LTX-Video models and run multi-scale video generation.

    Construction loads the transformer, VAE, scheduler, latent upsampler, text
    encoder and tokenizer (several from fixed ``ckpts/`` paths) and wraps them in
    an ``LTXMultiScalePipeline`` stored on ``self.pipeline``.
    """

    def __init__(
        self,
        model_filepath: str,
        text_encoder_filepath: str,
        dtype = torch.bfloat16,
        VAE_dtype = torch.bfloat16,
        mixed_precision_transformer = False
    ):
        """Load every sub-model and build the generation pipeline.

        Args:
            model_filepath: Transformer checkpoint path, or a list of paths. Entries
                containing "lora" are excluded from the transformer load; their
                presence marks the model as distilled, which selects the distilled
                pipeline config in ``generate``.
            text_encoder_filepath: Text-encoder checkpoint, loaded with
                ``offload.fast_load_transformers_model``.
            dtype: Currently overridden -- the transformer dtype is always bfloat16.
            VAE_dtype: Currently overridden -- the VAE and latent upsampler always
                use bfloat16.
            mixed_precision_transformer: If True, sets ``transformer._lock_dtype`` to
                float32; also the default ``mixed_precision`` used by ``generate``.
        """
        # NOTE(review): requested dtypes are unconditionally forced to bfloat16
        # (existing behavior) -- confirm whether only float16 was meant to be overridden.
        dtype = torch.bfloat16
        VAE_dtype = torch.bfloat16
        self.mixed_precision_transformer = mixed_precision_transformer

        # Accept a single path as well as a list; iterating a plain string would
        # otherwise inspect (and load) individual characters.
        if isinstance(model_filepath, str):
            model_filepath = [model_filepath]
        self.distilled = any("lora" in name for name in model_filepath)
        model_filepath = [name for name in model_filepath if "lora" not in name]

        vae = offload.fast_load_transformers_model(
            "ckpts/ltxv_0.9.7_VAE.safetensors", modelClass=CausalVideoAutoencoder
        )
        vae = vae.to(VAE_dtype)
        vae._model_dtype = VAE_dtype

        transformer = offload.fast_load_transformers_model(
            model_filepath, modelClass=Transformer3DModel
        )
        transformer._model_dtype = dtype
        if mixed_precision_transformer:
            transformer._lock_dtype = torch.float

        scheduler = RectifiedFlowScheduler.from_pretrained("ckpts/ltxv_scheduler.json")

        latent_upsampler = (
            LatentUpsampler.from_pretrained("ckpts/ltxv_0.9.7_spatial_upscaler.safetensors")
            .to("cpu")
            .eval()
        )
        latent_upsampler.to(VAE_dtype)
        latent_upsampler._model_dtype = VAE_dtype
        allowed_inference_steps = None

        text_encoder = offload.fast_load_transformers_model(text_encoder_filepath)
        patchifier = SymmetricPatchifier(patch_size=1)
        tokenizer = T5Tokenizer.from_pretrained("ckpts/T5_xxl_1.1")

        # Prompt enhancement is hard-disabled; the enhancer models are only loaded
        # when this flag is switched on.
        enhance_prompt = False
        if enhance_prompt:
            prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
                "ckpts/Florence2", trust_remote_code=True
            )
            prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
                "ckpts/Florence2", trust_remote_code=True
            )
            prompt_enhancer_llm_model = offload.fast_load_transformers_model(
                "ckpts/Llama3_2_quanto_bf16_int8.safetensors"
            )
            prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained("ckpts/Llama3_2")
            # The caption model is tagged to run in float32.
            prompt_enhancer_image_caption_model._model_dtype = torch.float
        else:
            prompt_enhancer_image_caption_model = None
            prompt_enhancer_image_caption_processor = None
            prompt_enhancer_llm_model = None
            prompt_enhancer_llm_tokenizer = None

        submodel_dict = {
            "transformer": transformer,
            "patchifier": patchifier,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "scheduler": scheduler,
            "vae": vae,
            "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
            "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
            "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
            "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
            "allowed_inference_steps": allowed_inference_steps,
        }
        pipeline = LTXVideoPipeline(**submodel_dict)
        self.pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)
        self.model = transformer
        self.vae = vae

    @staticmethod
    def _skip_layer_strategy_for(stg_mode: str) -> "SkipLayerStrategy":
        """Map a config ``stg_mode`` string (case-insensitive) to a SkipLayerStrategy.

        Raises:
            ValueError: If the mode is not one of the known aliases.
        """
        strategies = {
            "stg_av": SkipLayerStrategy.AttentionValues,
            "attention_values": SkipLayerStrategy.AttentionValues,
            "stg_as": SkipLayerStrategy.AttentionSkip,
            "attention_skip": SkipLayerStrategy.AttentionSkip,
            "stg_r": SkipLayerStrategy.Residual,
            "residual": SkipLayerStrategy.Residual,
            "stg_t": SkipLayerStrategy.TransformerBlock,
            "transformer_block": SkipLayerStrategy.TransformerBlock,
        }
        strategy = strategies.get(stg_mode.lower())
        if strategy is None:
            raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")
        return strategy

    def generate(
        self,
        input_prompt: str,
        n_prompt: str,
        image_start = None,
        image_end = None,
        input_video = None,
        sampling_steps = 50,
        image_cond_noise_scale: float = 0.15,
        input_media_path: Optional[str] = None,
        strength: Optional[float] = 1.0,
        seed: int = 42,
        height: Optional[int] = 704,
        width: Optional[int] = 1216,
        frame_num: int = 81,
        frame_rate: int = 30,
        fit_into_canvas = True,
        callback=None,
        device: Optional[str] = None,
        VAE_tile_size = None,
        **kwargs,
    ):
        """Generate a video and return it cropped to the requested size.

        Args:
            input_prompt: Positive text prompt.
            n_prompt: Negative text prompt.
            image_start: Optional image (read via ``.size``, so presumably PIL) used as
                conditioning at frame 0. When given, height/width are recomputed
                with ``calculate_new_dimensions`` from its size.
            image_end: Optional image used as conditioning at frame ``frame_num - 1``.
            input_video: Optional tensor used as conditioning at frame 0. Its last two
                dimensions replace height/width, and image_start/image_end are ignored.
            sampling_steps: Step count used for both stages of the multi-scale pipeline.
            image_cond_noise_scale: Forwarded to the pipeline.
            input_media_path: Optional media loaded with ``load_media_file`` and
                forwarded as ``media_items``.
            strength: Forwarded to the pipeline.
            seed: Seeds all RNGs and the torch.Generator.
            height, width, frame_num: Requested output size. Internally height/width
                are padded up to multiples of 32 and frames to 8 * N + 1.
            frame_rate: Forwarded to the pipeline.
            fit_into_canvas: Forwarded to ``calculate_new_dimensions``.
            callback: Forwarded to the pipeline.
            device: Torch device name; defaults to ``get_device()``.
            VAE_tile_size: Forwarded to the pipeline.
            **kwargs: Accepted and ignored.

        Returns:
            None if the pipeline returns None. Otherwise the pipeline output cropped
            back to ``frame_num`` x height x width, with values mapped by
            ``(x - 0.5) * 2`` and the batch dimension squeezed.

        Raises:
            ValueError: If the pipeline config file is missing, a conditioning frame
                index is out of range, or ``stg_mode`` is unknown.
        """
        num_inference_steps1 = sampling_steps
        num_inference_steps2 = sampling_steps

        conditioning_media_paths = []
        conditioning_start_frames = []
        if input_video is not None:
            conditioning_media_paths.append(input_video)
            conditioning_start_frames.append(0)
            height, width = input_video.shape[-2:]
        else:
            if image_start is not None:
                frame_width, frame_height = image_start.size
                height, width = calculate_new_dimensions(
                    height, width, frame_height, frame_width, fit_into_canvas, 32
                )
                conditioning_media_paths.append(image_start)
                conditioning_start_frames.append(0)
            if image_end is not None:
                conditioning_media_paths.append(image_end)
                conditioning_start_frames.append(frame_num - 1)
        # Every conditioning item uses the default strength of 1.0.
        conditioning_strengths = [1.0] * len(conditioning_media_paths)

        pipeline_config_path = (
            "ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml"
            if self.distilled
            else "ltx_video/configs/ltxv-13b-0.9.7-dev.yaml"
        )
        if not os.path.isfile(pipeline_config_path):
            raise ValueError(f"Pipeline config file {pipeline_config_path} does not exist")
        with open(pipeline_config_path, "r") as f:
            pipeline_config = yaml.safe_load(f)

        if any(f < 0 or f >= frame_num for f in conditioning_start_frames):
            raise ValueError(
                f"All conditioning start frames must be between 0 and {frame_num-1}"
            )

        # Round height/width up to multiples of 32 and the frame count up to 8 * N + 1.
        height_padded = ((height - 1) // 32 + 1) * 32
        width_padded = ((width - 1) // 32 + 1) * 32
        num_frames_padded = ((frame_num - 2) // 8 + 1) * 8 + 1
        padding = calculate_padding(height, width, height_padded, width_padded)
        logger.warning(
            f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
        )

        seed_everething(seed)
        device = device or get_device()
        generator = torch.Generator(device=device).manual_seed(seed)

        media_item = None
        if input_media_path:
            media_item = load_media_file(
                media_path=input_media_path,
                height=height,
                width=width,
                max_frames=num_frames_padded,
                padding=padding,
            )

        conditioning_items = (
            prepare_conditioning(
                conditioning_media_paths=conditioning_media_paths,
                conditioning_strengths=conditioning_strengths,
                conditioning_start_frames=conditioning_start_frames,
                height=height,
                width=width,
                num_frames=frame_num,
                padding=padding,
                pipeline=self.pipeline,
            )
            if conditioning_media_paths
            else None
        )

        # "stg_mode" must not be forwarded with the rest of the config; fall back to
        # the default when the config omits it instead of raising KeyError.
        stg_mode = pipeline_config.pop("stg_mode", "attention_values")
        skip_layer_strategy = self._skip_layer_strategy_for(stg_mode)

        sample = {
            "prompt": input_prompt,
            "prompt_attention_mask": None,
            "negative_prompt": n_prompt,
            "negative_prompt_attention_mask": None,
        }
        images = self.pipeline(
            **pipeline_config,
            ltxv_model=self,
            num_inference_steps1=num_inference_steps1,
            num_inference_steps2=num_inference_steps2,
            skip_layer_strategy=skip_layer_strategy,
            generator=generator,
            output_type="pt",
            callback_on_step_end=None,
            height=height_padded,
            width=width_padded,
            num_frames=num_frames_padded,
            frame_rate=frame_rate,
            **sample,
            media_items=media_item,
            strength=strength,
            conditioning_items=conditioning_items,
            is_video=True,
            vae_per_channel_normalize=True,
            image_cond_noise_scale=image_cond_noise_scale,
            mixed_precision=pipeline_config.get("mixed", self.mixed_precision_transformer),
            callback=callback,
            VAE_tile_size=VAE_tile_size,
            device=device,
        )
        if images is None:
            return None

        # Crop the padded output back to the requested frames / resolution. Stop
        # indices are measured from the full size so zero padding keeps the edge.
        pad_left, pad_right, pad_top, pad_bottom = padding
        bottom = images.shape[3] - pad_bottom
        right = images.shape[4] - pad_right
        images = images[:, :, :frame_num, pad_top:bottom, pad_left:right]
        # In-place (x - 0.5) * 2, then drop the batch dimension.
        return images.sub_(0.5).mul_(2).squeeze(0)
def prepare_conditioning(
    conditioning_media_paths: List[str],
    conditioning_strengths: List[float],
    conditioning_start_frames: List[int],
    height: int,
    width: int,
    num_frames: int,
    padding: tuple[int, int, int, int],
    pipeline: LTXVideoPipeline,
) -> Optional[List[ConditioningItem]]:
    """Turn conditioning inputs into ConditioningItem objects.

    Each input (a PIL image, a tensor, or a video path) is loaded with
    ``load_media_file`` using ``just_crop=True``. If ``pipeline`` exposes a callable
    ``trim_conditioning_sequence``, it decides how many frames of the input are used,
    and a warning is logged when that trims the input.

    Args:
        conditioning_media_paths: Conditioning inputs (images, tensors or video paths).
        conditioning_strengths: Strength for each input.
        conditioning_start_frames: Frame index at which each input applies.
        height: Height of the output frames.
        width: Width of the output frames.
        num_frames: Number of frames in the output video.
        padding: Padding forwarded to ``load_media_file``.
        pipeline: Pipeline object queried for ``trim_conditioning_sequence``.

    Returns:
        One ConditioningItem per input, in input order.
    """
    items = []
    for media, strength, start_frame in zip(
        conditioning_media_paths, conditioning_strengths, conditioning_start_frames
    ):
        available = 1 if isinstance(media, Image.Image) else get_media_num_frames(media)
        frames_to_use = available
        trimmer = getattr(pipeline, "trim_conditioning_sequence", None)
        if callable(trimmer):
            frames_to_use = trimmer(start_frame, available, num_frames)
            if frames_to_use < available:
                logger.warning(
                    f"Trimming conditioning video {media} from {available} to {frames_to_use} frames."
                )
        media_tensor = load_media_file(
            media_path=media,
            height=height,
            width=width,
            max_frames=frames_to_use,
            padding=padding,
            just_crop=True,
        )
        items.append(ConditioningItem(media_tensor, start_frame, strength))
    return items
def get_media_num_frames(media_path: str, max_frames: Optional[int] = None) -> int:
    """Return how many frames a conditioning input provides.

    Args:
        media_path: A PIL image (one frame), a tensor whose dimension 1 is taken as
            the frame axis, or a path to a .mp4/.avi/.mov/.mkv video file.
        max_frames: Optional upper bound on the returned count. ``None`` (the
            default) means no cap.

    Returns:
        The number of frames, capped at ``max_frames`` when given.

    Raises:
        Exception: If the input type or file extension is not supported.
    """
    # Tensors are checked first; the type checks are mutually exclusive.
    if torch.is_tensor(media_path):
        num_frames = media_path.shape[1]
    elif isinstance(media_path, Image.Image):
        num_frames = 1
    elif isinstance(media_path, str) and media_path.lower().endswith(
        (".mp4", ".avi", ".mov", ".mkv")
    ):
        reader = imageio.get_reader(media_path)
        try:
            num_frames = reader.count_frames()
        finally:
            # Release the underlying file / decoder even if counting fails.
            reader.close()
    else:
        raise Exception("video format not supported")
    if max_frames is not None:
        num_frames = min(num_frames, max_frames)
    return num_frames
def load_media_file(
    media_path: str,
    height: int,
    width: int,
    max_frames: int,
    padding: tuple[int, int, int, int],
    just_crop: bool = False,
) -> torch.Tensor:
    """Load an image, tensor or video file as a tensor with a leading batch axis.

    Args:
        media_path: A PIL image, a tensor, or a path to a .mp4/.avi/.mov/.mkv file.
        height: Target frame height passed to ``load_image_to_tensor_with_resize_and_crop``.
        width: Target frame width passed to ``load_image_to_tensor_with_resize_and_crop``.
        max_frames: Maximum number of frames kept from a tensor or video file.
        padding: ``(left, right, top, bottom)`` padding applied with
            ``torch.nn.functional.pad`` to image / video frames. Tensors are not padded.
        just_crop: Forwarded to ``load_image_to_tensor_with_resize_and_crop``.

    Returns:
        A tensor with a batch axis at dimension 0 and frames along dimension 2.

    Raises:
        Exception: If the input type or file extension is not supported.
    """
    if torch.is_tensor(media_path):
        # NOTE(review): tensors are assumed to be preprocessed already (no resize /
        # pad). Only a batch axis is added, and the frame count is capped like the
        # video-file path so trimming from prepare_conditioning is honored.
        return media_path.unsqueeze(0)[:, :, :max_frames]
    if isinstance(media_path, Image.Image):
        media_tensor = load_image_to_tensor_with_resize_and_crop(
            media_path, height, width, just_crop=just_crop
        )
        return torch.nn.functional.pad(media_tensor, padding)
    if isinstance(media_path, str) and media_path.lower().endswith(
        (".mp4", ".avi", ".mov", ".mkv")
    ):
        reader = imageio.get_reader(media_path)
        try:
            num_input_frames = min(reader.count_frames(), max_frames)
            # Read and preprocess the relevant frames from the video file.
            frames = []
            for i in range(num_input_frames):
                frame = Image.fromarray(reader.get_data(i))
                frame_tensor = load_image_to_tensor_with_resize_and_crop(
                    frame, height, width, just_crop=just_crop
                )
                frames.append(torch.nn.functional.pad(frame_tensor, padding))
        finally:
            # Close the reader even if decoding or preprocessing raises.
            reader.close()
        # Stack frames along the temporal dimension.
        return torch.cat(frames, dim=2)
    raise Exception("video format not supported")
if __name__ == "__main__":
    # No `main` function is defined in this module, so calling one would raise
    # NameError. Exit with an explanatory message instead.
    raise SystemExit(
        "This module provides the LTXV class and helpers; it has no command-line entry point."
    )
|