dangthr's picture
Update app.py
59a4c9f verified
raw
history blame
9.65 kB
import torch
import numpy as np
import random
import os
import yaml
import argparse
from pathlib import Path
import imageio
import tempfile
# NOTE(review): latent_upsampler_instance and target_inference_device are not
# defined anywhere in this copy of the file (the model-loading code that
# presumably precedes this fragment is missing) -- confirm against upstream.
# When an upsampler object is set, call .to(target_inference_device) on it.
if latent_upsampler_instance:
latent_upsampler_instance.to(target_inference_device)
# --- Helper function for dimension calculation ---
# Lower clamp bound applied to each output side by calculate_new_dimensions.
MIN_DIM_SLIDER = 256
# Length given to the shorter input side by calculate_new_dimensions (before
# clamping); also the square fallback size for zero-sized inputs.
TARGET_FIXED_SIDE = 768
def calculate_new_dimensions(orig_w, orig_h):
    """Choose an output (height, width) that follows the input aspect ratio.

    The shorter input side is mapped to ``TARGET_FIXED_SIDE``. The longer
    side is scaled by the aspect ratio and rounded to the nearest multiple
    of 32. Square inputs take the landscape path, so both sides get
    ``TARGET_FIXED_SIDE``. Each side is then clamped to
    ``[MIN_DIM_SLIDER, MAX_IMAGE_SIZE]``.

    The results are multiples of 32 only if ``MAX_IMAGE_SIZE`` is one too
    (``MAX_IMAGE_SIZE`` is not defined in the visible code -- confirm).

    Args:
        orig_w: original width.
        orig_h: original height.

    Returns:
        ``(height, width)`` as ints. The order is the reverse of the
        argument order. If either input side is 0, a square of
        ``TARGET_FIXED_SIDE`` is returned.
    """
    if orig_w == 0 or orig_h == 0:
        return int(TARGET_FIXED_SIDE), int(TARGET_FIXED_SIDE)

    def _clamp(side):
        # Keep a single side inside the allowed output range.
        return max(MIN_DIM_SLIDER, min(side, MAX_IMAGE_SIZE))

    is_landscape = orig_w >= orig_h
    if is_landscape:
        short_side, long_side = orig_h, orig_w
    else:
        short_side, long_side = orig_w, orig_h

    # Scale the long side by the same factor that maps short -> TARGET_FIXED_SIDE,
    # then snap to the nearest multiple of 32.
    scaled_long = round(TARGET_FIXED_SIDE * (long_side / short_side) / 32) * 32

    fixed_side = _clamp(TARGET_FIXED_SIDE)
    variable_side = _clamp(scaled_long)

    if is_landscape:
        return int(fixed_side), int(variable_side)
    return int(variable_side), int(fixed_side)
def generate(prompt, negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
input_image_filepath=None, input_video_filepath=None,
height_ui=512, width_ui=704, mode="text-to-video",
duration_ui=2.0, ui_frames_to_use=9,
seed_ui=42, randomize_seed=True, ui_guidance_scale=None, improve_texture_flag=True):
"""Generate a video with the loaded pipeline and save it as ``output_<NNNNN>.mp4``.

The file is written with a relative path, i.e. into the current working
directory.

Returns:
    (output_video_path, seed_ui): the path of the written file, and the seed
    that was used. When ``randomize_seed`` is True, this is a freshly drawn
    random value rather than the one passed in.

Raises:
    RuntimeError: in any of these cases:
        - the input image or video cannot be loaded;
        - ``improve_texture_flag`` is set but no latent upsampler is
          available;
        - no result tensor was produced;
        - the video cannot be saved even with the fallback writer.

NOTE(review): this copy of the function is incomplete; parts of the body are
missing (see the inline NOTE(review) comments). ``prompt``,
``negative_prompt``, ``height_ui``, ``width_ui`` and ``ui_frames_to_use`` are
not referenced in the remaining lines. Presumably they feed ``call_kwargs``
in the missing code -- confirm against upstream.
"""
# Optionally replace the caller's seed with a random 32-bit value, then pass it
# to seed_everething (defined elsewhere; presumably seeds all RNGs).
if randomize_seed:
seed_ui = random.randint(0, 2**32 - 1)
seed_everething(int(seed_ui))
# Default guidance scale: first_pass.guidance_scale from the YAML config, else 1.0.
if ui_guidance_scale is None:
ui_guidance_scale = PIPELINE_CONFIG_YAML.get("first_pass", {}).get("guidance_scale", 1.0)
# Desired frame count from the requested duration. FPS is not defined in this
# copy of the file -- presumably a module-level frame-rate constant.
target_frames_ideal = duration_ui * FPS
target_frames_rounded = round(target_frames_ideal)
if target_frames_rounded < 1:
# NOTE(review): the body of the `if` above is missing. So is the code that
# defines actual_height, actual_width, actual_num_frames, call_kwargs and
# media_tensor -- restore from upstream.
# Round height/width up to a multiple of 32. Round the frame count up to the
# smallest 8*k + 1 that is >= actual_num_frames (for actual_num_frames >= 1).
height_padded = ((actual_height - 1) // 32 + 1) * 32
width_padded = ((actual_width - 1) // 32 + 1) * 32
num_frames_padded = ((actual_num_frames - 2) // 8 + 1) * 8 + 1
padding_values = calculate_padding(actual_height, actual_width, height_padded, width_padded)
# Image conditioning. NOTE(review): the enclosing mode check and its `try:`
# header are missing. The ConditioningItem positional args are presumably
# (tensor, frame index 0, strength 1.0) -- confirm against ConditioningItem.
call_kwargs["conditioning_items"] = [ConditioningItem(media_tensor.to(target_inference_device), 0, 1.0)]
except Exception as e:
print(f"Error loading image {input_image_filepath}: {e}")
raise RuntimeError(f"Could not load image: {e}")
elif mode == "video-to-video" and input_video_filepath:
try:
# NOTE(review): the arguments to load_media_file(...) are missing from this copy.
call_kwargs["media_items"] = load_media_file(
).to(target_inference_device)
except Exception as e:
print(f"Error loading video {input_video_filepath}: {e}")
raise RuntimeError(f"Could not load video: {e}")
print(f"Moving models to {target_inference_device} for inference (if not already there)...")
result_images_tensor = None
# Multi-scale path: requires the latent upsampler. Both passes copy their
# settings from the YAML config, override guidance_scale with the UI value,
# and drop num_inference_steps.
if improve_texture_flag:
if not active_latent_upsampler:
raise RuntimeError("Spatial upscaler model not loaded or improve_texture not selected, cannot use multi-scale.")
multi_scale_pipeline_obj = LTXMultiScalePipeline(pipeline_instance, active_latent_upsampler)
first_pass_args = PIPELINE_CONFIG_YAML.get("first_pass", {}).copy()
first_pass_args["guidance_scale"] = float(ui_guidance_scale)
first_pass_args.pop("num_inference_steps", None)
second_pass_args = PIPELINE_CONFIG_YAML.get("second_pass", {}).copy()
second_pass_args["guidance_scale"] = float(ui_guidance_scale)
second_pass_args.pop("num_inference_steps", None)
multi_scale_call_kwargs = call_kwargs.copy()
# NOTE(review): the multi-scale pipeline call (multi_scale_pipeline_obj with
# first_pass_args / second_pass_args) is missing. So is the `else:` that
# creates single_pass_call_kwargs.
# Single-pass path: sampling settings come from the YAML first_pass section,
# with guidance_scale taken from the UI value.
first_pass_config_from_yaml = PIPELINE_CONFIG_YAML.get("first_pass", {})
single_pass_call_kwargs["timesteps"] = first_pass_config_from_yaml.get("timesteps")
single_pass_call_kwargs["guidance_scale"] = float(ui_guidance_scale)
single_pass_call_kwargs["stg_scale"] = first_pass_config_from_yaml.get("stg_scale")
single_pass_call_kwargs["rescaling_scale"] = first_pass_config_from_yaml.get("rescaling_scale")
single_pass_call_kwargs["skip_block_list"] = first_pass_config_from_yaml.get("skip_block_list")
single_pass_call_kwargs.pop("num_inference_steps", None)
single_pass_call_kwargs.pop("first_pass", None)
single_pass_call_kwargs.pop("second_pass", None)
result_images_tensor = pipeline_instance(**single_pass_call_kwargs).images
# Fail if neither branch produced images.
if result_images_tensor is None:
raise RuntimeError("Generation failed.")
# Crop away the padding added above. NOTE(review): the width slice and the
# start of the slicing expression closed by the stray `]` below are missing.
pad_left, pad_right, pad_top, pad_bottom = padding_values
slice_h_end = -pad_bottom if pad_bottom > 0 else None
]
# Take the first batch item and move axis 0 to the end, so frames are
# indexed on axis 0 below. Presumably (C, F, H, W) -> (F, H, W, C) --
# confirm. Values are clipped to [0, 1] and converted to uint8.
video_np = result_images_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy()
video_np = np.clip(video_np, 0, 1)
video_np = (video_np * 255).astype(np.uint8)
# NOTE(review): a random 5-digit suffix can collide with an existing file,
# which would then be overwritten. `tempfile` is imported at the top of the
# file but is not used here.
timestamp = random.randint(10000, 99999)
output_video_path = f"output_{timestamp}.mp4"
# First attempt uses macro_block_size=1, so imageio does not resize frames to
# a multiple of its default macro block size. If that fails, retry with an
# explicit FFMPEG/libx264 writer. frame_rate comes from call_kwargs, which is
# built in the missing code.
try:
with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], macro_block_size=1) as video_writer:
for frame_idx in range(video_np.shape[0]):
video_writer.append_data(video_np[frame_idx])
if frame_idx % 10 == 0:
print(f"Saving frame {frame_idx + 1}/{video_np.shape[0]}")
except Exception as e:
print(f"Error saving video with macro_block_size=1: {e}")
try:
with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], format='FFMPEG', codec='libx264', quality=8) as video_writer:
for frame_idx in range(video_np.shape[0]):
video_writer.append_data(video_np[frame_idx])
if frame_idx % 10 == 0:
print(f"Saving frame {frame_idx + 1}/{video_np.shape[0]} (fallback)")
except Exception as e2:
print(f"Fallback video saving error: {e2}")
raise RuntimeError(f"Failed to save video: {e2}")
return output_video_path, seed_ui
def main():
    """Parse command-line arguments and run a single video generation.

    Checks the mode-specific inputs:
    - ``--input-image`` is required for image-to-video;
    - ``--input-video`` is required for video-to-video;
    - height and width must be positive.

    It then rounds height and width up to the next multiple of 32 and calls
    ``generate``. On success it prints where the video was written and which
    seed was used.

    Invalid invocations are reported through ``argparse``: usage and message
    go to stderr and the process exits with status 2. Exceptions raised by
    ``generate`` are reported on stderr and re-raised, so the process still
    exits non-zero with a traceback.
    """
    import sys  # local: only needed for stderr reporting in this entry point

    parser = argparse.ArgumentParser(description="LTX Video Generation from Command Line")
    parser.add_argument("--prompt", required=True, help="Text prompt for video generation")
    parser.add_argument("--negative-prompt", default="worst quality, inconsistent motion, blurry, jittery, distorted",
                        help="Negative prompt")
    parser.add_argument("--mode", choices=["text-to-video", "image-to-video", "video-to-video"],
                        default="text-to-video", help="Generation mode")
    parser.add_argument("--input-image", help="Input image path for image-to-video mode")
    parser.add_argument("--input-video", help="Input video path for video-to-video mode")
    parser.add_argument("--duration", type=float, default=2.0, help="Video duration in seconds (0.3-8.5)")
    parser.add_argument("--height", type=int, default=512, help="Video height (must be divisible by 32)")
    parser.add_argument("--width", type=int, default=704, help="Video width (must be divisible by 32)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--randomize-seed", action="store_true", help="Use random seed")
    parser.add_argument("--guidance-scale", type=float, help="Guidance scale for generation")
    parser.add_argument("--no-improve-texture", action="store_true", help="Disable texture improvement (faster)")
    parser.add_argument("--frames-to-use", type=int, default=9, help="Frames to use from input video (for video-to-video)")
    args = parser.parse_args()

    # Validate parameters. parser.error() writes usage plus the message to
    # stderr and exits with status 2, so calling scripts can detect a bad
    # invocation. Printing and returning would exit 0.
    if args.mode == "image-to-video" and not args.input_image:
        parser.error("--input-image is required for image-to-video mode")
    if args.mode == "video-to-video" and not args.input_video:
        parser.error("--input-video is required for video-to-video mode")
    if args.height <= 0 or args.width <= 0:
        parser.error("--height and --width must be positive integers")

    # Round dimensions up to the next multiple of 32. Values that are not
    # multiples of 32 are adjusted rather than rejected.
    args.height = ((args.height - 1) // 32 + 1) * 32
    args.width = ((args.width - 1) // 32 + 1) * 32

    print("Starting video generation...")
    print(f"Prompt: {args.prompt}")
    print(f"Mode: {args.mode}")
    print(f"Duration: {args.duration}s")
    print(f"Resolution: {args.width}x{args.height}")

    try:
        output_path, used_seed = generate(
            prompt=args.prompt,
            negative_prompt=args.negative_prompt,
            input_image_filepath=args.input_image,
            input_video_filepath=args.input_video,
            height_ui=args.height,
            width_ui=args.width,
            mode=args.mode,
            duration_ui=args.duration,
            ui_frames_to_use=args.frames_to_use,
            seed_ui=args.seed,
            randomize_seed=args.randomize_seed,
            ui_guidance_scale=args.guidance_scale,
            improve_texture_flag=not args.no_improve_texture,
        )
    except Exception as e:
        # Report on stderr, then re-raise so the traceback and non-zero exit remain.
        print(f"Error during generation: {e}", file=sys.stderr)
        raise

    print("\nVideo generation completed!")
    print(f"Output saved to: {output_path}")
    print(f"Used seed: {used_seed}")
if __name__ == "__main__":
    # Show the resolved model directory when it exists, then run the CLI.
    # os.path.isdir() is already False for a missing path, so no separate
    # existence check is needed.
    # NOTE(review): models_dir is not defined in the visible code; confirm it
    # is set at module level.
    if os.path.isdir(models_dir):
        print(f"Model directory: {Path(models_dir).resolve()}")
    main()