jev-aleks's picture
scenedino init
9e15541
from .multi_view_head import (
MultiViewHead,
SimpleMultiViewHead,
MultiViewHead2,
MultiViewHead3,
make_attn_layers,
)
from .resnetfc import ResnetFC
from .mlp import ImplicitNet, make_embedding_encoder
# TODO: reorganize this directory
def make_head(conf, d_in: int, d_out: int):
    """Build a prediction head from a configuration mapping.

    The head class is selected by ``conf["type"]`` and instantiated through
    its ``from_conf`` constructor with ``conf["args"]`` and the given
    input/output dimensions. Optionally the resulting head is frozen.

    Args:
        conf: Dict-like head configuration (accessed with ``.get`` and
            ``[]``). Keys read here:
            ``"type"`` -- optional, defaults to ``"resnet"``; one of
            ``"mlp"``, ``"resnet"``, ``"MultiViewHead"``,
            ``"SimpleMultiViewHead"``, ``"MultiViewHead2"``,
            ``"MultiViewHead3"``.
            ``"args"`` -- required for every supported type; forwarded to
            ``from_conf``.
            ``"freeze"`` -- optional, defaults to ``False``.
            ``"name"`` -- optional; only used in the freeze log message
            (falls back to the head type).
        d_in: Input dimension forwarded to ``from_conf``.
        d_out: Output dimension forwarded to ``from_conf``.

    Returns:
        The constructed head. If ``conf["freeze"]`` is truthy, every
        parameter yielded by ``head.parameters()`` has ``requires_grad``
        set to ``False``.

    Raises:
        NotImplementedError: If the head type is not supported.
    """
    head_type = conf.get("type", "resnet")

    # Resolve the class first: the ``from_conf`` call is identical for every
    # supported type, so it is written once below.
    # NOTE: baseline-comparison types ("IBRNet", "NeuRay", "GeoNeRF",
    # "PixelNeRF") were previously sketched here as commented-out aliases of
    # MultiViewHead; they are intentionally not enabled.
    if head_type == "mlp":
        head_cls = ImplicitNet
    elif head_type == "resnet":
        head_cls = ResnetFC
    elif head_type == "MultiViewHead":
        head_cls = MultiViewHead
    elif head_type == "SimpleMultiViewHead":
        head_cls = SimpleMultiViewHead
    elif head_type == "MultiViewHead2":
        head_cls = MultiViewHead2
    elif head_type == "MultiViewHead3":
        head_cls = MultiViewHead3
    else:
        # Include the offending value so config typos are easy to spot.
        raise NotImplementedError(f"Unsupported Head type: {head_type!r}")

    head = head_cls.from_conf(conf["args"], d_in, d_out)

    if conf.get("freeze", False):
        # Stop gradient tracking for all weights of this head (used for
        # knowledge distillation, per the log message below).
        for param in head.parameters():
            param.requires_grad = False
        # "name" is optional: fall back to the head type rather than raising
        # KeyError after the head has already been built and frozen.
        head_name = conf.get("name", head_type)
        print(f"__freezing the {head_name} for knowledge distillation.")
    return head