Spaces:
Build error
Build error
File size: 7,762 Bytes
357c94c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import torch
from pathlib import Path
from loguru import logger
from hymm_sp.constants import PROMPT_TEMPLATE, PRECISION_TO_TYPE
from hymm_sp.vae import load_vae
from hymm_sp.modules import load_model
from hymm_sp.text_encoder import TextEncoder
import torch.distributed
from hymm_sp.modules.parallel_states import (
nccl_info,
)
from hymm_sp.modules.fp8_optimization import convert_fp8_linear
class Inference(object):
    """Container for the models used at inference time.

    Holds the VAE, the text encoder(s), the main model and an optional
    pipeline object, together with the parsed ``args`` and the device they
    are meant to run on. Use :meth:`from_pretrained` to build an instance
    from a checkpoint path.
    """

    def __init__(self,
                 args,
                 vae,
                 vae_kwargs,
                 text_encoder,
                 model,
                 text_encoder_2=None,
                 pipeline=None,
                 cpu_offload=False,
                 device=None,
                 logger=None):
        """Store the already-built components.

        Args:
            args: Parsed argument namespace; stored as-is.
            vae: VAE module returned by ``load_vae``.
            vae_kwargs (dict): Extra VAE information (``from_pretrained``
                stores ``s_ratio`` and ``t_ratio`` here).
            text_encoder: Primary text encoder.
            model: The main model.
            text_encoder_2: Optional secondary text encoder.
            pipeline: Optional pipeline object; stored as-is.
            cpu_offload (bool): Stored on the instance; not otherwise used
                in this class.
            device: Target device. When ``None``, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.
            logger: Logger object; stored as-is.
        """
        self.vae = vae
        self.vae_kwargs = vae_kwargs
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.model = model
        self.pipeline = pipeline
        self.cpu_offload = cpu_offload
        self.args = args

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        # With sequence parallelism enabled, each rank is pinned to the CUDA
        # device whose index equals its distributed rank; this overrides any
        # explicitly passed ``device``.
        if nccl_info.sp_size > 1:
            self.device = torch.device(f"cuda:{torch.distributed.get_rank()}")

        self.logger = logger

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_path,
                        args,
                        device=None,
                        **kwargs):
        """Build the model, VAE and text encoder(s) and wrap them in ``cls``.

        Side effect: gradients are disabled globally via
        ``torch.set_grad_enabled(False)``.

        Args:
            pretrained_model_path (str or pathlib.Path): Checkpoint file, or
                a directory containing a ``*_model_states.pt`` file.
            args: Parsed argument namespace. Fields read here include
                ``cpu_offload``, ``precision``, ``latent_channels``,
                ``use_fp8``, ``load_key``, the ``vae*`` settings and the
                text-encoder settings.
            device: Device for inference. When ``None``, ``"cuda"`` is used
                if available, otherwise ``"cpu"``.
            **kwargs: Accepted for call-site compatibility; currently unused.

        Returns:
            An instance of ``cls`` holding the loaded components. The
            module-level ``logger`` is passed through as the instance logger.
        """
        # ========================================================================
        logger.info(f"Got text-to-video model root path: {pretrained_model_path}")

        # ======================== Get the args path =============================
        # Set device and disable gradient
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        torch.set_grad_enabled(False)

        # When CPU offload is requested every component is created on the
        # CPU; otherwise everything goes straight to the target device.
        load_device = 'cpu' if args.cpu_offload else device
        model_dtype = PRECISION_TO_TYPE[args.precision]

        logger.info("Building model...")
        factor_kwargs = {'device': load_device, 'dtype': model_dtype}
        in_channels = args.latent_channels
        out_channels = args.latent_channels
        print("=" * 25, "build model", "=" * 25)
        model = load_model(
            args,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs
        )
        if args.use_fp8:
            convert_fp8_linear(model, pretrained_model_path, original_dtype=model_dtype)
        if args.cpu_offload:
            print("=" * 20, "load transformer to cpu")
            model = model.to('cpu')
            torch.cuda.empty_cache()
        else:
            model = model.to(device)
        model = Inference.load_state_dict(args, model, pretrained_model_path)
        model.eval()

        # ============================= Build extra models ========================
        # VAE
        print("=" * 25, "load vae", "=" * 25)
        # load_vae returns four values; the second one is not used here.
        vae, _, s_ratio, t_ratio = load_vae(args.vae, args.vae_precision, logger=logger, device=load_device)
        vae_kwargs = {'s_ratio': s_ratio, 't_ratio': t_ratio}

        # Text encoder
        # The prompt template (if any) contributes ``crop_start`` extra
        # positions on top of ``args.text_len``.
        if args.prompt_template_video is not None:
            prompt_template_video = PROMPT_TEMPLATE[args.prompt_template_video]
            crop_start = prompt_template_video.get("crop_start", 0)
        else:
            prompt_template_video = None
            crop_start = 0
        max_length = args.text_len + crop_start

        print("=" * 25, "load llava", "=" * 25)
        text_encoder = TextEncoder(text_encoder_type=args.text_encoder,
                                   max_length=max_length,
                                   text_encoder_precision=args.text_encoder_precision,
                                   tokenizer_type=args.tokenizer,
                                   use_attention_mask=args.use_attention_mask,
                                   prompt_template_video=prompt_template_video,
                                   hidden_state_skip_layer=args.hidden_state_skip_layer,
                                   apply_final_norm=args.apply_final_norm,
                                   reproduce=args.reproduce,
                                   logger=logger,
                                   device=load_device,
                                   )
        text_encoder_2 = None
        if args.text_encoder_2 is not None:
            text_encoder_2 = TextEncoder(text_encoder_type=args.text_encoder_2,
                                         max_length=args.text_len_2,
                                         text_encoder_precision=args.text_encoder_precision_2,
                                         tokenizer_type=args.tokenizer_2,
                                         use_attention_mask=args.use_attention_mask,
                                         reproduce=args.reproduce,
                                         logger=logger,
                                         device=load_device,
                                         )

        # NOTE(review): ``cpu_offload`` is not forwarded to the constructor,
        # so ``self.cpu_offload`` stays at its default ``False`` even when
        # ``args.cpu_offload`` is set. Confirm whether that is intended.
        return cls(args=args,
                   vae=vae,
                   vae_kwargs=vae_kwargs,
                   text_encoder=text_encoder,
                   model=model,
                   text_encoder_2=text_encoder_2,
                   device=device,
                   logger=logger)

    @staticmethod
    def load_state_dict(args, model, ckpt_path):
        """Load checkpoint weights into ``model`` and return it.

        Args:
            args: Namespace providing ``load_key``: the top-level checkpoint
                key holding the weights, or ``"."`` to use the whole
                checkpoint as the state dict.
            model: Module whose ``load_state_dict`` is called with
                ``strict=False``.
            ckpt_path (str or pathlib.Path): Checkpoint file, or a directory
                in which the first ``*_model_states.pt`` match is used.

        Returns:
            The same ``model`` object, with the weights loaded.

        Raises:
            FileNotFoundError: ``ckpt_path`` is a directory that contains no
                ``*_model_states.pt`` file.
            KeyError: ``load_key`` is neither a key of the checkpoint nor
                ``"."``.
        """
        load_key = args.load_key
        ckpt_path = Path(ckpt_path)
        if ckpt_path.is_dir():
            ckpt_dir = ckpt_path
            # Supplying a default avoids leaking a bare StopIteration when
            # the directory holds no matching checkpoint.
            ckpt_path = next(ckpt_dir.glob("*_model_states.pt"), None)
            if ckpt_path is None:
                raise FileNotFoundError(
                    f"No '*_model_states.pt' checkpoint found in directory: {ckpt_dir}"
                )

        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # trusted sources. The identity map_location keeps tensors on CPU.
        state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)

        if load_key in state_dict:
            state_dict = state_dict[load_key]
        elif load_key == ".":
            pass
        else:
            raise KeyError(f"Key '{load_key}' not found in the checkpoint. Existed keys: {state_dict.keys()}")

        # strict=False: missing or unexpected keys are tolerated.
        model.load_state_dict(state_dict, strict=False)
        return model

    def get_exp_dir_and_ckpt_id(self):
        """Infer the experiment directory and checkpoint id from ``self.ckpt``.

        The path is expected to look like
        ``<exp_dir>/checkpoints/<ckpt_id>/<file>``.

        Returns:
            tuple: ``(exp_dir, ckpt_id)`` where ``exp_dir`` is a
            ``pathlib.Path`` and ``ckpt_id`` is the name of the checkpoint
            file's parent directory.

        Raises:
            ValueError: ``self.ckpt`` is unset or ``None``, or the path does
                not follow the layout above.
        """
        # ``ckpt`` is not assigned in __init__, so it may be absent entirely;
        # treat that the same as an explicit None.
        ckpt_value = getattr(self, "ckpt", None)
        if ckpt_value is None:
            raise ValueError("The checkpoint path is not provided.")

        ckpt = Path(ckpt_value)
        ancestors = ckpt.parents
        # Paths that are too short cannot match the layout; report them with
        # the same error instead of an IndexError.
        if len(ancestors) < 3 or ancestors[1].name != "checkpoints":
            raise ValueError(f"We cannot infer the experiment directory from the checkpoint path: {ckpt}. "
                             f"It seems that the checkpoint path is not standard. Please explicitly provide the "
                             f"save path by --save-path.")

        # It should be a standard checkpoint path. We use the parent directory as the default save directory.
        exp_dir = ancestors[2]
        return exp_dir, ckpt.parent.name

    @staticmethod
    def parse_size(size):
        """Normalize ``size`` to a two-element (height, width) sequence.

        Args:
            size: An int (used for both height and width), or a list/tuple
                of one or two elements.

        Returns:
            A new two-element list when ``size`` was an int or had a single
            element; otherwise the original two-element list/tuple.

        Raises:
            ValueError: ``size`` is not an int, list or tuple, or does not
                have one or two elements.
        """
        if isinstance(size, int):
            size = [size]
        if not isinstance(size, (list, tuple)):
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        if len(size) == 1:
            size = [size[0], size[0]]
        if len(size) != 2:
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        return size
|