alexnasa's picture
Upload 243 files
2568013 verified
raw
history blame contribute delete
666 Bytes
from typing import Any
import torch.nn as nn
from .backbone import Backbone
from .backbone_croco_multiview import AsymmetricCroCoMulti
from .backbone_dino import BackboneDino, BackboneDinoCfg
from .backbone_resnet import BackboneResnet, BackboneResnetCfg
from .backbone_croco import AsymmetricCroCo, BackboneCrocoCfg
# Registry mapping a config's ``name`` string to a backbone *class* (not an
# instance); ``get_backbone`` looks the class up here and instantiates it.
BACKBONES: dict[str, type[Backbone[Any]]] = {
    "resnet": BackboneResnet,
    "dino": BackboneDino,
    "croco": AsymmetricCroCo,
    "croco_multi": AsymmetricCroCoMulti,
}

# Union of the configuration types accepted by ``get_backbone``.
# NOTE(review): "croco_multi" has no dedicated config type in this union —
# presumably it reuses BackboneCrocoCfg; confirm.
BackboneCfg = BackboneResnetCfg | BackboneDinoCfg | BackboneCrocoCfg
def get_backbone(cfg: BackboneCfg, d_in: int = 3) -> nn.Module:
    """Instantiate the backbone registered under ``cfg.name``.

    Args:
        cfg: Backbone configuration. Its ``name`` field selects the entry in
            ``BACKBONES``, and the whole config is forwarded as the first
            constructor argument.
        d_in: Second positional constructor argument (default 3); presumably
            the number of input channels — confirm against the backbone classes.

    Returns:
        The constructed backbone module.

    Raises:
        KeyError: If ``cfg.name`` is not a registered backbone name.
    """
    try:
        backbone_cls = BACKBONES[cfg.name]
    except KeyError:
        # Keep KeyError for existing callers, but list the valid choices so a
        # misconfigured name is easy to diagnose.
        raise KeyError(
            f"Unknown backbone {cfg.name!r}; expected one of {sorted(BACKBONES)}"
        ) from None
    # Construct outside the try so KeyErrors raised by a constructor are not
    # misreported as an unknown backbone name.
    return backbone_cls(cfg, d_in)