"""Model code for FlowMo. Sources: https://github.com/feizc/FluxMusic/blob/main/train.py https://github.com/black-forest-labs/flux/tree/main/src/flux """ import ast import itertools import math from dataclasses import dataclass from typing import List, Tuple import einops import torch from einops import rearrange, repeat from mup import MuReadout from torch import Tensor, nn import argparse import contextlib import copy import glob import os import subprocess import tempfile import time import fsspec import psutil import torch import torch.distributed as dist from mup import MuReadout, set_base_shapes from omegaconf import OmegaConf from torch.utils.data import DataLoader from .lookup_free_quantize import LFQ MUP_ENABLED = True def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: b, h, l, d = q.shape q, k = apply_rope(q, k, pe) if torch.__version__ == "2.0.1+cu117": # tmp workaround if d != 64: print("MUP is broken in this setting! Be careful!") x = torch.nn.functional.scaled_dot_product_attention( q, k, v, ) else: x = torch.nn.functional.scaled_dot_product_attention( q, k, v, scale=8.0 / d if MUP_ENABLED else None, ) assert x.shape == q.shape x = rearrange(x, "B H L D -> B L (H D)") return x def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim omega = 1.0 / (theta**scale) out = torch.einsum("...n,d->...nd", pos, omega) out = torch.stack( [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1, ) out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return out.float() def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) def _get_diagonal_gaussian(parameters): mean, logvar = torch.chunk(parameters, 2, dim=1) logvar = torch.clamp(logvar, -30.0, 20.0) return mean, logvar def _sample_diagonal_gaussian(mean, logvar): std = torch.exp(0.5 * logvar) x = mean + std * torch.randn(mean.shape, device=mean.device) return x def _kl_diagonal_gaussian(mean, logvar): var = torch.exp(logvar) return 0.5 * torch.sum(torch.pow(mean, 2) + var - 1.0 - logvar, dim=1).mean() class EmbedND(nn.Module): def __init__(self, dim: int, theta: int, axes_dim): super().__init__() self.dim = dim self.theta = theta self.axes_dim = axes_dim def forward(self, ids: Tensor) -> Tensor: n_axes = ids.shape[-1] emb = torch.cat( [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3, ) return emb.unsqueeze(1) def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): """ Create sinusoidal timestep embeddings. :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an (N, D) Tensor of positional embeddings. 
""" t = time_factor * t half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(t.device) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) if torch.is_floating_point(t): embedding = embedding.to(t) return embedding class MLPEmbedder(nn.Module): def __init__(self, in_dim: int, hidden_dim: int): super().__init__() self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) self.silu = nn.SiLU() self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) class RMSNorm(torch.nn.Module): def __init__(self, dim: int): super().__init__() self.scale = nn.Parameter(torch.ones(dim)) def forward(self, x: Tensor): x_dtype = x.dtype x = x.float() rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) return (x * rrms).to(dtype=x_dtype) * self.scale class QKNorm(torch.nn.Module): def __init__(self, dim: int): super().__init__() self.query_norm = RMSNorm(dim) self.key_norm = RMSNorm(dim) def forward(self, q: Tensor, k: Tensor, v: Tensor): q = self.query_norm(q) k = self.key_norm(k) return q.to(v), k.to(v) class SelfAttention(nn.Module): def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.norm = QKNorm(head_dim) self.proj = nn.Linear(dim, dim) def forward(self, x: Tensor, pe: Tensor) -> Tensor: qkv = self.qkv(x) q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) q, k = self.norm(q, k, v) x = attention(q, k, v, pe=pe) x = self.proj(x) return x @dataclass class ModulationOut: shift: Tensor scale: Tensor gate: Tensor class Modulation(nn.Module): def __init__(self, dim: int, double: bool): super().__init__() self.is_double = double self.multiplier = 6 if double else 3 self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) self.lin.weight[dim * 2 : dim * 3].data[:] = 0.0 self.lin.bias[dim * 2 : dim * 3].data[:] = 0.0 self.lin.weight[dim * 5 : dim * 6].data[:] = 0.0 self.lin.bias[dim * 5 : dim * 6].data[:] = 0.0 def forward(self, vec: Tensor) -> Tuple[ModulationOut, ModulationOut]: out = self.lin(nn.functional.silu(vec))[:, None, :].chunk( self.multiplier, dim=-1 ) return ( ModulationOut(*out[:3]), ModulationOut(*out[3:]) if self.is_double else None, ) class DoubleStreamBlock(nn.Module): def __init__( self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, ): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) self.num_heads = num_heads self.hidden_size = hidden_size self.img_mod = Modulation(hidden_size, double=True) self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.img_attn = SelfAttention( dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias ) self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.img_mlp = nn.Sequential( nn.Linear(hidden_size, mlp_hidden_dim, bias=True), nn.GELU(approximate="tanh"), nn.Linear(mlp_hidden_dim, hidden_size, bias=True), ) self.txt_mod = Modulation(hidden_size, double=True) self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.txt_attn = SelfAttention( dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias ) self.txt_norm2 = nn.LayerNorm(hidden_size, 
class DoubleStreamBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool = False,
    ):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
        pe_single, pe_double = pe
        p = 1
        if vec is None:
            # No conditioning vector: fall back to identity modulation
            # (shift=0, scale=1 after the `p +` offset below, gate=1).
            img_mod1, img_mod2 = ModulationOut(0, 1 - p, 1), ModulationOut(0, 1 - p, 1)
            txt_mod1, txt_mod2 = ModulationOut(0, 1 - p, 1), ModulationOut(0, 1 - p, 1)
        else:
            img_mod1, img_mod2 = self.img_mod(vec)
            txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (p + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(
            img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (p + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention over the concatenated txt+img sequence
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe_double)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp(
            (p + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp(
            (p + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )
        return img, txt


class LastLayer(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        patch_size: int,
        out_channels: int,
        readout_zero_init=False,
    ):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        if MUP_ENABLED:
            self.linear = MuReadout(
                hidden_size,
                patch_size * patch_size * out_channels,
                bias=True,
                readout_zero_init=readout_zero_init,
            )
        else:
            self.linear = nn.Linear(
                hidden_size, patch_size * patch_size * out_channels, bias=True
            )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x: Tensor, vec) -> Tensor:
        if vec is not None:
            shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
            x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        # norm_final is applied unconditionally before the readout, matching
        # the original control flow.
        x = self.norm_final(x)
        x = self.linear(x)
        return x


@dataclass
class FluxParams:
    in_channels: int
    patch_size: int
    context_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    axes_dim: List[int]
    theta: int
    qkv_bias: bool


DIT_ZOO = dict(
    dit_xl_4=dict(
        hidden_size=1152,
        mlp_ratio=4.0,
        num_heads=16,
        axes_dim=[8, 28, 28],
        theta=10_000,
        qkv_bias=True,
    ),
    dit_l_4=dict(
        hidden_size=1024,
        mlp_ratio=4.0,
        num_heads=16,
        axes_dim=[8, 28, 28],
        theta=10_000,
        qkv_bias=True,
    ),
    dit_b_4=dict(
        hidden_size=768,
        mlp_ratio=4.0,
        num_heads=12,
        axes_dim=[8, 28, 28],
        theta=10_000,
        qkv_bias=True,
    ),
    dit_s_4=dict(
        hidden_size=384,
        mlp_ratio=4.0,
        num_heads=6,
        axes_dim=[8, 28, 28],
        theta=10_000,
        qkv_bias=True,
    ),
    dit_mup_test=dict(
        hidden_size=768,
        mlp_ratio=4.0,
        num_heads=12,
        axes_dim=[8, 28, 28],
        theta=10_000,
        qkv_bias=True,
    ),
)
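# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model): how a DIT_ZOO preset combines
# with the per-model fields into a FluxParams, mirroring FlowMo.__init__.
# The concrete values here are hypothetical.
def _example_flux_params():
    patch_size = 4
    params = FluxParams(
        in_channels=3 * patch_size**2,  # RGB pixels per patch
        patch_size=patch_size,
        context_dim=18,
        depth=8,
        **DIT_ZOO["dit_b_4"],
    )
    # Invariants checked by Flux.__init__: hidden size divides evenly into
    # heads, and the rope axis dims sum to the per-head dim.
    assert params.hidden_size % params.num_heads == 0
    assert sum(params.axes_dim) == params.hidden_size // params.num_heads
# ---------------------------------------------------------------------------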
def prepare_idxs(img, code_length, patch_size):
    bs, c, h, w = img.shape

    img_ids = torch.zeros(h // patch_size, w // patch_size, 3, device=img.device)
    img_ids[..., 1] = (
        img_ids[..., 1] + torch.arange(h // patch_size, device=img.device)[:, None]
    )
    img_ids[..., 2] = (
        img_ids[..., 2] + torch.arange(w // patch_size, device=img.device)[None, :]
    )
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    txt_ids = (
        torch.zeros((bs, code_length, 3), device=img.device)
        + torch.arange(code_length, device=img.device)[None, :, None]
    )
    return img_ids, txt_ids


class Flux(nn.Module):
    """Transformer model for flow matching on sequences."""

    def __init__(self, params: FluxParams, name="", lsg=False):
        super().__init__()

        self.name = name
        self.lsg = lsg
        self.params = params
        self.in_channels = params.in_channels
        self.patch_size = params.patch_size
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {params.axes_dim} but expected positional dim {pe_dim}"
            )
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
        )
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.txt_in = nn.Linear(params.context_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for idx in range(params.depth)
            ]
        )

        self.final_layer_img = LastLayer(
            self.hidden_size, 1, self.out_channels, readout_zero_init=False
        )
        self.final_layer_txt = LastLayer(
            self.hidden_size, 1, params.context_dim, readout_zero_init=False
        )

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
    ) -> Tensor:
        b, c, h, w = img.shape
        img = rearrange(
            img,
            "b c (gh ph) (gw pw) -> b (gh gw) (ph pw c)",
            ph=self.patch_size,
            pw=self.patch_size,
        )
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        img = self.img_in(img)
        if timesteps is None:
            vec = None
        else:
            vec = self.time_in(timestep_embedding(timesteps, 256))

        txt = self.txt_in(txt)

        pe_single = self.pe_embedder(torch.cat((txt_ids,), dim=1))
        pe_double = self.pe_embedder(torch.cat((txt_ids, img_ids), dim=1))

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, pe=(pe_single, pe_double), vec=vec)

        img = self.final_layer_img(img, vec=vec)
        img = rearrange(
            img,
            "b (gh gw) (ph pw c) -> b c (gh ph) (gw pw)",
            ph=self.patch_size,
            pw=self.patch_size,
            gh=h // self.patch_size,
            gw=w // self.patch_size,
        )
        txt = self.final_layer_txt(txt, vec=vec)
        return img, txt, {"final_txt": txt}


def get_weights_to_fix(model):
    with torch.no_grad():
        for name, module in model.named_modules():
            if "double_blocks" in name and isinstance(module, torch.nn.Linear):
                yield name, module.weight
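# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model): the position ids produced by
# prepare_idxs. Image tokens carry (0, row, col); code tokens carry the same
# index on all three axes. Hypothetical sizes.
def _example_prepare_idxs():
    img = torch.zeros(2, 3, 16, 16)
    img_ids, txt_ids = prepare_idxs(img, code_length=5, patch_size=4)
    assert img_ids.shape == (2, (16 // 4) * (16 // 4), 3)  # one id per patch
    assert txt_ids.shape == (2, 5, 3)
    assert img_ids[0, 0].tolist() == [0.0, 0.0, 0.0]
    assert img_ids[0, -1].tolist() == [0.0, 3.0, 3.0]  # last row, last col
# ---------------------------------------------------------------------------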
class FlowMo(nn.Module):
    def __init__(self, width, config):
        super().__init__()
        code_length = config.model.code_length
        context_dim = config.model.context_dim
        enc_depth = config.model.enc_depth
        dec_depth = config.model.dec_depth

        patch_size = config.model.patch_size
        self.config = config

        self.image_size = config.data.image_size
        self.patch_size = config.model.patch_size
        self.code_length = code_length
        self.dit_mode = "dit_b_4"
        self.context_dim = context_dim
        # The KL quantizer needs a mean and a logvar per code feature, so the
        # encoder emits twice the context dim in that case.
        self.encoder_context_dim = context_dim * (
            1 + (self.config.model.quantization_type == "kl")
        )

        if config.model.quantization_type == "lfq":
            self.quantizer = LFQ(
                codebook_size=2**self.config.model.codebook_size_for_entropy,
                dim=self.config.model.codebook_size_for_entropy,
                num_codebooks=1,
                token_factorization=False,
            )

        if self.config.model.enc_mup_width is not None:
            enc_width = self.config.model.enc_mup_width
        else:
            enc_width = width

        encoder_params = FluxParams(
            in_channels=3 * patch_size**2,
            context_dim=self.encoder_context_dim,
            patch_size=patch_size,
            depth=enc_depth,
            **DIT_ZOO[self.dit_mode],
        )
        decoder_params = FluxParams(
            in_channels=3 * patch_size**2,
            context_dim=context_dim + 1,
            patch_size=patch_size,
            depth=dec_depth,
            **DIT_ZOO[self.dit_mode],
        )

        # width=4, dit_b_4 is the usual model
        encoder_params.hidden_size = enc_width * (encoder_params.hidden_size // 4)
        decoder_params.hidden_size = width * (decoder_params.hidden_size // 4)
        encoder_params.axes_dim = [(d // 4) * enc_width for d in encoder_params.axes_dim]
        decoder_params.axes_dim = [(d // 4) * width for d in decoder_params.axes_dim]

        self.encoder = Flux(encoder_params, name="encoder")
        self.decoder = Flux(decoder_params, name="decoder")

    @torch.compile
    def encode(self, img):
        b, c, h, w = img.shape
        img_idxs, txt_idxs = prepare_idxs(img, self.code_length, self.patch_size)
        txt = torch.zeros(
            (b, self.code_length, self.encoder_context_dim), device=img.device
        )
        _, code, aux = self.encoder(img, img_idxs, txt, txt_idxs, timesteps=None)
        return code, aux

    def _decode(self, img, code, timesteps):
        b, c, h, w = img.shape
        img_idxs, txt_idxs = prepare_idxs(
            img,
            self.code_length,
            self.patch_size,
        )
        pred, _, decode_aux = self.decoder(
            img, img_idxs, code, txt_idxs, timesteps=timesteps
        )
        return pred, decode_aux

    @torch.compile
    def decode(self, *args, **kwargs):
        return self._decode(*args, **kwargs)

    @torch.compile
    def decode_checkpointed(self, *args, **kwargs):
        # Need to compile(checkpoint), not checkpoint(compile)
        assert not kwargs, kwargs
        return torch.utils.checkpoint.checkpoint(
            self._decode,
            *args,
            # WARNING: Do not use_reentrant=True with compile, it will silently
            # produce incorrect gradients!
            use_reentrant=False,
        )

    @torch.compile
    def _quantize(self, code):
        """
        Args:
            code: [b, code_length, context_dim]

        Returns:
            quantized code of the same shape
        """
        b, t, f = code.shape
        indices = None
        if self.config.model.quantization_type == "noop":
            quantized = code
            quantizer_loss = torch.tensor(0.0).to(code.device)
        elif self.config.model.quantization_type == "kl":
            # colocating features of same token before split is maybe slightly
            # better?
            mean, logvar = _get_diagonal_gaussian(
                einops.rearrange(code, "b t f -> b (f t)")
            )
            code = einops.rearrange(
                _sample_diagonal_gaussian(mean, logvar),
                "b (f t) -> b t f",
                f=f // 2,
                t=t,
            )
            quantizer_loss = _kl_diagonal_gaussian(mean, logvar)
        elif self.config.model.quantization_type == "lfq":
            assert f % self.config.model.codebook_size_for_entropy == 0, f
            code = einops.rearrange(
                code,
                "b t (fg fh) -> b fg (t fh)",
                fg=self.config.model.codebook_size_for_entropy,
            )

            (quantized, entropy_aux_loss, indices), breakdown = self.quantizer(
                code, return_loss_breakdown=True
            )
            assert quantized.shape == code.shape
            quantized = einops.rearrange(quantized, "b fg (t fh) -> b t (fg fh)", t=t)

            quantizer_loss = (
                entropy_aux_loss * self.config.model.entropy_loss_weight
                + breakdown.commitment * self.config.model.commit_loss_weight
            )
            code = quantized
        else:
            raise NotImplementedError
        return code, indices, quantizer_loss

    def forward(
        self,
        img,
        noised_img,
        timesteps,
        enable_cfg=True,
    ):
        # Training forward pass: encode, quantize, then predict the velocity
        # for the noised image. rf_loss below relies on this signature.
        aux = {}

        code, encode_aux = self.encode(img)

        aux["original_code"] = code

        b, t, f = code.shape

        code, _, aux["quantizer_loss"] = self._quantize(code)

        mask = torch.ones_like(code[..., :1])
        code = torch.concatenate([code, mask], axis=-1)
        code_pre_cfg = code

        if self.config.model.enable_cfg and enable_cfg:
            cfg_mask = (torch.rand((b,), device=code.device) > 0.1)[:, None, None]
            code = code * cfg_mask

        v_est, decode_aux = self.decode(noised_img, code, timesteps)
        aux.update(decode_aux)

        if self.config.model.posttrain_sample:
            aux["posttrain_sample"] = self.reconstruct_checkpoint(code_pre_cfg)

        return v_est, aux

    def reconstruct_checkpoint(self, code):
        with torch.autocast(
            "cuda",
            dtype=torch.bfloat16,
        ):
            bs, *_ = code.shape

            z = torch.randn((bs, 3, self.image_size, self.image_size)).cuda()

            ts = (
                torch.rand((bs, self.config.model.posttrain_sample_k + 1))
                .cumsum(dim=1)
                .cuda()
            )
            ts = ts - ts[:, :1]
            ts = (ts / ts[:, -1:]).flip(dims=(1,))
            dts = ts[:, :-1] - ts[:, 1:]
            for i, (t, dt) in enumerate(zip(ts.T, dts.T)):
                if self.config.model.posttrain_sample_enable_cfg:
                    mask = (torch.rand((bs,), device=code.device) > 0.1)[
                        :, None, None
                    ].to(code.dtype)
                    code_t = code * mask
                else:
                    code_t = code
                vc, _ = self.decode_checkpointed(z, code_t, t)

                z = z - dt[:, None, None, None] * vc
        return z

    @torch.no_grad()
    def reconstruct(self, images, dtype=torch.bfloat16, code=None):
        """
        Args:
            images in [bchw] [-1, 1]

        Returns:
            images in [bchw] [-1, 1]
        """
        model = self
        config = self.config.eval.sampling
        # Stays None when a precomputed code is passed in.
        prequantized_code = None

        with torch.autocast(
            "cuda",
            dtype=dtype,
        ):
            bs, c, h, w = images.shape

            if code is None:
                x = images.cuda()
                prequantized_code = model.encode(x)[0].cuda()
                code, indices, _ = model._quantize(prequantized_code)

            z = torch.randn((bs, 3, h, w)).cuda()

            mask = torch.ones_like(code[..., :1])
            code = torch.concatenate([code * mask, mask], axis=-1)
            cfg_mask = 0.0
            null_code = code * cfg_mask if config.cfg != 1.0 else None

            samples = rf_sample(
                model,
                z,
                code,
                null_code=null_code,
                sample_steps=config.sample_steps,
                cfg=config.cfg,
                schedule=config.schedule,
            )[-1].clip(-1, 1)
        return samples.to(torch.float32), code, prequantized_code
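# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model): the feature regrouping used in
# FlowMo._quantize for the "lfq" path. With code [b, t, f] and
# codebook_size_for_entropy = fg, features are regrouped so the quantizer
# sees fg as its channel dim, and the grouping inverts exactly. Hypothetical
# sizes; fg must divide f.
def _example_lfq_grouping():
    b, t, f, fg = 2, 4, 18, 9
    code = torch.randn(b, t, f)
    grouped = einops.rearrange(code, "b t (fg fh) -> b fg (t fh)", fg=fg)
    assert grouped.shape == (b, fg, t * (f // fg))
    ungrouped = einops.rearrange(grouped, "b fg (t fh) -> b t (fg fh)", t=t)
    assert torch.equal(code, ungrouped)  # pure permutation, exact round trip
# ---------------------------------------------------------------------------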
def rf_loss(config, model, batch, aux_state):
    x = batch["image"]
    b = x.size(0)

    if config.opt.schedule == "lognormal":
        nt = torch.randn((b,)).to(x.device)
        t = torch.sigmoid(nt)
    elif config.opt.schedule == "fat_lognormal":
        nt = torch.randn((b,)).to(x.device)
        t = torch.sigmoid(nt)
        t = torch.where(torch.rand_like(t) <= 0.9, t, torch.rand_like(t))
    elif config.opt.schedule == "uniform":
        t = torch.rand((b,), device=x.device)
    elif config.opt.schedule.startswith("debug"):
        p = float(config.opt.schedule.split("_")[1])
        t = torch.ones((b,), device=x.device) * p
    else:
        raise NotImplementedError

    t = t.view([b, *([1] * len(x.shape[1:]))])
    z1 = torch.randn_like(x)
    zt = (1 - t) * x + t * z1

    zt, t = zt.to(x.dtype), t.to(x.dtype)

    vtheta, aux = model(
        img=x,
        noised_img=zt,
        timesteps=t.reshape((b,)),
    )

    diff = z1 - vtheta - x
    x_pred = zt - vtheta * t

    loss = ((diff) ** 2).mean(dim=list(range(1, len(x.shape))))
    loss = loss.mean()

    aux["loss_dict"] = {}
    aux["loss_dict"]["diffusion_loss"] = loss
    aux["loss_dict"]["quantizer_loss"] = aux["quantizer_loss"]

    if config.opt.lpips_weight != 0.0:
        aux_loss = 0.0
        if config.model.posttrain_sample:
            x_pred = aux["posttrain_sample"]
        lpips_dist = aux_state["lpips_model"](x, x_pred)
        lpips_dist = (config.opt.lpips_weight * lpips_dist).mean() + aux_loss
        aux["loss_dict"]["lpips_loss"] = lpips_dist
    else:
        lpips_dist = 0.0

    loss = loss + aux["quantizer_loss"] + lpips_dist
    aux["loss_dict"]["total_loss"] = loss
    return loss, aux


def _edm_to_flow_convention(noise_level):
    # z = x + \sigma z'
    return noise_level / (1 + noise_level)
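# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model): the step discretizations used
# by rf_sample below. "linear" integrates from t=1 to t=0 in uniform steps;
# "pow_p" spaces ts as k^(1/p), giving finer steps near t=1 and coarser ones
# near t=0. Hypothetical step count.
def _example_sample_schedule():
    n = 4
    ts = torch.arange(1, n + 1).flip(0) / n  # [1.00, 0.75, 0.50, 0.25]
    dts = torch.ones_like(ts) * (1.0 / n)  # constant 0.25
    assert torch.allclose(ts - dts, torch.tensor([0.75, 0.50, 0.25, 0.0]))
    p = 2.0
    ts_pow = torch.arange(0, n + 1).flip(0) ** (1 / p) / n ** (1 / p)
    assert ts_pow[0] == 1.0 and ts_pow[-1] == 0.0
# ---------------------------------------------------------------------------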
def rf_sample(
    model,
    z,
    code,
    null_code=None,
    sample_steps=25,
    cfg=2.0,
    schedule="linear",
):
    b = z.size(0)
    if schedule == "linear":
        ts = torch.arange(1, sample_steps + 1).flip(0) / sample_steps
        dts = torch.ones_like(ts) * (1.0 / sample_steps)
    elif schedule.startswith("pow"):
        p = float(schedule.split("_")[1])
        ts = torch.arange(0, sample_steps + 1).flip(0) ** (1 / p) / sample_steps ** (
            1 / p
        )
        dts = ts[:-1] - ts[1:]
    else:
        raise NotImplementedError

    if model.config.eval.sampling.cfg_interval is None:
        interval = None
    else:
        cfg_lo, cfg_hi = ast.literal_eval(model.config.eval.sampling.cfg_interval)
        interval = _edm_to_flow_convention(cfg_lo), _edm_to_flow_convention(cfg_hi)

    images = []
    for i, (t, dt) in enumerate(zip(ts, dts)):
        timesteps = torch.tensor([t] * b).to(z.device)
        vc, decode_aux = model.decode(img=z, timesteps=timesteps, code=code)
        if null_code is not None and (
            interval is None
            or ((t.item() >= interval[0]) and (t.item() <= interval[1]))
        ):
            vu, _ = model.decode(img=z, timesteps=timesteps, code=null_code)
            vc = vu + cfg * (vc - vu)

        z = z - dt * vc
        images.append(z)
    return images


def build_model(config):
    # MUP_ENABLED is read at module level (e.g. by LastLayer), so it must be
    # declared global here; a bare assignment would only create a local.
    global MUP_ENABLED

    with tempfile.TemporaryDirectory() as log_dir:
        MUP_ENABLED = config.model.enable_mup
        model_partial = FlowMo
        shared_kwargs = dict(config=config)
        model = model_partial(
            **shared_kwargs,
            width=config.model.mup_width,
        ).cuda()

        if config.model.enable_mup:
            print("Mup enabled!")
            with torch.device("cpu"):
                base_model = model_partial(
                    **shared_kwargs, width=config.model.mup_width
                )
                delta_model = model_partial(
                    **shared_kwargs,
                    width=(
                        config.model.mup_width * 4
                        if config.model.mup_width == 1
                        else 1
                    ),
                )
                true_model = model_partial(
                    **shared_kwargs, width=config.model.mup_width
                )

            if torch.distributed.is_initialized():
                bsh_path = os.path.join(log_dir, f"{dist.get_rank()}.bsh")
            else:
                bsh_path = os.path.join(log_dir, "0.bsh")
            set_base_shapes(
                true_model, base_model, delta=delta_model, savefile=bsh_path
            )
            model = set_base_shapes(model, base=bsh_path)
            for module in model.modules():
                if isinstance(module, MuReadout):
                    # Bind `module` as a default argument; a plain closure
                    # would capture the loop variable and every readout would
                    # end up using the width_mult of the last module.
                    module.width_mult = (
                        lambda m=module: m.weight.infshape.width_mult()
                    )
    return model
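# ---------------------------------------------------------------------------
# Illustrative usage sketch (not executed; not part of the model code). The
# config keys mirror the ones read in this file, but the path and values are
# hypothetical and a real config contains many more fields:
#
#   config = OmegaConf.load("configs/flowmo.yaml")  # hypothetical path
#   model = build_model(config)
#   images = torch.randn(4, 3, config.data.image_size, config.data.image_size)
#   recon, code, prequantized = model.reconstruct(images.cuda())
# ---------------------------------------------------------------------------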