# UniPic2-Metaquery / unipicv2 / stable_diffusion_3_conditioner.py

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

from unipicv2.modeling_connector import ConnectorEncoder
from unipicv2.configuration_connector import ConnectorConfig
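
# Maps hidden states from a multimodal LLM into the two text-conditioning
# inputs of the SD3 MMDiT: `prompt_embeds` (joint_attention_dim=4096) and
# `pooled_prompt_embeds` (pooled_projection_dim=2048). The default
# llm_hidden_size of 3584 suggests a Qwen2.5-VL-7B-style backbone (an
# assumption inferred from the default dims; the class itself is
# backbone-agnostic).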

class StableDiffusion3Conditioner(ModelMixin, ConfigMixin):
    model_type: str = "sd3_conditioner"  # stored in the config for Hub niceties

    @register_to_config
    def __init__(
        self,
        connector_config: dict,  # kwargs forwarded to ConnectorConfig(**connector_config)
        num_queries: int = 256,
        llm_hidden_size: int = 3584,
        pooled_projection_dim: int = 2048,
        joint_attention_dim: int = 4096,
    ):
        super().__init__()
        self.connector = ConnectorEncoder(ConnectorConfig(**connector_config))
        # Project LLM hidden states into the connector width, then from the
        # connector width into the two SD3 conditioning spaces.
        self.projector_1 = nn.Linear(llm_hidden_size, self.connector.config.hidden_size)
        self.projector_2 = nn.Linear(self.connector.config.hidden_size, pooled_projection_dim)
        self.projector_3 = nn.Linear(self.connector.config.hidden_size, joint_attention_dim)
        # Learnable query embeddings; they are consumed by the caller (e.g.
        # appended to the LLM input sequence) rather than used in forward().
        self.meta_queries = nn.Parameter(torch.zeros(num_queries, llm_hidden_size))

    def _init_weights(self, module):
        # No custom initialization: weights are expected to be loaded from a checkpoint.
        pass

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: (batch, seq_len, llm_hidden_size) hidden states from the LLM.

        Returns:
            prompt_embeds: (batch, seq_len, joint_attention_dim)
            pooled_prompt_embeds: (batch, pooled_projection_dim)
        """
        x = self.projector_1(x)  # (B, L, llm_hidden_size) -> (B, L, connector_hidden)
        x = self.connector(x)    # ConnectorEncoder expects (B, L, hidden)
        pooled_prompt_embeds = self.projector_2(x.mean(1))  # mean-pool over the sequence
        prompt_embeds = self.projector_3(x)
        return prompt_embeds, pooled_prompt_embeds
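

# Minimal usage sketch (illustrative only; `llm_hidden` stands in for hidden
# states produced by the upstream LLM):
#
#   cond = StableDiffusion3Conditioner(connector_config=dict(...))
#   llm_hidden = torch.randn(2, 256, 3584)   # (batch, seq_len, llm_hidden_size)
#   prompt_embeds, pooled = cond(llm_hidden)
#   prompt_embeds.shape  # (2, 256, 4096)
#   pooled.shape         # (2, 2048)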


if __name__ == "__main__":
    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--output", type=str, required=True)
    args = parser.parse_args()

    # The projection dims below match the text-conditioning interface of
    # stabilityai/stable-diffusion-3.5-medium.
    conditioner = StableDiffusion3Conditioner(
        num_queries=256,
        connector_config=dict(
            hidden_size=1536,
            intermediate_size=8960,
            num_hidden_layers=24,
            num_attention_heads=24,
            _attn_implementation="flash_attention_2",
        ),
        llm_hidden_size=3584,
        pooled_projection_dim=2048,
        joint_attention_dim=4096,
    ).bfloat16()

    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    info = conditioner.load_state_dict(checkpoint, strict=False)
    print(f"missing keys: {info.missing_keys}")
    print(f"unexpected keys: {info.unexpected_keys}")

    os.makedirs(args.output, exist_ok=True)
    conditioner.save_pretrained(args.output)
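
    # Quick sanity check: report the size of the exported conditioner.
    n_params = sum(p.numel() for p in conditioner.parameters())
    print(f"Saved {n_params / 1e6:.1f}M-parameter conditioner to {args.output}")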