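"""Inference pipeline assembly for hymm_sp.

Builds the transformer, VAE, and text encoder(s) from pretrained checkpoints
and wraps them in an ``Inference`` object, with optional CPU offload, FP8
linear-layer conversion, and sequence parallelism.
"""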
import torch
from pathlib import Path
from loguru import logger
from hymm_sp.constants import PROMPT_TEMPLATE, PRECISION_TO_TYPE
from hymm_sp.vae import load_vae
from hymm_sp.modules import load_model
from hymm_sp.text_encoder import TextEncoder
import torch.distributed
from hymm_sp.modules.parallel_states import (
    nccl_info,
)
from hymm_sp.modules.fp8_optimization import convert_fp8_linear


class Inference(object):
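    """Bundles the denoising transformer, VAE, and text encoder(s) needed for
    inference, together with the runtime options that control them.

    Instances are normally created through ``from_pretrained`` rather than by
    calling the constructor directly.
    """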
    def __init__(self,
                 args,
                 vae,
                 vae_kwargs,
                 text_encoder,
                 model,
                 text_encoder_2=None,
                 pipeline=None,
                 cpu_offload=False,
                 device=None,
                 logger=None):
        self.vae = vae
        self.vae_kwargs = vae_kwargs
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.model = model
        self.pipeline = pipeline
        self.cpu_offload = cpu_offload
        self.args = args
        self.device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
        # Under sequence parallelism, bind this process to the GPU matching its rank.
        if nccl_info.sp_size > 1:
            self.device = torch.device(f"cuda:{torch.distributed.get_rank()}")
        self.logger = logger

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_path,
                        args,
                        device=None,
                        **kwargs):
""" | |
Initialize the Inference pipeline. | |
Args: | |
pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints. | |
device (int): The device for inference. Default is 0. | |
logger (logging.Logger): The logger for the inference pipeline. Default is None. | |
""" | |
        logger.info(f"Got text-to-video model root path: {pretrained_model_path}")

        # ======================== Set device and precision ======================
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Inference only: disable gradient tracking globally.
        torch.set_grad_enabled(False)

        # ========================== Build the main model ========================
        logger.info("Building model...")
        factor_kwargs = {'device': 'cpu' if args.cpu_offload else device,
                         'dtype': PRECISION_TO_TYPE[args.precision]}
        in_channels = args.latent_channels
        out_channels = args.latent_channels
        print("=" * 25, "build model", "=" * 25)
        model = load_model(
            args,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs,
        )
        if args.use_fp8:
            # Optionally convert linear layers to FP8 to cut memory use.
            convert_fp8_linear(model, pretrained_model_path, original_dtype=PRECISION_TO_TYPE[args.precision])
        if args.cpu_offload:
            # Keep the transformer on CPU to reduce GPU memory pressure.
            print("=" * 20, "load transformer to cpu")
            model = model.to('cpu')
            torch.cuda.empty_cache()
        else:
            model = model.to(device)
        model = Inference.load_state_dict(args, model, pretrained_model_path)
        model.eval()
        # ============================= Build extra models =======================
        # VAE
        print("=" * 25, "load vae", "=" * 25)
        vae, _, s_ratio, t_ratio = load_vae(args.vae, args.vae_precision, logger=logger,
                                            device='cpu' if args.cpu_offload else device)
        vae_kwargs = {'s_ratio': s_ratio, 't_ratio': t_ratio}

        # Text encoder. The prompt template prefix is cropped from the encoder
        # output, so the maximum token length is extended by its length.
        if args.prompt_template_video is not None:
            crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
        else:
            crop_start = 0
        max_length = args.text_len + crop_start
        prompt_template_video = (PROMPT_TEMPLATE[args.prompt_template_video]
                                 if args.prompt_template_video is not None else None)
        print("=" * 25, "load llava", "=" * 25)
        text_encoder = TextEncoder(text_encoder_type=args.text_encoder,
                                   max_length=max_length,
                                   text_encoder_precision=args.text_encoder_precision,
                                   tokenizer_type=args.tokenizer,
                                   use_attention_mask=args.use_attention_mask,
                                   prompt_template_video=prompt_template_video,
                                   hidden_state_skip_layer=args.hidden_state_skip_layer,
                                   apply_final_norm=args.apply_final_norm,
                                   reproduce=args.reproduce,
                                   logger=logger,
                                   device='cpu' if args.cpu_offload else device)
        text_encoder_2 = None
        if args.text_encoder_2 is not None:
            text_encoder_2 = TextEncoder(text_encoder_type=args.text_encoder_2,
                                         max_length=args.text_len_2,
                                         text_encoder_precision=args.text_encoder_precision_2,
                                         tokenizer_type=args.tokenizer_2,
                                         use_attention_mask=args.use_attention_mask,
                                         reproduce=args.reproduce,
                                         logger=logger,
                                         device='cpu' if args.cpu_offload else device)

        return cls(args=args,
                   vae=vae,
                   vae_kwargs=vae_kwargs,
                   text_encoder=text_encoder,
                   model=model,
                   text_encoder_2=text_encoder_2,
                   device=device,
                   logger=logger)

    @staticmethod
    def load_state_dict(args, model, ckpt_path):
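        """Load pretrained weights into ``model``.

        If ``ckpt_path`` is a directory, the first file matching
        ``*_model_states.pt`` inside it is used. ``args.load_key`` selects a
        sub-dict of the loaded checkpoint; the special value "." means the
        file already stores a bare state dict.
        """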
        load_key = args.load_key
        ckpt_path = Path(ckpt_path)
        if ckpt_path.is_dir():
            ckpt_path = next(ckpt_path.glob("*_model_states.pt"))
        state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
        if load_key in state_dict:
            state_dict = state_dict[load_key]
        elif load_key == ".":
            pass  # "." means the checkpoint file itself is the state dict.
        else:
            raise KeyError(f"Key '{load_key}' not found in the checkpoint. Existing keys: {list(state_dict.keys())}")
        model.load_state_dict(state_dict, strict=False)
        return model

    def get_exp_dir_and_ckpt_id(self):
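        """Infer the experiment directory and checkpoint id from ``self.ckpt``.

        Assumes the standard layout (component names are illustrative):
            <exp_dir>/checkpoints/<ckpt_id>/<checkpoint file>
        and returns ``(<exp_dir>, "<ckpt_id>")``.
        """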
        # ``self.ckpt`` is expected to be set by the caller; guard against the
        # attribute being missing entirely.
        ckpt = getattr(self, "ckpt", None)
        if ckpt is None:
            raise ValueError("The checkpoint path is not provided.")
        ckpt = Path(ckpt)
        if ckpt.parents[1].name == "checkpoints":
            # A standard checkpoint path; use the experiment directory as the default save directory.
            exp_dir = ckpt.parents[2]
        else:
            raise ValueError(f"We cannot infer the experiment directory from the checkpoint path: {ckpt}. "
                             f"It seems that the checkpoint path is not standard. Please explicitly provide the "
                             f"save path via --save-path.")
        return exp_dir, ckpt.parent.name

    @staticmethod
    def parse_size(size):
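        """Normalize ``size`` to a two-element ``[height, width]`` sequence.

        Examples:
            parse_size(720)         -> [720, 720]
            parse_size([720, 1280]) -> [720, 1280]
        """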
        if isinstance(size, int):
            size = [size]
        if not isinstance(size, (list, tuple)):
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        if len(size) == 1:
            size = [size[0], size[0]]
        if len(size) != 2:
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        return size
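

# Minimal usage sketch (hypothetical values; in practice `args` comes from the
# project's own argument parser, which is not shown in this file):
#
#   args = build_args()                               # hypothetical helper
#   pipe = Inference.from_pretrained("ckpts/", args)  # hypothetical model path
#   size = Inference.parse_size(720)                  # -> [720, 720]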