# jev-aleks's picture
# scenedino init
# 9e15541
from scenedino.common.positional_encoding import PositionalEncoding
from .backbones import make_backbone
from .prediction_heads import make_head
from .bts import BTSNet
from scenedino.downstream_head import make_downstream_head
def make_model(config, downstream_config=None):
    """Build the model described by ``config``.

    Args:
        config: Mapping with the model settings. Keys read here:
            ``arch`` (default ``"BTSNet"``), ``sample_color`` (default
            ``True``), ``predict_dino`` (default ``False``), ``dino_dims``
            (default ``16``), ``uncertainty_predictor`` (optional backbone
            config), ``split_dino_heads`` (default ``False``),
            ``final_pred_head`` (optional), plus the required ``code``,
            ``encoder`` and ``decoder_heads`` entries. Each entry of
            ``decoder_heads`` must provide a ``"name"`` key.
        downstream_config: Optional config forwarded to
            ``make_downstream_head``; when ``None`` no downstream head is
            built.

    Returns:
        The constructed ``BTSNet`` instance.

    Raises:
        NotImplementedError: If ``arch`` is anything other than ``"BTSNet"``.
    """
    arch = config.get("arch", "BTSNet")

    # Fail fast: reject unsupported architectures before any sub-module
    # (e.g. the uncertainty-predictor backbone) is constructed, and name the
    # offending value so the config error is easy to locate.
    if arch != "BTSNet":
        raise NotImplementedError(
            f"Model architecture was not implemented yet: {arch!r}"
        )

    sample_color = config.get("sample_color", True)
    predict_dino = config.get("predict_dino", False)
    dino_dims = config.get("dino_dims", 16)

    # Output width shared by all heads when they are not split.
    # NOTE(review): predict_dino is ignored when sample_color is False
    # (d_out stays 4) -- confirm this is intended.
    if sample_color and predict_dino:
        d_out = 1 + dino_dims
    elif sample_color:
        d_out = 1
    else:
        d_out = 4

    # Construction order below is kept stable (uncertainty predictor,
    # positional encoding, encoder, heads, downstream head) in case weight
    # initialisation depends on global RNG state.
    uncertainty_predictor_conf = config.get("uncertainty_predictor", None)
    if uncertainty_predictor_conf is not None:
        uncertainty_predictor = make_backbone(uncertainty_predictor_conf)
    else:
        uncertainty_predictor = None

    code_xyz = PositionalEncoding.from_conf(config["code"], d_in=3)
    encoder = make_backbone(config["encoder"])
    # Heads consume the encoder latent concatenated with the encoded xyz.
    d_in = encoder.latent_size + code_xyz.d_out

    split_dino_heads = config.get("split_dino_heads", False)

    def _head_d_out(head_conf):
        # With split heads, "normal_head" emits a single channel and every
        # other head emits dino_dims channels; otherwise all heads share d_out.
        if split_dino_heads:
            return 1 if head_conf["name"] == "normal_head" else dino_dims
        return d_out

    heads = {
        head_conf["name"]: make_head(head_conf, d_in, _head_d_out(head_conf))
        for head_conf in config["decoder_heads"]
    }

    if downstream_config is not None:
        downstream_head = make_downstream_head(downstream_config)
    else:
        downstream_head = None

    # TODO: check ren_nc
    return BTSNet(
        config,
        encoder,
        code_xyz,
        heads,
        config.get("final_pred_head", None),
        uncertainty_predictor=uncertainty_predictor,
        ren_nc=None,
        downstream_head=downstream_head,
    )