# hymm_sp/inference.py
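"""Assemble the Hunyuan-Avatar inference pipeline: load the transformer, VAE,
and text encoder(s) from a pretrained checkpoint and wrap them in an
`Inference` object."""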
import torch
import torch.distributed
from pathlib import Path

from loguru import logger

from hymm_sp.constants import PROMPT_TEMPLATE, PRECISION_TO_TYPE
from hymm_sp.vae import load_vae
from hymm_sp.modules import load_model
from hymm_sp.modules.fp8_optimization import convert_fp8_linear
from hymm_sp.modules.parallel_states import nccl_info
from hymm_sp.text_encoder import TextEncoder


class Inference(object):
    """Holds the denoising model, VAE, and text encoder(s) used for video inference."""
def __init__(self,
args,
vae,
vae_kwargs,
text_encoder,
model,
text_encoder_2=None,
pipeline=None,
cpu_offload=False,
device=None,
logger=None):
self.vae = vae
self.vae_kwargs = vae_kwargs
self.text_encoder = text_encoder
self.text_encoder_2 = text_encoder_2
self.model = model
self.pipeline = pipeline
self.cpu_offload = cpu_offload
self.args = args
        self.device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
        if nccl_info.sp_size > 1:
            # Under sequence parallelism, pin each rank to its own GPU.
            self.device = torch.device(f"cuda:{torch.distributed.get_rank()}")
self.logger = logger

    @classmethod
def from_pretrained(cls,
pretrained_model_path,
args,
device=None,
**kwargs):
"""
Initialize the Inference pipeline.
Args:
pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints.
device (int): The device for inference. Default is 0.
logger (logging.Logger): The logger for the inference pipeline. Default is None.
"""
        # ========================================================================
        logger.info(f"Got text-to-video model root path: {pretrained_model_path}")

        # ==================== Set device and disable gradients ==================
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_grad_enabled(False)
logger.info("Building model...")
        # Create weights on CPU when offloading to keep GPU memory free at load time.
        factor_kwargs = {'device': 'cpu' if args.cpu_offload else device, 'dtype': PRECISION_TO_TYPE[args.precision]}
        in_channels = args.latent_channels
        out_channels = args.latent_channels
        print("=" * 25, "build model", "=" * 25)
model = load_model(
args,
in_channels=in_channels,
out_channels=out_channels,
factor_kwargs=factor_kwargs
)
        if args.use_fp8:
            # Convert linear layers to FP8 to cut transformer memory use.
            convert_fp8_linear(model, pretrained_model_path, original_dtype=PRECISION_TO_TYPE[args.precision])
        if args.cpu_offload:
            print('=' * 20, 'load transformer to cpu')
            model = model.to('cpu')
            torch.cuda.empty_cache()
        else:
            model = model.to(device)
        model = cls.load_state_dict(args, model, pretrained_model_path)
model.eval()
# ============================= Build extra models ========================
# VAE
print("="*25, f"load vae", "="*25)
vae, _, s_ratio, t_ratio = load_vae(args.vae, args.vae_precision, logger=logger, device='cpu' if args.cpu_offload else device)
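        # `s_ratio` / `t_ratio` are the VAE's spatial and temporal compression
        # ratios, used downstream to map pixel-space sizes to latent-space sizes.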
vae_kwargs = {'s_ratio': s_ratio, 't_ratio': t_ratio}
        # Text encoder
        if args.prompt_template_video is not None:
            prompt_template_video = PROMPT_TEMPLATE[args.prompt_template_video]
            # Skip the template prefix tokens so only the prompt counts toward `text_len`.
            crop_start = prompt_template_video.get("crop_start", 0)
        else:
            prompt_template_video = None
            crop_start = 0
        max_length = args.text_len + crop_start
print("="*25, f"load llava", "="*25)
text_encoder = TextEncoder(text_encoder_type = args.text_encoder,
max_length = max_length,
text_encoder_precision = args.text_encoder_precision,
tokenizer_type = args.tokenizer,
use_attention_mask = args.use_attention_mask,
prompt_template_video = prompt_template_video,
hidden_state_skip_layer = args.hidden_state_skip_layer,
apply_final_norm = args.apply_final_norm,
reproduce = args.reproduce,
logger = logger,
device = 'cpu' if args.cpu_offload else device ,
)
        # Optional secondary text encoder (e.g. CLIP), if configured
        text_encoder_2 = None
        if args.text_encoder_2 is not None:
            text_encoder_2 = TextEncoder(text_encoder_type=args.text_encoder_2,
                                         max_length=args.text_len_2,
                                         text_encoder_precision=args.text_encoder_precision_2,
                                         tokenizer_type=args.tokenizer_2,
                                         use_attention_mask=args.use_attention_mask,
                                         reproduce=args.reproduce,
                                         logger=logger,
                                         device='cpu' if args.cpu_offload else device,
                                         )
return cls(args=args,
vae=vae,
vae_kwargs=vae_kwargs,
text_encoder=text_encoder,
model=model,
text_encoder_2=text_encoder_2,
device=device,
logger=logger)

    @staticmethod
    def load_state_dict(args, model, ckpt_path):
        """Load weights from `ckpt_path` (a file, or a directory containing `*_model_states.pt`)."""
        load_key = args.load_key
        ckpt_path = Path(ckpt_path)
        if ckpt_path.is_dir():
            ckpt_path = next(ckpt_path.glob("*_model_states.pt"))
        state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
        if load_key in state_dict:
            # The checkpoint nests the weights under `load_key` (e.g. "module").
            state_dict = state_dict[load_key]
        elif load_key == ".":
            # "." means the checkpoint file itself is the state dict.
            pass
        else:
            raise KeyError(f"Key '{load_key}' not found in the checkpoint. "
                           f"Existing keys: {list(state_dict.keys())}")
        # strict=False tolerates missing/unexpected keys rather than erroring out.
        model.load_state_dict(state_dict, strict=False)
        return model
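
    # Illustrative checkpoint layout assumed by `load_state_dict` above and
    # `get_exp_dir_and_ckpt_id` below (file name is an example, not guaranteed):
    #   <exp_dir>/checkpoints/<ckpt_id>/mp_rank_00_model_states.pt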
    def get_exp_dir_and_ckpt_id(self):
        # NOTE: `self.ckpt` is never assigned in `__init__`; fall back to
        # `self.args.ckpt` (assumed to hold the checkpoint path) when absent.
        ckpt = getattr(self, "ckpt", None) or getattr(self.args, "ckpt", None)
        if ckpt is None:
            raise ValueError("The checkpoint path is not provided.")
        ckpt = Path(ckpt)
        if ckpt.parents[1].name == "checkpoints":
            # Standard layout: <exp_dir>/checkpoints/<ckpt_id>/<file>; take <exp_dir>.
            exp_dir = ckpt.parents[2]
        else:
            raise ValueError(f"Cannot infer the experiment directory from the checkpoint path: {ckpt}. "
                             f"The checkpoint path does not appear to be standard; please explicitly "
                             f"provide the save path via --save-path.")
        return exp_dir, ckpt.parent.name

    @staticmethod
    def parse_size(size):
        """Normalize `size` to a (height, width) pair; a single int means a square size."""
        if isinstance(size, int):
            size = [size]
        if not isinstance(size, (list, tuple)):
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        if len(size) == 1:
            size = [size[0], size[0]]
        if len(size) != 2:
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        return size
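

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The repo's real entry point builds an
# `args` namespace carrying the fields referenced above (precision, vae,
# text_encoder, cpu_offload, ...) and then calls something like:
#
#     inference = Inference.from_pretrained(pretrained_model_path, args, device="cuda")
#
# The pure helper below can be exercised without loading any weights:
if __name__ == "__main__":
    print(Inference.parse_size(704))          # -> [704, 704]
    print(Inference.parse_size((720, 1280)))  # -> (720, 1280)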