from .base_prompter import BasePrompter
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder
from ..models.stepvideo_text_encoder import STEP1TextEncoder
from transformers import BertTokenizer
import os
import torch


class StepVideoPrompter(BasePrompter):
    """Encode StepVideo prompts with two text encoders.

    The prompt is encoded by a BERT-tokenized CLIP text encoder (HunyuanDiT)
    and by the STEP1 LLM text encoder; a joint attention mask covering both
    embedding sequences is returned alongside the embeddings.
    """

    def __init__(
        self,
        tokenizer_1_path=None,
    ):
        if tokenizer_1_path is None:
            # Fall back to the tokenizer config bundled with the package.
            base_path = os.path.dirname(os.path.dirname(__file__))
            tokenizer_1_path = os.path.join(
                base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
        super().__init__()
        self.tokenizer_1 = BertTokenizer.from_pretrained(tokenizer_1_path)

    def fetch_models(self, text_encoder_1: HunyuanDiTCLIPTextEncoder = None, text_encoder_2: STEP1TextEncoder = None):
        # Attach text encoders that were loaded elsewhere (e.g. by the pipeline).
        self.text_encoder_1 = text_encoder_1
        self.text_encoder_2 = text_encoder_2

    def encode_prompt_using_clip(self, prompt, max_length, device):
        # Tokenize with the BERT tokenizer, then encode with the CLIP text encoder.
        text_inputs = self.tokenizer_1(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        prompt_embeds = self.text_encoder_1(
            text_inputs.input_ids.to(device),
            attention_mask=text_inputs.attention_mask.to(device),
        )
        return prompt_embeds

    def encode_prompt_using_llm(self, prompt, max_length, device):
        # The STEP1 encoder tokenizes internally and returns embeddings plus mask.
        y, y_mask = self.text_encoder_2(prompt, max_length=max_length, device=device)
        return y, y_mask

    def encode_prompt(self,
                      prompt,
                      positive=True,
                      device="cuda"):

        prompt = self.process_prompt(prompt, positive=positive)

        # The CLIP branch is capped at 77 tokens, the LLM branch at 320.
        clip_embeds = self.encode_prompt_using_clip(prompt, max_length=77, device=device)
        llm_embeds, llm_mask = self.encode_prompt_using_llm(prompt, max_length=320, device=device)

        # Left-pad the LLM mask with ones so the combined mask also covers the
        # CLIP embeddings, which downstream are expected to be concatenated in
        # front of the LLM embeddings along the sequence axis.
        llm_mask = torch.nn.functional.pad(llm_mask, (clip_embeds.shape[1], 0), value=1)

        return clip_embeds, llm_embeds, llm_mask
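

# Usage sketch (illustrative, not part of the module): assumes the two
# encoders were already constructed and loaded, e.g. from pretrained
# checkpoints, by the surrounding pipeline; `clip_encoder` and `llm_encoder`
# are placeholder names.
#
#     prompter = StepVideoPrompter()
#     prompter.fetch_models(text_encoder_1=clip_encoder, text_encoder_2=llm_encoder)
#     clip_embeds, llm_embeds, llm_mask = prompter.encode_prompt(
#         "a cat playing the piano", positive=True, device="cuda")
#     # clip_embeds spans 77 token positions, llm_embeds up to 320;
#     # llm_mask covers both sequences (77 CLIP positions + LLM positions).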