|
from typing import Any |
|
import torch.nn as nn |
|
|
|
from .backbone import Backbone |
|
from .backbone_croco_multiview import AsymmetricCroCoMulti |
|
from .backbone_dino import BackboneDino, BackboneDinoCfg |
|
from .backbone_resnet import BackboneResnet, BackboneResnetCfg |
|
from .backbone_croco import AsymmetricCroCo, BackboneCrocoCfg |
|
|
|
# Registry mapping a backbone config's ``name`` to the backbone class that
# get_backbone() constructs. The values are classes (they are called as
# constructors), not instances, hence ``type[...]`` in the annotation.
BACKBONES: dict[str, type[Backbone[Any]]] = {
    "resnet": BackboneResnet,
    "dino": BackboneDino,
    "croco": AsymmetricCroCo,
    "croco_multi": AsymmetricCroCoMulti,
}
|
|
|
# Union of the backbone config types accepted by get_backbone().
# Member order is kept as-is, since union-based config loaders may try the
# members in order.
# NOTE(review): BACKBONES also registers "croco_multi", which has no dedicated
# config class here -- presumably it is configured via BackboneCrocoCfg or a
# config defined elsewhere; confirm.
BackboneCfg = (
    BackboneResnetCfg
    | BackboneDinoCfg
    | BackboneCrocoCfg
)
|
|
|
|
|
def get_backbone(cfg: BackboneCfg, d_in: int = 3) -> nn.Module:
    """Construct the backbone registered under ``cfg.name``.

    Args:
        cfg: Backbone configuration. Its ``name`` attribute selects the class
            from ``BACKBONES``, and the whole config is passed to that class's
            constructor.
        d_in: Forwarded as the second positional constructor argument
            (default 3; presumably the number of input channels -- confirm
            against the backbone classes).

    Returns:
        The result of ``BACKBONES[cfg.name](cfg, d_in)``.

    Raises:
        KeyError: If ``cfg.name`` is not a registered backbone name. The
            message lists the registered names.
    """
    try:
        backbone_cls = BACKBONES[cfg.name]
    except KeyError:
        # Still a KeyError, so callers that already catch it keep working,
        # but now the message says which names are valid.
        raise KeyError(
            f"Unknown backbone {cfg.name!r}; expected one of {sorted(BACKBONES)}."
        ) from None
    # Only the lookup is guarded. Errors raised by the constructor itself
    # propagate unchanged.
    return backbone_cls(cfg, d_in)
|
|