# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import inspect
import logging
import math
import os
import pdb
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm
from diffusers import FlowMatchEulerDiscreteScheduler

sys.path.append('../OSS')
from OSS.OSS import search_OSS_video, infer_OSS
from OSS.model_wrap import _WrappedModel_Wan

from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.model_infer import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
# from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers import FlowDPMSolverMultistepScheduler
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


def set_seed(seed):
    """Seed the python, numpy and torch RNGs; seed == -1 draws a random seed."""
    if seed == -1:
        seed = random.randint(0, 1000000)
    seed = int(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
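

# Illustrative sketch (the helper name below is ours, not part of the pipeline):
# set_seed(-1) draws a fresh random seed, while any non-negative value makes runs
# reproducible across the python, numpy and torch RNGs seeded above.
def _set_seed_demo():
    set_seed(1234)
    a = torch.randn(2)
    set_seed(1234)
    b = torch.randn(2)
    assert torch.equal(a, b)  # identical draws after re-seeding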


class FlowMatchScheduler():
    """Discrete flow-matching scheduler with a shifted sigma schedule."""

    def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0,
                 sigma_max=1.0, sigma_min=0.003 / 1.002, inverse_timesteps=False,
                 extra_one_step=False, reverse_sigmas=False):
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.inverse_timesteps = inverse_timesteps
        self.extra_one_step = extra_one_step
        self.reverse_sigmas = reverse_sigmas
        self.set_timesteps(num_inference_steps)

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
        if shift is not None:
            self.shift = shift
        sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
        if self.extra_one_step:
            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
        else:
            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
        if self.inverse_timesteps:
            self.sigmas = torch.flip(self.sigmas, dims=[0])
        self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
        if self.reverse_sigmas:
            self.sigmas = 1 - self.sigmas
        self.timesteps = self.sigmas * self.num_train_timesteps
        if training:
            x = self.timesteps
            y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
            y_shifted = y - y.min()
            bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
            self.linear_timesteps_weights = bsmntw_weighing

    def step(self, model_output, timestep, sample, to_final=False):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        if to_final or timestep_id + 1 >= len(self.timesteps):
            sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
        else:
            sigma_ = self.sigmas[timestep_id + 1]
        prev_sample = sample + model_output * (sigma_ - sigma)
        return prev_sample

    def return_to_timestep(self, timestep, sample, sample_stablized):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        model_output = (sample - sample_stablized) / sigma
        return model_output

    def add_noise(self, original_samples, noise, timestep):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        sample = (1 - sigma) * original_samples + sigma * noise
        return sample

    def training_target(self, sample, noise, timestep):
        target = noise - sample
        return target

    def training_weight(self, timestep):
        timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
        weights = self.linear_timesteps_weights[timestep_id]
        return weights
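

# Illustrative sketch (added for exposition; the helper name is ours): the scheduler
# builds a shifted sigma schedule, mixes data and noise as x_t = (1 - sigma) * x0 + sigma * eps,
# and `step` follows the predicted velocity (eps - x0) from one sigma to the next.
def _flow_match_scheduler_demo():
    scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
    scheduler.set_timesteps(num_inference_steps=4, denoising_strength=1.0, shift=5.0)
    x0 = torch.zeros(1, 4)                       # clean sample
    eps = torch.randn(1, 4)                      # Gaussian noise
    t = scheduler.timesteps[0]                   # most-noised timestep in the schedule
    x_t = scheduler.add_noise(x0, eps, t)        # (1 - sigma_0) * x0 + sigma_0 * eps
    v = scheduler.training_target(x0, eps, t)    # velocity target: eps - x0
    x_prev = scheduler.step(v, t, x_t)           # one Euler step toward sigma_1
    return x_prev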


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
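

# Illustrative sketch (added for exposition; the helper name is ours): retrieve_timesteps
# forwards either a plain step count, custom `timesteps`, or custom `sigmas` to
# scheduler.set_timesteps and returns the resulting schedule plus its length.
def _retrieve_timesteps_demo():
    # FlowMatchEulerDiscreteScheduler is imported from diffusers above; only the
    # step-count path is shown, since the custom timesteps/sigmas paths require the
    # scheduler's set_timesteps to accept those keywords.
    sched = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=5.0)
    timesteps, num_steps = retrieve_timesteps(sched, num_inference_steps=40, device="cpu")
    return timesteps, num_steps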


class WanI2V:

    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
        init_on_cpu=True,
    ):
        r"""
        Initializes the image-to-video generation model components.

        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
            device_id (`int`, *optional*, defaults to 0):
                Id of target GPU device
            rank (`int`, *optional*, defaults to 0):
                Process rank for distributed training
            t5_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for T5 model
            dit_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for DiT model
            use_usp (`bool`, *optional*, defaults to False):
                Enable the USP distribution strategy.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            init_on_cpu (`bool`, *optional*, defaults to True):
                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
        self.rank = rank
        self.use_usp = use_usp
        self.t5_cpu = t5_cpu

        self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
        # self.scheduler = FlowMatchScheduler(shift=17, sigma_min=0.0, extra_one_step=True)

        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype

        shard_fn = partial(shard_model, device_id=device_id)
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None,
        )

        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size
        self.vae = WanVAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)

        self.clip = CLIPModel(
            dtype=config.clip_dtype,
            device=self.device,
            checkpoint_path=os.path.join(checkpoint_dir, config.clip_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

        logging.info(f"Creating WanModel from {checkpoint_dir}")
        self.model = WanModel.from_pretrained(checkpoint_dir)
        self.model.eval().requires_grad_(False)

        if t5_fsdp or dit_fsdp or use_usp:
            init_on_cpu = False

        if use_usp:
            from xfuser.core.distributed import get_sequence_parallel_world_size

            from .distributed.xdit_context_parallel import (usp_attn_forward,
                                                            usp_dit_forward)
            for block in self.model.blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward, block.self_attn)
            self.model.forward = types.MethodType(usp_dit_forward, self.model)
            self.sp_size = get_sequence_parallel_world_size()
        else:
            self.sp_size = 1

        if dist.is_initialized():
            dist.barrier()
        if dit_fsdp:
            self.model = shard_fn(self.model)
        else:
            if not init_on_cpu:
                self.model = self.model.to(self.device)

        self.sample_neg_prompt = config.sample_neg_prompt
    def generate(self,
                 args,
                 input_prompt,
                 img,
                 max_area=720 * 1280,
                 frame_num=81,
                 shift=5.0,
                 sample_solver='unipc',
                 sampling_steps=40,
                 guide_scale=5.0,
                 n_prompt="",
                 seed=-1,
                 offload_model=True,
                 student_steps=20,
                 norm=2,
                 frame_type="all",
                 channel_type="all",
                 random_channel=False,
                 ):
        r"""
        Generates video frames from input image and text prompt using diffusion process.

        Args:
            args:
                Namespace whose `save_file` attribute names the text file
                ("<save_file>.txt") where the searched OSS step schedule is written.
            input_prompt (`str`):
                Text prompt for content generation.
            img (PIL.Image.Image):
                Input image. Converted internally to a [3, H, W] tensor.
            max_area (`int`, *optional*, defaults to 720*1280):
                Maximum pixel area for latent space calculation. Controls video resolution scaling
            frame_num (`int`, *optional*, defaults to 81):
                How many frames to sample from a video. The number should be 4n+1
            shift (`float`, *optional*, defaults to 5.0):
                Noise schedule shift parameter. Affects temporal dynamics
                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
            sample_solver (`str`, *optional*, defaults to 'unipc'):
                Solver used to sample the video.
            sampling_steps (`int`, *optional*, defaults to 40):
                Number of diffusion sampling steps (the teacher schedule for the OSS search).
                Higher values improve quality but slow generation
            guide_scale (`float`, *optional*, defaults to 5.0):
                Classifier-free guidance scale. Controls prompt adherence vs. creativity
            n_prompt (`str`, *optional*, defaults to ""):
                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
            seed (`int`, *optional*, defaults to -1):
                Random seed for noise generation. If -1, use random seed
            offload_model (`bool`, *optional*, defaults to True):
                If True, offloads models to CPU during generation to save VRAM
            student_steps (`int`, *optional*, defaults to 20):
                Number of steps in the student schedule selected by `search_OSS_video`.
            norm (`int`, *optional*, defaults to 2):
                Norm order forwarded to `search_OSS_video` for the search objective.
            frame_type (`str`, *optional*, defaults to "all"):
                Forwarded to `search_OSS_video`; selects which frames the search considers.
            channel_type (`str`, *optional*, defaults to "all"):
                Forwarded to `search_OSS_video`; selects which channels the search considers.
            random_channel (`bool`, *optional*, defaults to False):
                Forwarded to `search_OSS_video`.

        Returns:
            torch.Tensor:
                Generated video frames tensor. Dimensions: (C, N, H, W) where:
                - C: Color channels (3 for RGB)
                - N: Number of frames (81)
                - H: Frame height (from max_area)
                - W: Frame width (from max_area)
        """
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)

        F = frame_num
        h, w = img.shape[1:]
        aspect_ratio = h / w
        lat_h = round(
            np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
            self.patch_size[1] * self.patch_size[1])
        lat_w = round(
            np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
            self.patch_size[2] * self.patch_size[2])
        h = lat_h * self.vae_stride[1]
        w = lat_w * self.vae_stride[2]

        max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
            self.patch_size[1] * self.patch_size[2])
        max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size

        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
        if seed >= 0:
            set_seed(seed)
        seed_g = torch.Generator(device=self.device)
        seed_g.manual_seed(seed)
        noise = torch.randn(
            16,
            F // 4 + 1,
            lat_h,
            lat_w,
            dtype=torch.float32,
            generator=seed_g,
            device=self.device)

        msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
        msk[:, 1:] = 0
        msk = torch.concat([
            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
        ], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        if n_prompt == "":
            n_prompt = self.sample_neg_prompt

        # preprocess
        if not self.t5_cpu:
            self.text_encoder.model = self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
            context_null = self.text_encoder([n_prompt], self.device)
            if offload_model:
                self.text_encoder.model = self.text_encoder.model.cpu()
        else:
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]

        self.clip.model = self.clip.model.to(self.device)
        clip_context = self.clip.visual([img[:, None, :, :]])
        if offload_model:
            self.clip.model = self.clip.model.cpu()
            torch.cuda.empty_cache()

        y = self.vae.encode([
            torch.concat([
                torch.nn.functional.interpolate(
                    img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1),
                torch.zeros(3, F - 1, h, w)
            ], dim=1).to(self.device)
        ])[0]
        y = torch.concat([msk, y])
        @contextmanager
        def noop_no_sync():
            yield

        no_sync = getattr(self.model, 'no_sync', noop_no_sync)

        # evaluation mode
        with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
            device = self.device
            num_inference_steps = sampling_steps
            self.scheduler.set_timesteps(num_inference_steps, 1.0, shift=5.0)
            # sample videos
            latents = noise
            if offload_model:
                torch.cuda.empty_cache()
            self.model = self.model.to(self.device)

            # arg_c = {
            #     'context': [context[0]],
            #     'clip_fea': clip_context,
            #     'seq_len': max_seq_len,
            #     'y': [y],
            # }
            #
            # arg_null = {
            #     'context': context_null,
            #     'clip_fea': clip_context,
            #     'seq_len': max_seq_len,
            #     'y': [y],
            # }

            # pre-process: wrap the model for the OSS routines (the wrapper holds the
            # negative context and guidance scale)
            model = _WrappedModel_Wan(self.model, self.scheduler.timesteps,
                                      self.num_train_timesteps, context_null, guide_scale)
            model_kwargs = {
                'seq_len': max_seq_len,
                'y': [y],
                'clip_fea': clip_context,
            }

            B = 1
            # latents = latents[0].unsqueeze(0)
            latents = latents.unsqueeze(0)
            # search the optimal student step schedule against the teacher schedule,
            # then sample with the searched steps
            oss_steps = search_OSS_video(model, latents, B, context, self.device,
                                         teacher_steps=sampling_steps, student_steps=student_steps,
                                         norm=norm, model_kwargs=model_kwargs, frame_type=frame_type,
                                         channel_type=channel_type, random_channel=random_channel)
            latents_oss = infer_OSS(oss_steps, model, latents, context, self.device, model_kwargs=model_kwargs)

            # save the searched step schedule and terminate the process; the code below
            # only runs if this early exit is removed
            with open("%s.txt" % args.save_file, "w") as f:
                f.write(str(oss_steps))
            os._exit(2333)
            # pdb.set_trace()
            # teacher video: run the full teacher schedule for comparison
            teacher_steps = list(range(1, sampling_steps + 1))
            latents_tea = infer_OSS(teacher_steps, model, latents, context, self.device, model_kwargs=model_kwargs)

            x0_oss = latents_oss
            x0_tea = latents_tea

            if offload_model:
                self.model.cpu()
                torch.cuda.empty_cache()
            if self.rank == 0:
                videos_oss = self.vae.decode(x0_oss)
                videos_tea = self.vae.decode(x0_tea)

            # for idx, t in enumerate(tqdm(self.scheduler.timesteps)):
            #     latent_model_input = [latent.to(self.device)]
            #     timestep = [t]
            #
            #     timestep = torch.stack(timestep).to(self.device)
            #     noise_pred_cond = self.model(latent_model_input, t=timestep, **arg_c)[0].to(torch.device('cpu') if offload_model else self.device)
            #     if offload_model:
            #         torch.cuda.empty_cache()
            #     noise_pred_uncond = self.model(latent_model_input, t=timestep, **arg_null)[0].to(torch.device('cpu') if offload_model else self.device)
            #     if offload_model:
            #         torch.cuda.empty_cache()
            #     noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
            #     # noise_pred = noise_pred_cond
            #     latent = latent.to(torch.device('cpu') if offload_model else self.device)
            #
            #     # latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
            #     temp_x0 = self.scheduler.step(
            #         noise_pred.unsqueeze(0),
            #         self.scheduler.timesteps[idx],
            #         latent.unsqueeze(0))[0]
            #     latent = temp_x0.squeeze(0)
            #
            #     x0 = [latent.to(self.device)]
            #     del latent_model_input, timestep
            #
            # if offload_model:
            #     self.model = self.model.cpu()
            #     torch.cuda.empty_cache()
            #
            # if self.rank == 0:
            #     videos = self.vae.decode(x0)

        del noise, latents
        # del self.scheduler
        if offload_model:
            gc.collect()
            torch.cuda.synchronize()
        if dist.is_initialized():
            dist.barrier()

        # return the OSS-sampled video (the teacher decode above is kept for comparison)
        return videos_oss[0] if self.rank == 0 else None
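

# Illustrative end-to-end sketch (never called; names, paths and the config lookup are
# assumptions for exposition). Note that, as written above, generate() saves the searched
# OSS schedule to "<save_file>.txt" and terminates the process, so the returned video is
# only reached if that early exit is removed.
def _wan_i2v_usage_sketch():
    from PIL import Image
    from wan.configs import WAN_CONFIGS  # assumed layout of the upstream Wan repo

    config = WAN_CONFIGS['i2v-14B']                        # hypothetical config key
    pipe = WanI2V(config, checkpoint_dir='./Wan2.1-I2V-14B-720P', device_id=0)
    img = Image.open('input.jpg').convert('RGB')           # placeholder input image
    run_args = types.SimpleNamespace(save_file='oss_steps')
    video = pipe.generate(
        run_args, 'a cat surfing a wave at sunset', img,
        max_area=720 * 1280, frame_num=81, sampling_steps=40,
        student_steps=20, guide_scale=5.0, seed=42)
    return video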