from .base_prompter import BasePrompter
from ..models.sd3_text_encoder import SD3TextEncoder1
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder
from transformers import CLIPTokenizer, LlamaTokenizerFast, CLIPImageProcessor
import os
import torch
from typing import Union
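
# Chat-style prompt templates wrapped around the user prompt before it is passed to
# the Llama-based text encoder (text_encoder_2). The template tokens are cropped back
# out of the encoder output using the "crop_start" offsets defined in PROMPT_TEMPLATE.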
PROMPT_TEMPLATE_ENCODE = (
    "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
    "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")

PROMPT_TEMPLATE_ENCODE_VIDEO = (
    "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
    "1. The main content and theme of the video."
    "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
    "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
    "4. background environment, light, style and atmosphere."
    "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")

PROMPT_TEMPLATE_ENCODE_I2V = (
    "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the image by detailing the color, shape, size, texture, "
    "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)

PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
    "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
    "1. The main content and theme of the video."
    "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
    "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
    "4. background environment, light, style and atmosphere."
    "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
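
# For each template, "crop_start" is the number of leading tokens contributed by the
# template itself; the encode_* methods below crop that many positions off the LLM
# hidden states so that only the user prompt remains. The i2v entries additionally
# record where the 576 image token embeddings sit in the tokenized sequence
# ("image_emb_start"/"image_emb_end") and the "double return" token id used to locate
# the assistant header.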
PROMPT_TEMPLATE = {
    "dit-llm-encode": {
        "template": PROMPT_TEMPLATE_ENCODE,
        "crop_start": 36,
    },
    "dit-llm-encode-video": {
        "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
        "crop_start": 95,
    },
    "dit-llm-encode-i2v": {
        "template": PROMPT_TEMPLATE_ENCODE_I2V,
        "crop_start": 36,
        "image_emb_start": 5,
        "image_emb_end": 581,
        "image_emb_len": 576,
        "double_return_token_id": 271,
    },
    "dit-llm-encode-video-i2v": {
        "template": PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
        "crop_start": 103,
        "image_emb_start": 5,
        "image_emb_end": 581,
        "image_emb_len": 576,
        "double_return_token_id": 271,
    },
}
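
# Default negative prompt for HunyuanVideo. It is only defined here; it is presumably
# consumed by the surrounding pipeline when building unconditional embeddings.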
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"


class HunyuanVideoPrompter(BasePrompter):

    def __init__(
        self,
        tokenizer_1_path=None,
        tokenizer_2_path=None,
    ):
        if tokenizer_1_path is None:
            base_path = os.path.dirname(os.path.dirname(__file__))
            tokenizer_1_path = os.path.join(
                base_path, "tokenizer_configs/hunyuan_video/tokenizer_1")
        if tokenizer_2_path is None:
            base_path = os.path.dirname(os.path.dirname(__file__))
            tokenizer_2_path = os.path.join(
                base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
        super().__init__()
        self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
        self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right')
        self.text_encoder_1: SD3TextEncoder1 = None
        self.text_encoder_2: HunyuanVideoLLMEncoder = None
        self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
        self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
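
    # Attach the text encoders. When text_encoder_2 is the multimodal (image + text)
    # variant, switch to the i2v templates and load the CLIP image processor used to
    # preprocess reference images in encode_prompt_using_mllm.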
    def fetch_models(self,
                     text_encoder_1: SD3TextEncoder1 = None,
                     text_encoder_2: Union[HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder] = None):
        self.text_encoder_1 = text_encoder_1
        self.text_encoder_2 = text_encoder_2
        if isinstance(text_encoder_2, HunyuanVideoMLLMEncoder):
            # processor
            # TODO: may need to replace processor with local implementation
            base_path = os.path.dirname(os.path.dirname(__file__))
            tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
            self.processor = CLIPImageProcessor.from_pretrained(tokenizer_2_path)
            # template
            self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode-i2v']
            self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video-i2v']
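
    # Wrap a prompt string (or a list of prompt strings) in the chat template expected
    # by the LLM encoder.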
    def apply_text_to_template(self, text, template):
        assert isinstance(template, str)
        if isinstance(text, list):
            return [self.apply_text_to_template(text_, template) for text_ in text]
        elif isinstance(text, str):
            # Will send string to tokenizer. Used for llm
            return template.format(text)
        else:
            raise TypeError(f"Unsupported prompt type: {type(text)}")
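
    # CLIP branch: tokenize with tokenizer_1 and return the first output of
    # text_encoder_1, which is used downstream as the pooled prompt embedding.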
    def encode_prompt_using_clip(self, prompt, max_length, device):
        tokenized_result = self.tokenizer_1(
            prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True
        )
        input_ids = tokenized_result.input_ids.to(device)
        attention_mask = tokenized_result.attention_mask.to(device)
        return self.text_encoder_1(input_ids=input_ids, extra_mask=attention_mask)[0]
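
    # LLM branch (text-only): encode the templated prompt, then crop the first
    # `crop_start` positions so the returned hidden states cover only the user prompt.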
    def encode_prompt_using_llm(self,
                                prompt,
                                max_length,
                                device,
                                crop_start,
                                hidden_state_skip_layer=2,
                                use_attention_mask=True):
        max_length += crop_start
        inputs = self.tokenizer_2(prompt,
                                  return_tensors="pt",
                                  padding="max_length",
                                  max_length=max_length,
                                  truncation=True)
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)
        last_hidden_state = self.text_encoder_2(input_ids, attention_mask, hidden_state_skip_layer)
        # crop out the template tokens
        if crop_start > 0:
            last_hidden_state = last_hidden_state[:, crop_start:]
            attention_mask = (attention_mask[:, crop_start:] if use_attention_mask else None)
        return last_hidden_state, attention_mask
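
    # Multimodal branch (image-to-video): encode reference images together with the
    # templated prompt, then split the hidden states into an image slice
    # (image_emb_start:image_emb_end) and a text slice with the template and assistant
    # header removed. The image slice is subsampled by `image_embed_interleave` and
    # concatenated in front of the text embeddings.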
    def encode_prompt_using_mllm(self,
                                 prompt,
                                 images,
                                 max_length,
                                 device,
                                 crop_start,
                                 hidden_state_skip_layer=2,
                                 use_attention_mask=True,
                                 image_embed_interleave=4):
        image_outputs = self.processor(images, return_tensors="pt")["pixel_values"].to(device)
        max_length += crop_start
        inputs = self.tokenizer_2(prompt,
                                  return_tensors="pt",
                                  padding="max_length",
                                  max_length=max_length,
                                  truncation=True)
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)
        last_hidden_state = self.text_encoder_2(input_ids=input_ids,
                                                attention_mask=attention_mask,
                                                hidden_state_skip_layer=hidden_state_skip_layer,
                                                pixel_values=image_outputs)
        text_crop_start = (crop_start - 1 + self.prompt_template_video.get("image_emb_len", 576))
        image_crop_start = self.prompt_template_video.get("image_emb_start", 5)
        image_crop_end = self.prompt_template_video.get("image_emb_end", 581)
        batch_indices, last_double_return_token_indices = torch.where(
            input_ids == self.prompt_template_video.get("double_return_token_id", 271))
        if last_double_return_token_indices.shape[0] == 3:
            # in case the prompt is too long; keep the appended index on the same device as input_ids
            last_double_return_token_indices = torch.cat((
                last_double_return_token_indices,
                torch.tensor([input_ids.shape[-1]], device=last_double_return_token_indices.device),
            ))
            batch_indices = torch.cat((batch_indices, torch.tensor([0], device=batch_indices.device)))
        last_double_return_token_indices = (last_double_return_token_indices.reshape(input_ids.shape[0], -1)[:, -1])
        batch_indices = batch_indices.reshape(input_ids.shape[0], -1)[:, -1]
        assistant_crop_start = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576) - 4)
        assistant_crop_end = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576))
        attention_mask_assistant_crop_start = (last_double_return_token_indices - 4)
        attention_mask_assistant_crop_end = last_double_return_token_indices
        text_last_hidden_state = []
        text_attention_mask = []
        image_last_hidden_state = []
        image_attention_mask = []
        for i in range(input_ids.shape[0]):
            text_last_hidden_state.append(
                torch.cat([
                    last_hidden_state[i, text_crop_start:assistant_crop_start[i].item()],
                    last_hidden_state[i, assistant_crop_end[i].item():],
                ]))
            text_attention_mask.append(
                torch.cat([
                    attention_mask[
                        i,
                        crop_start:attention_mask_assistant_crop_start[i].item(),
                    ],
                    attention_mask[i, attention_mask_assistant_crop_end[i].item():],
                ]) if use_attention_mask else None)
            image_last_hidden_state.append(last_hidden_state[i, image_crop_start:image_crop_end])
            image_attention_mask.append(
                torch.ones(image_last_hidden_state[-1].shape[0]).to(last_hidden_state.device).
                to(attention_mask.dtype) if use_attention_mask else None)
        text_last_hidden_state = torch.stack(text_last_hidden_state)
        text_attention_mask = torch.stack(text_attention_mask)
        image_last_hidden_state = torch.stack(image_last_hidden_state)
        image_attention_mask = torch.stack(image_attention_mask)
        image_last_hidden_state = image_last_hidden_state[:, ::image_embed_interleave, :]
        image_attention_mask = image_attention_mask[:, ::image_embed_interleave]
        assert (text_last_hidden_state.shape[0] == text_attention_mask.shape[0] and
                image_last_hidden_state.shape[0] == image_attention_mask.shape[0])
        last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
        attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
        return last_hidden_state, attention_mask
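
    # Top-level entry point: apply the chat template, run CLIP for the pooled
    # embedding, and run the LLM (text-only) or MLLM (with reference images) branch
    # for the sequence embedding and its attention mask.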
    def encode_prompt(self,
                      prompt,
                      images=None,
                      positive=True,
                      device="cuda",
                      clip_sequence_length=77,
                      llm_sequence_length=256,
                      data_type='video',
                      use_template=True,
                      hidden_state_skip_layer=2,
                      use_attention_mask=True,
                      image_embed_interleave=4):
        prompt = self.process_prompt(prompt, positive=positive)
        # apply template
        if use_template:
            template = self.prompt_template_video if data_type == 'video' else self.prompt_template
            prompt_formated = self.apply_text_to_template(prompt, template['template'])
        else:
            prompt_formated = prompt
        # Text encoder
        if data_type == 'video':
            crop_start = self.prompt_template_video.get("crop_start", 0)
        else:
            crop_start = self.prompt_template.get("crop_start", 0)
        # CLIP
        pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
        # LLM
        if images is None:
            prompt_emb, attention_mask = self.encode_prompt_using_llm(prompt_formated, llm_sequence_length, device, crop_start,
                                                                      hidden_state_skip_layer, use_attention_mask)
        else:
            prompt_emb, attention_mask = self.encode_prompt_using_mllm(prompt_formated, images, llm_sequence_length, device,
                                                                       crop_start, hidden_state_skip_layer, use_attention_mask,
                                                                       image_embed_interleave)
        return prompt_emb, pooled_prompt_emb, attention_mask
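

# Example usage (a minimal sketch; how the two text encoders are constructed and
# loaded is an assumption that depends on the surrounding pipeline, not this file):
#
#     prompter = HunyuanVideoPrompter()
#     prompter.fetch_models(text_encoder_1, text_encoder_2)
#     prompt_emb, pooled_prompt_emb, attention_mask = prompter.encode_prompt(
#         "A cat walks on the grass, realistic style.", data_type="video")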