from mmgp import offload
import argparse
import os
import random
from datetime import datetime
from pathlib import Path
from diffusers.utils import logging
from typing import Optional, List, Union
import yaml
from wan.utils.utils import calculate_new_dimensions
import imageio
import json
import numpy as np
import torch
from safetensors import safe_open
from PIL import Image
from transformers import (
    T5EncoderModel,
    T5Tokenizer,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
)
from huggingface_hub import hf_hub_download

from .models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
)
from .models.transformers.symmetric_patchifier import SymmetricPatchifier
from .models.transformers.transformer3d import Transformer3DModel
from .pipelines.pipeline_ltx_video import (
    ConditioningItem,
    LTXVideoPipeline,
    LTXMultiScalePipeline,
)
from .schedulers.rf import RectifiedFlowScheduler
from .utils.skip_layer_strategy import SkipLayerStrategy
from .models.autoencoders.latent_upsampler import LatentUpsampler
from .pipelines import crf_compressor
import cv2

MAX_HEIGHT = 720
MAX_WIDTH = 1280
MAX_NUM_FRAMES = 257

logger = logging.get_logger("LTX-Video")


def get_total_gpu_memory():
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        return total_memory
    return 0


def get_device():
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, Image.Image],
    target_height: int = 512,
    target_width: int = 768,
    just_crop: bool = False,
) -> torch.Tensor:
    """Load and process an image into a tensor.

    Args:
        image_input: Either a file path (str) or a PIL Image object
        target_height: Desired height of output tensor
        target_width: Desired width of output tensor
        just_crop: If True, only crop the image to the target size without resizing
    """
    if isinstance(image_input, str):
        image = Image.open(image_input).convert("RGB")
    elif isinstance(image_input, Image.Image):
        image = image_input
    else:
        raise ValueError("image_input must be either a file path or a PIL Image object")

    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height
    if aspect_ratio_frame > aspect_ratio_target:
        new_width = int(input_height * aspect_ratio_target)
        new_height = input_height
        x_start = (input_width - new_width) // 2
        y_start = 0
    else:
        new_width = input_width
        new_height = int(input_width / aspect_ratio_target)
        x_start = 0
        y_start = (input_height - new_height) // 2

    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    if not just_crop:
        image = image.resize((target_width, target_height))

    image = np.array(image)
    image = cv2.GaussianBlur(image, (3, 3), 0)
    frame_tensor = torch.from_numpy(image).float()
    frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0
    frame_tensor = frame_tensor.permute(2, 0, 1)
    frame_tensor = (frame_tensor / 127.5) - 1.0
    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
    return frame_tensor.unsqueeze(0).unsqueeze(2)
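
# Illustrative sketch (not called anywhere in this module): shows the expected output of
# load_image_to_tensor_with_resize_and_crop. The in-memory test image is a stand-in for a
# real input; values are normalized to roughly [-1, 1].
def _example_load_image_tensor() -> torch.Tensor:
    image = Image.new("RGB", (1920, 1080))  # placeholder image, any RGB source works
    tensor = load_image_to_tensor_with_resize_and_crop(
        image, target_height=512, target_width=768
    )
    # Shape is (batch=1, channels=3, frames=1, height, width)
    assert tensor.shape == (1, 3, 1, 512, 768)
    return tensor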
def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    # Calculate total padding needed
    pad_height = target_height - source_height
    pad_width = target_width - source_width

    # Calculate padding for each side
    pad_top = pad_height // 2
    pad_bottom = pad_height - pad_top  # Handles odd padding
    pad_left = pad_width // 2
    pad_right = pad_width - pad_left  # Handles odd padding

    # Return the padding in (left, right, top, bottom) order
    padding = (pad_left, pad_right, pad_top, pad_bottom)
    return padding


def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)


class LTXV:

    def __init__(
        self,
        model_filepath: List[str],
        text_encoder_filepath: str,
        dtype=torch.bfloat16,
        VAE_dtype=torch.bfloat16,
        mixed_precision_transformer=False,
    ):
        # if dtype == torch.float16: dtype = torch.bfloat16
        self.mixed_precision_transformer = mixed_precision_transformer
        self.distilled = any("lora" in name for name in model_filepath)
        model_filepath = [name for name in model_filepath if "lora" not in name]

        # with safe_open(ckpt_path, framework="pt") as f:
        #     metadata = f.metadata()
        #     config_str = metadata.get("config")
        #     configs = json.loads(config_str)
        #     allowed_inference_steps = configs.get("allowed_inference_steps", None)
        # transformer = Transformer3DModel.from_pretrained(ckpt_path)
        # transformer = offload.fast_load_transformers_model("c:/temp/ltxdistilled/diffusion_pytorch_model-00001-of-00006.safetensors", forcedConfigPath="c:/temp/ltxdistilled/config.json")

        # vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
        vae = offload.fast_load_transformers_model(
            "ckpts/ltxv_0.9.7_VAE.safetensors", modelClass=CausalVideoAutoencoder
        )
        # if VAE_dtype == torch.float16: VAE_dtype = torch.bfloat16
        vae = vae.to(VAE_dtype)
        vae._model_dtype = VAE_dtype
        # vae = offload.fast_load_transformers_model("vae.safetensors", modelClass=CausalVideoAutoencoder, modelPrefix="vae", forcedConfigPath="config_vae.json")
        # offload.save_model(vae, "vae.safetensors", config_file_path="config_vae.json")

        # model_filepath = "c:/temp/ltxd/ltxv-13b-0.9.7-distilled.safetensors"
        transformer = offload.fast_load_transformers_model(
            model_filepath, modelClass=Transformer3DModel
        )
        # offload.save_model(transformer, "ckpts/ltxv_0.9.7_13B_distilled_bf16.safetensors", config_file_path="c:/temp/ltxd/config.json")
        # offload.save_model(transformer, "ckpts/ltxv_0.9.7_13B_distilled_quanto_bf16_int8.safetensors", do_quantize=True, config_file_path="c:/temp/ltxd/config.json")
        # transformer = offload.fast_load_transformers_model(model_filepath, modelClass=Transformer3DModel)
        transformer._model_dtype = dtype
        if mixed_precision_transformer:
            transformer._lock_dtype = torch.float

        scheduler = RectifiedFlowScheduler.from_pretrained("ckpts/ltxv_scheduler.json")
        # transformer = offload.fast_load_transformers_model("ltx_13B_quanto_bf16_int8.safetensors", modelClass=Transformer3DModel, modelPrefix="model.diffusion_model", forcedConfigPath="config_transformer.json")
        # offload.save_model(transformer, "ltx_13B_quanto_bf16_int8.safetensors", do_quantize=True, config_file_path="config_transformer.json")

        latent_upsampler = (
            LatentUpsampler.from_pretrained("ckpts/ltxv_0.9.7_spatial_upscaler.safetensors")
            .to("cpu")
            .eval()
        )
        latent_upsampler.to(VAE_dtype)
        latent_upsampler._model_dtype = VAE_dtype

        allowed_inference_steps = None

        # text_encoder = T5EncoderModel.from_pretrained(
        #     "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
        # )
        # text_encoder.to(torch.bfloat16)
        # offload.save_model(text_encoder, "T5_xxl_1.1_enc_bf16.safetensors", config_file_path="T5_config.json")
        # offload.save_model(text_encoder, "T5_xxl_1.1_enc_quanto_bf16_int8.safetensors", do_quantize=True, config_file_path="T5_config.json")
        text_encoder = offload.fast_load_transformers_model(text_encoder_filepath)
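        # Remaining submodels: the latent patchifier, the T5 tokenizer matching the text
        # encoder loaded above, and the optional prompt-enhancer models (disabled by
        # default below).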
        patchifier = SymmetricPatchifier(patch_size=1)
        tokenizer = T5Tokenizer.from_pretrained("ckpts/T5_xxl_1.1")

        enhance_prompt = False
        if enhance_prompt:
            prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
                "ckpts/Florence2", trust_remote_code=True
            )
            prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
                "ckpts/Florence2", trust_remote_code=True
            )
            prompt_enhancer_llm_model = offload.fast_load_transformers_model(
                "ckpts/Llama3_2_quanto_bf16_int8.safetensors"
            )
            prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained("ckpts/Llama3_2")
        else:
            prompt_enhancer_image_caption_model = None
            prompt_enhancer_image_caption_processor = None
            prompt_enhancer_llm_model = None
            prompt_enhancer_llm_tokenizer = None

        # Container for the optional offload profiling below; only populated when
        # prompt enhancement is enabled.
        pipe = {}
        if prompt_enhancer_image_caption_model is not None:
            pipe["prompt_enhancer_image_caption_model"] = prompt_enhancer_image_caption_model
            prompt_enhancer_image_caption_model._model_dtype = torch.float
            pipe["prompt_enhancer_llm_model"] = prompt_enhancer_llm_model

        # offload.profile(pipe, profile_no=5, extraModelsToQuantize=None, quantizeTransformer=False, budgets={"prompt_enhancer_llm_model": 10000, "prompt_enhancer_image_caption_model": 10000, "vae": 3000, "*": 100}, verboseLevel=2)

        # Use submodels for the pipeline
        submodel_dict = {
            "transformer": transformer,
            "patchifier": patchifier,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "scheduler": scheduler,
            "vae": vae,
            "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
            "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
            "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
            "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
            "allowed_inference_steps": allowed_inference_steps,
        }
        pipeline = LTXVideoPipeline(**submodel_dict)
        pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)

        self.pipeline = pipeline
        self.model = transformer
        self.vae = vae
        # return pipeline, pipe

    def generate(
        self,
        input_prompt: str,
        n_prompt: str,
        image_start=None,
        image_end=None,
        input_video=None,
        sampling_steps=50,
        image_cond_noise_scale: float = 0.15,
        input_media_path: Optional[str] = None,
        strength: Optional[float] = 1.0,
        seed: int = 42,
        height: Optional[int] = 704,
        width: Optional[int] = 1216,
        frame_num: int = 81,
        frame_rate: int = 30,
        fit_into_canvas=True,
        callback=None,
        device: Optional[str] = None,
        VAE_tile_size=None,
        **kwargs,
    ):
        num_inference_steps1 = sampling_steps
        num_inference_steps2 = sampling_steps  # 10

        conditioning_strengths = None
        conditioning_media_paths = []
        conditioning_start_frames = []
        if input_video is not None:
            conditioning_media_paths.append(input_video)
            conditioning_start_frames.append(0)
            height, width = input_video.shape[-2:]
        else:
            if image_start is not None:
                frame_width, frame_height = image_start.size
                height, width = calculate_new_dimensions(
                    height, width, frame_height, frame_width, fit_into_canvas, 32
                )
                conditioning_media_paths.append(image_start)
                conditioning_start_frames.append(0)
            if image_end is not None:
                conditioning_media_paths.append(image_end)
                conditioning_start_frames.append(frame_num - 1)

        if len(conditioning_media_paths) == 0:
            conditioning_media_paths = None
            conditioning_start_frames = None

        if self.distilled:
            pipeline_config = "ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml"
        else:
            pipeline_config = "ltx_video/configs/ltxv-13b-0.9.7-dev.yaml"
        # check if pipeline_config is a file
        if not os.path.isfile(pipeline_config):
            raise ValueError(f"Pipeline config file {pipeline_config} does not exist")
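        # The YAML file supplies the remaining pipeline keyword arguments (for example the
        # spatiotemporal guidance mode popped below); they are forwarded to the pipeline
        # call via **pipeline_config.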
config file {pipeline_config} does not exist") with open(pipeline_config, "r") as f: pipeline_config = yaml.safe_load(f) # Validate conditioning arguments if conditioning_media_paths: # Use default strengths of 1.0 if not conditioning_strengths: conditioning_strengths = [1.0] * len(conditioning_media_paths) if not conditioning_start_frames: raise ValueError( "If `conditioning_media_paths` is provided, " "`conditioning_start_frames` must also be provided" ) if len(conditioning_media_paths) != len(conditioning_strengths) or len( conditioning_media_paths ) != len(conditioning_start_frames): raise ValueError( "`conditioning_media_paths`, `conditioning_strengths`, " "and `conditioning_start_frames` must have the same length" ) if any(s < 0 or s > 1 for s in conditioning_strengths): raise ValueError("All conditioning strengths must be between 0 and 1") if any(f < 0 or f >= frame_num for f in conditioning_start_frames): raise ValueError( f"All conditioning start frames must be between 0 and {frame_num-1}" ) # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1) height_padded = ((height - 1) // 32 + 1) * 32 width_padded = ((width - 1) // 32 + 1) * 32 num_frames_padded = ((frame_num - 2) // 8 + 1) * 8 + 1 padding = calculate_padding(height, width, height_padded, width_padded) logger.warning( f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}" ) # prompt_enhancement_words_threshold = pipeline_config[ # "prompt_enhancement_words_threshold" # ] # prompt_word_count = len(prompt.split()) # enhance_prompt = ( # prompt_enhancement_words_threshold > 0 # and prompt_word_count < prompt_enhancement_words_threshold # ) # # enhance_prompt = False # if prompt_enhancement_words_threshold > 0 and not enhance_prompt: # logger.info( # f"Prompt has {prompt_word_count} words, which exceeds the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled." 
        seed_everything(seed)
        device = device or get_device()
        generator = torch.Generator(device=device).manual_seed(seed)

        media_item = None
        if input_media_path:
            media_item = load_media_file(
                media_path=input_media_path,
                height=height,
                width=width,
                max_frames=num_frames_padded,
                padding=padding,
            )

        conditioning_items = (
            prepare_conditioning(
                conditioning_media_paths=conditioning_media_paths,
                conditioning_strengths=conditioning_strengths,
                conditioning_start_frames=conditioning_start_frames,
                height=height,
                width=width,
                num_frames=frame_num,
                padding=padding,
                pipeline=self.pipeline,
            )
            if conditioning_media_paths
            else None
        )

        # Pop the guidance mode so it is not forwarded again via **pipeline_config below.
        stg_mode = pipeline_config.pop("stg_mode", "attention_values")
        if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
            skip_layer_strategy = SkipLayerStrategy.AttentionValues
        elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
            skip_layer_strategy = SkipLayerStrategy.AttentionSkip
        elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
            skip_layer_strategy = SkipLayerStrategy.Residual
        elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
            skip_layer_strategy = SkipLayerStrategy.TransformerBlock
        else:
            raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")

        # Prepare input for the pipeline
        sample = {
            "prompt": input_prompt,
            "prompt_attention_mask": None,
            "negative_prompt": n_prompt,
            "negative_prompt_attention_mask": None,
        }

        images = self.pipeline(
            **pipeline_config,
            ltxv_model=self,
            num_inference_steps1=num_inference_steps1,
            num_inference_steps2=num_inference_steps2,
            skip_layer_strategy=skip_layer_strategy,
            generator=generator,
            output_type="pt",
            callback_on_step_end=None,
            height=height_padded,
            width=width_padded,
            num_frames=num_frames_padded,
            frame_rate=frame_rate,
            **sample,
            media_items=media_item,
            strength=strength,
            conditioning_items=conditioning_items,
            is_video=True,
            vae_per_channel_normalize=True,
            image_cond_noise_scale=image_cond_noise_scale,
            mixed_precision=pipeline_config.get("mixed", self.mixed_precision_transformer),
            callback=callback,
            VAE_tile_size=VAE_tile_size,
            device=device,
            # enhance_prompt=enhance_prompt,
        )
        if images is None:
            return None

        # Crop the padded images to the desired resolution and number of frames
        (pad_left, pad_right, pad_top, pad_bottom) = padding
        pad_bottom = -pad_bottom
        pad_right = -pad_right
        if pad_bottom == 0:
            pad_bottom = images.shape[3]
        if pad_right == 0:
            pad_right = images.shape[4]
        images = images[:, :, :frame_num, pad_top:pad_bottom, pad_left:pad_right]
        images = images.sub_(0.5).mul_(2).squeeze(0)
        return images


def prepare_conditioning(
    conditioning_media_paths: List[str],
    conditioning_strengths: List[float],
    conditioning_start_frames: List[int],
    height: int,
    width: int,
    num_frames: int,
    padding: tuple[int, int, int, int],
    pipeline: LTXVideoPipeline,
) -> Optional[List[ConditioningItem]]:
    """Prepare conditioning items based on input media paths and their parameters.

    Args:
        conditioning_media_paths: List of paths to conditioning media (images or videos)
        conditioning_strengths: List of conditioning strengths for each media item
        conditioning_start_frames: List of frame indices where each item should be applied
        height: Height of the output frames
        width: Width of the output frames
        num_frames: Number of frames in the output video
        padding: Padding to apply to the frames
        pipeline: LTXVideoPipeline object used for condition video trimming

    Returns:
        A list of ConditioningItem objects.
    """
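    # For each conditioning input: videos longer than needed are trimmed through the
    # pipeline's trim_conditioning_sequence (when available), then each item is loaded,
    # cropped/padded to the target size and wrapped in a ConditioningItem.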
""" conditioning_items = [] for path, strength, start_frame in zip( conditioning_media_paths, conditioning_strengths, conditioning_start_frames ): if isinstance(path, Image.Image): num_input_frames = orig_num_input_frames = 1 else: num_input_frames = orig_num_input_frames = get_media_num_frames(path) if hasattr(pipeline, "trim_conditioning_sequence") and callable( getattr(pipeline, "trim_conditioning_sequence") ): num_input_frames = pipeline.trim_conditioning_sequence( start_frame, orig_num_input_frames, num_frames ) if num_input_frames < orig_num_input_frames: logger.warning( f"Trimming conditioning video {path} from {orig_num_input_frames} to {num_input_frames} frames." ) media_tensor = load_media_file( media_path=path, height=height, width=width, max_frames=num_input_frames, padding=padding, just_crop=True, ) conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength)) return conditioning_items def get_media_num_frames(media_path: str) -> int: if isinstance(media_path, Image.Image): return 1 elif torch.is_tensor(media_path): return media_path.shape[1] elif isinstance(media_path, str) and any( media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]): reader = imageio.get_reader(media_path) return min(reader.count_frames(), max_frames) else: raise Exception("video format not supported") def load_media_file( media_path: str, height: int, width: int, max_frames: int, padding: tuple[int, int, int, int], just_crop: bool = False, ) -> torch.Tensor: if isinstance(media_path, Image.Image): # Input image media_tensor = load_image_to_tensor_with_resize_and_crop( media_path, height, width, just_crop=just_crop ) media_tensor = torch.nn.functional.pad(media_tensor, padding) elif torch.is_tensor(media_path): media_tensor = media_path.unsqueeze(0) num_input_frames = media_tensor.shape[2] elif isinstance(media_path, str) and any( media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]): reader = imageio.get_reader(media_path) num_input_frames = min(reader.count_frames(), max_frames) # Read and preprocess the relevant frames from the video file. frames = [] for i in range(num_input_frames): frame = Image.fromarray(reader.get_data(i)) frame_tensor = load_image_to_tensor_with_resize_and_crop( frame, height, width, just_crop=just_crop ) frame_tensor = torch.nn.functional.pad(frame_tensor, padding) frames.append(frame_tensor) reader.close() # Stack frames along the temporal dimension media_tensor = torch.cat(frames, dim=2) else: raise Exception("video format not supported") return media_tensor if __name__ == "__main__": main()