from copy import copy
from typing import List, Optional, Tuple

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Subset

import ignite.distributed as idist

from scenedino.training.base_trainer import base_training

# TODO: change dataset
from scenedino.datasets import make_datasets
from scenedino.common.scheduler import make_scheduler
from scenedino.renderer import NeRFRenderer
from scenedino.models.backbones.dino.dinov2_module import *
from scenedino.training.trainer import BTSWrapper
from scenedino.models import make_model
from scenedino.common.ray_sampler import get_ray_sampler
from scenedino.losses import make_loss

class EncoderDummy(nn.Module):
    def __init__(self, size, feat_dim, num_views=1) -> None:
        super().__init__()
        # The "encoder output" is a learnable feature map of the given size,
        # initialized as a random tensor and optimized directly during overfitting.
        self.feats = nn.Parameter(torch.randn(num_views, feat_dim, *size))
        self.latent_size = feat_dim

    def forward(self, x):
        n = x.shape[0]
        return [self.feats.expand(n, -1, -1, -1)]
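
# Usage sketch (illustrative only; the concrete shapes below are assumptions, not values from a
# config): EncoderDummy ignores the image content entirely and always returns its learnable
# feature map, expanded to the batch size, so gradient descent optimizes the scene features directly.
#
#   enc = EncoderDummy(size=(192, 640), feat_dim=64)
#   feats = enc(torch.zeros(2, 3, 192, 640))[0]  # shape: (2, 64, 192, 640)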

class EncoderDinoDummy(nn.Module):
    def __init__(self,
                 mode: str,                                 # downsample-prediction, upsample-gt
                 decoder_arch: str,                         # nearest, bilinear, sfp, dpt
                 upsampler_arch: Optional[str],             # nearest, bilinear, multiscale-crop
                 downsampler_arch: Optional[str],           # sample-center, featup
                 encoder_arch: str,                         # vit-s, vit-b, fit3d-s
                 separate_gt_encoder_arch: Optional[str],   # vit-s, vit-b, fit3d-s, None (reuses encoder)
                 encoder_freeze: bool,
                 dim_reduction_arch: str,                   # orthogonal-linear, mlp
                 num_ch_enc: np.array,
                 intermediate_features: List[int],
                 decoder_out_dim: int,
                 dino_pca_dim: int,
                 image_size: Tuple[int, int],
                 key_features: bool,
                 ):
        super().__init__()
        self.feats = nn.Parameter(torch.randn(1, decoder_out_dim, *image_size))
        self.latent_size = decoder_out_dim

        if separate_gt_encoder_arch is None:
            self.gt_encoder = build_encoder(encoder_arch, image_size, [], key_features)  # ONLY IN OVERFIT DUMMY!
        else:
            self.gt_encoder = build_encoder(separate_gt_encoder_arch, image_size, [], key_features)
        for p in self.gt_encoder.parameters(True):
            p.requires_grad = False

        # Set up how predicted and ground-truth features are matched for the loss
        if mode == "downsample-prediction":
            assert upsampler_arch is None
            self.downsampler = build_downsampler(downsampler_arch, self.gt_encoder.latent_size)
            self.gt_wrapper = self.gt_encoder
        elif mode == "upsample-gt":
            assert downsampler_arch is None
            self.downsampler = None
            self.gt_wrapper = build_gt_upsampling_wrapper(upsampler_arch, self.gt_encoder, image_size)
        else:
            raise NotImplementedError

        self.extra_outs = 0
        self.latent_size = decoder_out_dim
        self.dino_pca_dim = dino_pca_dim
        self.dim_reduction = build_dim_reduction(dim_reduction_arch, self.gt_encoder.latent_size, dino_pca_dim)
        self.visualization = VisualizationModule(self.gt_encoder.latent_size)

    def forward(self, x, ground_truth=False):
        if ground_truth:
            return self.gt_wrapper(x)
        return [self.feats.expand(x.shape[0], -1, -1, -1)]

    def downsample(self, x, mode="patch"):
        if self.downsampler is None:
            return None
        else:
            return self.downsampler(x, mode)

    def expand_dim(self, features):
        return self.dim_reduction.transform_expand(features)

    def fit_visualization(self, features, refit=True):
        return self.visualization.fit_pca(features, refit)

    def transform_visualization(self, features, norm=False, from_dim=0):
        return self.visualization.transform_pca(features, norm, from_dim)

    def fit_transform_kmeans_visualization(self, features):
        return self.visualization.fit_transform_kmeans_batch(features)

    @classmethod
    def from_conf(cls, conf):
        return cls(
            mode=conf.mode,
            decoder_arch=conf.decoder_arch,
            upsampler_arch=conf.get("upsampler_arch", None),
            downsampler_arch=conf.get("downsampler_arch", None),
            encoder_arch=conf.encoder_arch,
            separate_gt_encoder_arch=conf.get("separate_gt_encoder_arch", None),
            encoder_freeze=conf.encoder_freeze,
            dim_reduction_arch=conf.dim_reduction_arch,
            num_ch_enc=conf.get("num_ch_enc", None),
            intermediate_features=conf.get("intermediate_features", []),
            decoder_out_dim=conf.decoder_out_dim,
            dino_pca_dim=conf.dino_pca_dim,
            image_size=conf.image_size,
            key_features=conf.key_features,
        )
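
# Reading of the class above (descriptive, not additional behavior): in "downsample-prediction"
# mode the predicted feature map is brought down to the patch resolution of the frozen
# gt_encoder via `downsample`, while in "upsample-gt" mode the ground-truth features are
# upsampled to image resolution by `gt_wrapper` instead. `from_conf` expects a config node
# that supports both attribute access and `.get` (e.g. an OmegaConf DictConfig, which is an
# assumption consistent with how `config["encoder"]` is passed in below).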

class BTSWrapperOverfit(BTSWrapper):
    def __init__(self, renderer, ray_sampler, config, eval_nvs=False, size=None) -> None:
        super().__init__(renderer, ray_sampler, config, eval_nvs)

        if config["predict_dino"]:
            encoder_dummy = EncoderDinoDummy.from_conf(config["encoder"])
        else:
            encoder_dummy = EncoderDummy(
                size,
                config["encoder"]["d_out"],
            )
        self.renderer.net.encoder = encoder_dummy
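
# BTSWrapperOverfit keeps the regular BTSWrapper training logic but replaces the encoder of the
# bound renderer with one of the dummy encoders above, so training optimizes the dummy's
# learnable feature map for a single example instead of running a pretrained backbone.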

def training(local_rank, config):
    return base_training(
        local_rank,
        config,
        get_dataflow,
        initialize,
    )

def get_dataflow(config):
    # - Get train/test datasets
    if idist.get_local_rank() > 0:
        # Ensure that only local rank 0 downloads the dataset,
        # so that each node ends up with a single copy of it.
        idist.barrier()

    train_dataset_full = make_datasets(config["dataset"])[0]
    train_dataset = Subset(
        train_dataset_full,
        [config.get("example", config["dataset"].get("skip", 0))],
    )
    train_dataset.dataset._skip = config["dataset"].get("skip", 0)

    validation_datasets = {}
    for name, validation_config in config["validation"].items():
        dataset = copy(train_dataset)
        dataset.dataset.return_depth = True
        validation_datasets[name] = dataset

    if idist.get_local_rank() == 0:
        # Local rank 0 has finished preparing the dataset; releasing the barrier only now lets
        # the other processes proceed, which keeps the download from being repeated per process.
        idist.barrier()

    # Setup data loaders, also adapted to the distributed config: nccl, gloo, xla-tpu
    train_loader_full = DataLoader(train_dataset_full)
    train_loader = DataLoader(train_dataset)
    validation_loaders = {}
    for name, dataset in validation_datasets.items():
        validation_loaders[name] = DataLoader(dataset)

    return (train_loader, train_loader_full), validation_loaders
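
# The dataflow deliberately wraps the full dataset in a single-element Subset (index
# `config["example"]`, falling back to the dataset `skip`), which is what makes this an
# overfitting setup: every iteration sees the same sample. The full loader is returned
# alongside it, presumably for evaluation or visualization by the caller (an assumption;
# `base_training` decides how it is used).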

def initialize(config: dict):
    net = make_model(config["model"])
    renderer = NeRFRenderer.from_conf(config["renderer"])
    renderer = renderer.bind_parallel(net, gpus=None).eval()

    mode = config.get("mode", "depth")

    ray_sampler = get_ray_sampler(config["training"]["ray_sampler"])

    model = BTSWrapperOverfit(
        renderer,
        ray_sampler,
        config["model"],
        mode == "nvs",
        size=config["dataset"].get("image_size", (192, 640)),
    )
    model = idist.auto_model(model)

    optimizer = optim.Adam(model.parameters(), **config["training"]["optimizer"]["args"])
    optimizer = idist.auto_optim(optimizer)

    lr_scheduler = make_scheduler(config["training"].get("scheduler", {}), optimizer)

    criterion = [
        make_loss(config_loss)
        for config_loss in config["training"]["loss"]
    ]

    return model, optimizer, criterion, lr_scheduler
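
# Minimal launch sketch (not part of the original module; the config path and OmegaConf loading
# are assumptions for illustration -- only `training` and the config keys used above come from
# this file):
#
#   if __name__ == "__main__":
#       from omegaconf import OmegaConf
#       config = OmegaConf.load("configs/overfit.yaml")
#       with idist.Parallel(backend=None) as parallel:  # or e.g. "nccl" for multi-GPU runs
#           parallel.run(training, config)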