Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,245 Bytes
9e15541 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
from scenedino.common.positional_encoding import PositionalEncoding
from .backbones import make_backbone
from .prediction_heads import make_head
from .bts import BTSNet
from scenedino.downstream_head import make_downstream_head
def make_model(config, downstream_config=None):
    """Build a model from a configuration mapping.

    Only the ``"BTSNet"`` architecture (the default for ``config["arch"]``)
    is supported.

    Args:
        config: Model configuration.
            Optional keys, read with ``.get``:
            ``arch`` (default ``"BTSNet"``), ``sample_color`` (``True``),
            ``predict_dino`` (``False``), ``dino_dims`` (``16``),
            ``uncertainty_predictor`` (``None``),
            ``split_dino_heads`` (``False``) and
            ``final_pred_head`` (``None``).
            Required keys, indexed directly:
            ``code``, ``encoder`` and ``decoder_heads``. ``decoder_heads``
            is iterated, and each entry must provide a ``"name"``.
        downstream_config: Optional configuration forwarded to
            ``make_downstream_head``. When ``None``, no downstream head is
            built.

    Returns:
        The constructed ``BTSNet`` instance.

    Raises:
        NotImplementedError: If ``arch`` is anything other than
            ``"BTSNet"``. It is raised before any sub-module is built.
    """
    arch = config.get("arch", "BTSNet")
    # Fail fast: validate the architecture before constructing any
    # sub-module (e.g. the uncertainty backbone). An unsupported value then
    # raises immediately instead of after, or during, backbone construction.
    if arch != "BTSNet":
        raise NotImplementedError(
            f"Model architecture was not implemented yet: {arch!r}"
        )

    sample_color = config.get("sample_color", True)
    predict_dino = config.get("predict_dino", False)
    dino_dims = config.get("dino_dims", 16)

    # Output width of each decoder head when heads are not split.
    # NOTE(review): presumably 1 = density only (colour sampled from images)
    # and 4 = RGB + density; confirm against BTSNet.
    # NOTE(review): predict_dino is ignored when sample_color is False;
    # confirm this is intended.
    if sample_color and predict_dino:
        d_out = 1 + dino_dims
    elif sample_color:
        d_out = 1
    else:
        d_out = 4

    # Built first, as in the original ordering of sub-module construction.
    uncertainty_predictor_conf = config.get("uncertainty_predictor", None)
    uncertainty_predictor = (
        make_backbone(uncertainty_predictor_conf)
        if uncertainty_predictor_conf is not None
        else None
    )

    code_xyz = PositionalEncoding.from_conf(config["code"], d_in=3)
    encoder = make_backbone(config["encoder"])
    # Heads consume the encoder latent concatenated with the positional code.
    d_in = encoder.latent_size + code_xyz.d_out

    split_dino_heads = config.get("split_dino_heads", False)
    heads = {}
    for head_conf in config["decoder_heads"]:
        name = head_conf["name"]
        if split_dino_heads:
            # Split mode: "normal_head" gets width 1, every other head gets
            # width dino_dims.
            head_d_out = 1 if name == "normal_head" else dino_dims
        else:
            head_d_out = d_out
        heads[name] = make_head(head_conf, d_in, head_d_out)

    downstream_head = (
        make_downstream_head(downstream_config)
        if downstream_config is not None
        else None
    )

    # TODO: check ren_nc
    return BTSNet(
        config,
        encoder,
        code_xyz,
        heads,
        config.get("final_pred_head", None),
        uncertainty_predictor=uncertainty_predictor,
        ren_nc=None,
        downstream_head=downstream_head,
    )
|