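"""
Command-line entry point for LTX-Video generation.

Loads the distilled LTX-Video pipeline (plus an optional latent upsampler used
for multi-scale texture refinement), then generates a video from a text prompt,
optionally conditioned on an input image or an input video. Results are written
as MP4 files to the local output directory.
"""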
import torch
import numpy as np
import random
import os
import yaml
import argparse
from pathlib import Path
import imageio
import tempfile
from PIL import Image
from huggingface_hub import hf_hub_download
import shutil
from inference import (
create_ltx_video_pipeline,
create_latent_upsampler,
load_image_to_tensor_with_resize_and_crop,
seed_everething,
get_device,
calculate_padding,
load_media_file
)
from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXMultiScalePipeline, LTXVideoPipeline
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
config_file_path = "configs/ltxv-13b-0.9.7-distilled.yaml"
with open(config_file_path, "r") as file:
PIPELINE_CONFIG_YAML = yaml.safe_load(file)
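# --- Model repo and generation constants ---
# MAX_NUM_FRAMES (257 = 8 * 32 + 1) at FPS = 30 caps clips at roughly 8.5 seconds.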
LTX_REPO = "Lightricks/LTX-Video"
MAX_IMAGE_SIZE = PIPELINE_CONFIG_YAML.get("max_resolution", 1280)
MAX_NUM_FRAMES = 257
FPS = 30.0
# --- Global variables for loaded models ---
pipeline_instance = None
latent_upsampler_instance = None
models_dir = "downloaded_models_gradio_cpu_init"
Path(models_dir).mkdir(parents=True, exist_ok=True)
# Create the output directory
output_dir = "output"
Path(output_dir).mkdir(parents=True, exist_ok=True)
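# --- Download model weights from the Hugging Face Hub (skipped if already cached) ---
# The YAML config entries are rewritten to point at the downloaded local paths.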
print("Downloading models (if not present)...")
distilled_model_actual_path = hf_hub_download(
repo_id=LTX_REPO,
filename=PIPELINE_CONFIG_YAML["checkpoint_path"],
local_dir=models_dir,
local_dir_use_symlinks=False
)
PIPELINE_CONFIG_YAML["checkpoint_path"] = distilled_model_actual_path
print(f"Distilled model path: {distilled_model_actual_path}")
SPATIAL_UPSCALER_FILENAME = PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"]
spatial_upscaler_actual_path = hf_hub_download(
repo_id=LTX_REPO,
filename=SPATIAL_UPSCALER_FILENAME,
local_dir=models_dir,
local_dir_use_symlinks=False
)
PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"] = spatial_upscaler_actual_path
print(f"Spatial upscaler model path: {spatial_upscaler_actual_path}")
print("Creating LTX Video pipeline on CPU...")
pipeline_instance = create_ltx_video_pipeline(
ckpt_path=PIPELINE_CONFIG_YAML["checkpoint_path"],
precision=PIPELINE_CONFIG_YAML["precision"],
text_encoder_model_name_or_path=PIPELINE_CONFIG_YAML["text_encoder_model_name_or_path"],
sampler=PIPELINE_CONFIG_YAML["sampler"],
device="cpu",
enhance_prompt=False,
prompt_enhancer_image_caption_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_image_caption_model_name_or_path"],
prompt_enhancer_llm_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_llm_model_name_or_path"],
)
print("LTX Video pipeline created on CPU.")
if PIPELINE_CONFIG_YAML.get("spatial_upscaler_model_path"):
print("Creating latent upsampler on CPU...")
latent_upsampler_instance = create_latent_upsampler(
PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"],
device="cpu"
)
print("Latent upsampler created on CPU.")
target_inference_device = "cuda"
print(f"Target inference device: {target_inference_device}")
pipeline_instance.to(target_inference_device)
if latent_upsampler_instance:
latent_upsampler_instance.to(target_inference_device)
# --- Helper function for dimension calculation ---
MIN_DIM_SLIDER = 256
TARGET_FIXED_SIDE = 768
def calculate_new_dimensions(orig_w, orig_h):
"""
Calculates new dimensions for height and width sliders based on original media dimensions.
Ensures one side is TARGET_FIXED_SIDE, the other is scaled proportionally,
both are multiples of 32, and within [MIN_DIM_SLIDER, MAX_IMAGE_SIZE].
"""
if orig_w == 0 or orig_h == 0:
return int(TARGET_FIXED_SIDE), int(TARGET_FIXED_SIDE)
if orig_w >= orig_h:
new_h = TARGET_FIXED_SIDE
aspect_ratio = orig_w / orig_h
new_w_ideal = new_h * aspect_ratio
new_w = round(new_w_ideal / 32) * 32
new_w = max(MIN_DIM_SLIDER, min(new_w, MAX_IMAGE_SIZE))
new_h = max(MIN_DIM_SLIDER, min(new_h, MAX_IMAGE_SIZE))
else:
new_w = TARGET_FIXED_SIDE
aspect_ratio = orig_h / orig_w
new_h_ideal = new_w * aspect_ratio
new_h = round(new_h_ideal / 32) * 32
new_h = max(MIN_DIM_SLIDER, min(new_h, MAX_IMAGE_SIZE))
new_w = max(MIN_DIM_SLIDER, min(new_w, MAX_IMAGE_SIZE))
return int(new_h), int(new_w)
def generate(prompt, negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
input_image_filepath=None, input_video_filepath=None,
height_ui=512, width_ui=704, mode="text-to-video",
duration_ui=2.0, ui_frames_to_use=9,
seed_ui=42, randomize_seed=True, ui_guidance_scale=None, improve_texture_flag=True):
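    """
    Generate a video and return (output_video_path, seed_used).

    Depending on `mode`, generation is unconditional (text-to-video) or
    conditioned on an input image / input video. When `improve_texture_flag`
    is set and the latent upsampler is available, a two-pass multi-scale
    pipeline is used; otherwise a single pass of the base pipeline runs.
    """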
if randomize_seed:
seed_ui = random.randint(0, 2**32 - 1)
seed_everething(int(seed_ui))
if ui_guidance_scale is None:
ui_guidance_scale = PIPELINE_CONFIG_YAML.get("first_pass", {}).get("guidance_scale", 1.0)
target_frames_ideal = duration_ui * FPS
target_frames_rounded = round(target_frames_ideal)
if target_frames_rounded < 1:
target_frames_rounded = 1
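    # The model expects frame counts of the form 8 * n + 1; round to the nearest
    # such value and clamp to [9, MAX_NUM_FRAMES].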
n_val = round((float(target_frames_rounded) - 1.0) / 8.0)
actual_num_frames = int(n_val * 8 + 1)
actual_num_frames = max(9, actual_num_frames)
actual_num_frames = min(MAX_NUM_FRAMES, actual_num_frames)
actual_height = int(height_ui)
actual_width = int(width_ui)
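    # Pad height/width up to multiples of 32 and the frame count up to 8 * n + 1,
    # keeping the padding amounts so they can be cropped off after decoding.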
height_padded = ((actual_height - 1) // 32 + 1) * 32
width_padded = ((actual_width - 1) // 32 + 1) * 32
num_frames_padded = ((actual_num_frames - 2) // 8 + 1) * 8 + 1
padding_values = calculate_padding(actual_height, actual_width, height_padded, width_padded)
call_kwargs = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"height": height_padded,
"width": width_padded,
"num_frames": num_frames_padded,
"frame_rate": int(FPS),
"generator": torch.Generator(device=target_inference_device).manual_seed(int(seed_ui)),
"output_type": "pt",
"conditioning_items": None,
"media_items": None,
"decode_timestep": PIPELINE_CONFIG_YAML["decode_timestep"],
"decode_noise_scale": PIPELINE_CONFIG_YAML["decode_noise_scale"],
"stochastic_sampling": PIPELINE_CONFIG_YAML["stochastic_sampling"],
"image_cond_noise_scale": 0.15,
"is_video": True,
"vae_per_channel_normalize": True,
"mixed_precision": (PIPELINE_CONFIG_YAML["precision"] == "mixed_precision"),
"offload_to_cpu": False,
"enhance_prompt": False,
}
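    # Map the configured STG mode string onto the corresponding skip-layer guidance strategy.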
stg_mode_str = PIPELINE_CONFIG_YAML.get("stg_mode", "attention_values")
if stg_mode_str.lower() in ["stg_av", "attention_values"]:
call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.AttentionValues
elif stg_mode_str.lower() in ["stg_as", "attention_skip"]:
call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.AttentionSkip
elif stg_mode_str.lower() in ["stg_r", "residual"]:
call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.Residual
elif stg_mode_str.lower() in ["stg_t", "transformer_block"]:
call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.TransformerBlock
else:
raise ValueError(f"Invalid stg_mode: {stg_mode_str}")
if mode == "image-to-video" and input_image_filepath:
try:
media_tensor = load_image_to_tensor_with_resize_and_crop(
input_image_filepath, actual_height, actual_width
)
media_tensor = torch.nn.functional.pad(media_tensor, padding_values)
call_kwargs["conditioning_items"] = [ConditioningItem(media_tensor.to(target_inference_device), 0, 1.0)]
except Exception as e:
print(f"Error loading image {input_image_filepath}: {e}")
raise RuntimeError(f"Could not load image: {e}")
elif mode == "video-to-video" and input_video_filepath:
try:
call_kwargs["media_items"] = load_media_file(
media_path=input_video_filepath,
height=actual_height,
width=actual_width,
max_frames=int(ui_frames_to_use),
padding=padding_values
).to(target_inference_device)
except Exception as e:
print(f"Error loading video {input_video_filepath}: {e}")
raise RuntimeError(f"Could not load video: {e}")
print(f"Moving models to {target_inference_device} for inference (if not already there)...")
active_latent_upsampler = None
if improve_texture_flag and latent_upsampler_instance:
active_latent_upsampler = latent_upsampler_instance
result_images_tensor = None
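    # Multi-scale path: run the first pass at a downscaled resolution, upsample the
    # latents, then refine them in a second pass. Otherwise run a single base-pipeline pass.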
if improve_texture_flag:
if not active_latent_upsampler:
raise RuntimeError("Spatial upscaler model not loaded or improve_texture not selected, cannot use multi-scale.")
multi_scale_pipeline_obj = LTXMultiScalePipeline(pipeline_instance, active_latent_upsampler)
first_pass_args = PIPELINE_CONFIG_YAML.get("first_pass", {}).copy()
first_pass_args["guidance_scale"] = float(ui_guidance_scale)
first_pass_args.pop("num_inference_steps", None)
second_pass_args = PIPELINE_CONFIG_YAML.get("second_pass", {}).copy()
second_pass_args["guidance_scale"] = float(ui_guidance_scale)
second_pass_args.pop("num_inference_steps", None)
multi_scale_call_kwargs = call_kwargs.copy()
multi_scale_call_kwargs.update({
"downscale_factor": PIPELINE_CONFIG_YAML["downscale_factor"],
"first_pass": first_pass_args,
"second_pass": second_pass_args,
})
print(f"Calling multi-scale pipeline (eff. HxW: {actual_height}x{actual_width}, Frames: {actual_num_frames} -> Padded: {num_frames_padded}) on {target_inference_device}")
result_images_tensor = multi_scale_pipeline_obj(**multi_scale_call_kwargs).images
else:
single_pass_call_kwargs = call_kwargs.copy()
first_pass_config_from_yaml = PIPELINE_CONFIG_YAML.get("first_pass", {})
single_pass_call_kwargs["timesteps"] = first_pass_config_from_yaml.get("timesteps")
single_pass_call_kwargs["guidance_scale"] = float(ui_guidance_scale)
single_pass_call_kwargs["stg_scale"] = first_pass_config_from_yaml.get("stg_scale")
single_pass_call_kwargs["rescaling_scale"] = first_pass_config_from_yaml.get("rescaling_scale")
single_pass_call_kwargs["skip_block_list"] = first_pass_config_from_yaml.get("skip_block_list")
single_pass_call_kwargs.pop("num_inference_steps", None)
single_pass_call_kwargs.pop("first_pass", None)
single_pass_call_kwargs.pop("second_pass", None)
single_pass_call_kwargs.pop("downscale_factor", None)
print(f"Calling base pipeline (padded HxW: {height_padded}x{width_padded}, Frames: {actual_num_frames} -> Padded: {num_frames_padded}) on {target_inference_device}")
result_images_tensor = pipeline_instance(**single_pass_call_kwargs).images
if result_images_tensor is None:
raise RuntimeError("Generation failed.")
pad_left, pad_right, pad_top, pad_bottom = padding_values
slice_h_end = -pad_bottom if pad_bottom > 0 else None
slice_w_end = -pad_right if pad_right > 0 else None
result_images_tensor = result_images_tensor[
:, :, :actual_num_frames, pad_top:slice_h_end, pad_left:slice_w_end
]
video_np = result_images_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy()
video_np = np.clip(video_np, 0, 1)
video_np = (video_np * 255).astype(np.uint8)
    # Write the result to a uniquely named MP4 in the output directory (random 5-digit suffix).
    suffix = random.randint(10000, 99999)
    output_video_path = os.path.join(output_dir, f"output_{suffix}.mp4")
try:
with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], macro_block_size=1) as video_writer:
for frame_idx in range(video_np.shape[0]):
video_writer.append_data(video_np[frame_idx])
if frame_idx % 10 == 0:
print(f"Saving frame {frame_idx + 1}/{video_np.shape[0]}")
except Exception as e:
print(f"Error saving video with macro_block_size=1: {e}")
try:
with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], format='FFMPEG', codec='libx264', quality=8) as video_writer:
for frame_idx in range(video_np.shape[0]):
video_writer.append_data(video_np[frame_idx])
if frame_idx % 10 == 0:
print(f"Saving frame {frame_idx + 1}/{video_np.shape[0]} (fallback)")
except Exception as e2:
print(f"Fallback video saving error: {e2}")
raise RuntimeError(f"Failed to save video: {e2}")
print(f"Video saved successfully to: {output_video_path}")
return output_video_path, seed_ui
def main():
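    """
    Parse command-line arguments and run a single generation.

    Example invocation (illustrative prompt and paths):
        python app.py --prompt "a red fox running through snow" \
            --mode image-to-video --input-image ./fox.jpg --duration 4 --randomize-seed
    """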
parser = argparse.ArgumentParser(description="LTX Video Generation from Command Line")
parser.add_argument("--prompt", required=True, help="Text prompt for video generation")
parser.add_argument("--negative-prompt", default="worst quality, inconsistent motion, blurry, jittery, distorted",
help="Negative prompt")
parser.add_argument("--mode", choices=["text-to-video", "image-to-video", "video-to-video"],
default="text-to-video", help="Generation mode")
parser.add_argument("--input-image", help="Input image path for image-to-video mode")
parser.add_argument("--input-video", help="Input video path for video-to-video mode")
parser.add_argument("--duration", type=float, default=2.0, help="Video duration in seconds (0.3-8.5)")
parser.add_argument("--height", type=int, default=512, help="Video height (must be divisible by 32)")
parser.add_argument("--width", type=int, default=704, help="Video width (must be divisible by 32)")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument("--randomize-seed", action="store_true", help="Use random seed")
parser.add_argument("--guidance-scale", type=float, help="Guidance scale for generation")
parser.add_argument("--no-improve-texture", action="store_true", help="Disable texture improvement (faster)")
parser.add_argument("--frames-to-use", type=int, default=9, help="Frames to use from input video (for video-to-video)")
args = parser.parse_args()
# Validate parameters
if args.mode == "image-to-video" and not args.input_image:
print("Error: --input-image is required for image-to-video mode")
return
if args.mode == "video-to-video" and not args.input_video:
print("Error: --input-video is required for video-to-video mode")
return
# Ensure dimensions are divisible by 32
args.height = ((args.height - 1) // 32 + 1) * 32
args.width = ((args.width - 1) // 32 + 1) * 32
print(f"Starting video generation...")
print(f"Prompt: {args.prompt}")
print(f"Mode: {args.mode}")
print(f"Duration: {args.duration}s")
print(f"Resolution: {args.width}x{args.height}")
print(f"Output directory: {os.path.abspath(output_dir)}")
try:
output_path, used_seed = generate(
prompt=args.prompt,
negative_prompt=args.negative_prompt,
input_image_filepath=args.input_image,
input_video_filepath=args.input_video,
height_ui=args.height,
width_ui=args.width,
mode=args.mode,
duration_ui=args.duration,
ui_frames_to_use=args.frames_to_use,
seed_ui=args.seed,
randomize_seed=args.randomize_seed,
ui_guidance_scale=args.guidance_scale,
improve_texture_flag=not args.no_improve_texture
)
print(f"\n✅ Video generation completed!")
print(f"📁 Output saved to: {output_path}")
print(f"🎲 Used seed: {used_seed}")
print(f"📂 Full path: {os.path.abspath(output_path)}")
except Exception as e:
print(f"❌ Error during generation: {e}")
raise
if __name__ == "__main__":
if os.path.exists(models_dir) and os.path.isdir(models_dir):
print(f"Model directory: {Path(models_dir).resolve()}")
print(f"Output directory: {Path(output_dir).resolve()}")
main()