"""
Structured Latent Variable Encoder Module
----------------------------------------
This file defines encoder classes for the Structured Latent Variable Autoencoder (SLatVAE).
It contains implementations for the sparse transformer-based encoder that maps input
features to a latent distribution, as well as a memory-efficient elastic version.
The encoder follows a variational approach, outputting means and log variances for
the latent space representation.
"""
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...modules import sparse as sp
from .base import SparseTransformerBase
from ..sparse_elastic_mixin import SparseTransformerElasticMixin


class SLatEncoder(SparseTransformerBase):
    """
    Sparse Latent Variable Encoder that uses a transformer architecture to encode
    sparse data into a latent distribution.
    """
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
    ):
"""
Initialize the Sparse Latent Encoder.
Args:
resolution: Input data resolution
in_channels: Number of input feature channels
model_channels: Number of internal model feature channels
latent_channels: Dimension of the latent space
num_blocks: Number of transformer blocks
num_heads: Number of attention heads (optional)
num_head_channels: Channels per attention head if num_heads is None
mlp_ratio: Expansion ratio for MLP in transformer blocks
attn_mode: Type of attention mechanism to use
window_size: Size of attention windows if using windowed attention
pe_mode: Positional encoding mode (absolute or relative)
use_fp16: Whether to use half-precision floating point
use_checkpoint: Whether to use gradient checkpointing
qk_rms_norm: Whether to apply RMS normalization to query and key
"""
        super().__init__(
            in_channels=in_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        # Output layer projects to twice the latent dimension (for mean and logvar)
        self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        """
        Initialize model weights with special handling for the output layer.
        The output layer weights are initialized to zero to stabilize training.
        """
        super().initialize_weights()
        # Zero-out output layers for better training stability
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def forward(self, x: sp.SparseTensor, sample_posterior: bool = True, return_raw: bool = False):
        """
        Forward pass through the encoder.

        Args:
            x: Input sparse tensor
            sample_posterior: Whether to sample from the posterior or return the mean
            return_raw: Whether to return the mean and logvar in addition to the samples

        Returns:
            If return_raw is True:
                - sampled latent variables, mean, and logvar
            Otherwise:
                - sampled latent variables only
        """
        # Process through transformer blocks
        h = super().forward(x)
        h = h.type(x.dtype)
        # Apply layer normalization to features
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h)

        # Split output into mean and logvar components
        mean, logvar = h.feats.chunk(2, dim=-1)
        if sample_posterior:
            # Reparameterization trick: z = mean + std * epsilon
            std = torch.exp(0.5 * logvar)
            z = mean + std * torch.randn_like(std)
        else:
            # Use mean directly without sampling
            z = mean
        z = h.replace(z)

        if return_raw:
            return z, mean, logvar
        else:
            return z


class ElasticSLatEncoder(SparseTransformerElasticMixin, SLatEncoder):
    """
    SLat VAE encoder with elastic memory management.
    Used for training with low VRAM by dynamically managing memory allocation
    and performing operations with reduced memory footprint.
    """
    pass
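

# Illustrative usage sketch (not part of the original module): it shows how the
# encoder is typically driven and how the returned mean/logvar feed a VAE KL term.
# The hyperparameter values and the sp.SparseTensor construction below are
# assumptions about the surrounding package, not guarantees.
def _demo_encode() -> None:
    """
    Minimal sketch, assuming sp.SparseTensor accepts float features plus int32
    [batch, x, y, z] coordinates and that the package's sparse/attention
    backends are installed. Hyperparameters are illustrative placeholders.
    """
    encoder = SLatEncoder(
        resolution=64,        # assumed voxel grid resolution
        in_channels=1024,     # assumed input feature width
        model_channels=768,
        latent_channels=8,
        num_blocks=12,
        num_heads=12,
    )

    # Hypothetical sparse input: 32 active voxels with `in_channels` features each.
    feats = torch.randn(32, 1024)
    coords = torch.randint(0, 64, (32, 4), dtype=torch.int32)
    coords[:, 0] = 0  # single-sample batch index
    x = sp.SparseTensor(feats=feats, coords=coords)  # assumed constructor signature

    # Sample latents via the reparameterized posterior and keep mean/logvar.
    z, mean, logvar = encoder(x, sample_posterior=True, return_raw=True)

    # Standard diagonal-Gaussian KL against a unit prior, averaged over elements.
    kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
    print(z.feats.shape, kl.item())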