File size: 6,879 Bytes
491eded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""
This file implements radiance field decoders for Structured Latent VAE models.
The main class SLatRadianceFieldDecoder is a sparse transformer-based decoder that 
transforms latent codes into sparse representations of 3D scenes (Strivec representation).
It also includes an elastic memory version (ElasticSLatRadianceFieldDecoder) for low VRAM training.
"""

from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ...modules import sparse as sp
from .base import SparseTransformerBase
from ...representations import Strivec
from ..sparse_elastic_mixin import SparseTransformerElasticMixin


class SLatRadianceFieldDecoder(SparseTransformerBase):
    """
    Sparse transformer decoder that turns latent codes into Strivec radiance fields.

    The input sparse tensor is passed through the ``SparseTransformerBase`` backbone,
    layer-normalized over the feature dimension, and projected by a sparse linear
    layer to ``out_channels`` features per row. Those features are then split,
    according to ``self.layout``, into the ``trivec``, ``density`` and ``features_dc``
    attributes of one ``Strivec`` per batch item.
    """
    def __init__(
        self,
        resolution: int,  # Resolution of the output 3D grid
        model_channels: int,  # Number of channels in the model's hidden layers
        latent_channels: int,  # Number of channels in the latent code
        num_blocks: int,  # Number of transformer blocks
        num_heads: Optional[int] = None,  # Number of attention heads
        num_head_channels: Optional[int] = 64,  # Channels per attention head
        mlp_ratio: float = 4,  # Ratio for MLP hidden dimension
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",  # Attention mode
        window_size: int = 8,  # Size of local attention window
        pe_mode: Literal["ape", "rope"] = "ape",  # Positional encoding mode
        use_fp16: bool = False,  # Whether to use half precision
        use_checkpoint: bool = False,  # Whether to use gradient checkpointing
        qk_rms_norm: bool = False,  # Whether to normalize query and key
        representation_config: Optional[dict] = None,  # Output representation config; must provide 'rank' and 'dim'
    ):
        """
        Build the backbone, compute the output layout and create the output projection.

        Raises:
            TypeError: if ``representation_config`` is None (raised by ``_calc_layout``).
            KeyError: if ``representation_config`` lacks the 'rank' or 'dim' key.
        """
        # All backbone hyper-parameters are forwarded unchanged; the latent
        # channels act as the backbone's input channels.
        super().__init__(
            in_channels=latent_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        self.rep_config = representation_config
        # Must run before the output layer is created: it sets self.out_channels.
        self._calc_layout()
        # Final layer to project features to the output representation
        self.out_layer = sp.SparseLinear(model_channels, self.out_channels)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        """
        Initialize the weights of the model.

        Runs the base-class initialization, then zeroes the output layer's weight
        and bias so the projection initially outputs zeros.
        """
        super().initialize_weights()
        # Zero-out output layers for better training stability
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def _calc_layout(self) -> None:
        """
        Calculate the output tensor layout for the Strivec representation.

        Populates ``self.layout`` with, for each component ('trivec', 'density',
        'features_dc'), its per-row 'shape', its flattened 'size', and its
        'range' (start, end) of channels in the output features. Components are
        packed back to back in that order; ``self.out_channels`` is the total.

        Raises:
            TypeError: if no representation config was supplied.
            KeyError: if the config lacks 'rank' or 'dim'.
        """
        if self.rep_config is None:
            # Same exception type the bare subscript used to raise, but with a
            # message that points at the actual cause.
            raise TypeError(
                "representation_config must be a dict providing 'rank' and 'dim', got None"
            )
        rank = self.rep_config['rank']
        dim = self.rep_config['dim']
        self.layout = {
            'trivec': {'shape': (rank, 3, dim), 'size': rank * 3 * dim},
            'density': {'shape': (rank,), 'size': rank},
            'features_dc': {'shape': (rank, 1, 3), 'size': rank * 3},
        }
        # Assign consecutive, non-overlapping channel ranges in dict order.
        start = 0
        for entry in self.layout.values():
            entry['range'] = (start, start + entry['size'])
            start += entry['size']
        self.out_channels = start

    def to_representation(self, x: sp.SparseTensor) -> List[Strivec]:
        """
        Convert a batch of network outputs to 3D representations.

        Args:
            x: The [N x * x C] sparse tensor output by the network.

        Returns:
            list of Strivec representations, one per batch item
        """
        # Follow the device of the network output rather than assuming the
        # default CUDA device, so CPU and non-default-GPU inputs stay consistent.
        # NOTE(review): assumes Strivec accepts a torch.device for `device` - confirm.
        device = x.feats.device
        # Depth value written for every row; loop-invariant, so computed once.
        depth = int(np.log2(self.resolution))

        ret = []
        for i in range(x.shape[0]):
            # Index selecting the rows of batch item i in coords/feats.
            batch_index = x.layout[i]
            representation = Strivec(
                sh_degree=0,
                resolution=self.resolution,
                aabb=[-0.5, -0.5, -0.5, 1, 1, 1],  # Axis-aligned bounding box
                rank=self.rep_config['rank'],
                dim=self.rep_config['dim'],
                device=device,
            )
            representation.density_shift = 0.0
            # Drop the first coords column (presumably the batch index - verify)
            # and map integer coordinates to cell centers, scaled by 1/resolution.
            representation.position = (x.coords[batch_index][:, 1:].float() + 0.5) / self.resolution
            # One depth entry per position, all set to log2(resolution).
            representation.depth = torch.full(
                (representation.position.shape[0], 1), depth, dtype=torch.uint8, device=device
            )

            # Slice this item's features once, then carve out each component
            # according to the layout.
            feats = x.feats[batch_index]
            for k, v in self.layout.items():
                lo, hi = v['range']
                setattr(representation, k, feats[:, lo:hi].reshape(-1, *v['shape']))

            # Add 1 to trivec for stability (prevent zero vectors); with the
            # zero-initialized output layer this makes trivec start at 1.
            representation.trivec = representation.trivec + 1
            ret.append(representation)
        return ret

    def forward(self, x: sp.SparseTensor) -> List[Strivec]:
        """
        Forward pass through the decoder.

        Args:
            x: Input sparse tensor containing latent codes

        Returns:
            List of Strivec representations
        """
        # Pass through transformer backbone
        h = super().forward(x)
        # Cast the backbone output back to the input's dtype.
        h = h.type(x.dtype)
        # Layer normalization over the last (feature) dimension, with no
        # learnable weight or bias.
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        # Final projection to output features
        h = self.out_layer(h)
        # Convert network output to Strivec representations
        return self.to_representation(h)


class ElasticSLatRadianceFieldDecoder(SparseTransformerElasticMixin, SLatRadianceFieldDecoder):
    """
    Elastic-memory variant of the SLat VAE radiance field decoder, meant for low-VRAM training.

    The class defines nothing of its own: it only combines
    ``SparseTransformerElasticMixin`` with ``SLatRadianceFieldDecoder``. The mixin
    is listed first, so its methods take precedence in the method resolution order.
    """