# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Set, Tuple, Union
from dataclasses import dataclass
from inspect import isfunction
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # used by TextImageResampler when gradient checkpointing is enabled
from einops import rearrange, repeat

from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from src.utils.data_utils import pad_to_square, pad_to_target
from transformers import CLIPProcessor, CLIPModel, CLIPVisionModelWithProjection, CLIPVisionModel
from collections import OrderedDict
class SquaredReLU(nn.Module):
    """ReLU followed by squaring, i.e. max(x, 0) ** 2."""

    def forward(self, x: torch.Tensor):
        return torch.square(torch.relu(x))


class AdaLayerNorm(nn.Module):
    """LayerNorm whose shift and scale are predicted from a timestep embedding (adaLN)."""

    def __init__(self, embedding_dim: int, time_embedding_dim: Optional[int] = None, ln_bias=True):
        super().__init__()
        if time_embedding_dim is None:
            time_embedding_dim = embedding_dim

        self.silu = nn.SiLU()
        self.linear = nn.Linear(time_embedding_dim, 2 * embedding_dim, bias=True)
        # Zero-init so the block starts out as a plain LayerNorm (shift = 0, scale = 1).
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6, bias=ln_bias)

    def forward(self, x: torch.Tensor, timestep_embedding: torch.Tensor) -> torch.Tensor:
        emb = self.linear(self.silu(timestep_embedding))
        shift, scale = emb.view(len(x), 1, -1).chunk(2, dim=-1)
        x = self.norm(x) * (1 + scale) + shift
        return x
class PerceiverAttentionBlock(nn.Module):
    """Perceiver-style block: latents cross-attend to (optionally latents + context) and pass through an MLP."""

    def __init__(
        self, d_model: int, n_heads: int,
        time_embedding_dim: Optional[int] = None,
        double_kv: Optional[bool] = True,
    ):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.n_heads = n_heads

        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, d_model * 4)),
                    ("sq_relu", SquaredReLU()),
                    ("c_proj", nn.Linear(d_model * 4, d_model)),
                ]
            )
        )

        self.double_kv = double_kv
        self.ln_1 = AdaLayerNorm(d_model, time_embedding_dim)
        self.ln_2 = AdaLayerNorm(d_model, time_embedding_dim)
        self.ln_ff = AdaLayerNorm(d_model, time_embedding_dim)

    def attention(self, q: torch.Tensor, kv: torch.Tensor, attn_mask: torch.Tensor = None):
        attn_output, _ = self.attn(q, kv, kv, need_weights=False, key_padding_mask=attn_mask)
        return attn_output

    def forward(
        self,
        x: torch.Tensor,
        latents: torch.Tensor,
        timestep_embedding: torch.Tensor = None,
        attn_mask: torch.Tensor = None,
    ):
        normed_latents = self.ln_1(latents, timestep_embedding)
        normed_x = self.ln_2(x, timestep_embedding)
        if self.double_kv:
            # Keys/values are the latents concatenated with the context features.
            kv = torch.cat([normed_latents, normed_x], dim=1)
        else:
            kv = normed_x

        attn = self.attention(
            q=normed_latents,
            kv=kv,
            attn_mask=attn_mask,
        )
        if attn_mask is not None:
            # attn_mask is the key padding mask over the concatenated kv; its first half
            # corresponds to the latent queries (assumes latents and x share a sequence length).
            query_padding_mask = attn_mask.chunk(2, -1)[0].unsqueeze(-1)  # (B, 2S) -> (B, S, 1)
            latents = latents + attn * (~query_padding_mask).to(attn)
        else:
            latents = latents + attn

        latents = latents + self.mlp(self.ln_ff(latents, timestep_embedding))
        return latents
class CLIPModAdapter(ModelMixin, ConfigMixin):
    """Maps LLM text hidden states and CLIP image features to modulation features for the diffusion transformer."""

    @register_to_config
    def __init__(
        self,
        out_dim=3072,
        width=1024,
        pblock_width=512,
        layers=6,
        pblock_layers=1,
        heads=8,
        input_text_dim=4096,
        input_image_dim=1024,
        pblock_single_blocks=0,
    ):
        super().__init__()
        self.out_dim = out_dim
        # Global resampler: one feature of size out_dim per text token.
        self.net = TextImageResampler(
            width=width,
            layers=layers,
            heads=heads,
            input_text_dim=input_text_dim,
            input_image_dim=input_image_dim,
            time_embedding_dim=64,
            output_dim=out_dim,
        )
        # Per-block resampler: one feature per token for each of the 19 double blocks
        # (plus any single blocks); split out per block in forward().
        self.net2 = TextImageResampler(
            width=pblock_width,
            layers=pblock_layers,
            heads=heads,
            input_text_dim=input_text_dim,
            input_image_dim=input_image_dim,
            time_embedding_dim=64,
            output_dim=out_dim * (19 + pblock_single_blocks),
        )

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True
        self.net.enable_gradient_checkpointing()
        self.net2.enable_gradient_checkpointing()

    def forward(self, t_emb, llm_hidden_states, clip_outputs):
        # Collapse 4-D hidden states to (B, S, D) by taking the last slice along the
        # penultimate dimension (e.g. the final LLM layer if layers are stacked there).
        if len(llm_hidden_states.shape) > 3:
            llm_hidden_states = llm_hidden_states[..., -1, :]
        batch_size, seq_length = llm_hidden_states.shape[:2]

        img_cls_feat = clip_outputs["image_embeds"]  # (B, 768)
        img_last_feat = clip_outputs["last_hidden_state"]  # (B, 257, 1024)
        img_layer_feats = clip_outputs["hidden_states"]  # [(B, 257, 1024) * 25]
        img_second_last_feat = img_layer_feats[-2]  # (B, 257, 1024)

        # The penultimate CLIP layer is used as the image conditioning.
        img_hidden_states = img_second_last_feat  # (B, 257, 1024)

        x = self.net(llm_hidden_states, img_hidden_states)  # (B, S, 3072)
        x2 = self.net2(llm_hidden_states, img_hidden_states).view(batch_size, seq_length, -1, self.out_dim)  # (B, S, N, 3072)

        return x, x2
class TextImageResampler(nn.Module):
    """Stack of PerceiverAttentionBlocks in which text tokens (latents) attend to CLIP image tokens."""

    def __init__(
        self,
        width: int = 768,
        layers: int = 6,
        heads: int = 8,
        output_dim: int = 3072,
        input_text_dim: int = 4096,
        input_image_dim: int = 1024,
        time_embedding_dim: int = 64,
    ):
        super().__init__()
        self.output_dim = output_dim
        self.input_text_dim = input_text_dim
        self.input_image_dim = input_image_dim
        self.time_embedding_dim = time_embedding_dim

        self.text_proj_in = nn.Linear(input_text_dim, width)
        self.image_proj_in = nn.Linear(input_image_dim, width)

        self.perceiver_blocks = nn.Sequential(
            *[
                PerceiverAttentionBlock(
                    width, heads, time_embedding_dim=self.time_embedding_dim
                )
                for _ in range(layers)
            ]
        )

        self.proj_out = nn.Sequential(
            nn.Linear(width, output_dim), nn.LayerNorm(output_dim)
        )

        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True

    def forward(
        self,
        text_hidden_states: torch.Tensor,
        image_hidden_states: torch.Tensor,
    ):
        # The resampler is not timestep-conditioned: a constant zero embedding is passed,
        # so each AdaLayerNorm's shift/scale reduce to learned constants.
        timestep_embedding = torch.zeros((text_hidden_states.shape[0], 1, self.time_embedding_dim)).to(text_hidden_states)

        text_hidden_states = self.text_proj_in(text_hidden_states)
        image_hidden_states = self.image_proj_in(image_hidden_states)

        for p_block in self.perceiver_blocks:
            if self.gradient_checkpointing:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward

                text_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(p_block),
                    image_hidden_states,
                    text_hidden_states,
                    timestep_embedding,
                )
            else:
                text_hidden_states = p_block(image_hidden_states, text_hidden_states, timestep_embedding=timestep_embedding)

        text_hidden_states = self.proj_out(text_hidden_states)

        return text_hidden_states
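

# Minimal smoke test for CLIPModAdapter. This is an illustrative sketch only: the
# dimensions below are made up to keep the run cheap, and a real pipeline would pass
# genuine LLM hidden states and CLIPVisionModel outputs instead of random tensors.
if __name__ == "__main__":
    batch_size, text_len, image_len = 2, 4, 10

    adapter = CLIPModAdapter(
        out_dim=64,
        width=64,
        pblock_width=32,
        layers=1,
        pblock_layers=1,
        heads=4,
        input_text_dim=128,
        input_image_dim=96,
        pblock_single_blocks=0,
    )

    llm_hidden_states = torch.randn(batch_size, text_len, 128)
    clip_outputs = {
        "image_embeds": torch.randn(batch_size, 96),  # pooled image feature (not used by forward)
        "last_hidden_state": torch.randn(batch_size, image_len, 96),
        "hidden_states": [torch.randn(batch_size, image_len, 96) for _ in range(3)],
    }
    t_emb = None  # forward() currently ignores the timestep embedding

    x, x2 = adapter(t_emb, llm_hidden_states, clip_outputs)
    print(x.shape)   # (2, 4, 64): one modulation feature per text token
    print(x2.shape)  # (2, 4, 19, 64): one feature per token per double block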