import re
import copy
from typing import Literal

from PIL import Image
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from transformers import AutoTokenizer
from transformers.cache_utils import Cache, StaticCache

from models.nextstep_model import NextStep
from vae.nextstep_ae import AutoencoderKL
from utils.image_utils import to_pil
from utils.model_utils import layer_norm
from utils.compile_utils import compile_manager
from utils.misc import set_seed

DEFAULT_IMAGE_AREA_TOKEN = "<|image_area|>"


def hw2str(h: int, w: int) -> str:
    return f"{h}*{w}"

class NextStepPipeline:
    def __init__(
        self,
        model_name_or_path: str | None = None,
        vae_name_or_path: str | None = None,
        tokenizer: AutoTokenizer | None = None,
        model: nn.Module | None = None,
        vae: AutoencoderKL | None = None,
    ):
        if model is not None:
            self.tokenizer = copy.deepcopy(tokenizer)
            self.tokenizer.padding_side = "left"
            self.model = model
        elif model_name_or_path is not None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name_or_path,
                local_files_only=True,
                model_max_length=4096,
                padding_side="left",
                use_fast=True,
            )
            self.model: NextStep = NextStep.from_pretrained(model_name_or_path, local_files_only=True)
        else:
            raise ValueError("model or model_name_or_path is required")
        self.tokenizer.add_eos_token = False

        if vae_name_or_path is None:
            vae_name_or_path = getattr(self.model.config, "vae_name_or_path", None)
        if vae is not None:
            self.vae = vae
        elif vae_name_or_path is not None:
            self.vae = AutoencoderKL.from_pretrained(vae_name_or_path)
        else:
            raise ValueError("vae or vae_name_or_path is required")

        self.model.eval()
        self.vae.eval()

        vae_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.down_factor = vae_factor * self.model.config.latent_patch_size
        self.shift_factor = getattr(self.vae.config, "shift_factor", 0.0)
        self.scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0)

        self.boi = self.model.config.boi
        self.eoi = self.model.config.eoi
        self.image_placeholder_id = self.model.config.image_placeholder_id

        self.pil2tensor = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

        self.__device = self.model.device
        self.__dtype = self.model.dtype

    # Exposed as read-only properties: downstream code accesses `self.device` and
    # `self.dtype` as attributes (e.g. in `generate_image`), not as methods.
    @property
    def device(self):
        return self.__device

    @property
    def device_type(self):
        if isinstance(self.__device, str):
            return self.__device
        return self.__device.type

    @property
    def dtype(self):
        return self.__dtype

    def to(self, device: str | None = None, dtype: torch.dtype | None = None):
        if device is not None:
            self.__device = device
        if dtype is not None:
            self.__dtype = dtype
        self.model.to(self.__device, dtype=self.__dtype)
        self.vae.to(self.__device, dtype=self.__dtype)
        return self

    def _image_str(self, hw: tuple[int, int] = (256, 256)):
        latent_hw = (hw[0] // self.down_factor, hw[1] // self.down_factor)
        image_ids = [self.boi] + [self.image_placeholder_id] * (latent_hw[0] * latent_hw[1]) + [self.eoi]
        image_str = DEFAULT_IMAGE_AREA_TOKEN + hw2str(*latent_hw) + self.tokenizer.decode(image_ids)
        return image_str

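    # Worked example for `_image_str` above (hedged; the exact numbers come from the
    # checkpoint config): if the VAE downsamples by 8x and latent_patch_size is 2,
    # down_factor is 16, so _image_str((256, 256)) gives latent_hw = (16, 16) and the
    # string "<|image_area|>16*16" followed by the decoded <boi> token, 256 decoded
    # image-placeholder tokens, and the decoded <eoi> token.
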
    def _check_input(
        self, captions: str | list[str], images: Image.Image | list[Image.Image] | None
    ) -> tuple[list[str], list[Image.Image] | None]:
        if not isinstance(captions, list):
            captions = [captions]

        if images is not None:
            if not isinstance(images, list):
                images = [images]

            # Validate image count matches <image> tokens in captions
            image_token_count = 0
            for caption in captions:
                num_image_token = len(re.findall(r"<image>", caption))
                assert num_image_token == 1, f"Caption `{caption}` has {num_image_token} image tokens, but only 1 is allowed."
                image_token_count += num_image_token

            if image_token_count != len(images):
                raise ValueError(
                    f"Number of images ({len(images)}) does not match number of image tokens ({image_token_count}).\n"
                    f"Captions: {captions}"
                )

            hws = [(image.size[1], image.size[0]) for image in images]

            # Replace <image> tokens sequentially with corresponding image_str based on hw
            processed_captions = []
            image_idx = 0
            for caption in captions:
                processed_caption = caption
                num_image_tokens = processed_caption.count("<image>")
                # Replace each <image> token in order
                for _ in range(num_image_tokens):
                    processed_caption = processed_caption.replace("<image>", self._image_str(hws[image_idx]), 1)
                    image_idx += 1
                processed_captions.append(processed_caption)
            captions = processed_captions

        return captions, images

    def _build_captions(
        self,
        captions: str | list[str],
        images: list[Image.Image] | None = None,
        num_images_per_caption: int = 1,
        positive_prompt: str | None = None,
        negative_prompt: str | None = None,
        cfg: float = 1.0,
        cfg_img: float = 1.0,
    ):
        # 1. repeat captions and images
        if not isinstance(captions, list):
            captions = [captions]
        captions = [caption for caption in captions for _ in range(num_images_per_caption)]
        if images is not None:
            images = [image for image in images for _ in range(num_images_per_caption)]

        # 2. add positive prompt
        if positive_prompt is None:
            positive_prompt = ""
        captions = [f"{caption} {positive_prompt}" for caption in captions]

        # 3. add negative prompt
        if negative_prompt is None:
            negative_prompt = ""
        num_samples = len(captions)
        if cfg != 1.0 and cfg_img != 1.0:  # use both image and text CFG
            w, h = images[0].size
            captions = captions + [self._image_str((h, w)) + negative_prompt] * num_samples
            images = images + images
            captions = captions + [negative_prompt] * num_samples
        elif cfg != 1.0 and cfg_img == 1.0:  # use text CFG
            captions = captions + [negative_prompt] * num_samples
        elif cfg == 1.0 and cfg_img == 1.0:
            pass

        return captions, images

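    # Batch layout produced by `_build_captions` (an illustration of the code above):
    # with cfg != 1.0 and cfg_img != 1.0 the caption list becomes three groups of
    # num_samples entries -- [text + image conditional] + [image-only + negative_prompt]
    # + [negative_prompt only] -- matching the 3x duplication of cur_inputs_embeds in
    # `decoding`; with only cfg != 1.0 it is two groups, matching the 2x duplication.
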
    def _add_prefix_ids(self, hw: tuple[int, int], input_ids: torch.Tensor, attention_mask: torch.Tensor):
        prefix_str = DEFAULT_IMAGE_AREA_TOKEN + hw2str(hw[0] // self.down_factor, hw[1] // self.down_factor)
        prefix_output = self.tokenizer(prefix_str, truncation=False, add_special_tokens=True, return_tensors="pt")
        prefix_input_ids = prefix_output.input_ids.to(input_ids.device, dtype=input_ids.dtype)
        prefix_attention_mask = prefix_output.attention_mask.to(attention_mask.device, dtype=attention_mask.dtype)

        # remove bos token
        if self.tokenizer.bos_token is not None:
            prefix_input_ids = prefix_input_ids[:, 1:]
            prefix_attention_mask = prefix_attention_mask[:, 1:]

        # add boi token
        prefix_input_ids = torch.cat(
            [
                prefix_input_ids,
                prefix_input_ids.new_tensor([self.model.config.boi]).unsqueeze(0),
            ],
            dim=1,
        )
        prefix_attention_mask = torch.cat(
            [
                prefix_attention_mask,
                prefix_attention_mask.new_ones((prefix_attention_mask.shape[0], 1)),
            ],
            dim=1,
        )

        bsz = input_ids.shape[0]
        input_ids = torch.cat([input_ids, prefix_input_ids.expand(bsz, -1)], dim=1)
        attention_mask = torch.cat([attention_mask, prefix_attention_mask.expand(bsz, -1)], dim=1)
        return input_ids, attention_mask

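    # `decoding` autoregressively samples `max_new_len` continuous image tokens: each
    # step projects the current hidden state through image_out_projector, samples a
    # token from the image head with the per-step CFG scales, optionally layer-norms
    # it, re-embeds it via image_in_projector (duplicated 2x/3x for CFG), and runs one
    # incremental forward pass to extend the KV cache.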
    def decoding(
        self,
        c: torch.Tensor,
        attention_mask: torch.Tensor,
        past_key_values: Cache,
        max_new_len: int,
        num_images_per_caption: int,
        use_norm: bool = False,
        cfg: float = 1.0,
        cfg_img: float = 1.0,
        cfg_schedule: Literal["linear", "constant"] = "constant",
        timesteps_shift: float = 1.0,
        num_sampling_steps: int = 20,
        progress: bool = True,
    ):
        indices = list(range(max_new_len))
        indices = tqdm(indices, unit="tokens") if progress else indices

        tokens = None
        unnormed_tokens = None
        for _ in indices:
            # cfg schedule follows Muse
            if cfg_schedule == "linear":
                tokens_len = 0 if tokens is None else tokens.shape[1]
                cfg_iter = max(cfg / 2, 1 + (cfg - 1) * tokens_len / max_new_len)
                cfg_img_iter = max(cfg_img / 2, 1 + (cfg_img - 1) * tokens_len / max_new_len)
            elif cfg_schedule == "constant":
                cfg_iter = cfg
                cfg_img_iter = cfg_img
            else:
                raise NotImplementedError

            c = self.model.image_out_projector(c)
            token_sampled = self.model.image_head.sample(
                c=c.squeeze(1),
                cfg=cfg_iter,
                cfg_img=cfg_img_iter,
                timesteps_shift=timesteps_shift,
                num_sampling_steps=num_sampling_steps,
                noise_repeat=num_images_per_caption,
            )
            unnormed_token_sampled = token_sampled.clone()
            if use_norm:
                token_sampled = layer_norm(token_sampled, normalized_shape=token_sampled.size()[1:])

            if tokens is not None:
                tokens = torch.cat([tokens, token_sampled.unsqueeze(1)], dim=1)
                unnormed_tokens = torch.cat([unnormed_tokens, unnormed_token_sampled.unsqueeze(1)], dim=1)
            else:
                tokens = token_sampled.unsqueeze(1)
                unnormed_tokens = unnormed_token_sampled.unsqueeze(1)

            cur_inputs_embeds = self.model.image_in_projector(tokens[:, -1:])
            if cfg != 1.0 and cfg_img == 1.0:
                cur_inputs_embeds = torch.cat([cur_inputs_embeds, cur_inputs_embeds], dim=0)
            elif cfg != 1.0 and cfg_img != 1.0:
                cur_inputs_embeds = torch.cat([cur_inputs_embeds, cur_inputs_embeds, cur_inputs_embeds], dim=0)

            attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
            outputs = self.model.forward_model(
                inputs_embeds=cur_inputs_embeds,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values
            c = outputs.last_hidden_state[:, -1:]

        return unnormed_tokens

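    # `generate_image` is the end-to-end path: validate and expand the captions,
    # VAE-encode any conditioning images, prefill the language model over the
    # tokenized prompt plus the <|image_area|> prefix, decode latent image tokens
    # step by step with `decoding`, then unpatchify and VAE-decode to PIL images.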
    def generate_image(
        self,
        captions: str | list[str],
        images: list[Image.Image] | None = None,
        num_images_per_caption: int = 1,
        positive_prompt: str | None = None,
        negative_prompt: str | None = None,
        hw: tuple[int, int] = (256, 256),
        use_norm: bool = False,
        cfg: float = 1.0,
        cfg_img: float = 1.0,
        cfg_schedule: Literal["linear", "constant"] = "constant",
        num_sampling_steps: int = 20,
        timesteps_shift: float = 1.0,
        seed: int = 42,
        progress: bool = True,
    ) -> list[Image.Image]:
        # 1. check input
        captions, images = self._check_input(captions, images)

        # 2. build captions
        captions, images = self._build_captions(
            captions, images, num_images_per_caption, positive_prompt, negative_prompt, cfg, cfg_img
        )

        # 3. encode images
        # `images` must be processed by `process_images` before calling this function
        latents = None
        if images is not None:
            pixel_values = [self.pil2tensor(image) for image in images]
            pixel_values = torch.stack(pixel_values).to(self.device)
            with compile_manager.compile_disabled():
                posterior = self.vae.encode(pixel_values.to(self.vae.dtype)).latent_dist
            latents = (posterior.sample() - self.shift_factor) * self.scaling_factor

        if seed is not None:
            set_seed(seed)

        # 4. tokenize caption & add prefix ids
        output = self.tokenizer(
            captions,
            padding="longest",
            truncation=False,
            add_special_tokens=True,
            return_tensors="pt",
            padding_side="left",
        )
        input_ids = output.input_ids.to(self.device)
        attention_mask = output.attention_mask.to(self.device)
        input_ids, attention_mask = self._add_prefix_ids(hw, input_ids, attention_mask)

        # 5. LLM prefill
        max_new_len = (hw[0] // self.down_factor) * (hw[1] // self.down_factor)
        max_cache_len = input_ids.shape[1] + max_new_len
        past_key_values = StaticCache(
            config=self.model.config,
            max_batch_size=input_ids.shape[0],
            max_cache_len=max_cache_len,
            device=self.device,
            dtype=self.dtype,
        )
        inputs_embeds = self.model.prepare_inputs_embeds(input_ids, latents)
        with compile_manager.compile_disabled():
            outputs = self.model.forward_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )
        past_key_values = outputs.past_key_values
        c = outputs.last_hidden_state[:, -1:]

        # 6. decoding
        tokens = self.decoding(
            c=c,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            max_new_len=max_new_len,
            num_images_per_caption=num_images_per_caption,
            use_norm=use_norm,
            cfg=cfg,
            cfg_img=cfg_img,
            cfg_schedule=cfg_schedule,
            timesteps_shift=timesteps_shift,
            num_sampling_steps=num_sampling_steps,
            progress=progress,
        )

        # 7. unpatchify
        latents = self.model.unpatchify(tokens, h=hw[0] // self.down_factor, w=hw[1] // self.down_factor)
        latents = (latents / self.scaling_factor) + self.shift_factor

        # 8. decode latents
        with compile_manager.compile_disabled():
            sampled_images = self.vae.decode(latents.to(self.vae.dtype)).sample
        sampled_images = sampled_images.detach().cpu().to(torch.float32)
        pil_images = [to_pil(img) for img in sampled_images]
        return pil_images
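# Minimal usage sketch of the pipeline above. The checkpoint paths are placeholders
# (assumptions), and running on CUDA with bfloat16 is only one possible setup;
# adjust both to your local NextStep and VAE weights and hardware.
if __name__ == "__main__":
    pipeline = NextStepPipeline(
        model_name_or_path="path/to/nextstep_checkpoint",  # assumed local path
        vae_name_or_path="path/to/nextstep_vae",  # assumed local path
    ).to("cuda", dtype=torch.bfloat16)

    images = pipeline.generate_image(
        "A photo of a red fox in the snow.",
        hw=(256, 256),
        cfg=7.5,
        num_sampling_steps=20,
        seed=42,
    )
    images[0].save("sample.png")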