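"""Overfitting training setup for SceneDINO / BTS.

Replaces the network's encoder with a dummy whose feature map is a learnable
parameter, and trains on a single example (a one-element Subset). A setup like
this is typically used to sanity-check the rendering and loss pipeline.
"""
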
from copy import copy
from typing import List, Optional, Tuple

import ignite.distributed as idist
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Subset

from scenedino.training.base_trainer import base_training
# TODO: change dataset
from scenedino.datasets import make_datasets
from scenedino.common.scheduler import make_scheduler
from scenedino.renderer import NeRFRenderer
# The star import provides the build_* factory helpers and VisualizationModule
# used by the dummy encoders below.
from scenedino.models.backbones.dino.dinov2_module import *
from scenedino.training.trainer import BTSWrapper
from scenedino.models import make_model
from scenedino.common.ray_sampler import get_ray_sampler
from scenedino.losses import make_loss


class EncoderDummy(nn.Module):
    def __init__(self, size, feat_dim, num_views=1) -> None:
        super().__init__()
        # The "feature map" is a directly learnable parameter, initialized randomly.
        self.feats = nn.Parameter(torch.randn(num_views, feat_dim, *size))
        self.latent_size = feat_dim

    def forward(self, x):
        n = x.shape[0]
        # Ignore the input image entirely; expand the stored map to the batch size.
        return [self.feats.expand(n, -1, -1, -1)]
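

# A minimal shape check for EncoderDummy (a sketch; nothing in this file runs it):
#
#     enc = EncoderDummy(size=(192, 640), feat_dim=64)
#     out = enc(torch.zeros(2, 3, 192, 640))
#     assert out[0].shape == (2, 64, 192, 640)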


class EncoderDinoDummy(nn.Module):
    def __init__(
        self,
        mode: str,  # downsample-prediction, upsample-gt
        decoder_arch: str,  # nearest, bilinear, sfp, dpt
        upsampler_arch: Optional[str],  # nearest, bilinear, multiscale-crop
        downsampler_arch: Optional[str],  # sample-center, featup
        encoder_arch: str,  # vit-s, vit-b, fit3d-s
        separate_gt_encoder_arch: Optional[str],  # vit-s, vit-b, fit3d-s, None (reuses encoder)
        encoder_freeze: bool,
        dim_reduction_arch: str,  # orthogonal-linear, mlp
        num_ch_enc: np.ndarray,
        intermediate_features: List[int],
        decoder_out_dim: int,
        dino_pca_dim: int,
        image_size: Tuple[int, int],
        key_features: bool,
    ):
super().__init__()
self.feats = nn.Parameter(torch.randn(1, decoder_out_dim, *image_size))
self.latent_size = decoder_out_dim
if separate_gt_encoder_arch is None:
self.gt_encoder = build_encoder(encoder_arch, image_size, [], key_features) # ONLY IN OVERFIT DUMMY!
else:
self.gt_encoder = build_encoder(separate_gt_encoder_arch, image_size, [], key_features)
        # The GT encoder is supervision-only; freeze all of its parameters.
        for p in self.gt_encoder.parameters():
            p.requires_grad = False
        # Set up the ground-truth feature path for the feature loss: either
        # downsample the prediction to the GT patch grid, or upsample the GT
        # features to the prediction's resolution.
        if mode == "downsample-prediction":
            assert upsampler_arch is None
            self.downsampler = build_downsampler(downsampler_arch, self.gt_encoder.latent_size)
            self.gt_wrapper = self.gt_encoder
        elif mode == "upsample-gt":
            assert downsampler_arch is None
            self.downsampler = None
            self.gt_wrapper = build_gt_upsampling_wrapper(upsampler_arch, self.gt_encoder, image_size)
        else:
            raise NotImplementedError
        self.extra_outs = 0
        self.dino_pca_dim = dino_pca_dim
        self.dim_reduction = build_dim_reduction(dim_reduction_arch, self.gt_encoder.latent_size, dino_pca_dim)
        self.visualization = VisualizationModule(self.gt_encoder.latent_size)

    def forward(self, x, ground_truth=False):
        if ground_truth:
            # Real (frozen) encoder features, used as supervision targets.
            return self.gt_wrapper(x)
        # Otherwise behave like EncoderDummy: a learnable map expanded to batch size.
        return [self.feats.expand(x.shape[0], -1, -1, -1)]

    def downsample(self, x, mode="patch"):
        if self.downsampler is None:
            return None
        else:
            return self.downsampler(x, mode)

    def expand_dim(self, features):
        return self.dim_reduction.transform_expand(features)

    def fit_visualization(self, features, refit=True):
        return self.visualization.fit_pca(features, refit)

    def transform_visualization(self, features, norm=False, from_dim=0):
        return self.visualization.transform_pca(features, norm, from_dim)

    def fit_transform_kmeans_visualization(self, features):
        return self.visualization.fit_transform_kmeans_batch(features)

    @classmethod
    def from_conf(cls, conf):
        return cls(
            mode=conf.mode,
            decoder_arch=conf.decoder_arch,
            upsampler_arch=conf.get("upsampler_arch", None),
            downsampler_arch=conf.get("downsampler_arch", None),
            encoder_arch=conf.encoder_arch,
            separate_gt_encoder_arch=conf.get("separate_gt_encoder_arch", None),
            encoder_freeze=conf.encoder_freeze,
            dim_reduction_arch=conf.dim_reduction_arch,
            num_ch_enc=conf.get("num_ch_enc", None),
            intermediate_features=conf.get("intermediate_features", []),
            decoder_out_dim=conf.decoder_out_dim,
            dino_pca_dim=conf.dino_pca_dim,
            image_size=conf.image_size,
            key_features=conf.key_features,
        )
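

# Sketch of the expected `conf` mapping (OmegaConf-style, with attribute access
# and `.get`); values are illustrative, drawn from the options noted next to
# the __init__ parameters above:
#
#     mode: downsample-prediction
#     decoder_arch: dpt
#     downsampler_arch: sample-center
#     encoder_arch: vit-s
#     encoder_freeze: true
#     dim_reduction_arch: orthogonal-linear
#     decoder_out_dim: 128
#     dino_pca_dim: 64
#     image_size: [192, 640]
#     key_features: false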


class BTSWrapperOverfit(BTSWrapper):
    def __init__(self, renderer, ray_sampler, config, eval_nvs=False, size=None) -> None:
        super().__init__(renderer, ray_sampler, config, eval_nvs)
        if config["predict_dino"]:
            encoder_dummy = EncoderDinoDummy.from_conf(config["encoder"])
        else:
            encoder_dummy = EncoderDummy(
                size,
                config["encoder"]["d_out"],
            )
        # Swap the network's real encoder for the dummy, so optimization acts
        # directly on the dummy's learnable feature map.
        self.renderer.net.encoder = encoder_dummy


def training(local_rank, config):
    return base_training(
        local_rank,
        config,
        get_dataflow,
        initialize,
    )
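

# base_training (scenedino.training.base_trainer) drives the training loop; it
# is handed get_dataflow and initialize (defined below) to build the data
# loaders and the model/optimizer/criterion/scheduler from the config.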


def get_dataflow(config):
    # - Get train/test datasets
    if idist.get_local_rank() > 0:
        # Ensure that only local rank 0 downloads the dataset; all other
        # processes wait at the barrier until it has finished.
        idist.barrier()

    train_dataset_full = make_datasets(config["dataset"])[0]
    # Overfit on a single example: wrap the full dataset in a one-element Subset.
    train_dataset = Subset(
        train_dataset_full,
        [config.get("example", config["dataset"].get("skip", 0))],
    )
    train_dataset.dataset._skip = config["dataset"].get("skip", 0)

    validation_datasets = {}
    for name, validation_config in config["validation"].items():
        # Shallow copy: the underlying dataset object is shared, so setting
        # return_depth here also affects the training subset.
        dataset = copy(train_dataset)
        dataset.dataset.return_depth = True
        validation_datasets[name] = dataset

    if idist.get_local_rank() == 0:
        # Rank 0 has finished preparing the dataset: release the processes
        # waiting at the barrier above. This keeps shared-resource setup from
        # running more than once across processes.
        idist.barrier()

    # Plain data loaders (defaults: batch size 1, no shuffling).
    train_loader_full = DataLoader(train_dataset_full)
    train_loader = DataLoader(train_dataset)

    validation_loaders = {}
    for name, dataset in validation_datasets.items():
        validation_loaders[name] = DataLoader(dataset)

    return (train_loader, train_loader_full), validation_loaders
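

# Illustrative use of the returned structure (not invoked in this file):
#
#     (train_loader, train_loader_full), val_loaders = get_dataflow(config)
#     batch = next(iter(train_loader))  # always the same single overfitting example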


def initialize(config: dict):
    net = make_model(config["model"])
    renderer = NeRFRenderer.from_conf(config["renderer"])
    renderer = renderer.bind_parallel(net, gpus=None).eval()

    mode = config.get("mode", "depth")
    ray_sampler = get_ray_sampler(config["training"]["ray_sampler"])

    model = BTSWrapperOverfit(
        renderer,
        ray_sampler,
        config["model"],
        mode == "nvs",
        size=config["dataset"].get("image_size", (192, 640)),
    )
    model = idist.auto_model(model)

    optimizer = optim.Adam(model.parameters(), **config["training"]["optimizer"]["args"])
    optimizer = idist.auto_optim(optimizer)

    lr_scheduler = make_scheduler(config["training"].get("scheduler", {}), optimizer)

    criterion = [
        make_loss(config_loss)
        for config_loss in config["training"]["loss"]
    ]

    return model, optimizer, criterion, lr_scheduler
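

# A typical entry point would run `training` through Ignite's distributed
# launcher (a sketch; the actual runner lives elsewhere in the repository):
#
#     with idist.Parallel(backend=None) as parallel:
#         parallel.run(training, config)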