# jetclustering/src/layers/utils_training.py
# (last commit e75a247 by gregorkrzmanc; original file size 1.23 kB)
from lightning.pytorch.callbacks import BaseFinetuning
import torch
import torch.nn as nn
class FreezeClustering(BaseFinetuning):
    """Lightning callback that keeps the clustering-related submodules frozen.

    Before training starts, the ``batch_norm``, ``gatr``, ``clustering`` and
    ``beta`` submodules of the LightningModule are frozen. They are never
    unfrozen afterwards: ``finetune_function`` applies no unfreezing schedule.

    Args:
        train_bn: Forwarded to ``BaseFinetuning.freeze``. When ``True``,
            Lightning leaves BatchNorm layers inside the frozen submodules
            trainable. The default ``False`` freezes them as well, so that
            explicitly freezing ``pl_module.batch_norm`` actually takes effect.
    """

    def __init__(self, train_bn: bool = False):
        super().__init__()
        self._train_bn = train_bn

    def freeze_before_training(self, pl_module):
        """Freeze the clustering submodules of ``pl_module`` before training.

        Args:
            pl_module: The LightningModule being trained. It must expose
                ``batch_norm``, ``gatr``, ``clustering`` and ``beta``
                attributes; those are the modules that get frozen.
        """
        print("freezing the following module:", pl_module)
        # BaseFinetuning.freeze defaults to train_bn=True, which turns BatchNorm
        # layers trainable instead of freezing them. `batch_norm` is presumably
        # a BatchNorm layer (TODO confirm), so the default would have left it
        # trainable; pass the flag explicitly.
        for module in (
            pl_module.batch_norm,
            pl_module.gatr,
            pl_module.clustering,
            pl_module.beta,
        ):
            self.freeze(module, train_bn=self._train_bn)
        print("CLUSTERING HAS BEEN FROOOZEN")

    def finetune_function(self, pl_module, current_epoch, optimizer):
        """Leave the frozen submodules frozen; no parameter group is unfrozen.

        Lightning's BaseFinetuning invokes this hook during training with the
        current epoch and optimizer. This implementation only prints a notice.
        """
        print("Not finetunning")