import torch
import torch.distributed
from pathlib import Path
from loguru import logger

from hymm_sp.constants import PROMPT_TEMPLATE, PRECISION_TO_TYPE
from hymm_sp.vae import load_vae
from hymm_sp.modules import load_model
from hymm_sp.text_encoder import TextEncoder
from hymm_sp.modules.parallel_states import nccl_info
from hymm_sp.modules.fp8_optimization import convert_fp8_linear


class Inference(object):
    def __init__(self, 
                 args,
                 vae, 
                 vae_kwargs, 
                 text_encoder, 
                 model, 
                 text_encoder_2=None, 
                 pipeline=None, 
                 cpu_offload=False,
                 device=None, 
                 logger=None):
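        """Store pre-built pipeline components.

        Prefer ``Inference.from_pretrained`` to construct the pipeline
        directly from checkpoints.
        """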
        self.vae = vae
        self.vae_kwargs = vae_kwargs
        
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        
        self.model = model
        self.pipeline = pipeline
        self.cpu_offload = cpu_offload
        
        self.args = args
        # Default to CUDA when available; under sequence parallelism each
        # rank is pinned to its own GPU.
        if device is not None:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if nccl_info.sp_size > 1:
            self.device = torch.device(f"cuda:{torch.distributed.get_rank()}")
        
        self.logger = logger

    @classmethod
    def from_pretrained(cls, 
                        pretrained_model_path,
                        args,
                        device=None,
                        **kwargs):
        """
        Initialize the Inference pipeline.

        Args:
            pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints.
            device (int): The device for inference. Default is 0.
            logger (logging.Logger): The logger for the inference pipeline. Default is None.
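
        Example (illustrative; assumes ``args`` comes from the project's
        argument parser and the checkpoint layout matches):

            pipe = Inference.from_pretrained("ckpts/", args, device="cuda")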
        """
        # ======================== Build main model ==============================
        logger.info(f"Got text-to-video model root path: {pretrained_model_path}")
        
        # Set device and disable gradient
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        torch.set_grad_enabled(False)
        logger.info("Building model...")
        factor_kwargs = {'device': 'cpu' if args.cpu_offload else device, 'dtype': PRECISION_TO_TYPE[args.precision]}
        in_channels = args.latent_channels
        out_channels = args.latent_channels
        print("="*25, f"build model", "="*25)
        model = load_model(
            args,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs
        )
        if args.use_fp8:
            convert_fp8_linear(model, pretrained_model_path, original_dtype=PRECISION_TO_TYPE[args.precision])
        if args.cpu_offload:
            # Keep the transformer on CPU; it is moved to the GPU on demand.
            logger.info("Loading transformer to CPU for offloading")
            model = model.to('cpu')
            torch.cuda.empty_cache()
        else:
            model = model.to(device)
        model = cls.load_state_dict(args, model, pretrained_model_path)
        model.eval()
        
        # ============================= Build extra models ========================
        # VAE
        print("="*25, f"load vae", "="*25)
        vae, _, s_ratio, t_ratio = load_vae(args.vae, args.vae_precision, logger=logger, device='cpu' if args.cpu_offload else device)
        vae_kwargs = {'s_ratio': s_ratio, 't_ratio': t_ratio}
        
        # Text encoder
        if args.prompt_template_video is not None:
            crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
        else:
            crop_start = 0
        max_length = args.text_len + crop_start

        prompt_template_video = (PROMPT_TEMPLATE[args.prompt_template_video]
                                 if args.prompt_template_video is not None else None)
        logger.info("Loading text encoder (llava)...")
        text_encoder = TextEncoder(text_encoder_type=args.text_encoder,
                                   max_length=max_length,
                                   text_encoder_precision=args.text_encoder_precision,
                                   tokenizer_type=args.tokenizer,
                                   use_attention_mask=args.use_attention_mask,
                                   prompt_template_video=prompt_template_video,
                                   hidden_state_skip_layer=args.hidden_state_skip_layer,
                                   apply_final_norm=args.apply_final_norm,
                                   reproduce=args.reproduce,
                                   logger=logger,
                                   device='cpu' if args.cpu_offload else device)
        text_encoder_2 = None
        if args.text_encoder_2 is not None:
            text_encoder_2 = TextEncoder(text_encoder_type=args.text_encoder_2,
                                         max_length=args.text_len_2,
                                         text_encoder_precision=args.text_encoder_precision_2,
                                         tokenizer_type=args.tokenizer_2,
                                         use_attention_mask=args.use_attention_mask,
                                         reproduce=args.reproduce,
                                         logger=logger,
                                         device='cpu' if args.cpu_offload else device)

        return cls(args=args, 
                   vae=vae, 
                   vae_kwargs=vae_kwargs, 
                   text_encoder=text_encoder,
                   model=model, 
                   text_encoder_2=text_encoder_2, 
                   device=device, 
                   logger=logger)

    @staticmethod
    def load_state_dict(args, model, ckpt_path):
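        """Load weights into ``model`` from ``ckpt_path``.

        If ``ckpt_path`` is a directory, the first file matching
        ``*_model_states.pt`` inside it is used (a DeepSpeed-style naming
        convention, inferred from the glob below). ``args.load_key``
        selects a sub-dict inside the checkpoint; ``"."`` means the file
        itself is the state dict.
        """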
        load_key = args.load_key
        ckpt_path = Path(ckpt_path)
        if ckpt_path.is_dir():
            ckpt_path = next(ckpt_path.glob("*_model_states.pt"))
        state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
        if load_key in state_dict:
            state_dict = state_dict[load_key]
        elif load_key == ".":
            # "." means the checkpoint file itself is the state dict.
            pass
        else:
            raise KeyError(f"Key '{load_key}' not found in the checkpoint. "
                           f"Existing keys: {list(state_dict.keys())}")
        model.load_state_dict(state_dict, strict=False)
        return model

    def get_exp_dir_and_ckpt_id(self):
        """Infer the experiment directory and checkpoint id from ``self.ckpt``.

        Assumes the standard layout ``<exp_dir>/checkpoints/<ckpt_id>/<file>``.
        """
        # Note: `self.ckpt` is expected to be set by the caller; it is not
        # assigned in `__init__`.
        if getattr(self, "ckpt", None) is None:
            raise ValueError("The checkpoint path is not provided.")

        ckpt = Path(self.ckpt)
        if ckpt.parents[1].name == "checkpoints":
            # Standard checkpoint path; use the experiment root as the
            # default save directory.
            exp_dir = ckpt.parents[2]
        else:
            raise ValueError(f"Cannot infer the experiment directory from the checkpoint path: {ckpt}. "
                             f"The path does not follow the standard layout; please provide the save "
                             f"path explicitly via --save-path.")
        return exp_dir, ckpt.parent.name

    @staticmethod
    def parse_size(size):
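        """Normalize a size spec to ``[height, width]``.

        Example (illustrative values)::

            Inference.parse_size(512)          # -> [512, 512]
            Inference.parse_size((720, 1280))  # -> [720, 1280]
        """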
        if isinstance(size, int):
            size = [size, size]
        if not isinstance(size, (list, tuple)):
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        if len(size) == 1:
            size = [size[0], size[0]]
        if len(size) != 2:
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        return list(size)