|
from .loss import Loss |
|
from .loss_depth import LossDepth, LossDepthCfgWrapper |
|
from .loss_lpips import LossLpips, LossLpipsCfgWrapper |
|
from .loss_mse import LossMse, LossMseCfgWrapper |
|
from .loss_opacity import LossOpacity, LossOpacityCfgWrapper |
|
from .loss_depth_gt import LossDepthGT, LossDepthGTCfgWrapper |
|
from .loss_lod import LossLOD, LossLODCfgWrapper |
|
from .loss_depth_consis import LossDepthConsis, LossDepthConsisCfgWrapper |
|
from .loss_normal_consis import LossNormalConsis, LossNormalConsisCfgWrapper |
|
from .loss_chamfer_distance import LossChamferDistance, LossChamferDistanceCfgWrapper |
|
# Registry used by get_losses: each key is a loss config wrapper type and the
# value is the loss class that is called with an instance of that wrapper.
# Lookup is by exact type, so every supported wrapper needs its own entry.
LOSSES = {
    LossDepthCfgWrapper: LossDepth,
    LossLpipsCfgWrapper: LossLpips,
    LossMseCfgWrapper: LossMse,
    LossOpacityCfgWrapper: LossOpacity,
    LossDepthGTCfgWrapper: LossDepthGT,
    LossLODCfgWrapper: LossLOD,
    LossDepthConsisCfgWrapper: LossDepthConsis,
    LossNormalConsisCfgWrapper: LossNormalConsis,
    LossChamferDistanceCfgWrapper: LossChamferDistance,
}
|
|
|
# Union of every config wrapper type that has an entry in LOSSES; used as the
# element type of the list accepted by get_losses.
LossCfgWrapper = (
    LossDepthCfgWrapper
    | LossLpipsCfgWrapper
    | LossMseCfgWrapper
    | LossOpacityCfgWrapper
    | LossDepthGTCfgWrapper
    | LossLODCfgWrapper
    | LossDepthConsisCfgWrapper
    | LossNormalConsisCfgWrapper
    | LossChamferDistanceCfgWrapper
)
|
|
|
def get_losses(cfgs: list[LossCfgWrapper]) -> list[Loss]:
    """Build one loss object per config wrapper, in the order given.

    Args:
        cfgs: Config wrapper instances. The exact type of each one is used
            as the key into ``LOSSES``.

    Returns:
        A list holding, for every entry of ``cfgs``, the result of calling
        the registered loss class with that entry as its only argument.

    Raises:
        KeyError: If the type of a config wrapper has no entry in ``LOSSES``.
    """
    losses = []
    for cfg in cfgs:
        # Dispatch on the exact wrapper type; subclasses are not matched.
        loss_cls = LOSSES[type(cfg)]
        losses.append(loss_cls(cfg))
    return losses
|
|