Spaces:
Running
on
Zero
Running
on
Zero
from .multi_view_head import ( | |
MultiViewHead, | |
SimpleMultiViewHead, | |
MultiViewHead2, | |
MultiViewHead3, | |
make_attn_layers, | |
) | |
from .resnetfc import ResnetFC | |
from .mlp import ImplicitNet, make_embedding_encoder | |
# TODO: reorganize this directory
def make_head(conf, d_in: int, d_out: int):
    """Build a prediction head from a config and optionally freeze it.

    Args:
        conf: Dict-like head config (must support ``.get`` and ``[]``).
            Keys read here:
              - ``"type"``: head class selector; defaults to ``"resnet"``.
              - ``"args"``: sub-config forwarded to the selected class's
                ``from_conf``; required for every supported type.
              - ``"freeze"``: if truthy, ``requires_grad`` is set to False
                on every parameter yielded by ``head.parameters()``.
              - ``"name"``: optional label used only in the freeze message;
                falls back to the head type when absent.
        d_in: Input dimension forwarded to ``from_conf``.
        d_out: Output dimension forwarded to ``from_conf``.

    Returns:
        The object returned by the selected class's ``from_conf``.

    Raises:
        NotImplementedError: if ``conf["type"]`` is not a supported head type.
    """
    head_type = conf.get("type", "resnet")

    # Config type string -> head class. Built inside the function so these
    # names are looked up at call time, exactly like the original branch chain.
    head_classes = {
        "mlp": ImplicitNet,
        "resnet": ResnetFC,
        "MultiViewHead": MultiViewHead,
        "SimpleMultiViewHead": SimpleMultiViewHead,
        "MultiViewHead2": MultiViewHead2,
        "MultiViewHead3": MultiViewHead3,
    }
    # NOTE: baseline-comparison types (IBRNet, NeuRay, GeoNeRF, PixelNeRF) were
    # previously sketched here as aliases of MultiViewHead; they are
    # intentionally not registered.

    head_cls = head_classes.get(head_type)
    if head_cls is None:
        raise NotImplementedError(
            f"Unsupported Head type {head_type!r}; "
            f"expected one of: {', '.join(head_classes)}"
        )
    head = head_cls.from_conf(conf["args"], d_in, d_out)

    if conf.get("freeze", False):
        for param in head.parameters():
            param.requires_grad = False
        # "name" is optional: a missing key must not crash here after the
        # parameters have already been frozen.
        head_name = conf.get("name", head_type)
        print(f"__freezing the {head_name} for knowledge distillation.")
    return head