diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..3a7f218fe4cad3e2013abccab252da5556e31023 --- /dev/null +++ b/app.py @@ -0,0 +1,81 @@ +import cv2 +import glob +import gradio as gr +import mediapy +import nibabel +import numpy as np +import shutil +import torch +import torch.nn.functional as F + +from omegaconf import OmegaConf +from skp import builder + + +def window(x, WL=400, WW=2500): + lower, upper = WL - WW // 2, WL + WW // 2 + x = np.clip(x, lower, upper) + x = x - lower + x = x / (upper - lower) + return (x * 255).astype("uint8") + + +def rescale(x): + x = x / 255. + x = x - 0.5 + x = x * 2.0 + return x + + +def generate_segmentation_video(study): + img = nibabel.load(study).get_fdata()[:, ::-1, ::-1].transpose(2, 1, 0) + img = window(img) + + X = torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0) + X = F.interpolate(X, size=(192, 192, 192), mode="nearest") + X = rescale(X) + with torch.no_grad(): + seg_output = seg_model(X) + + seg_output = torch.sigmoid(seg_output) + p_spine = seg_output[:, :7].sum(1) + seg_output = torch.argmax(seg_output, dim=1) + 1 + seg_output[p_spine < 0.5] = 0 + seg_output = F.interpolate(seg_output.unsqueeze(0).float(), size=img.shape, mode="nearest") + seg_output = seg_output.squeeze(0).squeeze(0).numpy() + seg_output = (seg_output * 255 / 7).astype("uint8") + seg_output = np.stack([cv2.applyColorMap(_, cv2.COLORMAP_JET) for _ in seg_output]) + + frames = [] + skip = 8 + for idx in range(0, img.shape[2], skip): + i = img[:, :, idx] + o = seg_output[:, :, idx] + i = cv2.cvtColor(i, cv2.COLOR_GRAY2RGB) + frame = np.concatenate((i, o), 1) + frames.append(frame) + mediapy.write_video("video.mp4", frames, fps=30) + return "video.mp4" + + +ffmpeg_path = shutil.which('ffmpeg') +mediapy.set_ffmpeg(ffmpeg_path) + +config = OmegaConf.load("configs/pseudoseg000.yaml") +config.model.load_pretrained = "seg.ckpt" +seg_model = builder.build_model(config).eval() +examples = glob.glob("examples/*.nii.gz") + +with gr.Blocks(theme="dark-peach") as demo: + select_study = gr.Dropdown(choices=sorted(examples), type="value", label="Select a study") + button_predict = gr.Button("Predict") + video_output = gr.Video() + button_predict.click(fn=generate_segmentation_video, + inputs=select_study, + outputs=video_output) + + +if __name__ == "__main__": + demo.launch(debug=True, share=True) + + diff --git a/configs/chunk000.yaml b/configs/chunk000.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a1a259316b6495322a87cd1661114d7616036ba3 --- /dev/null +++ b/configs/chunk000.yaml @@ -0,0 +1,89 @@ +experiment: + seed: 88 + save_dir: ../experiments/ + + +data: + annotations: ../data/train_vertebra_chunks_kfold.csv + data_dir: ../data/train-numpy-vertebra-chunks + input: filename + target: fracture + outer_fold: 0 + dataset: + name: NumpyChunkDataset + params: + flip: true + invert: false + channels: grayscale + z_lt: resample_resample + z_gt: resample_resample + num_images: 64 + + +transform: + resize: + name: resize_ignore_3d + params: + imsize: [64, 288, 288] + augment: + null + crop: + null + preprocess: + name: Preprocessor + params: + image_range: [0, 255] + input_range: [0, 1] + mean: [0.5] + sdev: [0.5] + + +task: + name: ClassificationTask + params: + + +model: + name: Net3D + params: + backbone: x3d_l + backbone_params: + z_strides: [1, 1, 1, 1, 1] + pretrained: true + num_classes: 1 + dropout: 0.2 + pool: avg + in_channels: 1 + multisample_dropout: true + + +loss: + name: BCEWithLogitsLoss + params: + + 
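+# Note on how these config blocks are consumed (see skp/builder.py later in this diff):
+# each section's `name` is looked up in the corresponding module and instantiated with
+# `params` passed as keyword arguments, so an empty `params:` block falls back to the
+# constructor defaults (here, a plain BCEWithLogitsLoss; a `pos_weight` entry, if given,
+# is converted to a tensor by build_loss).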
+optimizer: + name: AdamW + params: + lr: 3.0e-4 + weight_decay: 5.0e-4 + + +scheduler: + name: CosineAnnealingLR + params: + final_lr: 0.0 + + +train: + batch_size: 4 + num_epochs: 10 + + +evaluate: + metrics: [AUROC] + monitor: auc_mean + mode: max + + + diff --git a/configs/chunkseq003.yaml b/configs/chunkseq003.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3adc2861393026265f4402a9826fa95d525aabf8 --- /dev/null +++ b/configs/chunkseq003.yaml @@ -0,0 +1,67 @@ +experiment: + seed: 88 + save_dir: ../experiments/ + + +data: + annotations: ../data/train_chunk_features_kfold.csv + data_dir: ../data/train-chunk000-features/foldx + input: filename + target: [C1, C2, C3, C4, C5, C6, C7, patient_overall] + outer_fold: 0 + dataset: + name: FeatureDataset + params: + seq_len: 7 + reverse: false + normalize: false + exam_level_label: true + + +task: + name: ClassificationTask + params: + + +model: + name: DualTransformer + params: + num_classes: 1 + embedding_dim: 432 + hidden_dim: 864 + n_layers: 3 + n_heads: 16 + + +loss: + name: MultilabelWeightedBCE + params: + weights: [1, 1, 1, 1, 1, 1, 1, 7] + pos_weight: 2.0 + + +optimizer: + name: AdamW + params: + lr: 1.0e-5 + weight_decay: 5.0e-4 + + +scheduler: + name: CosineAnnealingLR + params: + final_lr: 0 + + +train: + batch_size: 32 + num_epochs: 25 + + +evaluate: + batch_size: 1 + metrics: [CompetitionMetric, AUROC] + monitor: comp_metric + mode: min + + diff --git a/configs/pseudoseg000.yaml b/configs/pseudoseg000.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a8e6059c7d37a5aaf2938054f23a6eaf31670bd0 --- /dev/null +++ b/configs/pseudoseg000.yaml @@ -0,0 +1,110 @@ +experiment: + seed: 88 + save_dir: ../experiments/ + + +data: + annotations: ../data/train_seg_whole_192_kfold_with_pseudo.csv + data_dir: ../data/ + input: filename + target: label + outer_fold: 0 + dataset: + name: NumpyChunkSegmentDataset + params: + segmentation_format: numpy + channels: grayscale + flip: true + transpose: true + invert: false + verbose: true + num_images: 192 + z_lt: resample_resample + z_gt: resample_resample + one_hot_encode: true + num_classes: 8 + add_foreground_channel: false + + +transform: + resize: + name: resize_ignore_3d + params: + imsize: [192, 192, 192] + augment: + null + crop: + null + preprocess: + name: Preprocessor + params: + image_range: [0, 255] + input_range: [0, 1] + mean: [0.5] + sdev: [0.5] + + +task: + name: SegmentationTask3D + params: + chunk_validation: true + + +model: + name: NetSegment3D + params: + architecture: DeepLabV3Plus_3D + encoder_name: x3d_l + encoder_params: + pretrained: true + output_stride: 16 + z_strides: [2, 2, 2, 2, 2] + decoder_params: + upsampling: 4 + deep_supervision: true + num_classes: 8 + in_channels: 1 + dropout: 0.2 + + +loss: + name: SupervisorLoss + params: + segmentation_loss: DiceBCELoss + scale_factors: [0.25, 0.25] + loss_weights: [1.0, 0.25, 0.25] + loss_params: + dice_loss_params: + mode: multilabel + exponent: 2 + smooth: 1.0 + bce_loss_params: + smooth_factor: 0.01 + pos_weight: 1.0 + dice_loss_weight: 1.0 + bce_loss_weight: 0.2 + + +optimizer: + name: AdamW + params: + lr: 3.0e-4 + weight_decay: 5.0e-4 + + +scheduler: + name: CosineAnnealingLR + params: + final_lr: 0.0 + + +train: + batch_size: 4 + num_epochs: 10 + + +evaluate: + batch_size: 1 + metrics: [DSC] + monitor: dsc_ignore_mean + mode: max \ No newline at end of file diff --git a/examples/1.2.826.0.1.3680043.15773.nii.gz b/examples/1.2.826.0.1.3680043.15773.nii.gz new file mode 100644 
index 0000000000000000000000000000000000000000..5674d93fa3a95212de1f15b8eecee183f6ecf41e --- /dev/null +++ b/examples/1.2.826.0.1.3680043.15773.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a316d4cdb9534c662a209dea2b50fd57168398b1a658d14937ce285d3b792917 +size 65868417 diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..a9f1eea092d5e971b5475b82ee835cec7f196bad --- /dev/null +++ b/packages.txt @@ -0,0 +1 @@ +ffmpeg \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..5bca9c4ceed5fc000408cda0a8f93b9099d66a1f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +omegaconf +mediapy +nibabel +opencv-python +timm +torch +transformers diff --git a/seg.ckpt b/seg.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..fc3175a6886e2ef6e9a73bf0a54ff688fbb9a792 --- /dev/null +++ b/seg.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa6ee0036af98df68621b5cacbed9b4cd290eb1b59c6af7785a7e9c81ed74afa +size 21569386 diff --git a/seq.ckpt b/seq.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..de8e2ba34fda65d84635d93d49d2ffe5ccc03e05 --- /dev/null +++ b/seq.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:653637b500e3ae5ffab8b07f34d36662396ab3eacec8024e5ecea952d7c2c07e +size 18011334 diff --git a/skp/.DS_Store b/skp/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..870555a0e96bf63b46a3981ae6df262585d35819 Binary files /dev/null and b/skp/.DS_Store differ diff --git a/skp/__init__.py b/skp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/skp/__pycache__/__init__.cpython-39.pyc b/skp/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..267c19f530086cc5577d119e2905d8d6315d7d4b Binary files /dev/null and b/skp/__pycache__/__init__.cpython-39.pyc differ diff --git a/skp/__pycache__/builder.cpython-39.pyc b/skp/__pycache__/builder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f50c43a5635b18a353ccbfd8cd3f10d9c6e970e Binary files /dev/null and b/skp/__pycache__/builder.cpython-39.pyc differ diff --git a/skp/builder.py b/skp/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..635f03864e059b4feb515765ba177c58927334dc --- /dev/null +++ b/skp/builder.py @@ -0,0 +1,187 @@ +import numpy as np +import torch + +from . 
import models + + +def get_name_and_params(base): + name = getattr(base, 'name') + params = getattr(base, 'params') or {} + return name, params + + +def get_transform(base, transform, mode=None): + if not base: return None + transform = getattr(base, transform) + if not transform: return None + name, params = get_name_and_params(transform) + if mode: + params.update({'mode': mode}) + return getattr(data.transforms, name)(**params) + + +def build_transforms(cfg, mode): + # 1-Resize + resizer = get_transform(cfg.transform, 'resize') + # 2-(Optional) Data augmentation + augmenter = None + if mode == "train": + augmenter = get_transform(cfg.transform, 'augment') + # 3-(Optional) Crop + cropper = get_transform(cfg.transform, 'crop', mode=mode) + # 4-Preprocess + preprocessor = get_transform(cfg.transform, 'preprocess') + return { + 'resize': resizer, + 'augment': augmenter, + 'crop': cropper, + 'preprocess': preprocessor + } + + +def build_dataset(cfg, data_info, mode): + dataset_class = getattr(data.datasets, cfg.data.dataset.name) + dataset_params = cfg.data.dataset.params + dataset_params.test_mode = mode != 'train' + dataset_params = dict(dataset_params) + if "FeatureDataset" not in cfg.data.dataset.name: + transforms = build_transforms(cfg, mode) + dataset_params.update(transforms) + dataset_params.update(data_info) + return dataset_class(**dataset_params) + + +def build_dataloader(cfg, dataset, mode): + + def worker_init_fn(worker_id): + np.random.seed(np.random.get_state()[1][0] + worker_id) + + dataloader_params = {} + dataloader_params['num_workers'] = cfg.data.num_workers + dataloader_params['drop_last'] = mode == 'train' + dataloader_params['shuffle'] = mode == 'train' + dataloader_params["pin_memory"] = cfg.data.get("pin_memory", True) + if mode in ('train', 'valid'): + if mode == "train": + dataloader_params['batch_size'] = cfg.train.batch_size + elif mode == "valid": + dataloader_params["batch_size"] = cfg.evaluate.get("batch_size") or cfg.train.batch_size + sampler = None + if cfg.data.get("sampler") and mode == 'train': + name, params = get_name_and_params(cfg.data.sampler) + sampler = getattr(data.samplers, name)(dataset, **params) + if sampler: + dataloader_params['shuffle'] = False + if cfg.strategy == 'ddp': + sampler = data.samplers.DistributedSamplerWrapper(sampler) + dataloader_params['sampler'] = sampler + print(f'Using sampler {sampler} for training ...') + elif cfg.strategy == 'ddp': + dataloader_params["shuffle"] = False + dataloader_params['sampler'] = DistributedSampler(dataset, shuffle=mode=="train") + else: + assert cfg.strategy != "ddp", "DDP currently not supported for inference" + dataloader_params['batch_size'] = cfg.evaluate.get("batch_size") or cfg.train.batch_size + + loader = DataLoader(dataset, + **dataloader_params, + worker_init_fn=worker_init_fn) + return loader + + +def build_model(cfg): + name, params = get_name_and_params(cfg.model) + if cfg.model.params.get("cnn_params", None): + cnn_params = cfg.model.params.cnn_params + if cnn_params.get("load_pretrained_backbone", None): + if "foldx" in cnn_params.load_pretrained_backbone: + cfg.model.params.cnn_params.load_pretrained_backbone = cnn_params.load_pretrained_backbone.\ + replace("foldx", f"fold{cfg.data.outer_fold}") + print(f'Creating model <{name}> ...') + model = getattr(models.engine, name)(**params) + if 'backbone' in cfg.model.params: + print(f' Using backbone <{cfg.model.params.backbone}> ...') + if 'pretrained' in cfg.model.params: + print(f' Pretrained : {cfg.model.params.pretrained}') + 
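+    # Optionally load a full pretrained checkpoint (app.py sets `model.load_pretrained: seg.ckpt`
+    # for the demo). The Lightning checkpoint stores weights under a `model.` prefix, which is
+    # stripped below, and any `loss_fn` buffers are dropped before calling load_state_dict.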
if "load_pretrained" in cfg.model: + import re + if "foldx" in cfg.model.load_pretrained: + cfg.model.load_pretrained = cfg.model.load_pretrained.replace("foldx", f"fold{cfg.data.outer_fold}") + print(f" Loading pretrained checkpoint from {cfg.model.load_pretrained}") + weights = torch.load(cfg.model.load_pretrained, map_location=lambda storage, loc: storage)['state_dict'] + weights = {re.sub(r'^model.', '', k) : v for k,v in weights.items() if "loss_fn" not in k} + model.load_state_dict(weights) + return model + + +def build_loss(cfg): + name, params = get_name_and_params(cfg.loss) + print(f'Using loss function <{name}> ...') + params = dict(params) + if "pos_weight" in params: + params["pos_weight"] = torch.tensor(params["pos_weight"]) + criterion = getattr(losses, name)(**params) + return criterion + + +def build_scheduler(cfg, optimizer): + # Some schedulers will require manipulation of config params + # My specifications were to make it more intuitive for me + name, params = get_name_and_params(cfg.scheduler) + print(f'Using learning rate schedule <{name}> ...') + + if name == 'CosineAnnealingLR': + # eta_min <-> final_lr + # Set T_max as 100000 ... this is changed in on_train_start() method + # of the LightningModule task + + params = { + 'T_max': 100000, + 'eta_min': max(params.final_lr, 1.0e-8) + } + + if name in ('OneCycleLR', 'CustomOneCycleLR'): + # Use learning rate from optimizer parameters as initial learning rate + lr_0 = cfg.optimizer.params.lr + lr_1 = params.max_lr + lr_2 = params.final_lr + # lr_0 -> lr_1 -> lr_2 + pct_start = params.pct_start + params = {} + params['steps_per_epoch'] = 100000 # see above- will fix in task + params['epochs'] = cfg.train.num_epochs + params['max_lr'] = lr_1 + params['pct_start'] = pct_start + params['div_factor'] = lr_1 / lr_0 # max/init + params['final_div_factor'] = lr_0 / max(lr_2, 1.0e-8) # init/final + + scheduler = getattr(optim, name)(optimizer=optimizer, **params) + + # Some schedulers might need more manipulation after instantiation + if name in ('OneCycleLR', 'CustomOneCycleLR'): + scheduler.pct_start = params['pct_start'] + + # Set update frequency + if name in ('OneCycleLR', 'CustomOneCycleLR', 'CosineAnnealingLR'): + scheduler.update_frequency = 'on_batch' + elif name in ('ReduceLROnPlateau'): + scheduler.update_frequency = 'on_valid' + else: + scheduler.update_frequency = 'on_epoch' + + return scheduler + + +def build_optimizer(cfg, parameters): + name, params = get_name_and_params(cfg.optimizer) + print(f'Using optimizer <{name}> ...') + optimizer = getattr(optim, name)(parameters, **params) + return optimizer + + +def build_task(cfg, model): + name, params = get_name_and_params(cfg.task) + print(f'Building task <{name}> ...') + return getattr(tasks, name)(cfg, model, **params) + + diff --git a/skp/models/__init__.py b/skp/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d3652ec81abb27ddd7b48d1ae3af9d6443a7f026 --- /dev/null +++ b/skp/models/__init__.py @@ -0,0 +1 @@ +from . 
import engine \ No newline at end of file diff --git a/skp/models/__pycache__/__init__.cpython-39.pyc b/skp/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71d84c1afa299b577e43fcddd891f0e2c0620a6c Binary files /dev/null and b/skp/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/skp/models/__pycache__/backbones.cpython-39.pyc b/skp/models/__pycache__/backbones.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b14aeb204deec73bec7a5c75ebbe0ba4ac7a5af5 Binary files /dev/null and b/skp/models/__pycache__/backbones.cpython-39.pyc differ diff --git a/skp/models/__pycache__/engine.cpython-39.pyc b/skp/models/__pycache__/engine.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75ff0aaa84321451522f75237d3f3a211422e04c Binary files /dev/null and b/skp/models/__pycache__/engine.cpython-39.pyc differ diff --git a/skp/models/__pycache__/sequence.cpython-39.pyc b/skp/models/__pycache__/sequence.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfb86308b126e54d4f9fc69b47a206840318772b Binary files /dev/null and b/skp/models/__pycache__/sequence.cpython-39.pyc differ diff --git a/skp/models/__pycache__/tools.cpython-39.pyc b/skp/models/__pycache__/tools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a1587a3fb3bd7a0f790e2d4de41259b9e84bf7c Binary files /dev/null and b/skp/models/__pycache__/tools.cpython-39.pyc differ diff --git a/skp/models/backbones.py b/skp/models/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..82679caef0b2bbdfcf72264bc343b1c7d49affe4 --- /dev/null +++ b/skp/models/backbones.py @@ -0,0 +1,114 @@ +import re +import timm +import torch + +from functools import partial +from timm.models.vision_transformer import VisionTransformer +from timm.models.swin_transformer_v2 import SwinTransformerV2 + +from .vmz.backbones import * + + +def check_name(name, s): + return bool(re.search(s, name)) + + +def create_backbone(name, pretrained, features_only=False, **kwargs): + try: + model = timm.create_model(name, pretrained=pretrained, + features_only=features_only, + num_classes=0, global_pool="") + except Exception as e: + assert name in BACKBONES, f"{name} is not a valid backbone" + model = BACKBONES[name](pretrained=pretrained, features_only=features_only, **kwargs) + with torch.no_grad(): + if check_name(name, r"x3d|csn|r2plus1d|i3d"): + dim_feats = model(torch.randn((2, 3, 64, 64, 64))).size(1) + elif isinstance(model, (VisionTransformer, SwinTransformerV2)): + dim_feats = model.norm.normalized_shape[0] + else: + dim_feats = model(torch.randn((2, 3, 128, 128))).size(1) + return model, dim_feats + + +def create_csn(name, pretrained, features_only=False, z_strides=[1, 1, 1, 1, 1], **kwargs): + if features_only: + raise Exception("features_only is currently not supported") + if not pretrained: + from pytorchvideo.models import hub + model = getattr(hub, name)(pretrained=False) + else: + model = torch.hub.load("facebookresearch/pytorchvideo:main", model=name, pretrained=pretrained) + model.blocks[5] = nn.Identity() + return model + + +def create_x3d(name, pretrained, features_only=False, z_strides=[1, 1, 1, 1, 1], **kwargs): + if not pretrained: + from pytorchvideo.models import hub + model = getattr(hub, name)(pretrained=False) + else: + model = torch.hub.load("facebookresearch/pytorchvideo", model=name, pretrained=pretrained) + for idx, z in 
enumerate(z_strides): + assert z in [1, 2], "Only z-strides of 1 or 2 are supported" + if z == 2: + if idx == 0: + stem_layer = model.blocks[0].conv.conv_t + w = stem_layer.weight + w = w.repeat(1, 1, 3, 1, 1) + in_channels, out_channels = stem_layer.in_channels, stem_layer.out_channels + model.blocks[0].conv.conv_t = nn.Conv3d(in_channels, out_channels, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)) + else: + model.blocks[idx].res_blocks[0].branch1_conv.stride = (2, 2, 2) + model.blocks[idx].res_blocks[0].branch2.conv_b.stride = (2, 2, 2) + + if features_only: + model.blocks[-1] = nn.Identity() + model = X3D_Features(model) + else: + model.blocks[-1] = nn.Sequential( + model.blocks[-1].pool.pre_conv, + model.blocks[-1].pool.pre_norm, + model.blocks[-1].pool.pre_act, + ) + + return model + + +def create_i3d(name, pretrained, features_only=False, **kwargs): + from pytorchvideo.models import hub + model = getattr(hub, name)(pretrained=pretrained) + model.blocks[-1] = nn.Identity() + return model + + +class X3D_Features(nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + self.out_channels = [24, 24, 48, 96, 192] + + def forward(self, x): + features = [] + for idx in range(len(self.model.blocks) - 1): + x = self.model.blocks[idx](x) + features.append(x) + return features + + +BACKBONES = { + "x3d_xs": partial(create_x3d, name="x3d_xs"), + "x3d_s": partial(create_x3d, name="x3d_s"), + "x3d_m": partial(create_x3d, name="x3d_m"), + "x3d_l": partial(create_x3d, name="x3d_l"), + "i3d_r50": partial(create_i3d, name="i3d_r50"), + "csn_r101": partial(create_csn, name="csn_r101"), + "ir_csn_50": ir_csn_50, + "ir_csn_101": ir_csn_101, + "ir_csn_152": ir_csn_152, + "ip_csn_50": ip_csn_50, + "ip_csn_101": ip_csn_101, + "ip_csn_152": ip_csn_152, + "r2plus1d_34": r2plus1d_34 +} \ No newline at end of file diff --git a/skp/models/engine.py b/skp/models/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..73a13992c45ddfe4e2ecd96dbc87e40f5b94cacb --- /dev/null +++ b/skp/models/engine.py @@ -0,0 +1,257 @@ +import math +import numpy as np +import re +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pytorchvideo.models.x3d import create_x3d_stem +from timm.models.vision_transformer import VisionTransformer +from timm.models.swin_transformer_v2 import SwinTransformerV2 +from . import backbones +from . 
import segmentation +from .pooling import create_pool2d_layer, create_pool3d_layer +from .sequence import Transformer, DualTransformer, DualTransformerV2 +from .tools import change_initial_stride, change_num_input_channels + + +class Net2D(nn.Module): + + def __init__(self, + backbone, + pretrained, + num_classes, + dropout, + pool, + in_channels=3, + change_stride=None, + feature_reduction=None, + multisample_dropout=False, + load_pretrained_backbone=None, + freeze_backbone=False, + backbone_params={}, + pool_layer_params={}): + + super().__init__() + self.backbone, dim_feats = backbones.create_backbone(name=backbone, pretrained=pretrained, **backbone_params) + if isinstance(pool, str): + self.pool_layer = create_pool2d_layer(name=pool, **pool_layer_params) + else: + self.pool_layer = nn.Identity() + if pool == "catavgmax": + dim_feats *= 2 + self.msdo = multisample_dropout + if in_channels != 3: + self.backbone = change_num_input_channels(self.backbone, in_channels) + if change_stride: + self.backbone = change_initial_stride(self.backbone, tuple(change_stride), in_channels) + self.dropout = nn.Dropout(p=dropout) + if isinstance(feature_reduction, int): + # Use 1D grouped convolution to reduce # of parameters + groups = math.gcd(dim_feats, feature_reduction) + self.feature_reduction = nn.Conv1d(dim_feats, feature_reduction, groups=groups, kernel_size=1, + stride=1, bias=False) + dim_feats = feature_reduction + self.classifier = nn.Linear(dim_feats, num_classes) + + if load_pretrained_backbone: + # Assumes that model has a `backbone` attribute + # Note: if you want to load the entire pretrained model, this is done via the + # builder.build_model function + print(f"Loading pretrained backbone from {load_pretrained_backbone} ...") + weights = torch.load(load_pretrained_backbone, map_location=lambda storage, loc: storage)['state_dict'] + weights = {re.sub(r'^model.', '', k) : v for k,v in weights.items()} + # Get feature_reduction, if present + feat_reduce_weight = {re.sub(r"^feature_reduction.", "", k): v + for k, v in weights.items() if "feature_reduction" in k} + # Get backbone only + weights = {re.sub(r'^backbone.', '', k) : v for k,v in weights.items() if 'backbone' in k} + self.backbone.load_state_dict(weights) + if len(feat_reduce_weight) > 0: + print("Also loading feature reduction layer ...") + self.feature_reduction.load_state_dict(feat_reduce_weight) + + if freeze_backbone: + print("Freezing backbone ...") + for param in self.backbone.parameters(): + param.requires_grad = False + + def extract_features(self, x): + features = self.backbone(x) + features = self.pool_layer(features) + if isinstance(self.backbone, VisionTransformer): + features = features[:, self.backbone.num_prefix_tokens:].mean(dim=1) + if isinstance(self.backbone, SwinTransformerV2): + features = features.mean(dim=1) + if hasattr(self, "feature_reduction"): + features = self.feature_reduction(features.unsqueeze(-1)).squeeze(-1) + return features + + def forward(self, x): + features = self.extract_features(x) + if self.msdo: + x = torch.mean(torch.stack([self.classifier(self.dropout(features)) for _ in range(5)]), dim=0) + else: + x = self.classifier(self.dropout(features)) + # Important nuance: + # For binary classification, the model returns a tensor of shape (N,) + # Otherwise, (N,C) + return x[:, 0] if self.classifier.out_features == 1 else x + + +class SeqNet2D(Net2D): + + def forward(self, x): + # x.shape = (N, C, Z, H, W) + features = torch.stack([self.extract_features(x[:, :, _]) for _ in range(x.size(2))], 
dim=2) + features = features.max(2)[0] + + if self.msdo: + x = torch.mean(torch.stack([self.classifier(self.dropout(features)) for _ in range(5)]), dim=0) + else: + x = self.classifier(self.dropout(features)) + # Important nuance: + # For binary classification, the model returns a tensor of shape (N,) + # Otherwise, (N,C) + return x[:, 0] if self.classifier.out_features == 1 else x + + +class TDCNN(nn.Module): + + def __init__(self, cnn_params, transformer_params, freeze_cnn=False, freeze_transformer=False): + super().__init__() + self.cnn = Net2D(**cnn_params) + del self.cnn.dropout + del self.cnn.classifier + self.transformer = Transformer(**transformer_params) + + if freeze_cnn: + for param in self.cnn.parameters(): + param.requires_grad = False + + if freeze_transformer: + for param in self.transformer.parameters(): + param.requires_grad = False + + def extract_features(self, x): + N, C, Z, H, W = x.size() + assert N == 1, "For feature extraction, batch size must be 1" + features = self.cnn.extract_features(x.squeeze(0).transpose(0, 1)).unsqueeze(0) + # features.shape = (1, Z, dim_feats) + return self.transformer.extract_features((features, torch.ones((features.size(0), features.size(1))).to(features.device))) + + def forward(self, x): + # BCZHW + features = torch.stack([self.cnn.extract_features(x[:, :, i]) for i in range(x.size(2))], dim=1) + # B, seq_len, dim_feat + return self.transformer((features, torch.ones((features.size(0), features.size(1))).to(features.device))) + + +class Net2DWith3DStem(Net2D): + + def __init__(self, *args, **kwargs): + stem_out_channels = kwargs.pop("stem_out_channels", 24) + load_pretrained_stem = kwargs.pop("load_pretrained_stem", None) + conv_kernel_size = tuple(kwargs.pop("conv_kernel_size", (5, 3, 3))) + conv_stride = tuple(kwargs.pop("conv_stride", (1, 2, 2))) + in_channels = kwargs.pop("in_channels", 3) + kwargs["in_channels"] = stem_out_channels + super().__init__(*args, **kwargs) + self.stem_layer = create_x3d_stem(in_channels=in_channels, + out_channels=stem_out_channels, + conv_kernel_size=conv_kernel_size, + conv_stride=conv_stride) + if kwargs["pretrained"]: + from pytorchvideo.models.hub import x3d_l + self.stem_layer.load_state_dict(x3d_l(pretrained=True).blocks[0].state_dict()) + + if load_pretrained_stem: + import re + print(f" Loading pretrained stem from {load_pretrained_stem} ...") + weights = torch.load(load_pretrained_stem, map_location=lambda storage, loc: storage)['state_dict'] + stem_weights = {k.replace("model.backbone.blocks.0.", ""): v for k, v in weights.items() if "backbone.blocks.0" in k} + self.stem_layer.load_state_dict(stem_weights) + + def forward(self, x): + x = self.stem_layer(x) + x = x.mean(3) + features = self.extract_features(x) + if self.msdo: + x = torch.mean(torch.stack([self.classifier(self.dropout(features)) for _ in range(5)]), dim=0) + else: + x = self.classifier(self.dropout(features)) + # Important nuance: + # For binary classification, the model returns a tensor of shape (N,) + # Otherwise, (N,C) + return x[:, 0] if self.classifier.out_features == 1 else x + + +class Net3D(Net2D): + + def __init__(self, *args, **kwargs): + z_strides = kwargs.pop("z_strides", [1,1,1,1,1]) + super().__init__(*args, **kwargs) + self.pool_layer = create_pool3d_layer(name=kwargs["pool"], **kwargs.pop("pool_layer_params", {})) + + +class NetSegment2D(nn.Module): + """ For now, this class essentially servers as a wrapper for the + segmentation model which is mostly defined in the segmentation submodule, + similar to the original 
segmentation_models.pytorch. + + It may be worth refactoring it in the future, such that you define this as + a general class, then select your choice of encoder and decoder. The encoder + is pretty much the same across all the segmentation models currently + implemented (DeepLabV3+, FPN, Unet). + """ + def __init__(self, + architecture, + encoder_name, + encoder_params, + decoder_params, + num_classes, + dropout, + in_channels, + load_pretrained_encoder=None, + freeze_encoder=False, + deep_supervision=False, + pool_layer_params={}, + aux_head_params={}): + + super().__init__() + + self.segmentation_model = getattr(segmentation, architecture)( + encoder_name=encoder_name, + encoder_params=encoder_params, + dropout=dropout, + classes=num_classes, + deep_supervision=deep_supervision, + in_channels=in_channels, + **decoder_params + ) + + + if load_pretrained_encoder: + # Assumes that model has a `encoder` attribute + # Note: if you want to load the entire pretrained model, this is done via the + # builder.build_model function + print(f"Loading pretrained encoder from {load_pretrained_encoder} ...") + weights = torch.load(load_pretrained_encoder, map_location=lambda storage, loc: storage)['state_dict'] + weights = {re.sub(r'^model.segmentation_model', '', k) : v for k,v in weights.items()} + # Get encoder only + weights = {re.sub(r'^encoder.', '', k) : v for k,v in weights.items() if 'backbone' in k} + self.segmentation_model.encoder.load_state_dict(weights) + + if freeze_encoder: + print("Freezing encoder ...") + for param in self.segmentation_model.encoder.parameters(): + param.requires_grad = False + + + def forward(self, x): + return self.segmentation_model(x) + + +class NetSegment3D(NetSegment2D): + + pass diff --git a/skp/models/pooling/__init__.py b/skp/models/pooling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8e7cb19e154836f0631ad018916e9cf58f7a210c --- /dev/null +++ b/skp/models/pooling/__init__.py @@ -0,0 +1,3 @@ +from .pool3d import create_pool3d_layer +from .pool2d import create_pool2d_layer +from .pool1d import create_pool1d_layer \ No newline at end of file diff --git a/skp/models/pooling/__pycache__/__init__.cpython-39.pyc b/skp/models/pooling/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..235e064807968220925aac63f8ce9045f95d5e3a Binary files /dev/null and b/skp/models/pooling/__pycache__/__init__.cpython-39.pyc differ diff --git a/skp/models/pooling/__pycache__/gem.cpython-39.pyc b/skp/models/pooling/__pycache__/gem.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62a29ef9408fca040be67c5a2f156b52b018427d Binary files /dev/null and b/skp/models/pooling/__pycache__/gem.cpython-39.pyc differ diff --git a/skp/models/pooling/__pycache__/pool1d.cpython-39.pyc b/skp/models/pooling/__pycache__/pool1d.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..706649d7f0727a8e666f1241e257b0e66dc97753 Binary files /dev/null and b/skp/models/pooling/__pycache__/pool1d.cpython-39.pyc differ diff --git a/skp/models/pooling/__pycache__/pool2d.cpython-39.pyc b/skp/models/pooling/__pycache__/pool2d.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff93608319c61792118ccfe058d87df2bb99b1b0 Binary files /dev/null and b/skp/models/pooling/__pycache__/pool2d.cpython-39.pyc differ diff --git a/skp/models/pooling/__pycache__/pool3d.cpython-39.pyc b/skp/models/pooling/__pycache__/pool3d.cpython-39.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..8d70f29f88bf0d1268b336c5141ab5c21ba37e36 Binary files /dev/null and b/skp/models/pooling/__pycache__/pool3d.cpython-39.pyc differ diff --git a/skp/models/pooling/gem.py b/skp/models/pooling/gem.py new file mode 100644 index 0000000000000000000000000000000000000000..a14f145095f6b8b9ed37ce0559f4d563cb6d9bc8 --- /dev/null +++ b/skp/models/pooling/gem.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# From: https://github.com/filipradenovic/cnnimageretrieval-pytorch/blob/master/cirtorch/layers/pooling.py +def gem_1d(x, p=3, eps=1e-6): + return F.avg_pool1d(x.clamp(min=eps).pow(p), (x.size(-1),)).pow(1./p) + + +def gem_2d(x, p=3, eps=1e-6): + return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p) + + +def gem_3d(x, p=3, eps=1e-6): + return F.avg_pool3d(x.clamp(min=eps).pow(p), (x.size(-3), x.size(-2), x.size(-1))).pow(1./p) + + +_GEM_FN = { + 1: gem_1d, 2: gem_2d, 3: gem_3d +} + + +class GeM(nn.Module): + + def __init__(self, p=3, eps=1e-6, dim=2): + super().__init__() + self.p = nn.Parameter(torch.ones(1)*p) + self.eps = eps + self.dim = dim + self.flatten = nn.Flatten(1) + + def forward(self, x): + pooled = _GEM_FN[self.dim](x, p=self.p, eps=self.eps) + return self.flatten(pooled) \ No newline at end of file diff --git a/skp/models/pooling/pool1d.py b/skp/models/pooling/pool1d.py new file mode 100644 index 0000000000000000000000000000000000000000..7bd7aa29edcff1b9bcfa7a6e53a3232520b34fdf --- /dev/null +++ b/skp/models/pooling/pool1d.py @@ -0,0 +1,107 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .gem import GeM + + +def adaptive_avgmax_pool1d(x, output_size=1): + x_avg = F.adaptive_avg_pool1d(x, output_size) + x_max = F.adaptive_max_pool1d(x, output_size) + return 0.5 * (x_avg + x_max) + + +def adaptive_catavgmax_pool1d(x, output_size=1): + x_avg = F.adaptive_avg_pool1d(x, output_size) + x_max = F.adaptive_max_pool1d(x, output_size) + return torch.cat((x_avg, x_max), 1) + + +def select_adaptive_pool1d(x, pool_type='avg', output_size=1): + """Selectable global pooling function with dynamic input kernel size + """ + if pool_type == 'avg': + x = F.adaptive_avg_pool1d(x, output_size) + elif pool_type == 'avgmax': + x = adaptive_avgmax_pool1d(x, output_size) + elif pool_type == 'catavgmax': + x = adaptive_catavgmax_pool1d(x, output_size) + elif pool_type == 'max': + x = F.adaptive_max_pool1d(x, output_size) + else: + assert False, 'Invalid pool type: %s' % pool_type + return x + + +class FastAdaptiveAvgPool1d(nn.Module): + def __init__(self, flatten=False): + super(FastAdaptiveAvgPool1d, self).__init__() + self.flatten = flatten + + def forward(self, x): + return x.mean(2, keepdim=not self.flatten) + + +class AdaptiveAvgMaxPool1d(nn.Module): + def __init__(self, output_size=1): + super(AdaptiveAvgMaxPool1d, self).__init__() + self.output_size = output_size + + def forward(self, x): + return adaptive_avgmax_pool1d(x, self.output_size) + + +class AdaptiveCatAvgMaxPool1d(nn.Module): + def __init__(self, output_size=1): + super(AdaptiveCatAvgMaxPool1d, self).__init__() + self.output_size = output_size + + def forward(self, x): + return adaptive_catavgmax_pool1d(x, self.output_size) + + +class SelectAdaptivePool1d(nn.Module): + """Selectable global pooling layer with dynamic input kernel size + """ + def __init__(self, output_size=1, pool_type='fast', flatten=False): + super(SelectAdaptivePool1d, self).__init__() + self.pool_type = pool_type 
or '' # convert other falsy values to empty string for consistent TS typing + self.flatten = nn.Flatten(1) if flatten else nn.Identity() + if pool_type == '': + self.pool = nn.Identity() # pass through + elif pool_type == 'fast': + assert output_size == 1 + self.pool = FastAdaptiveAvgPool1d(flatten) + self.flatten = nn.Identity() + elif pool_type == 'avg': + self.pool = nn.AdaptiveAvgPool1d(output_size) + elif pool_type == 'avgmax': + self.pool = AdaptiveAvgMaxPool1d(output_size) + elif pool_type == 'catavgmax': + self.pool = AdaptiveCatAvgMaxPool1d(output_size) + elif pool_type == 'max': + self.pool = nn.AdaptiveMaxPool1d(output_size) + else: + assert False, 'Invalid pool type: %s' % pool_type + + def is_identity(self): + return not self.pool_type + + def forward(self, x): + x = self.pool(x) + x = self.flatten(x) + return x + + def __repr__(self): + return self.__class__.__name__ + ' (' \ + + 'pool_type=' + self.pool_type \ + + ', flatten=' + str(self.flatten) + ')' + + +def create_pool1d_layer(name, **kwargs): + assert name in ["avg", "max", "fast", "avgmax", "catavgmax", "gem"] + if name != "gem": + pool1d_layer = SelectAdaptivePool1d(pool_type=name, flatten=True) + elif name == "gem": + pool1d_layer = GeM(dim=1, **kwargs) + return pool1d_layer \ No newline at end of file diff --git a/skp/models/pooling/pool2d.py b/skp/models/pooling/pool2d.py new file mode 100644 index 0000000000000000000000000000000000000000..e73fde3c6274b17e59c4a3578295c67f65d2448d --- /dev/null +++ b/skp/models/pooling/pool2d.py @@ -0,0 +1,16 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.models.layers import SelectAdaptivePool2d + +from .gem import GeM + + +def create_pool2d_layer(name, **kwargs): + assert name in ["avg", "max", "fast", "avgmax", "catavgmax", "gem"] + if name != "gem": + pool2d_layer = SelectAdaptivePool2d(pool_type=name, flatten=True) + elif name == "gem": + pool2d_layer = GeM(dim=2, **kwargs) + return pool2d_layer \ No newline at end of file diff --git a/skp/models/pooling/pool3d.py b/skp/models/pooling/pool3d.py new file mode 100644 index 0000000000000000000000000000000000000000..c25588fbd7d8480704ce7bb84fa1ceac22bb7d54 --- /dev/null +++ b/skp/models/pooling/pool3d.py @@ -0,0 +1,107 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .gem import GeM + + +def adaptive_avgmax_pool3d(x, output_size=1): + x_avg = F.adaptive_avg_pool3d(x, output_size) + x_max = F.adaptive_max_pool3d(x, output_size) + return 0.5 * (x_avg + x_max) + + +def adaptive_catavgmax_pool3d(x, output_size=1): + x_avg = F.adaptive_avg_pool3d(x, output_size) + x_max = F.adaptive_max_pool3d(x, output_size) + return torch.cat((x_avg, x_max), 1) + + +def select_adaptive_pool3d(x, pool_type='avg', output_size=1): + """Selectable global pooling function with dynamic input kernel size + """ + if pool_type == 'avg': + x = F.adaptive_avg_pool3d(x, output_size) + elif pool_type == 'avgmax': + x = adaptive_avgmax_pool3d(x, output_size) + elif pool_type == 'catavgmax': + x = adaptive_catavgmax_pool3d(x, output_size) + elif pool_type == 'max': + x = F.adaptive_max_pool3d(x, output_size) + else: + assert False, 'Invalid pool type: %s' % pool_type + return x + + +class FastAdaptiveAvgPool3d(nn.Module): + def __init__(self, flatten=False): + super(FastAdaptiveAvgPool3d, self).__init__() + self.flatten = flatten + + def forward(self, x): + return x.mean((2,3,4), keepdim=not self.flatten) + + +class AdaptiveAvgMaxPool3d(nn.Module): + def __init__(self, output_size=1): + 
super(AdaptiveAvgMaxPool3d, self).__init__() + self.output_size = output_size + + def forward(self, x): + return adaptive_avgmax_pool3d(x, self.output_size) + + +class AdaptiveCatAvgMaxPool3d(nn.Module): + def __init__(self, output_size=1): + super(AdaptiveCatAvgMaxPool3d, self).__init__() + self.output_size = output_size + + def forward(self, x): + return adaptive_catavgmax_pool3d(x, self.output_size) + + +class SelectAdaptivePool3d(nn.Module): + """Selectable global pooling layer with dynamic input kernel size + """ + def __init__(self, output_size=1, pool_type='fast', flatten=False): + super(SelectAdaptivePool3d, self).__init__() + self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing + self.flatten = nn.Flatten(1) if flatten else nn.Identity() + if pool_type == '': + self.pool = nn.Identity() # pass through + elif pool_type == 'fast': + assert output_size == 1 + self.pool = FastAdaptiveAvgPool3d(flatten) + self.flatten = nn.Identity() + elif pool_type == 'avg': + self.pool = nn.AdaptiveAvgPool3d(output_size) + elif pool_type == 'avgmax': + self.pool = AdaptiveAvgMaxPool3d(output_size) + elif pool_type == 'catavgmax': + self.pool = AdaptiveCatAvgMaxPool3d(output_size) + elif pool_type == 'max': + self.pool = nn.AdaptiveMaxPool3d(output_size) + else: + assert False, 'Invalid pool type: %s' % pool_type + + def is_identity(self): + return not self.pool_type + + def forward(self, x): + x = self.pool(x) + x = self.flatten(x) + return x + + def __repr__(self): + return self.__class__.__name__ + ' (' \ + + 'pool_type=' + self.pool_type \ + + ', flatten=' + str(self.flatten) + ')' + + +def create_pool3d_layer(name, **kwargs): + assert name in ["avg", "max", "fast", "avgmax", "catavgmax", "gem"] + if name != "gem": + pool1d_layer = SelectAdaptivePool3d(pool_type=name, flatten=True) + elif name == "gem": + pool1d_layer = GeM(dim=3, **kwargs) + return pool1d_layer \ No newline at end of file diff --git a/skp/models/rev_mvit/REV_MVIT_B_16_CONV.yaml b/skp/models/rev_mvit/REV_MVIT_B_16_CONV.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b119537e53048439b30152f3f33df7a82013cb8c --- /dev/null +++ b/skp/models/rev_mvit/REV_MVIT_B_16_CONV.yaml @@ -0,0 +1,109 @@ +TRAIN: + ENABLE: True + DATASET: imagenet + BATCH_SIZE: 256 + EVAL_PERIOD: 10 + CHECKPOINT_PERIOD: 1 + AUTO_RESUME: True + +DATA: + # PATH_TO_DATA_DIR: path-to-imagenet-dir + MEAN: [0.485, 0.456, 0.406] + STD: [0.229, 0.224, 0.225] + NUM_FRAMES: 64 + TRAIN_CROP_SIZE: 224 + TEST_CROP_SIZE: 224 + INPUT_CHANNEL_NUM: [3] +MVIT: + PATCH_2D: False + ZERO_DECAY_POS_CLS: False + MODE: "conv" + CLS_EMBED_ON: False + PATCH_KERNEL: [3, 7, 7] + PATCH_STRIDE: [2, 4, 4] + PATCH_PADDING: [1, 3, 3] + EMBED_DIM: 96 + NUM_HEADS: 1 + MLP_RATIO: 4.0 + QKV_BIAS: True + DROPPATH_RATE: 0.1 + DROPOUT_RATE: 0.0 + DEPTH: 16 + LAYER_SCALE_INIT_VALUE: 0.0 + HEAD_INIT_SCALE: 1.0 + USE_MEAN_POOLING: False + USE_ABS_POS: True + USE_FIXED_SINCOS_POS: False + SEP_POS_EMBED: False + REL_POS_SPATIAL: False + REL_POS_TEMPORAL: False + REL_POS_ZERO_INIT: False + RESIDUAL_POOLING: False + NORM: "layernorm" + NORM_STEM: False + DIM_MUL: [[1, 2.0], [3, 2.0], [14, 2.0]] + HEAD_MUL: [[1, 2.0], [3, 2.0], [14, 2.0]] + POOL_FIRST: null + POOL_KVQ_KERNEL: [1, 3, 3] + POOL_KV_STRIDE_ADAPTIVE: [1, 4, 4] + POOL_Q_STRIDE: [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]] + SEPARATE_QKV : True + REV: + ENABLE: True + RESPATH_FUSE: "concat" + BUFFER_LAYERS : [1,3, 14] + RES_PATH : "conv" + PRE_Q_FUSION: "concat_linear_2" 
+DETECTION: + ENABLE: False +AUG: + ENABLE: True + COLOR_JITTER: 0.4 + AA_TYPE: rand-m9-n6-mstd0.5-inc1 + INTERPOLATION: bicubic + RE_PROB: 0.25 + RE_MODE: pixel + RE_COUNT: 1 + RE_SPLIT: False +MIXUP: + ENABLE: True + ALPHA: 0.8 + CUTMIX_ALPHA: 1.0 + PROB: 1.0 + SWITCH_PROB: 0.5 + LABEL_SMOOTH_VALUE: 0.1 +SOLVER: + BASE_LR_SCALE_NUM_SHARDS: True + BASE_LR: 0.00025 + LR_POLICY: cosine + MAX_EPOCH: 300 + MOMENTUM: 0.9 + WEIGHT_DECAY: 0.05 + WARMUP_EPOCHS: 70.0 + WARMUP_START_LR: 1e-8 + OPTIMIZING_METHOD: adamw + COSINE_AFTER_WARMUP: True + COSINE_END_LR: 1e-6 + ZERO_WD_1D_PARAM: True + CLIP_GRAD_L2NORM: 1.0 +MODEL: + NUM_CLASSES: 1000 + ARCH: mvit + MODEL_NAME: MViT + LOSS_FUNC: soft_cross_entropy + DROPOUT_RATE: 0.0 + HEAD_ACT: "softmax" + DETACH_FINAL_FC: False +CONTRASTIVE: + NUM_MLP_LAYERS: 1 +TEST: + ENABLE: False + DATASET: imagenet + BATCH_SIZE: 256 +DATA_LOADER: + NUM_WORKERS: 8 + PIN_MEMORY: True +NUM_GPUS: 2 +NUM_SHARDS: 1 +RNG_SEED: 0 +OUTPUT_DIR: . \ No newline at end of file diff --git a/skp/models/rev_mvit/__init__.py b/skp/models/rev_mvit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/skp/models/rev_mvit/__pycache__/__init__.cpython-39.pyc b/skp/models/rev_mvit/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c2a2d2c36538f17bd959a6e57366147d33613d0 Binary files /dev/null and b/skp/models/rev_mvit/__pycache__/__init__.cpython-39.pyc differ diff --git a/skp/models/rev_mvit/__pycache__/attention.cpython-39.pyc b/skp/models/rev_mvit/__pycache__/attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac5f9f290c0167f1ce3c14ee9f3ec518f354ef97 Binary files /dev/null and b/skp/models/rev_mvit/__pycache__/attention.cpython-39.pyc differ diff --git a/skp/models/rev_mvit/__pycache__/batchnorm_helper.cpython-39.pyc b/skp/models/rev_mvit/__pycache__/batchnorm_helper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44357684d1a581b2826d35a99f02e7028c6e4ac8 Binary files /dev/null and b/skp/models/rev_mvit/__pycache__/batchnorm_helper.cpython-39.pyc differ diff --git a/skp/models/rev_mvit/__pycache__/common.cpython-39.pyc b/skp/models/rev_mvit/__pycache__/common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f51b691defd12a275055aec9f9708cbfabb5380 Binary files /dev/null and b/skp/models/rev_mvit/__pycache__/common.cpython-39.pyc differ diff --git a/skp/models/rev_mvit/__pycache__/head_helper.cpython-39.pyc b/skp/models/rev_mvit/__pycache__/head_helper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a288ca25d3aed2d3a03ab87147cfb2bbf07f3832 Binary files /dev/null and b/skp/models/rev_mvit/__pycache__/head_helper.cpython-39.pyc differ diff --git a/skp/models/rev_mvit/__pycache__/reversible_mvit.cpython-39.pyc b/skp/models/rev_mvit/__pycache__/reversible_mvit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b9faada08c5e066bd9617438c131d5ac92f2698 Binary files /dev/null and b/skp/models/rev_mvit/__pycache__/reversible_mvit.cpython-39.pyc differ diff --git a/skp/models/rev_mvit/__pycache__/stem_helper.cpython-39.pyc b/skp/models/rev_mvit/__pycache__/stem_helper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18258f4874991fd0baa841ac27ce6325a136b7b8 Binary files /dev/null and b/skp/models/rev_mvit/__pycache__/stem_helper.cpython-39.pyc differ 
diff --git a/skp/models/rev_mvit/__pycache__/utils.cpython-39.pyc b/skp/models/rev_mvit/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41afb76336bf8a3aef9bf8b6cf852647286c82ea Binary files /dev/null and b/skp/models/rev_mvit/__pycache__/utils.cpython-39.pyc differ diff --git a/skp/models/rev_mvit/__pycache__/video_model_builder.cpython-39.pyc b/skp/models/rev_mvit/__pycache__/video_model_builder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96b14a37aae4f15a82dd408807520dccf4b3b139 Binary files /dev/null and b/skp/models/rev_mvit/__pycache__/video_model_builder.cpython-39.pyc differ diff --git a/skp/models/rev_mvit/attention.py b/skp/models/rev_mvit/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..08070a7a1e333958f13764772198ab2f20cc9a4e --- /dev/null +++ b/skp/models/rev_mvit/attention.py @@ -0,0 +1,568 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + + +import numpy +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import trunc_normal_ + +from .common import DropPath, Mlp + + +def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None): + if pool is None: + return tensor, thw_shape + tensor_dim = tensor.ndim + if tensor_dim == 4: + pass + elif tensor_dim == 3: + tensor = tensor.unsqueeze(1) + else: + raise NotImplementedError(f"Unsupported input dimension {tensor.shape}") + + if has_cls_embed: + cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :] + + B, N, L, C = tensor.shape + T, H, W = thw_shape + tensor = ( + tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous() + ) + + tensor = pool(tensor) + + thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]] + L_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4] + tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3) + if has_cls_embed: + tensor = torch.cat((cls_tok, tensor), dim=2) + if norm is not None: + tensor = norm(tensor) + # Assert tensor_dim in [3, 4] + if tensor_dim == 4: + pass + else: # tensor_dim == 3: + tensor = tensor.squeeze(1) + return tensor, thw_shape + + +def get_rel_pos(rel_pos, d): + if isinstance(d, int): + ori_d = rel_pos.shape[0] + if ori_d == d: + return rel_pos + else: + # Interpolate rel pos. + new_pos_embed = F.interpolate( + rel_pos.reshape(1, ori_d, -1).permute(0, 2, 1), + size=d, + mode="linear", + ) + + return new_pos_embed.reshape(-1, d).permute(1, 0) + + +def cal_rel_pos_spatial( + attn, q, k, has_cls_embed, q_shape, k_shape, rel_pos_h, rel_pos_w +): + """ + Decomposed Spatial Relative Positional Embeddings. + """ + sp_idx = 1 if has_cls_embed else 0 + q_t, q_h, q_w = q_shape + k_t, k_h, k_w = k_shape + dh = int(2 * max(q_h, k_h) - 1) + dw = int(2 * max(q_w, k_w) - 1) + + # Scale up rel pos if shapes for q and k are different. + q_h_ratio = max(k_h / q_h, 1.0) + k_h_ratio = max(q_h / k_h, 1.0) + dist_h = ( + torch.arange(q_h)[:, None] * q_h_ratio + - torch.arange(k_h)[None, :] * k_h_ratio + ) + dist_h += (k_h - 1) * k_h_ratio + q_w_ratio = max(k_w / q_w, 1.0) + k_w_ratio = max(q_w / k_w, 1.0) + dist_w = ( + torch.arange(q_w)[:, None] * q_w_ratio + - torch.arange(k_w)[None, :] * k_w_ratio + ) + dist_w += (k_w - 1) * k_w_ratio + + # Intepolate rel pos if needed. 
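+    # dh and dw are the relative-position table lengths, 2 * max(q, k) - 1, so every
+    # query/key offset indexes a valid row; get_rel_pos (defined above) linearly
+    # interpolates the learned table whenever its stored length differs from dh / dw.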
+ rel_pos_h = get_rel_pos(rel_pos_h, dh) + rel_pos_w = get_rel_pos(rel_pos_w, dw) + Rh = rel_pos_h[dist_h.long()] + Rw = rel_pos_w[dist_w.long()] + + B, n_head, q_N, dim = q.shape + + r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim) + rel_h_q = torch.einsum( + "bythwc,hkc->bythwk", r_q, Rh + ) # [B, H, q_t, qh, qw, k_h] + rel_w_q = torch.einsum( + "bythwc,wkc->bythwk", r_q, Rw + ) # [B, H, q_t, qh, qw, k_w] + + attn[:, :, sp_idx:, sp_idx:] = ( + attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w) + + rel_h_q[:, :, :, :, :, None, :, None] + + rel_w_q[:, :, :, :, :, None, None, :] + ).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w) + + return attn + + +def cal_rel_pos_temporal(attn, q, has_cls_embed, q_shape, k_shape, rel_pos_t): + """ + Temporal Relative Positional Embeddings. + """ + sp_idx = 1 if has_cls_embed else 0 + q_t, q_h, q_w = q_shape + k_t, k_h, k_w = k_shape + dt = int(2 * max(q_t, k_t) - 1) + # Intepolate rel pos if needed. + rel_pos_t = get_rel_pos(rel_pos_t, dt) + + # Scale up rel pos if shapes for q and k are different. + q_t_ratio = max(k_t / q_t, 1.0) + k_t_ratio = max(q_t / k_t, 1.0) + dist_t = ( + torch.arange(q_t)[:, None] * q_t_ratio + - torch.arange(k_t)[None, :] * k_t_ratio + ) + dist_t += (k_t - 1) * k_t_ratio + Rt = rel_pos_t[dist_t.long()] + + B, n_head, q_N, dim = q.shape + + r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim) + # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim] + r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape( + q_t, B * n_head * q_h * q_w, dim + ) + + # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t] + rel = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1) + # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t] + rel = rel.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5) + + attn[:, :, sp_idx:, sp_idx:] = ( + attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w) + + rel[:, :, :, :, :, :, None, None] + ).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w) + + return attn + + +class MultiScaleAttention(nn.Module): + def __init__( + self, + dim, + dim_out, + input_size, + num_heads=8, + qkv_bias=False, + drop_rate=0.0, + kernel_q=(1, 1, 1), + kernel_kv=(1, 1, 1), + stride_q=(1, 1, 1), + stride_kv=(1, 1, 1), + norm_layer=nn.LayerNorm, + has_cls_embed=True, + # Options include `conv`, `avg`, and `max`. + mode="conv", + # If True, perform pool before projection. + pool_first=False, + rel_pos_spatial=False, + rel_pos_temporal=False, + rel_pos_zero_init=False, + residual_pooling=False, + separate_qkv=False, + ): + super().__init__() + self.pool_first = pool_first + self.separate_qkv = separate_qkv + self.drop_rate = drop_rate + self.num_heads = num_heads + self.dim_out = dim_out + head_dim = dim_out // num_heads + self.scale = head_dim**-0.5 + self.has_cls_embed = has_cls_embed + self.mode = mode + padding_q = [int(q // 2) for q in kernel_q] + padding_kv = [int(kv // 2) for kv in kernel_kv] + + if pool_first or separate_qkv: + self.q = nn.Linear(dim, dim_out, bias=qkv_bias) + self.k = nn.Linear(dim, dim_out, bias=qkv_bias) + self.v = nn.Linear(dim, dim_out, bias=qkv_bias) + else: + self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias) + + self.proj = nn.Linear(dim_out, dim_out) + if drop_rate > 0.0: + self.proj_drop = nn.Dropout(drop_rate) + + # Skip pooling with kernel and stride size of (1, 1, 1). 
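+        # Emptying the kernel tuples here leaves pool_q / pool_k / pool_v as None,
+        # and attention_pool returns its input unchanged when the pool layer is None.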
+ if numpy.prod(kernel_q) == 1 and numpy.prod(stride_q) == 1: + kernel_q = () + if numpy.prod(kernel_kv) == 1 and numpy.prod(stride_kv) == 1: + kernel_kv = () + + if mode in ("avg", "max"): + pool_op = nn.MaxPool3d if mode == "max" else nn.AvgPool3d + self.pool_q = ( + pool_op(kernel_q, stride_q, padding_q, ceil_mode=False) + if len(kernel_q) > 0 + else None + ) + self.pool_k = ( + pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False) + if len(kernel_kv) > 0 + else None + ) + self.pool_v = ( + pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False) + if len(kernel_kv) > 0 + else None + ) + elif mode == "conv" or mode == "conv_unshared": + if pool_first: + dim_conv = dim // num_heads if mode == "conv" else dim + else: + dim_conv = dim_out // num_heads if mode == "conv" else dim_out + self.pool_q = ( + nn.Conv3d( + dim_conv, + dim_conv, + kernel_q, + stride=stride_q, + padding=padding_q, + groups=dim_conv, + bias=False, + ) + if len(kernel_q) > 0 + else None + ) + self.norm_q = norm_layer(dim_conv) if len(kernel_q) > 0 else None + self.pool_k = ( + nn.Conv3d( + dim_conv, + dim_conv, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=dim_conv, + bias=False, + ) + if len(kernel_kv) > 0 + else None + ) + self.norm_k = norm_layer(dim_conv) if len(kernel_kv) > 0 else None + self.pool_v = ( + nn.Conv3d( + dim_conv, + dim_conv, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=dim_conv, + bias=False, + ) + if len(kernel_kv) > 0 + else None + ) + self.norm_v = norm_layer(dim_conv) if len(kernel_kv) > 0 else None + else: + raise NotImplementedError(f"Unsupported model {mode}") + + self.rel_pos_spatial = rel_pos_spatial + self.rel_pos_temporal = rel_pos_temporal + if self.rel_pos_spatial: + assert input_size[1] == input_size[2] + size = input_size[1] + q_size = size // stride_q[1] if len(stride_q) > 0 else size + kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size + rel_sp_dim = 2 * max(q_size, kv_size) - 1 + + self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) + if not rel_pos_zero_init: + trunc_normal_(self.rel_pos_h, std=0.02) + trunc_normal_(self.rel_pos_w, std=0.02) + if self.rel_pos_temporal: + self.rel_pos_t = nn.Parameter( + torch.zeros(2 * input_size[0] - 1, head_dim) + ) + if not rel_pos_zero_init: + trunc_normal_(self.rel_pos_t, std=0.02) + + self.residual_pooling = residual_pooling + + def forward(self, x, thw_shape): + B, N, _ = x.shape + + if self.pool_first: + if self.mode == "conv_unshared": + fold_dim = 1 + else: + fold_dim = self.num_heads + x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3) + q = k = v = x + else: + assert self.mode != "conv_unshared" + if not self.separate_qkv: + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, -1) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + else: + q = k = v = x + q = ( + self.q(q) + .reshape(B, N, self.num_heads, -1) + .permute(0, 2, 1, 3) + ) + k = ( + self.k(k) + .reshape(B, N, self.num_heads, -1) + .permute(0, 2, 1, 3) + ) + v = ( + self.v(v) + .reshape(B, N, self.num_heads, -1) + .permute(0, 2, 1, 3) + ) + + q, q_shape = attention_pool( + q, + self.pool_q, + thw_shape, + has_cls_embed=self.has_cls_embed, + norm=self.norm_q if hasattr(self, "norm_q") else None, + ) + k, k_shape = attention_pool( + k, + self.pool_k, + thw_shape, + has_cls_embed=self.has_cls_embed, + norm=self.norm_k if hasattr(self, "norm_k") else None, + ) + v, v_shape = attention_pool( + v, + self.pool_v, + 
thw_shape, + has_cls_embed=self.has_cls_embed, + norm=self.norm_v if hasattr(self, "norm_v") else None, + ) + + if self.pool_first: + q_N = ( + numpy.prod(q_shape) + 1 + if self.has_cls_embed + else numpy.prod(q_shape) + ) + k_N = ( + numpy.prod(k_shape) + 1 + if self.has_cls_embed + else numpy.prod(k_shape) + ) + v_N = ( + numpy.prod(v_shape) + 1 + if self.has_cls_embed + else numpy.prod(v_shape) + ) + + q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1) + q = ( + self.q(q) + .reshape(B, q_N, self.num_heads, -1) + .permute(0, 2, 1, 3) + ) + + v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1) + v = ( + self.v(v) + .reshape(B, v_N, self.num_heads, -1) + .permute(0, 2, 1, 3) + ) + + k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1) + k = ( + self.k(k) + .reshape(B, k_N, self.num_heads, -1) + .permute(0, 2, 1, 3) + ) + + N = q.shape[2] + attn = (q * self.scale) @ k.transpose(-2, -1) + if self.rel_pos_spatial: + attn = cal_rel_pos_spatial( + attn, + q, + k, + self.has_cls_embed, + q_shape, + k_shape, + self.rel_pos_h, + self.rel_pos_w, + ) + + if self.rel_pos_temporal: + attn = cal_rel_pos_temporal( + attn, + q, + self.has_cls_embed, + q_shape, + k_shape, + self.rel_pos_t, + ) + attn = attn.softmax(dim=-1) + + x = attn @ v + + if self.residual_pooling: + if self.has_cls_embed: + x[:, :, 1:, :] += q[:, :, 1:, :] + else: + x = x + q + + x = x.transpose(1, 2).reshape(B, -1, self.dim_out) + x = self.proj(x) + + if self.drop_rate > 0.0: + x = self.proj_drop(x) + return x, q_shape + + +class MultiScaleBlock(nn.Module): + def __init__( + self, + dim, + dim_out, + num_heads, + input_size, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + drop_path=0.0, + layer_scale_init_value=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + up_rate=None, + kernel_q=(1, 1, 1), + kernel_kv=(1, 1, 1), + stride_q=(1, 1, 1), + stride_kv=(1, 1, 1), + mode="conv", + has_cls_embed=True, + pool_first=False, + rel_pos_spatial=False, + rel_pos_temporal=False, + rel_pos_zero_init=False, + residual_pooling=False, + dim_mul_in_att=False, + separate_qkv=False, + ): + super().__init__() + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + self.dim_mul_in_att = dim_mul_in_att + kernel_skip = [s + 1 if s > 1 else s for s in stride_q] + stride_skip = stride_q + padding_skip = [int(skip // 2) for skip in kernel_skip] + att_dim = dim_out if dim_mul_in_att else dim + self.attn = MultiScaleAttention( + dim, + att_dim, + num_heads=num_heads, + input_size=input_size, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + kernel_q=kernel_q, + kernel_kv=kernel_kv, + stride_q=stride_q, + stride_kv=stride_kv, + norm_layer=norm_layer, + has_cls_embed=has_cls_embed, + mode=mode, + pool_first=pool_first, + rel_pos_spatial=rel_pos_spatial, + rel_pos_temporal=rel_pos_temporal, + rel_pos_zero_init=rel_pos_zero_init, + residual_pooling=residual_pooling, + separate_qkv=separate_qkv, + ) + self.drop_path = ( + DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + ) + self.norm2 = norm_layer(att_dim) + mlp_hidden_dim = int(att_dim * mlp_ratio) + self.has_cls_embed = has_cls_embed + # TODO: check the use case for up_rate, and merge the following lines + if up_rate is not None and up_rate > 1: + mlp_dim_out = dim * up_rate + else: + mlp_dim_out = dim_out + self.mlp = Mlp( + in_features=att_dim, + hidden_features=mlp_hidden_dim, + out_features=mlp_dim_out, + act_layer=act_layer, + drop_rate=drop_rate, + ) + if layer_scale_init_value > 0: + self.gamma_1 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), 
requires_grad=True + ) + self.gamma_2 = nn.Parameter( + layer_scale_init_value * torch.ones((dim_out)), + requires_grad=True, + ) + else: + self.gamma_1, self.gamma_2 = None, None + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + self.pool_skip = ( + nn.MaxPool3d( + kernel_skip, stride_skip, padding_skip, ceil_mode=False + ) + if len(stride_skip) > 0 and numpy.prod(stride_skip) > 1 + else None + ) + + def forward(self, x, thw_shape=None): + x_norm = self.norm1(x) + x_block, thw_shape_new = self.attn(x_norm, thw_shape) + if self.dim_mul_in_att and self.dim != self.dim_out: + x = self.proj(x_norm) + x_res, _ = attention_pool( + x, self.pool_skip, thw_shape, has_cls_embed=self.has_cls_embed + ) + if self.gamma_1 is not None: + x = x_res + self.drop_path(self.gamma_1 * x_block) + else: + x = x_res + self.drop_path(x_block) + x_norm = self.norm2(x) + x_mlp = self.mlp(x_norm) + if not self.dim_mul_in_att and self.dim != self.dim_out: + x = self.proj(x_norm) + if self.gamma_2 is not None: + x = x + self.drop_path(self.gamma_2 * x_mlp) + else: + x = x + self.drop_path(x_mlp) + if thw_shape: + return x, thw_shape_new + else: + return x \ No newline at end of file diff --git a/skp/models/rev_mvit/batchnorm_helper.py b/skp/models/rev_mvit/batchnorm_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..0028c5fe1dab9f4a266284813ad7c00de6a49102 --- /dev/null +++ b/skp/models/rev_mvit/batchnorm_helper.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""BatchNorm (BN) utility functions and custom batch-size BN implementations""" + +from functools import partial +import torch +import torch.nn as nn + +from pytorchvideo.layers.batch_norm import ( + NaiveSyncBatchNorm1d, + NaiveSyncBatchNorm3d, +) # noqa + + +def get_norm(cfg): + """ + Args: + cfg (CfgNode): model building configs, details are in the comments of + the config file. + Returns: + nn.Module: the normalization layer. + """ + if cfg.BN.NORM_TYPE in {"batchnorm", "sync_batchnorm_apex"}: + return nn.BatchNorm3d + elif cfg.BN.NORM_TYPE == "sub_batchnorm": + return partial(SubBatchNorm3d, num_splits=cfg.BN.NUM_SPLITS) + elif cfg.BN.NORM_TYPE == "sync_batchnorm": + return partial( + NaiveSyncBatchNorm3d, + num_sync_devices=cfg.BN.NUM_SYNC_DEVICES, + global_sync=cfg.BN.GLOBAL_SYNC, + ) + else: + raise NotImplementedError( + "Norm type {} is not supported".format(cfg.BN.NORM_TYPE) + ) + + +class SubBatchNorm3d(nn.Module): + """ + The standard BN layer computes stats across all examples in a GPU. In some + cases it is desirable to compute stats across only a subset of examples + (e.g., in multigrid training https://arxiv.org/abs/1912.00998). + SubBatchNorm3d splits the batch dimension into N splits, and run BN on + each of them separately (so that the stats are computed on each subset of + examples (1/N of batch) independently. During evaluation, it aggregates + the stats from all splits into one BN. + """ + + def __init__(self, num_splits, **args): + """ + Args: + num_splits (int): number of splits. + args (list): other arguments. + """ + super(SubBatchNorm3d, self).__init__() + self.num_splits = num_splits + num_features = args["num_features"] + # Keep only one set of weight and bias. 
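+
+# --- Editor's note: standalone illustrative sketch, not part of the upstream code. ---
+# The core trick of SubBatchNorm3d (defined here) is a reshape: folding the split
+# factor into the channel dimension lets a single BatchNorm3d compute statistics
+# independently for each subset of the batch. Shapes below are arbitrary examples.
+import torch
+import torch.nn as nn
+
+num_splits, n, c, t, h, w = 2, 4, 8, 4, 6, 6
+x = torch.randn(n, c, t, h, w)
+split_bn = nn.BatchNorm3d(c * num_splits)              # stats per (split, channel) pair
+y = split_bn(x.view(n // num_splits, c * num_splits, t, h, w))
+y = y.view(n, c, t, h, w)
+assert y.shape == x.shape
+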
+ if args.get("affine", True): + self.affine = True + args["affine"] = False + self.weight = torch.nn.Parameter(torch.ones(num_features)) + self.bias = torch.nn.Parameter(torch.zeros(num_features)) + else: + self.affine = False + self.bn = nn.BatchNorm3d(**args) + args["num_features"] = num_features * num_splits + self.split_bn = nn.BatchNorm3d(**args) + + def _get_aggregated_mean_std(self, means, stds, n): + """ + Calculate the aggregated mean and stds. + Args: + means (tensor): mean values. + stds (tensor): standard deviations. + n (int): number of sets of means and stds. + """ + mean = means.view(n, -1).sum(0) / n + std = ( + stds.view(n, -1).sum(0) / n + + ((means.view(n, -1) - mean) ** 2).view(n, -1).sum(0) / n + ) + return mean.detach(), std.detach() + + def aggregate_stats(self): + """ + Synchronize running_mean, and running_var. Call this before eval. + """ + if self.split_bn.track_running_stats: + ( + self.bn.running_mean.data, + self.bn.running_var.data, + ) = self._get_aggregated_mean_std( + self.split_bn.running_mean, + self.split_bn.running_var, + self.num_splits, + ) + + def forward(self, x): + if self.training: + n, c, t, h, w = x.shape + x = x.view(n // self.num_splits, c * self.num_splits, t, h, w) + x = self.split_bn(x) + x = x.view(n, c, t, h, w) + else: + x = self.bn(x) + if self.affine: + x = x * self.weight.view((-1, 1, 1, 1)) + x = x + self.bias.view((-1, 1, 1, 1)) + return x \ No newline at end of file diff --git a/skp/models/rev_mvit/common.py b/skp/models/rev_mvit/common.py new file mode 100644 index 0000000000000000000000000000000000000000..ed249cd8799784e7817ad2e48731c3a37f8f06e9 --- /dev/null +++ b/skp/models/rev_mvit/common.py @@ -0,0 +1,154 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop_rate=0.0, + ): + super().__init__() + self.drop_rate = drop_rate + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + if self.drop_rate > 0.0: + self.drop = nn.Dropout(drop_rate) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + if self.drop_rate > 0.0: + x = self.drop(x) + x = self.fc2(x) + if self.drop_rate > 0.0: + x = self.drop(x) + return x + + +class Permute(nn.Module): + def __init__(self, dims): + super().__init__() + self.dims = dims + + def forward(self, x): + return x.permute(*self.dims) + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + """ + Stochastic Depth per sample. 
+ """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + mask.floor_() # binarize + output = x.div(keep_prob) * mask + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class TwoStreamFusion(nn.Module): + def __init__(self, mode, dim=None, kernel=3, padding=1): + """ + A general constructor for neural modules fusing two equal sized tensors + in forward. Following options are supported: + + "add" / "max" / "min" / "avg" : respective operations on the two halves. + "concat" : NOOP. + "concat_linear_{dim_mult}_{drop_rate}" : MLP to fuse with hidden dim "dim_mult" + (optional, def 1.) higher than input dim + with optional dropout "drop_rate" (def: 0.) + "ln+concat_linear_{dim_mult}_{drop_rate}" : perform MLP after layernorm on the input. + + """ + super().__init__() + self.mode = mode + if mode == "add": + self.fuse_fn = lambda x: torch.stack(torch.chunk(x, 2, dim=2)).sum( + dim=0 + ) + elif mode == "max": + self.fuse_fn = ( + lambda x: torch.stack(torch.chunk(x, 2, dim=2)) + .max(dim=0) + .values + ) + elif mode == "min": + self.fuse_fn = ( + lambda x: torch.stack(torch.chunk(x, 2, dim=2)) + .min(dim=0) + .values + ) + elif mode == "avg": + self.fuse_fn = lambda x: torch.stack(torch.chunk(x, 2, dim=2)).mean( + dim=0 + ) + elif mode == "concat": + # x itself is the channel concat version + self.fuse_fn = lambda x: x + elif "concat_linear" in mode: + if len(mode.split("_")) == 2: + dim_mult = 1.0 + drop_rate = 0.0 + elif len(mode.split("_")) == 3: + dim_mult = float(mode.split("_")[-1]) + drop_rate = 0.0 + + elif len(mode.split("_")) == 4: + dim_mult = float(mode.split("_")[-2]) + drop_rate = float(mode.split("_")[-1]) + else: + raise NotImplementedError + + if mode.split("+")[0] == "ln": + self.fuse_fn = nn.Sequential( + nn.LayerNorm(dim), + Mlp( + in_features=dim, + hidden_features=int(dim * dim_mult), + act_layer=nn.GELU, + out_features=dim, + drop_rate=drop_rate, + ), + ) + else: + self.fuse_fn = Mlp( + in_features=dim, + hidden_features=int(dim * dim_mult), + act_layer=nn.GELU, + out_features=dim, + drop_rate=drop_rate, + ) + + else: + raise NotImplementedError + + def forward(self, x): + if "concat_linear" in self.mode: + return self.fuse_fn(x) + x + + else: + return self.fuse_fn(x) \ No newline at end of file diff --git a/skp/models/rev_mvit/head_helper.py b/skp/models/rev_mvit/head_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..e0d9abc39ce7b74f7a84192e0e73be43f6e26f9e --- /dev/null +++ b/skp/models/rev_mvit/head_helper.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+ +"""ResNe(X)t Head helper.""" + +import torch +import torch.nn as nn + +from .batchnorm_helper import ( + NaiveSyncBatchNorm1d as NaiveSyncBatchNorm1d, +) + + +class MLPHead(nn.Module): + def __init__( + self, + dim_in, + dim_out, + mlp_dim, + num_layers, + bn_on=False, + bias=True, + flatten=False, + xavier_init=True, + bn_sync_num=1, + global_sync=False, + ): + super(MLPHead, self).__init__() + self.flatten = flatten + b = False if bn_on else bias + # assert bn_on or bn_sync_num=1 + mlp_layers = [nn.Linear(dim_in, mlp_dim, bias=b)] + mlp_layers[-1].xavier_init = xavier_init + for i in range(1, num_layers): + if bn_on: + if global_sync or bn_sync_num > 1: + mlp_layers.append( + NaiveSyncBatchNorm1d( + num_sync_devices=bn_sync_num, + global_sync=global_sync, + num_features=mlp_dim, + ) + ) + else: + mlp_layers.append(nn.BatchNorm1d(num_features=mlp_dim)) + mlp_layers.append(nn.ReLU(inplace=True)) + if i == num_layers - 1: + d = dim_out + b = bias + else: + d = mlp_dim + mlp_layers.append(nn.Linear(mlp_dim, d, bias=b)) + mlp_layers[-1].xavier_init = xavier_init + self.projection = nn.Sequential(*mlp_layers) + + def forward(self, x): + if x.ndim == 5: + x = x.permute((0, 2, 3, 4, 1)) + if self.flatten: + x = x.reshape(-1, x.shape[-1]) + + return self.projection(x) + + +class TransformerBasicHead(nn.Module): + """ + BasicHead. No pool. + """ + + def __init__( + self, + dim_in, + num_classes, + dropout_rate=0.0, + act_func="softmax", + cfg=None, + ): + """ + Perform linear projection and activation as head for tranformers. + Args: + dim_in (int): the channel dimension of the input to the head. + num_classes (int): the channel dimensions of the output to the head. + dropout_rate (float): dropout rate. If equal to 0.0, perform no + dropout. + act_func (string): activation function to use. 'softmax': applies + softmax on the output. 'sigmoid': applies sigmoid on the output. + """ + super(TransformerBasicHead, self).__init__() + if dropout_rate > 0.0: + self.dropout = nn.Dropout(dropout_rate) + self.projection = nn.Linear(dim_in, num_classes, bias=True) + + if cfg.CONTRASTIVE.NUM_MLP_LAYERS == 1: + self.projection = nn.Linear(dim_in, num_classes, bias=True) + else: + self.projection = MLPHead( + dim_in, + num_classes, + cfg.CONTRASTIVE.MLP_DIM, + cfg.CONTRASTIVE.NUM_MLP_LAYERS, + bn_on=cfg.CONTRASTIVE.BN_MLP, + bn_sync_num=cfg.BN.NUM_SYNC_DEVICES + if cfg.CONTRASTIVE.BN_SYNC_MLP + else 1, + global_sync=( + cfg.CONTRASTIVE.BN_SYNC_MLP and cfg.BN.GLOBAL_SYNC + ), + ) + self.detach_final_fc = cfg.MODEL.DETACH_FINAL_FC + + # Softmax for evaluation and testing. + if act_func == "softmax": + self.act = nn.Softmax(dim=1) + elif act_func == "sigmoid": + self.act = nn.Sigmoid() + elif act_func == "none": + self.act = None + else: + raise NotImplementedError( + "{} is not supported as an activation" + "function.".format(act_func) + ) + + def forward(self, x): + if hasattr(self, "dropout"): + x = self.dropout(x) + if self.detach_final_fc: + x = x.detach() + x = self.projection(x) + + if not self.training: + if self.act is not None: + x = self.act(x) + # Performs fully convolutional inference. 
+ if x.ndim == 5 and x.shape[1:4] > torch.Size([1, 1, 1]): + x = x.mean([1, 2, 3]) + + x = x.view(x.shape[0], -1) + + return x diff --git a/skp/models/rev_mvit/reversible_mvit.py b/skp/models/rev_mvit/reversible_mvit.py new file mode 100644 index 0000000000000000000000000000000000000000..8ad9f43649ebc49e275aa1c244f8135a0dbcf9d3 --- /dev/null +++ b/skp/models/rev_mvit/reversible_mvit.py @@ -0,0 +1,696 @@ +import sys +from functools import partial +import torch +from torch import nn +from torch.autograd import Function as Function + +from .attention import MultiScaleAttention, attention_pool +from .common import Mlp, TwoStreamFusion, drop_path +from .utils import round_width + + +class ReversibleMViT(nn.Module): + """ + Reversible model builder. This builds the reversible transformer encoder + and allows reversible training. + + Karttikeya Mangalam, Haoqi Fan, Yanghao Li, Chao-Yuan Wu, Bo Xiong, + Christoph Feichtenhofer, Jitendra Malik + "Reversible Vision Transformers" + + https://openaccess.thecvf.com/content/CVPR2022/papers/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.pdf + """ + + def __init__(self, config, model): + """ + The `__init__` method of any subclass should also contain these + arguments. + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + model (nn.Module): parent MViT module this module forms + a reversible encoder in. + """ + + super().__init__() + self.cfg = config + + embed_dim = self.cfg.MVIT.EMBED_DIM + depth = self.cfg.MVIT.DEPTH + num_heads = self.cfg.MVIT.NUM_HEADS + mlp_ratio = self.cfg.MVIT.MLP_RATIO + qkv_bias = self.cfg.MVIT.QKV_BIAS + + drop_path_rate = self.cfg.MVIT.DROPPATH_RATE + self.dropout = config.MVIT.DROPOUT_RATE + self.pre_q_fusion = self.cfg.MVIT.REV.PRE_Q_FUSION + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + input_size = model.patch_dims + + self.layers = nn.ModuleList([]) + self.no_custom_backward = False + + if self.cfg.MVIT.NORM == "layernorm": + norm_layer = partial(nn.LayerNorm, eps=1e-6) + else: + raise NotImplementedError("Only supports layernorm.") + + dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1) + for i in range(len(self.cfg.MVIT.DIM_MUL)): + dim_mul[self.cfg.MVIT.DIM_MUL[i][0]] = self.cfg.MVIT.DIM_MUL[i][1] + for i in range(len(self.cfg.MVIT.HEAD_MUL)): + head_mul[self.cfg.MVIT.HEAD_MUL[i][0]] = self.cfg.MVIT.HEAD_MUL[i][ + 1 + ] + + pool_q = model.pool_q + pool_kv = model.pool_kv + stride_q = model.stride_q + stride_kv = model.stride_kv + + for i in range(depth): + + num_heads = round_width(num_heads, head_mul[i]) + + # Upsampling inside the MHPA, input to the Q-pooling block is lower C dimension + # This localizes the feature changes in a single block, making more computation reversible. 
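+
+# --- Editor's note: standalone illustrative sketch, not part of the upstream code. ---
+# The round_width calls below scale the embedding width by dim_mul and snap it
+# to a multiple of the (possibly updated) head count so head_dim stays integral.
+# This is a simplified restatement of utils.round_width with made-up numbers.
+embed_dim, num_heads, dim_mul = 96, 2, 2.0
+scaled = embed_dim * dim_mul                                        # 192.0
+divisor = num_heads
+dim_out = max(1, int(scaled + divisor / 2) // divisor * divisor)    # 192
+assert dim_out % num_heads == 0 and dim_out == 192
+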
+ embed_dim = round_width( + embed_dim, dim_mul[i - 1] if i > 0 else 1.0, divisor=num_heads + ) + dim_out = round_width( + embed_dim, + dim_mul[i], + divisor=round_width(num_heads, head_mul[i + 1]), + ) + + if i in self.cfg.MVIT.REV.BUFFER_LAYERS: + layer_type = StageTransitionBlock + input_mult = 2 if "concat" in self.pre_q_fusion else 1 + else: + layer_type = ReversibleBlock + input_mult = 1 + + dimout_correction = ( + 2 if (input_mult == 2 and "concat" in self.pre_q_fusion) else 1 + ) + + self.layers.append( + layer_type( + dim=embed_dim + * input_mult, # added only for concat fusion before Qpooling layers + input_size=input_size, + dim_out=dim_out * input_mult // dimout_correction, + num_heads=num_heads, + cfg=self.cfg, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + kernel_q=pool_q[i] if len(pool_q) > i else [], + kernel_kv=pool_kv[i] if len(pool_kv) > i else [], + stride_q=stride_q[i] if len(stride_q) > i else [], + stride_kv=stride_kv[i] if len(stride_kv) > i else [], + layer_id=i, + pre_q_fusion=self.pre_q_fusion, + ) + ) + # F is the attention block + self.layers[-1].F.thw = input_size + + if len(stride_q[i]) > 0: + input_size = [ + size // stride + for size, stride in zip(input_size, stride_q[i]) + ] + + embed_dim = dim_out + + @staticmethod + def vanilla_backward(h, layers, buffer): + """ + Using rev layers without rev backpropagation. Debugging purposes only. + Activated with self.no_custom_backward. + """ + + # split into hidden states (h) and attention_output (a) + h, a = torch.chunk(h, 2, dim=-1) + for _, layer in enumerate(layers): + a, h = layer(a, h) + + return torch.cat([a, h], dim=-1) + + def forward(self, x): + + # process the layers in a reversible stack and an irreversible stack. + stack = [] + for l_i in range(len(self.layers)): + if isinstance(self.layers[l_i], StageTransitionBlock): + stack.append(("StageTransition", l_i)) + else: + if len(stack) == 0 or stack[-1][0] == "StageTransition": + stack.append(("Reversible", [])) + stack[-1][1].append(l_i) + + for layer_seq in stack: + + if layer_seq[0] == "StageTransition": + x = self.layers[layer_seq[1]](x) + + else: + x = torch.cat([x, x], dim=-1) + + # no need for custom backprop in eval/model stat log + if not self.training or self.no_custom_backward: + executing_fn = ReversibleMViT.vanilla_backward + else: + executing_fn = RevBackProp.apply + + x = executing_fn( + x, + self.layers[layer_seq[1][0] : layer_seq[1][-1] + 1], + [], # buffer activations + ) + + # Apply dropout + x = nn.functional.dropout(x, p=self.dropout, training=self.training) + + return x + + +class RevBackProp(Function): + """ + Custom Backpropagation function to allow (A) flusing memory in foward + and (B) activation recomputation reversibly in backward for gradient calculation. + + Inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py + """ + + @staticmethod + def forward( + ctx, + x, + layers, + buffer_layers, # List of layer ids for int activation to buffer + ): + """ + Reversible Forward pass. Any intermediate activations from `buffer_layers` are + cached in ctx for forward pass. This is not necessary for standard usecases. + Each reversible layer implements its own forward pass logic. 
+ """ + buffer_layers.sort() + + X_1, X_2 = torch.chunk(x, 2, dim=-1) + + intermediate = [] + + for layer in layers: + + X_1, X_2 = layer(X_1, X_2) + + if layer.layer_id in buffer_layers: + intermediate.extend([X_1.detach(), X_2.detach()]) + + if len(buffer_layers) == 0: + all_tensors = [X_1.detach(), X_2.detach()] + else: + intermediate = [torch.LongTensor(buffer_layers), *intermediate] + all_tensors = [X_1.detach(), X_2.detach(), *intermediate] + + ctx.save_for_backward(*all_tensors) + ctx.layers = layers + + return torch.cat([X_1, X_2], dim=-1) + + @staticmethod + def backward(ctx, dx): + """ + Reversible Backward pass. Any intermediate activations from `buffer_layers` are + recovered from ctx. Each layer implements its own loic for backward pass (both + activation recomputation and grad calculation). + """ + dX_1, dX_2 = torch.chunk(dx, 2, dim=-1) + + # retrieve params from ctx for backward + X_1, X_2, *int_tensors = ctx.saved_tensors + + # no buffering + if len(int_tensors) != 0: + buffer_layers = int_tensors[0].tolist() + + else: + buffer_layers = [] + + layers = ctx.layers + + for _, layer in enumerate(layers[::-1]): + + if layer.layer_id in buffer_layers: + + X_1, X_2, dX_1, dX_2 = layer.backward_pass( + Y_1=int_tensors[ + buffer_layers.index(layer.layer_id) * 2 + 1 + ], + Y_2=int_tensors[ + buffer_layers.index(layer.layer_id) * 2 + 2 + ], + dY_1=dX_1, + dY_2=dX_2, + ) + + else: + + X_1, X_2, dX_1, dX_2 = layer.backward_pass( + Y_1=X_1, + Y_2=X_2, + dY_1=dX_1, + dY_2=dX_2, + ) + + dx = torch.cat([dX_1, dX_2], dim=-1) + + del int_tensors + del dX_1, dX_2, X_1, X_2 + + return dx, None, None + + +class StageTransitionBlock(nn.Module): + """ + Blocks for changing the feature dimensions in MViT (using Q-pooling). + See Section 3.3.1 in paper for details. + """ + + def __init__( + self, + dim, + input_size, + dim_out, + num_heads, + mlp_ratio, + qkv_bias, + drop_path, + kernel_q, + kernel_kv, + stride_q, + stride_kv, + cfg, + norm_layer=nn.LayerNorm, + pre_q_fusion=None, + layer_id=0, + ): + """ + Uses the same structure of F and G functions as Reversible Block except + without using reversible forward (and backward) pass. + """ + super().__init__() + + self.drop_path_rate = drop_path + + embed_dim = dim + + self.F = AttentionSubBlock( + dim=embed_dim, + input_size=input_size, + num_heads=num_heads, + cfg=cfg, + dim_out=dim_out, + kernel_q=kernel_q, + kernel_kv=kernel_kv, + stride_q=stride_q, + stride_kv=stride_kv, + norm_layer=norm_layer, + ) + + self.G = MLPSubblock( + dim=dim_out, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + ) + + self.layer_id = layer_id + + self.is_proj = False + self.has_cls_embed = cfg.MVIT.CLS_EMBED_ON + + self.is_conv = False + self.pool_first = cfg.MVIT.POOL_FIRST + self.mode = cfg.MVIT.MODE + self.pre_q_fuse = TwoStreamFusion(pre_q_fusion, dim=dim) + + if cfg.MVIT.REV.RES_PATH == "max": + self.res_conv = False + self.pool_skip = nn.MaxPool3d( + # self.attention.attn.pool_q.kernel_size, + [s + 1 if s > 1 else s for s in self.F.attn.pool_q.stride], + self.F.attn.pool_q.stride, + [int(k // 2) for k in self.F.attn.pool_q.stride], + # self.attention.attn.pool_q.padding, + ceil_mode=False, + ) + + elif cfg.MVIT.REV.RES_PATH == "conv": + self.res_conv = True + else: + raise NotImplementedError + + # Add a linear projection in residual branch + if embed_dim != dim_out: + self.is_proj = True + self.res_proj = nn.Linear(embed_dim, dim_out, bias=True) + + def forward( + self, + x, + ): + """ + Forward logic is similar to MultiScaleBlock with Q-pooling. 
+ """ + x = self.pre_q_fuse(x) + + # fork tensor for residual connections + x_res = x + + # This uses conv to pool the residual hidden features + # but done before pooling only if not pool_first + if self.is_proj and not self.pool_first: + x_res = self.res_proj(x_res) + + if self.res_conv: + + # Pooling the hidden features with the same conv as Q + N, L, C = x_res.shape + + # This handling is the same as that of q in MultiScaleAttention + if self.mode == "conv_unshared": + fold_dim = 1 + else: + fold_dim = self.F.attn.num_heads + + # Output is (B, N, L, C) + x_res = x_res.reshape(N, L, fold_dim, C // fold_dim).permute( + 0, 2, 1, 3 + ) + + x_res, _ = attention_pool( + x_res, + self.F.attn.pool_q, + # thw_shape = self.attention.attn.thw, + thw_shape=self.F.thw, + has_cls_embed=self.has_cls_embed, + norm=self.F.attn.norm_q + if hasattr(self.F.attn, "norm_q") + else None, + ) + x_res = x_res.permute(0, 2, 1, 3).reshape(N, x_res.shape[2], C) + + else: + # Pooling the hidden features with max op + x_res, _ = attention_pool( + x_res, + self.pool_skip, + thw_shape=self.F.attn.thw, + has_cls_embed=self.has_cls_embed, + ) + + # If pool_first then project to higher dim now + if self.is_proj and self.pool_first: + x_res = self.res_proj(x_res) + + x = self.F(x) + x = x_res + x + x = x + self.G(x) + + x = drop_path(x, drop_prob=self.drop_path_rate, training=self.training) + + return x + + +class ReversibleBlock(nn.Module): + """ + Reversible Blocks for Reversible Vision Transformer and also + for state-preserving blocks in Reversible MViT. See Section + 3.3.2 in paper for details. + """ + + def __init__( + self, + dim, + input_size, + dim_out, + num_heads, + mlp_ratio, + qkv_bias, + drop_path, + kernel_q, + kernel_kv, + stride_q, + stride_kv, + cfg, + norm_layer=nn.LayerNorm, + layer_id=0, + **kwargs + ): + """ + Block is composed entirely of function F (Attention + sub-block) and G (MLP sub-block) including layernorm. + """ + super().__init__() + + self.drop_path_rate = drop_path + + self.F = AttentionSubBlock( + dim=dim, + input_size=input_size, + num_heads=num_heads, + cfg=cfg, + dim_out=dim_out, + kernel_q=kernel_q, + kernel_kv=kernel_kv, + stride_q=stride_q, + stride_kv=stride_kv, + norm_layer=norm_layer, + ) + + self.G = MLPSubblock( + dim=dim, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + ) + + self.layer_id = layer_id + + self.seeds = {} + + def seed_cuda(self, key): + """ + Fix seeds to allow for stochastic elements such as + dropout to be reproduced exactly in activation + recomputation in the backward pass. 
+ """ + + # randomize seeds + # use cuda generator if available + if ( + hasattr(torch.cuda, "default_generators") + and len(torch.cuda.default_generators) > 0 + ): + # GPU + device_idx = torch.cuda.current_device() + seed = torch.cuda.default_generators[device_idx].seed() + else: + # CPU + seed = int(torch.seed() % sys.maxsize) + + self.seeds[key] = seed + torch.manual_seed(self.seeds[key]) + + def forward(self, X_1, X_2): + """ + forward pass equations: + Y_1 = X_1 + Attention(X_2), F = Attention + Y_2 = X_2 + MLP(Y_1), G = MLP + """ + + self.seed_cuda("attn") + # Y_1 : attn_output + f_X_2 = self.F(X_2) + + self.seed_cuda("droppath") + f_X_2_dropped = drop_path( + f_X_2, drop_prob=self.drop_path_rate, training=self.training + ) + + # Y_1 = X_1 + f(X_2) + Y_1 = X_1 + f_X_2_dropped + + # free memory + del X_1 + + self.seed_cuda("FFN") + g_Y_1 = self.G(Y_1) + + torch.manual_seed(self.seeds["droppath"]) + g_Y_1_dropped = drop_path( + g_Y_1, drop_prob=self.drop_path_rate, training=self.training + ) + + # Y_2 = X_2 + g(Y_1) + Y_2 = X_2 + g_Y_1_dropped + + del X_2 + + return Y_1, Y_2 + + def backward_pass( + self, + Y_1, + Y_2, + dY_1, + dY_2, + ): + """ + equation for activation recomputation: + X_2 = Y_2 - G(Y_1), G = MLP + X_1 = Y_1 - F(X_2), F = Attention + """ + + # temporarily record intermediate activation for G + # and use them for gradient calculcation of G + with torch.enable_grad(): + + Y_1.requires_grad = True + + torch.manual_seed(self.seeds["FFN"]) + g_Y_1 = self.G(Y_1) + + torch.manual_seed(self.seeds["droppath"]) + g_Y_1 = drop_path( + g_Y_1, drop_prob=self.drop_path_rate, training=self.training + ) + + g_Y_1.backward(dY_2, retain_graph=True) + + # activation recomputation is by design and not part of + # the computation graph in forward pass. + with torch.no_grad(): + + X_2 = Y_2 - g_Y_1 + del g_Y_1 + + dY_1 = dY_1 + Y_1.grad + Y_1.grad = None + + # record F activations and calc gradients on F + with torch.enable_grad(): + X_2.requires_grad = True + + torch.manual_seed(self.seeds["attn"]) + f_X_2 = self.F(X_2) + + torch.manual_seed(self.seeds["droppath"]) + f_X_2 = drop_path( + f_X_2, drop_prob=self.drop_path_rate, training=self.training + ) + + f_X_2.backward(dY_1, retain_graph=True) + + # propagate reverse computed acitvations at the start of + # the previou block for backprop.s + with torch.no_grad(): + + X_1 = Y_1 - f_X_2 + + del f_X_2, Y_1 + dY_2 = dY_2 + X_2.grad + + X_2.grad = None + X_2 = X_2.detach() + + return X_1, X_2, dY_1, dY_2 + + +class MLPSubblock(nn.Module): + """ + This creates the function G such that the entire block can be + expressed as F(G(X)). Includes pre-LayerNorm. + """ + + def __init__( + self, + dim, + mlp_ratio, + norm_layer=nn.LayerNorm, + ): + + super().__init__() + self.norm = norm_layer(dim, eps=1e-6, elementwise_affine=True) + + mlp_hidden_dim = int(dim * mlp_ratio) + + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=nn.GELU, + ) + + def forward(self, x): + return self.mlp(self.norm(x)) + + +class AttentionSubBlock(nn.Module): + """ + This creates the function F such that the entire block can be + expressed as F(G(X)). Includes pre-LayerNorm. 
+ """ + + def __init__( + self, + dim, + input_size, + num_heads, + cfg, + dim_out=None, + kernel_q=(1, 1, 1), + kernel_kv=(1, 1, 1), + stride_q=(1, 1, 1), + stride_kv=(1, 1, 1), + norm_layer=nn.LayerNorm, + ): + + super().__init__() + self.norm = norm_layer(dim, eps=1e-6, elementwise_affine=True) + + # This will be set externally during init + self.thw = None + + # the actual attention details are the same as Multiscale + # attention for MViTv2 (with channel up=projection inside block) + # can also implement no upprojection attention for vanilla ViT + self.attn = MultiScaleAttention( + dim, + dim_out, + input_size=input_size, + num_heads=num_heads, + kernel_q=kernel_q, + kernel_kv=kernel_kv, + stride_q=stride_q, + stride_kv=stride_kv, + norm_layer=norm_layer, + drop_rate=cfg.MVIT.DROPOUT_RATE, + qkv_bias=cfg.MVIT.QKV_BIAS, + has_cls_embed=cfg.MVIT.CLS_EMBED_ON, + mode=cfg.MVIT.MODE, + pool_first=cfg.MVIT.POOL_FIRST, + rel_pos_spatial=cfg.MVIT.REL_POS_SPATIAL, + rel_pos_temporal=cfg.MVIT.REL_POS_TEMPORAL, + rel_pos_zero_init=cfg.MVIT.REL_POS_ZERO_INIT, + residual_pooling=cfg.MVIT.RESIDUAL_POOLING, + separate_qkv=cfg.MVIT.SEPARATE_QKV, + ) + + def forward(self, x): + out, _ = self.attn(self.norm(x), self.thw) + return out \ No newline at end of file diff --git a/skp/models/rev_mvit/stem_helper.py b/skp/models/rev_mvit/stem_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..1e8b2020a3dc8cf611118a900de8783e91d5921f --- /dev/null +++ b/skp/models/rev_mvit/stem_helper.py @@ -0,0 +1,325 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""ResNe(X)t 3D stem helper.""" + +import torch +import torch.nn as nn + + +def get_stem_func(name): + """ + Retrieves the stem module by name. + """ + trans_funcs = {"x3d_stem": X3DStem, "basic_stem": ResNetBasicStem} + assert ( + name in trans_funcs.keys() + ), "Transformation function '{}' not supported".format(name) + return trans_funcs[name] + + +class VideoModelStem(nn.Module): + """ + Video 3D stem module. Provides stem operations of Conv, BN, ReLU, MaxPool + on input data tensor for one or multiple pathways. + """ + + def __init__( + self, + dim_in, + dim_out, + kernel, + stride, + padding, + inplace_relu=True, + eps=1e-5, + bn_mmt=0.1, + norm_module=nn.BatchNorm3d, + stem_func_name="basic_stem", + ): + """ + The `__init__` method of any subclass should also contain these + arguments. List size of 1 for single pathway models (C2D, I3D, Slow + and etc), list size of 2 for two pathway models (SlowFast). + + Args: + dim_in (list): the list of channel dimensions of the inputs. + dim_out (list): the output dimension of the convolution in the stem + layer. + kernel (list): the kernels' size of the convolutions in the stem + layers. Temporal kernel size, height kernel size, width kernel + size in order. + stride (list): the stride sizes of the convolutions in the stem + layer. Temporal kernel stride, height kernel size, width kernel + size in order. + padding (list): the paddings' sizes of the convolutions in the stem + layer. Temporal padding size, height padding size, width padding + size in order. + inplace_relu (bool): calculate the relu on the original input + without allocating new memory. + eps (float): epsilon for batch norm. + bn_mmt (float): momentum for batch norm. Noted that BN momentum in + PyTorch = 1 - BN momentum in Caffe2. + norm_module (nn.Module): nn.Module for the normalization layer. The + default is nn.BatchNorm3d. 
+ stem_func_name (string): name of the the stem function applied on + input to the network. + """ + super(VideoModelStem, self).__init__() + + assert ( + len( + { + len(dim_in), + len(dim_out), + len(kernel), + len(stride), + len(padding), + } + ) + == 1 + ), "Input pathway dimensions are not consistent. {} {} {} {} {}".format( + len(dim_in), + len(dim_out), + len(kernel), + len(stride), + len(padding), + ) + + self.num_pathways = len(dim_in) + self.kernel = kernel + self.stride = stride + self.padding = padding + self.inplace_relu = inplace_relu + self.eps = eps + self.bn_mmt = bn_mmt + # Construct the stem layer. + self._construct_stem(dim_in, dim_out, norm_module, stem_func_name) + + def _construct_stem(self, dim_in, dim_out, norm_module, stem_func_name): + trans_func = get_stem_func(stem_func_name) + + for pathway in range(len(dim_in)): + stem = trans_func( + dim_in[pathway], + dim_out[pathway], + self.kernel[pathway], + self.stride[pathway], + self.padding[pathway], + self.inplace_relu, + self.eps, + self.bn_mmt, + norm_module, + ) + self.add_module("pathway{}_stem".format(pathway), stem) + + def forward(self, x): + assert ( + len(x) == self.num_pathways + ), "Input tensor does not contain {} pathway".format(self.num_pathways) + # use a new list, don't modify in-place the x list, which is bad for activation checkpointing. + y = [] + for pathway in range(len(x)): + m = getattr(self, "pathway{}_stem".format(pathway)) + y.append(m(x[pathway])) + return y + + +class ResNetBasicStem(nn.Module): + """ + ResNe(X)t 3D stem module. + Performs spatiotemporal Convolution, BN, and Relu following by a + spatiotemporal pooling. + """ + + def __init__( + self, + dim_in, + dim_out, + kernel, + stride, + padding, + inplace_relu=True, + eps=1e-5, + bn_mmt=0.1, + norm_module=nn.BatchNorm3d, + ): + """ + The `__init__` method of any subclass should also contain these arguments. + + Args: + dim_in (int): the channel dimension of the input. Normally 3 is used + for rgb input, and 2 or 3 is used for optical flow input. + dim_out (int): the output dimension of the convolution in the stem + layer. + kernel (list): the kernel size of the convolution in the stem layer. + temporal kernel size, height kernel size, width kernel size in + order. + stride (list): the stride size of the convolution in the stem layer. + temporal kernel stride, height kernel size, width kernel size in + order. + padding (int): the padding size of the convolution in the stem + layer, temporal padding size, height padding size, width + padding size in order. + inplace_relu (bool): calculate the relu on the original input + without allocating new memory. + eps (float): epsilon for batch norm. + bn_mmt (float): momentum for batch norm. Noted that BN momentum in + PyTorch = 1 - BN momentum in Caffe2. + norm_module (nn.Module): nn.Module for the normalization layer. The + default is nn.BatchNorm3d. + """ + super(ResNetBasicStem, self).__init__() + self.kernel = kernel + self.stride = stride + self.padding = padding + self.inplace_relu = inplace_relu + self.eps = eps + self.bn_mmt = bn_mmt + # Construct the stem layer. 
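+
+# --- Editor's note: standalone illustrative sketch, not part of the upstream code. ---
+# Shape effect of the basic stem built below (Conv3d -> BN -> ReLU -> MaxPool3d):
+# spatial resolution drops by 4x while the temporal length is preserved. The
+# kernel/stride values here are typical examples, not the configured ones.
+import torch
+import torch.nn as nn
+
+stem = nn.Sequential(
+    nn.Conv3d(3, 64, (5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False),
+    nn.BatchNorm3d(64),
+    nn.ReLU(inplace=True),
+    nn.MaxPool3d(kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1]),
+)
+x = torch.randn(1, 3, 8, 64, 64)            # (B, C, T, H, W)
+y = stem(x)
+assert y.shape == (1, 64, 8, 16, 16)
+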
+ self._construct_stem(dim_in, dim_out, norm_module) + + def _construct_stem(self, dim_in, dim_out, norm_module): + self.conv = nn.Conv3d( + dim_in, + dim_out, + self.kernel, + stride=self.stride, + padding=self.padding, + bias=False, + ) + self.bn = norm_module( + num_features=dim_out, eps=self.eps, momentum=self.bn_mmt + ) + self.relu = nn.ReLU(self.inplace_relu) + self.pool_layer = nn.MaxPool3d( + kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1] + ) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + x = self.pool_layer(x) + return x + + +class X3DStem(nn.Module): + """ + X3D's 3D stem module. + Performs a spatial followed by a depthwise temporal Convolution, BN, and Relu following by a + spatiotemporal pooling. + """ + + def __init__( + self, + dim_in, + dim_out, + kernel, + stride, + padding, + inplace_relu=True, + eps=1e-5, + bn_mmt=0.1, + norm_module=nn.BatchNorm3d, + ): + """ + The `__init__` method of any subclass should also contain these arguments. + + Args: + dim_in (int): the channel dimension of the input. Normally 3 is used + for rgb input, and 2 or 3 is used for optical flow input. + dim_out (int): the output dimension of the convolution in the stem + layer. + kernel (list): the kernel size of the convolution in the stem layer. + temporal kernel size, height kernel size, width kernel size in + order. + stride (list): the stride size of the convolution in the stem layer. + temporal kernel stride, height kernel size, width kernel size in + order. + padding (int): the padding size of the convolution in the stem + layer, temporal padding size, height padding size, width + padding size in order. + inplace_relu (bool): calculate the relu on the original input + without allocating new memory. + eps (float): epsilon for batch norm. + bn_mmt (float): momentum for batch norm. Noted that BN momentum in + PyTorch = 1 - BN momentum in Caffe2. + norm_module (nn.Module): nn.Module for the normalization layer. The + default is nn.BatchNorm3d. + """ + super(X3DStem, self).__init__() + self.kernel = kernel + self.stride = stride + self.padding = padding + self.inplace_relu = inplace_relu + self.eps = eps + self.bn_mmt = bn_mmt + # Construct the stem layer. + self._construct_stem(dim_in, dim_out, norm_module) + + def _construct_stem(self, dim_in, dim_out, norm_module): + self.conv_xy = nn.Conv3d( + dim_in, + dim_out, + kernel_size=(1, self.kernel[1], self.kernel[2]), + stride=(1, self.stride[1], self.stride[2]), + padding=(0, self.padding[1], self.padding[2]), + bias=False, + ) + self.conv = nn.Conv3d( + dim_out, + dim_out, + kernel_size=(self.kernel[0], 1, 1), + stride=(self.stride[0], 1, 1), + padding=(self.padding[0], 0, 0), + bias=False, + groups=dim_out, + ) + + self.bn = norm_module( + num_features=dim_out, eps=self.eps, momentum=self.bn_mmt + ) + self.relu = nn.ReLU(self.inplace_relu) + + def forward(self, x): + x = self.conv_xy(x) + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class PatchEmbed(nn.Module): + """ + PatchEmbed. 
+ """ + + def __init__( + self, + dim_in=3, + dim_out=768, + kernel=(1, 16, 16), + stride=(1, 4, 4), + padding=(1, 7, 7), + conv_2d=False, + ): + super().__init__() + if conv_2d: + conv = nn.Conv2d + else: + conv = nn.Conv3d + self.proj = conv( + dim_in, + dim_out, + kernel_size=kernel, + stride=stride, + padding=padding, + ) + + def forward(self, x, keep_spatial=False): + x = self.proj(x) + if keep_spatial: + return x, x.shape + # B C (T) H W -> B (T)HW C + return x.flatten(2).transpose(1, 2), x.shape \ No newline at end of file diff --git a/skp/models/rev_mvit/utils.py b/skp/models/rev_mvit/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ca56e648a56d8317dfcfac4e078a23fcc8c8d6c2 --- /dev/null +++ b/skp/models/rev_mvit/utils.py @@ -0,0 +1,221 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +import numpy as np +import torch + + +def round_width(width, multiplier, min_width=1, divisor=1, verbose=False): + if not multiplier: + return width + width *= multiplier + min_width = min_width or divisor + if verbose: + print(f"min width {min_width}") + print(f"width {width} divisor {divisor}") + print(f"other {int(width + divisor / 2) // divisor * divisor}") + + width_out = max(min_width, int(width + divisor / 2) // divisor * divisor) + if width_out < 0.9 * width: + width_out += divisor + return int(width_out) + + +def validate_checkpoint_wrapper_import(checkpoint_wrapper): + """ + Check if checkpoint_wrapper is imported. + """ + if checkpoint_wrapper is None: + raise ImportError("Please install fairscale.") + + +def get_gkern(kernlen, std): + """Returns a 2D Gaussian kernel array.""" + + def _gaussian_fn(kernlen, std): + n = torch.arange(0, kernlen).float() + n -= n.mean() + n /= std + w = torch.exp(-0.5 * n**2) + return w + + gkern1d = _gaussian_fn(kernlen, std) + gkern2d = torch.outer(gkern1d, gkern1d) + return gkern2d / gkern2d.sum() + + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False): + """ + grid_size: int of the grid height and width + t_size: int of the temporal size + return: + pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + assert embed_dim % 4 == 0 + embed_dim_spatial = embed_dim // 4 * 3 + embed_dim_temporal = embed_dim // 4 + + # spatial + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed_spatial = get_2d_sincos_pos_embed_from_grid( + embed_dim_spatial, grid + ) + + # temporal + grid_t = np.arange(t_size, dtype=np.float32) + pos_embed_temporal = get_1d_sincos_pos_embed_from_grid( + embed_dim_temporal, grid_t + ) + + # concate: [T, H, W] order + pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] + pos_embed_temporal = np.repeat( + pos_embed_temporal, grid_size**2, axis=1 + ) # [T, H*W, D // 4] + pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] + pos_embed_spatial = np.repeat( + pos_embed_spatial, t_size, axis=0 + ) # [T, H*W, D // 4 * 3] + + pos_embed = 
np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) + pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D] + + if cls_token: + pos_embed = np.concatenate( + [np.zeros([1, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate( + [np.zeros([1, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[0] + ) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[1] + ) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int( + (pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5 + ) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size, orig_size, new_size, new_size) + ) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size, orig_size, embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size, new_size), + mode="bicubic", + align_corners=False, + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +def calc_mvit_feature_geometry(cfg): + feat_size = [ + [ + cfg.DATA.NUM_FRAMES // cfg.MVIT.PATCH_STRIDE[0] + if 
len(cfg.MVIT.PATCH_STRIDE) > 2 + else 1, + cfg.DATA.TRAIN_CROP_SIZE // cfg.MVIT.PATCH_STRIDE[-2], + cfg.DATA.TRAIN_CROP_SIZE // cfg.MVIT.PATCH_STRIDE[-1], + ] + for i in range(cfg.MVIT.DEPTH) + ] + feat_stride = [ + [ + cfg.MVIT.PATCH_STRIDE[0] if len(cfg.MVIT.PATCH_STRIDE) > 2 else 1, + cfg.MVIT.PATCH_STRIDE[-2], + cfg.MVIT.PATCH_STRIDE[-1], + ] + for i in range(cfg.MVIT.DEPTH) + ] + for _, x in enumerate(cfg.MVIT.POOL_Q_STRIDE): + for i in range(cfg.MVIT.DEPTH): + if i >= x[0]: + for j in range(len(feat_size[i])): + feat_size[i][j] = feat_size[i][j] // x[j + 1] + feat_stride[i][j] = feat_stride[i][j] * x[j + 1] + return feat_size, feat_stride \ No newline at end of file diff --git a/skp/models/rev_mvit/video_model_builder.py b/skp/models/rev_mvit/video_model_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..0dd551fb9a21cf34751834e67880edc1824aadf5 --- /dev/null +++ b/skp/models/rev_mvit/video_model_builder.py @@ -0,0 +1,472 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + + +"""Video models.""" + +import math +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import trunc_normal_ + + +# import slowfast.utils.weight_init_helper as init_helper +from .attention import MultiScaleBlock +# from slowfast.models.batchnorm_helper import get_norm +from .common import TwoStreamFusion +from .reversible_mvit import ReversibleMViT +from .utils import ( + calc_mvit_feature_geometry, + get_3d_sincos_pos_embed, + round_width, + validate_checkpoint_wrapper_import, +) + +from . import head_helper, stem_helper # noqae + + +class MViT(nn.Module): + """ + Model builder for MViTv1 and MViTv2. + + "MViTv2: Improved Multiscale Vision Transformers for Classification and Detection" + Yanghao Li, Chao-Yuan Wu, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik, Christoph Feichtenhofer + https://arxiv.org/abs/2112.01526 + "Multiscale Vision Transformers" + Haoqi Fan, Bo Xiong, Karttikeya Mangalam, Yanghao Li, Zhicheng Yan, Jitendra Malik, Christoph Feichtenhofer + https://arxiv.org/abs/2104.11227 + """ + + def __init__(self, cfg): + super().__init__() + # Get parameters. + assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE + self.cfg = cfg + pool_first = cfg.MVIT.POOL_FIRST + # Prepare input. + spatial_size = cfg.DATA.TRAIN_CROP_SIZE + temporal_size = cfg.DATA.NUM_FRAMES + in_chans = cfg.DATA.INPUT_CHANNEL_NUM[0] + self.use_2d_patch = cfg.MVIT.PATCH_2D + self.enable_detection = cfg.DETECTION.ENABLE + self.enable_rev = cfg.MVIT.REV.ENABLE + self.patch_stride = cfg.MVIT.PATCH_STRIDE + if self.use_2d_patch: + self.patch_stride = [1] + self.patch_stride + self.T = cfg.DATA.NUM_FRAMES // self.patch_stride[0] + self.H = cfg.DATA.TRAIN_CROP_SIZE // self.patch_stride[1] + self.W = cfg.DATA.TRAIN_CROP_SIZE // self.patch_stride[2] + # Prepare output. 
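+
+# --- Editor's note: standalone illustrative sketch, not part of the upstream code. ---
+# The token-grid arithmetic computed just above (self.T / self.H / self.W), with
+# example values rather than the configured ones: a 16x224x224 clip and patch
+# stride (2, 4, 4) give an 8x56x56 grid, i.e. 25088 tokens before any Q-pooling.
+import math
+
+num_frames, crop_size, patch_stride = 16, 224, (2, 4, 4)
+T = num_frames // patch_stride[0]
+H = W = crop_size // patch_stride[1]
+assert (T, H, W) == (8, 56, 56) and math.prod((T, H, W)) == 25088
+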
+ num_classes = cfg.MODEL.NUM_CLASSES + embed_dim = cfg.MVIT.EMBED_DIM + # Prepare backbone + num_heads = cfg.MVIT.NUM_HEADS + mlp_ratio = cfg.MVIT.MLP_RATIO + qkv_bias = cfg.MVIT.QKV_BIAS + self.drop_rate = cfg.MVIT.DROPOUT_RATE + depth = cfg.MVIT.DEPTH + drop_path_rate = cfg.MVIT.DROPPATH_RATE + layer_scale_init_value = cfg.MVIT.LAYER_SCALE_INIT_VALUE + head_init_scale = cfg.MVIT.HEAD_INIT_SCALE + mode = cfg.MVIT.MODE + self.cls_embed_on = cfg.MVIT.CLS_EMBED_ON + self.use_mean_pooling = cfg.MVIT.USE_MEAN_POOLING + # Params for positional embedding + self.use_abs_pos = cfg.MVIT.USE_ABS_POS + self.use_fixed_sincos_pos = cfg.MVIT.USE_FIXED_SINCOS_POS + self.sep_pos_embed = cfg.MVIT.SEP_POS_EMBED + self.rel_pos_spatial = cfg.MVIT.REL_POS_SPATIAL + self.rel_pos_temporal = cfg.MVIT.REL_POS_TEMPORAL + if cfg.MVIT.NORM == "layernorm": + norm_layer = partial(nn.LayerNorm, eps=1e-6) + else: + raise NotImplementedError("Only supports layernorm.") + self.num_classes = num_classes + self.patch_embed = stem_helper.PatchEmbed( + dim_in=in_chans, + dim_out=embed_dim, + kernel=cfg.MVIT.PATCH_KERNEL, + stride=cfg.MVIT.PATCH_STRIDE, + padding=cfg.MVIT.PATCH_PADDING, + conv_2d=self.use_2d_patch, + ) + + self.input_dims = [temporal_size, spatial_size, spatial_size] + assert self.input_dims[1] == self.input_dims[2] + self.patch_dims = [ + self.input_dims[i] // self.patch_stride[i] + for i in range(len(self.input_dims)) + ] + num_patches = math.prod(self.patch_dims) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + if self.cls_embed_on: + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + pos_embed_dim = num_patches + 1 + else: + pos_embed_dim = num_patches + + if self.use_abs_pos: + if self.sep_pos_embed: + self.pos_embed_spatial = nn.Parameter( + torch.zeros( + 1, self.patch_dims[1] * self.patch_dims[2], embed_dim + ) + ) + self.pos_embed_temporal = nn.Parameter( + torch.zeros(1, self.patch_dims[0], embed_dim) + ) + if self.cls_embed_on: + self.pos_embed_class = nn.Parameter( + torch.zeros(1, 1, embed_dim) + ) + else: + self.pos_embed = nn.Parameter( + torch.zeros( + 1, + pos_embed_dim, + embed_dim, + ), + requires_grad=not self.use_fixed_sincos_pos, + ) + + if self.drop_rate > 0.0: + self.pos_drop = nn.Dropout(p=self.drop_rate) + + dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1) + for i in range(len(cfg.MVIT.DIM_MUL)): + dim_mul[cfg.MVIT.DIM_MUL[i][0]] = cfg.MVIT.DIM_MUL[i][1] + for i in range(len(cfg.MVIT.HEAD_MUL)): + head_mul[cfg.MVIT.HEAD_MUL[i][0]] = cfg.MVIT.HEAD_MUL[i][1] + + pool_q = [[] for i in range(cfg.MVIT.DEPTH)] + pool_kv = [[] for i in range(cfg.MVIT.DEPTH)] + stride_q = [[] for i in range(cfg.MVIT.DEPTH)] + stride_kv = [[] for i in range(cfg.MVIT.DEPTH)] + + for i in range(len(cfg.MVIT.POOL_Q_STRIDE)): + stride_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_Q_STRIDE[i][ + 1: + ] + if cfg.MVIT.POOL_KVQ_KERNEL is not None: + pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_KVQ_KERNEL + else: + pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = [ + s + 1 if s > 1 else s for s in cfg.MVIT.POOL_Q_STRIDE[i][1:] + ] + + # If POOL_KV_STRIDE_ADAPTIVE is not None, initialize POOL_KV_STRIDE. 
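+
+# --- Editor's note: standalone illustrative sketch, not part of the upstream code. ---
+# How the adaptive KV-stride schedule below behaves: every time a layer pools Q,
+# the running KV stride is divided by the same factor (floored at 1), so the
+# K/V resolution tracks the shrinking Q resolution. Values are made up.
+depth = 6
+stride_q = {1: [1, 2, 2], 3: [1, 2, 2]}      # layers that pool Q (example)
+_kv = [1, 8, 8]                              # POOL_KV_STRIDE_ADAPTIVE (example)
+pool_kv_stride = []
+for i in range(depth):
+    if i in stride_q:
+        _kv = [max(s // q, 1) for s, q in zip(_kv, stride_q[i])]
+    pool_kv_stride.append([i] + _kv)
+assert pool_kv_stride[0][1:] == [1, 8, 8]
+assert pool_kv_stride[1][1:] == [1, 4, 4]
+assert pool_kv_stride[3][1:] == [1, 2, 2]
+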
+ if cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE is not None: + _stride_kv = cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE + cfg.MVIT.POOL_KV_STRIDE = [] + for i in range(cfg.MVIT.DEPTH): + if len(stride_q[i]) > 0: + _stride_kv = [ + max(_stride_kv[d] // stride_q[i][d], 1) + for d in range(len(_stride_kv)) + ] + cfg.MVIT.POOL_KV_STRIDE.append([i] + _stride_kv) + + for i in range(len(cfg.MVIT.POOL_KV_STRIDE)): + stride_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = cfg.MVIT.POOL_KV_STRIDE[ + i + ][1:] + if cfg.MVIT.POOL_KVQ_KERNEL is not None: + pool_kv[ + cfg.MVIT.POOL_KV_STRIDE[i][0] + ] = cfg.MVIT.POOL_KVQ_KERNEL + else: + pool_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = [ + s + 1 if s > 1 else s + for s in cfg.MVIT.POOL_KV_STRIDE[i][1:] + ] + + self.pool_q = pool_q + self.pool_kv = pool_kv + self.stride_q = stride_q + self.stride_kv = stride_kv + + self.norm_stem = norm_layer(embed_dim) if cfg.MVIT.NORM_STEM else None + + input_size = self.patch_dims + + if self.enable_rev: + + # rev does not allow cls token + assert not self.cls_embed_on + + self.rev_backbone = ReversibleMViT(cfg, self) + + embed_dim = round_width( + embed_dim, dim_mul.prod(), divisor=num_heads + ) + + self.fuse = TwoStreamFusion( + cfg.MVIT.REV.RESPATH_FUSE, dim=2 * embed_dim + ) + + if "concat" in self.cfg.MVIT.REV.RESPATH_FUSE: + self.norm = norm_layer(2 * embed_dim) + else: + self.norm = norm_layer(embed_dim) + + else: + + self.blocks = nn.ModuleList() + + for i in range(depth): + num_heads = round_width(num_heads, head_mul[i]) + if cfg.MVIT.DIM_MUL_IN_ATT: + dim_out = round_width( + embed_dim, + dim_mul[i], + divisor=round_width(num_heads, head_mul[i]), + ) + else: + dim_out = round_width( + embed_dim, + dim_mul[i + 1], + divisor=round_width(num_heads, head_mul[i + 1]), + ) + attention_block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + input_size=input_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=self.drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + kernel_q=pool_q[i] if len(pool_q) > i else [], + kernel_kv=pool_kv[i] if len(pool_kv) > i else [], + stride_q=stride_q[i] if len(stride_q) > i else [], + stride_kv=stride_kv[i] if len(stride_kv) > i else [], + mode=mode, + has_cls_embed=self.cls_embed_on, + pool_first=pool_first, + rel_pos_spatial=self.rel_pos_spatial, + rel_pos_temporal=self.rel_pos_temporal, + rel_pos_zero_init=cfg.MVIT.REL_POS_ZERO_INIT, + residual_pooling=cfg.MVIT.RESIDUAL_POOLING, + dim_mul_in_att=cfg.MVIT.DIM_MUL_IN_ATT, + separate_qkv=cfg.MVIT.SEPARATE_QKV, + ) + + self.blocks.append(attention_block) + if len(stride_q[i]) > 0: + input_size = [ + size // stride + for size, stride in zip(input_size, stride_q[i]) + ] + + embed_dim = dim_out + + self.norm = norm_layer(embed_dim) + + if self.enable_detection: + raise Exception("Detection is not supported") + else: + self.head = head_helper.TransformerBasicHead( + 2 * embed_dim + if ("concat" in cfg.MVIT.REV.RESPATH_FUSE and self.enable_rev) + else embed_dim, + num_classes, + dropout_rate=cfg.MODEL.DROPOUT_RATE, + act_func=cfg.MODEL.HEAD_ACT, + cfg=cfg, + ) + if self.use_abs_pos: + if self.sep_pos_embed: + trunc_normal_(self.pos_embed_spatial, std=0.02) + trunc_normal_(self.pos_embed_temporal, std=0.02) + if self.cls_embed_on: + trunc_normal_(self.pos_embed_class, std=0.02) + else: + trunc_normal_(self.pos_embed, std=0.02) + if self.use_fixed_sincos_pos: + pos_embed = get_3d_sincos_pos_embed( + self.pos_embed.shape[-1], + self.H, + self.T, + cls_token=self.cls_embed_on, + ) + self.pos_embed.data.copy_( + 
torch.from_numpy(pos_embed).float().unsqueeze(0) + ) + + if self.cls_embed_on: + trunc_normal_(self.cls_token, std=0.02) + self.apply(self._init_weights) + + self.head.projection.weight.data.mul_(head_init_scale) + self.head.projection.bias.data.mul_(head_init_scale) + + self.feat_size, self.feat_stride = calc_mvit_feature_geometry(cfg) + + def _init_weights(self, m): + if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0.02) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0.02) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + names = [] + if self.cfg.MVIT.ZERO_DECAY_POS_CLS: + if self.use_abs_pos: + if self.sep_pos_embed: + names.extend( + [ + "pos_embed_spatial", + "pos_embed_temporal", + "pos_embed_class", + ] + ) + else: + names.append("pos_embed") + if self.rel_pos_spatial: + names.extend(["rel_pos_h", "rel_pos_w", "rel_pos_hw"]) + if self.rel_pos_temporal: + names.extend(["rel_pos_t"]) + if self.cls_embed_on: + names.append("cls_token") + + return names + + def _get_pos_embed(self, pos_embed, bcthw): + + if len(bcthw) == 4: + t, h, w = 1, bcthw[-2], bcthw[-1] + else: + t, h, w = bcthw[-3], bcthw[-2], bcthw[-1] + if self.cls_embed_on: + cls_pos_embed = pos_embed[:, 0:1, :] + pos_embed = pos_embed[:, 1:] + txy_num = pos_embed.shape[1] + p_t, p_h, p_w = self.patch_dims + assert p_t * p_h * p_w == txy_num + + if (p_t, p_h, p_w) != (t, h, w): + new_pos_embed = F.interpolate( + pos_embed[:, :, :] + .reshape(1, p_t, p_h, p_w, -1) + .permute(0, 4, 1, 2, 3), + size=(t, h, w), + mode="trilinear", + ) + pos_embed = new_pos_embed.reshape(1, -1, t * h * w).permute(0, 2, 1) + + if self.cls_embed_on: + pos_embed = torch.cat((cls_pos_embed, pos_embed), dim=1) + + return pos_embed + + def _forward_reversible(self, x): + """ + Reversible specific code for forward computation. 
+ """ + # rev does not support cls token or detection + assert not self.cls_embed_on + assert not self.enable_detection + + x = self.rev_backbone(x) + + if self.use_mean_pooling: + x = self.fuse(x) + x = x.mean(1) + x = self.norm(x) + else: + x = self.norm(x) + x = self.fuse(x) + x = x.mean(1) + + x = self.head(x) + + return x + + def forward(self, x, bboxes=None, return_attn=False): + x = x[0] + x, bcthw = self.patch_embed(x) + bcthw = list(bcthw) + if len(bcthw) == 4: # Fix bcthw in case of 4D tensor + bcthw.insert(2, torch.tensor(self.T)) + T, H, W = bcthw[-3], bcthw[-2], bcthw[-1] + assert len(bcthw) == 5 and (T, H, W) == (self.T, self.H, self.W), bcthw + B, N, C = x.shape + s = 1 if self.cls_embed_on else 0 + if self.use_fixed_sincos_pos: + x += self.pos_embed[:, s:, :] # s: on/off cls token + + if self.cls_embed_on: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + if self.use_fixed_sincos_pos: + cls_tokens = cls_tokens + self.pos_embed[:, :s, :] + x = torch.cat((cls_tokens, x), dim=1) + + if self.use_abs_pos: + if self.sep_pos_embed: + pos_embed = self.pos_embed_spatial.repeat( + 1, self.patch_dims[0], 1 + ) + torch.repeat_interleave( + self.pos_embed_temporal, + self.patch_dims[1] * self.patch_dims[2], + dim=1, + ) + if self.cls_embed_on: + pos_embed = torch.cat([self.pos_embed_class, pos_embed], 1) + x += self._get_pos_embed(pos_embed, bcthw) + else: + x += self._get_pos_embed(self.pos_embed, bcthw) + + if self.drop_rate: + x = self.pos_drop(x) + + if self.norm_stem: + x = self.norm_stem(x) + + thw = [T, H, W] + + if self.enable_rev: + x = self._forward_reversible(x) + + else: + for blk in self.blocks: + x, thw = blk(x, thw) + + if self.enable_detection: + assert not self.enable_rev + + x = self.norm(x) + if self.cls_embed_on: + x = x[:, 1:] + + B, _, C = x.shape + x = x.transpose(1, 2).reshape(B, C, thw[0], thw[1], thw[2]) + + x = self.head([x], bboxes) + + else: + if self.use_mean_pooling: + if self.cls_embed_on: + x = x[:, 1:] + x = x.mean(1) + x = self.norm(x) + elif self.cls_embed_on: + x = self.norm(x) + x = x[:, 0] + else: # this is default, [norm->mean] + x = self.norm(x) + x = x.mean(1) + x = self.head(x) + + return x \ No newline at end of file diff --git a/skp/models/segmentation/__init__.py b/skp/models/segmentation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9c72b0e5e81f547f2d9db6696ab5b711169687 --- /dev/null +++ b/skp/models/segmentation/__init__.py @@ -0,0 +1,7 @@ +from .decoders.deeplabv3 import DeepLabV3Plus +from .decoders.deeplabv3_3d import DeepLabV3Plus_3D +from .decoders.fpn import FPN +from .decoders.nas_fpn import NASFPN +from .decoders.unet import Unet +from .decoders.unet_3d import Unet_3D + diff --git a/skp/models/segmentation/__pycache__/__init__.cpython-39.pyc b/skp/models/segmentation/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36c4fcdb59d478a14eff21a1996bc6f71ecce151 Binary files /dev/null and b/skp/models/segmentation/__pycache__/__init__.cpython-39.pyc differ diff --git a/skp/models/segmentation/base/__init__.py b/skp/models/segmentation/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fafa66ee589cb951dd27db2ba7cd51729e5f39ab --- /dev/null +++ b/skp/models/segmentation/base/__init__.py @@ -0,0 +1,12 @@ +from .model import SegmentationModel + +from .modules import ( + Conv2dReLU, + Attention, +) + +from .heads import ( + SegmentationHead, + SegmentationHead_3D, 
+ ClassificationHead, +) diff --git a/skp/models/segmentation/base/__pycache__/__init__.cpython-39.pyc b/skp/models/segmentation/base/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b03a9a4bb808917f1cb9d993846ee000545f4cea Binary files /dev/null and b/skp/models/segmentation/base/__pycache__/__init__.cpython-39.pyc differ diff --git a/skp/models/segmentation/base/__pycache__/heads.cpython-39.pyc b/skp/models/segmentation/base/__pycache__/heads.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee53ebace209b1d14b5f88b0648429362d9b8d50 Binary files /dev/null and b/skp/models/segmentation/base/__pycache__/heads.cpython-39.pyc differ diff --git a/skp/models/segmentation/base/__pycache__/initialization.cpython-39.pyc b/skp/models/segmentation/base/__pycache__/initialization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..181130b8ef3632c76e98c0c8a522b54c29e148b0 Binary files /dev/null and b/skp/models/segmentation/base/__pycache__/initialization.cpython-39.pyc differ diff --git a/skp/models/segmentation/base/__pycache__/model.cpython-39.pyc b/skp/models/segmentation/base/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2aa902c543a6fce99eb48c69bebefbd107a0220f Binary files /dev/null and b/skp/models/segmentation/base/__pycache__/model.cpython-39.pyc differ diff --git a/skp/models/segmentation/base/__pycache__/modules.cpython-39.pyc b/skp/models/segmentation/base/__pycache__/modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31f52d8dc6f1da23357ac8f2542c67df7528fd2a Binary files /dev/null and b/skp/models/segmentation/base/__pycache__/modules.cpython-39.pyc differ diff --git a/skp/models/segmentation/base/heads.py b/skp/models/segmentation/base/heads.py new file mode 100644 index 0000000000000000000000000000000000000000..c89439ca81baaeeff07f4eda364283c46284e6a3 --- /dev/null +++ b/skp/models/segmentation/base/heads.py @@ -0,0 +1,34 @@ +import torch.nn as nn +import torch.nn.functional as F + +from ...pooling import create_pool2d_layer + + +class SegmentationHead(nn.Sequential): + def __init__(self, in_channels, out_channels, dropout=0.2, kernel_size=3, upsampling=1): + dropout = nn.Dropout2d(p=dropout) if dropout else nn.Identity() + conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() + super().__init__(dropout, conv2d, upsampling) + + +class SegmentationHead_3D(nn.Module): + def __init__(self, in_channels, out_channels, dropout=0.2, kernel_size=3, upsampling=1): + super().__init__() + self.dropout = nn.Dropout3d(p=dropout) if dropout else nn.Identity() + self.conv3d = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + self.upsampling = upsampling + + def forward(self, x): + x = self.dropout(x) + x = self.conv3d(x) + x = F.interpolate(x, scale_factor=self.upsampling, mode="trilinear", align_corners=False) + return x + + +class ClassificationHead(nn.Sequential): + def __init__(self, in_channels, classes, pooling="avg", dropout=0.2): + pool = create_pool2d_layer(pooling) + dropout = nn.Dropout(p=dropout) if dropout else nn.Identity() + linear = nn.Linear(in_channels, classes, bias=True) + super().__init__(pool, dropout, linear) diff --git a/skp/models/segmentation/base/initialization.py 
b/skp/models/segmentation/base/initialization.py new file mode 100644 index 0000000000000000000000000000000000000000..9622130204a0172d43a5f32f4ade065e100f746e --- /dev/null +++ b/skp/models/segmentation/base/initialization.py @@ -0,0 +1,27 @@ +import torch.nn as nn + + +def initialize_decoder(module): + for m in module.modules(): + + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +def initialize_head(module): + for m in module.modules(): + if isinstance(m, (nn.Linear, nn.Conv2d)): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) diff --git a/skp/models/segmentation/base/model.py b/skp/models/segmentation/base/model.py new file mode 100644 index 0000000000000000000000000000000000000000..4b8b828bbc77d94ade7c38bd8d30dc92ea5f5661 --- /dev/null +++ b/skp/models/segmentation/base/model.py @@ -0,0 +1,63 @@ +import torch +from . import initialization as init + + +class SegmentationModel(torch.nn.Module): + def initialize(self): + init.initialize_decoder(self.decoder) + init.initialize_head(self.segmentation_head) + if self.classification_head is not None: + init.initialize_head(self.classification_head) + + def check_input_shape(self, x): + + h, w = x.shape[-2:] + output_stride = self.encoder.output_stride + if h % output_stride != 0 or w % output_stride != 0: + new_h = (h // output_stride + 1) * output_stride if h % output_stride != 0 else h + new_w = (w // output_stride + 1) * output_stride if w % output_stride != 0 else w + raise RuntimeError( + f"Wrong input shape height={h}, width={w}. Expected image height and width " + f"divisible by {output_stride}. Consider pad your images to shape ({new_h}, {new_w})." + ) + + def forward(self, x): + """Sequentially pass `x` trough model`s encoder, decoder and heads""" + + self.check_input_shape(x) + + features = self.encoder(x) + decoder_output = self.decoder(*features) + + if self.deep_supervision and self.training: + output = [] + output.append(self.segmentation_head(decoder_output[-1])) + for head_idx in range(len(self.supervisor_heads)): + output.append(self.supervisor_heads[head_idx](decoder_output[-head_idx-2])) + return tuple(output) + else: + masks = self.segmentation_head(decoder_output) + + if self.classification_head is not None: + labels = self.classification_head(features[-1]) + return masks, labels + + return masks + + @torch.no_grad() + def predict(self, x): + """Inference method. 
Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()` + + Args: + x: 4D torch tensor with shape (batch_size, channels, height, width) + + Return: + prediction: 4D torch tensor with shape (batch_size, classes, height, width) + + """ + if self.training: + self.eval() + + x = self.forward(x) + + return x diff --git a/skp/models/segmentation/base/modules.py b/skp/models/segmentation/base/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..fcbc845fb095a502226545a1603341cec98d145b --- /dev/null +++ b/skp/models/segmentation/base/modules.py @@ -0,0 +1,194 @@ +import torch +import torch.nn as nn + +try: + from inplace_abn import InPlaceABN +except ImportError: + InPlaceABN = None + + +class Conv2dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + + if use_batchnorm == "inplace" and InPlaceABN is None: + raise RuntimeError( + "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. " + + "To install see: https://github.com/mapillary/inplace_abn" + ) + + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + if use_batchnorm == "inplace": + bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) + relu = nn.Identity() + + elif use_batchnorm and use_batchnorm != "inplace": + bn = nn.BatchNorm2d(out_channels) + + else: + bn = nn.Identity() + + super(Conv2dReLU, self).__init__(conv, bn, relu) + + +def GroupNorm(num_channels): + + return nn.GroupNorm(num_groups=16, num_channels=num_channels) + + +class Conv3dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + + if use_batchnorm == "inplace" and InPlaceABN is None: + raise RuntimeError( + "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. 
" + + "To install see: https://github.com/mapillary/inplace_abn" + ) + + conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + if use_batchnorm == "inplace": + bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) + relu = nn.Identity() + + elif use_batchnorm and use_batchnorm != "inplace": + bn = nn.BatchNorm3d(out_channels) + + else: + bn = GroupNorm(out_channels) + + super(Conv3dReLU, self).__init__(conv, bn, relu) + + +class SCSEModule_3D(nn.Module): + def __init__(self, in_channels, reduction=16): + super().__init__() + self.cSE = nn.Sequential( + nn.AdaptiveAvgPool3d(1), + nn.Conv3d(in_channels, in_channels // reduction, 1), + nn.ReLU(inplace=True), + nn.Conv3d(in_channels // reduction, in_channels, 1), + nn.Sigmoid(), + ) + self.sSE = nn.Sequential(nn.Conv3d(in_channels, 1, 1), nn.Sigmoid()) + + def forward(self, x): + return x * self.cSE(x) + x * self.sSE(x) + + +class SCSEModule(nn.Module): + def __init__(self, in_channels, reduction=16): + super().__init__() + self.cSE = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, in_channels // reduction, 1), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels // reduction, in_channels, 1), + nn.Sigmoid(), + ) + self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid()) + + def forward(self, x): + return x * self.cSE(x) + x * self.sSE(x) + + +class ArgMax(nn.Module): + def __init__(self, dim=None): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.argmax(x, dim=self.dim) + + +class Clamp(nn.Module): + def __init__(self, min=0, max=1): + super().__init__() + self.min, self.max = min, max + + def forward(self, x): + return torch.clamp(x, self.min, self.max) + + +class Activation(nn.Module): + def __init__(self, name, **params): + + super().__init__() + + if name is None or name == "identity": + self.activation = nn.Identity(**params) + elif name == "sigmoid": + self.activation = nn.Sigmoid() + elif name == "softmax2d": + self.activation = nn.Softmax(dim=1, **params) + elif name == "softmax": + self.activation = nn.Softmax(**params) + elif name == "logsoftmax": + self.activation = nn.LogSoftmax(**params) + elif name == "tanh": + self.activation = nn.Tanh() + elif name == "argmax": + self.activation = ArgMax(**params) + elif name == "argmax2d": + self.activation = ArgMax(dim=1, **params) + elif name == "clamp": + self.activation = Clamp(**params) + elif callable(name): + self.activation = name(**params) + else: + raise ValueError( + f"Activation should be callable/sigmoid/softmax/logsoftmax/tanh/" + f"argmax/argmax2d/clamp/None; got {name}" + ) + + def forward(self, x): + return self.activation(x) + + +class Attention(nn.Module): + def __init__(self, name, **params): + super().__init__() + + if name is None: + self.attention = nn.Identity(**params) + elif name == "scse": + self.attention = SCSEModule(**params) + elif name == "scse_3d": + self.attention = SCSEModule_3D(**params) + else: + raise ValueError("Attention {} is not implemented".format(name)) + + def forward(self, x): + return self.attention(x) diff --git a/skp/models/segmentation/decoders/deeplabv3/__init__.py b/skp/models/segmentation/decoders/deeplabv3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dab67386047e66725f09cc0d2be03dd93c6cf713 --- /dev/null +++ b/skp/models/segmentation/decoders/deeplabv3/__init__.py @@ -0,0 +1 @@ +from .model import DeepLabV3Plus 
diff --git a/skp/models/segmentation/decoders/deeplabv3/__pycache__/__init__.cpython-39.pyc b/skp/models/segmentation/decoders/deeplabv3/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d142dda3c3e6cd98ab50656a9011eb939e18dadc Binary files /dev/null and b/skp/models/segmentation/decoders/deeplabv3/__pycache__/__init__.cpython-39.pyc differ diff --git a/skp/models/segmentation/decoders/deeplabv3/__pycache__/decoder.cpython-39.pyc b/skp/models/segmentation/decoders/deeplabv3/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09d58fb4afc13588840836d5bf74986c56eef19a Binary files /dev/null and b/skp/models/segmentation/decoders/deeplabv3/__pycache__/decoder.cpython-39.pyc differ diff --git a/skp/models/segmentation/decoders/deeplabv3/__pycache__/model.cpython-39.pyc b/skp/models/segmentation/decoders/deeplabv3/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a5d9c054af2b4841124cb4f3505e932b685ae8d Binary files /dev/null and b/skp/models/segmentation/decoders/deeplabv3/__pycache__/model.cpython-39.pyc differ diff --git a/skp/models/segmentation/decoders/deeplabv3/decoder.py b/skp/models/segmentation/decoders/deeplabv3/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..cdaf98468ea1eb2a531fe6f303c0a96ca506148b --- /dev/null +++ b/skp/models/segmentation/decoders/deeplabv3/decoder.py @@ -0,0 +1,214 @@ +""" +BSD 3-Clause License + +Copyright (c) Soumith Chintala 2016, +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+""" + +import torch +from torch import nn +from torch.nn import functional as F + + +class DeepLabV3PlusDecoder(nn.Module): + def __init__( + self, + encoder_channels, + out_channels=256, + atrous_rates=(12, 24, 36), + output_stride=16, + deep_supervision=False + ): + super().__init__() + assert output_stride in [8, 16, 32] + + self.out_channels = out_channels + self.output_stride = output_stride + self.deep_supervision = deep_supervision + + self.aspp = nn.Sequential( + ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True), + SeparableConv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + if output_stride == 32: + scale_factor = 8 + elif output_stride == 16: + scale_factor = 4 + else: + scale_factor = 2 + + self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor) + + highres_in_channels = encoder_channels[-4] + highres_out_channels = 48 # proposed by authors of paper + self.block1 = nn.Sequential( + nn.Conv2d(highres_in_channels, highres_out_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(highres_out_channels), + nn.ReLU(), + ) + self.block2 = nn.Sequential( + SeparableConv2d( + highres_out_channels + out_channels, + out_channels, + kernel_size=3, + padding=1, + bias=False, + ), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + def forward(self, *features): + aspp_features = self.aspp(features[-1]) + aspp_features = self.up(aspp_features) + high_res_features = self.block1(features[-4]) + concat_features = torch.cat([aspp_features, high_res_features], dim=1) + fused_features = self.block2(concat_features) + + if self.deep_supervision and self.training: + return aspp_features, high_res_features, fused_features + + return fused_features + + +class ASPPConv(nn.Sequential): + def __init__(self, in_channels, out_channels, dilation): + super().__init__( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + padding=dilation, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + +class ASPPSeparableConv(nn.Sequential): + def __init__(self, in_channels, out_channels, dilation): + super().__init__( + SeparableConv2d( + in_channels, + out_channels, + kernel_size=3, + padding=dilation, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + +class ASPPPooling(nn.Sequential): + def __init__(self, in_channels, out_channels): + super().__init__( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + def forward(self, x): + size = x.shape[-2:] + for mod in self: + x = mod(x) + return F.interpolate(x, size=size, mode="bilinear", align_corners=False) + + +class ASPP(nn.Module): + def __init__(self, in_channels, out_channels, atrous_rates, separable=False): + super(ASPP, self).__init__() + modules = [] + modules.append( + nn.Sequential( + nn.Conv2d(in_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + ) + + rate1, rate2, rate3 = tuple(atrous_rates) + ASPPConvModule = ASPPConv if not separable else ASPPSeparableConv + + modules.append(ASPPConvModule(in_channels, out_channels, rate1)) + modules.append(ASPPConvModule(in_channels, out_channels, rate2)) + modules.append(ASPPConvModule(in_channels, out_channels, rate3)) + modules.append(ASPPPooling(in_channels, out_channels)) + + self.convs = nn.ModuleList(modules) + + self.project = nn.Sequential( + nn.Conv2d(5 * out_channels, out_channels, 
kernel_size=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + nn.Dropout(0.5), + ) + + def forward(self, x): + res = [] + for conv in self.convs: + res.append(conv(x)) + res = torch.cat(res, dim=1) + return self.project(res) + + +class SeparableConv2d(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + bias=True, + ): + dephtwise_conv = nn.Conv2d( + in_channels, + in_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=False, + ) + pointwise_conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + bias=bias, + ) + super().__init__(dephtwise_conv, pointwise_conv) diff --git a/skp/models/segmentation/decoders/deeplabv3/model.py b/skp/models/segmentation/decoders/deeplabv3/model.py new file mode 100644 index 0000000000000000000000000000000000000000..eb5d44cf7f27e45eb71b3cea7547b2594b111365 --- /dev/null +++ b/skp/models/segmentation/decoders/deeplabv3/model.py @@ -0,0 +1,120 @@ +from torch import nn +from typing import Optional + +from ...base import ( + SegmentationModel, + SegmentationHead, + ClassificationHead, +) +from ...encoders.create import create_encoder +from .decoder import DeepLabV3PlusDecoder + + +class DeepLabV3Plus(SegmentationModel): + """DeepLabV3+ implementation from "Encoder-Decoder with Atrous Separable + Convolution for Semantic Image Segmentation" + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution + encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features + two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features + with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). + Default is 5 + encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and + other pretrained weights (see table with available weights for each encoder_name) + encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation) + decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values) + decoder_channels: A number of convolution filters in ASPP module. Default is 256 + in_channels: A number of input channels for the model, default is 3 (RGB images) + classes: A number of classes for output mask (or you can think as a number of channels of output mask) + activation: An activation function to apply after the final convolution layer. + Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. + Default is **None** + upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity + aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build + on top of encoder if **aux_params** is not **None** (default). Supported params: + - classes (int): A number of classes + - pooling (str): One of "max", "avg". 
Default is "avg" + - dropout (float): Dropout factor in [0, 1) + - activation (str): An activation function to apply "sigmoid"/"softmax" + (could be **None** to return logits) + Returns: + ``torch.nn.Module``: **DeepLabV3Plus** + + Reference: + https://arxiv.org/abs/1802.02611v3 + + """ + + def __init__( + self, + encoder_name: str, + encoder_params: dict = {"pretrained": True, "output_stride": 16}, + decoder_channels: int = 256, + decoder_atrous_rates: tuple = (12, 24, 36), + dropout: float = 0.2, + in_channels: int = 3, + classes: int = 1, + deep_supervision: bool = False, + activation: Optional[str] = None, + upsampling: int = 4, + aux_params: Optional[dict] = None, + ): + super().__init__() + + encoder_output_stride = encoder_params.pop("output_stride", None) + if encoder_output_stride not in [8, 16, 32]: + raise ValueError("Encoder output stride should be 8, 16, or 32; got {}".format(encoder_output_stride)) + + self.encoder = create_encoder( + name=encoder_name, + encoder_params=encoder_params, + encoder_output_stride=encoder_output_stride, + in_channels=in_channels + ) + + self.decoder = DeepLabV3PlusDecoder( + encoder_channels=self.encoder.out_channels, + out_channels=decoder_channels, + atrous_rates=decoder_atrous_rates, + output_stride=encoder_output_stride, + deep_supervision=deep_supervision + ) + + self.segmentation_head = SegmentationHead( + in_channels=self.decoder.out_channels, + out_channels=classes, + kernel_size=1, + dropout=dropout, + upsampling=upsampling, + ) + + self.deep_supervision = deep_supervision + if self.deep_supervision: + self.supervisor_heads = [] + self.supervisor_heads.append( + SegmentationHead( + in_channels=48, + out_channels=classes, + dropout=dropout, + kernel_size=3, + upsampling=1, + ) + ) + self.supervisor_heads.append( + SegmentationHead( + in_channels=decoder_channels, + out_channels=classes, + dropout=dropout, + kernel_size=3, + upsampling=1, + ) + ) + self.supervisor_heads = nn.Sequential(*self.supervisor_heads) + + if aux_params is not None: + self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) + else: + self.classification_head = None diff --git a/skp/models/segmentation/decoders/deeplabv3_3d/__init__.py b/skp/models/segmentation/decoders/deeplabv3_3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a508598da9934c7b4f140b772fa6ee2005c9b747 --- /dev/null +++ b/skp/models/segmentation/decoders/deeplabv3_3d/__init__.py @@ -0,0 +1 @@ +from .model import DeepLabV3Plus_3D diff --git a/skp/models/segmentation/decoders/deeplabv3_3d/__pycache__/__init__.cpython-39.pyc b/skp/models/segmentation/decoders/deeplabv3_3d/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c72d0025e6de1fd50969a5178e19f44c1f42868 Binary files /dev/null and b/skp/models/segmentation/decoders/deeplabv3_3d/__pycache__/__init__.cpython-39.pyc differ diff --git a/skp/models/segmentation/decoders/deeplabv3_3d/__pycache__/decoder.cpython-39.pyc b/skp/models/segmentation/decoders/deeplabv3_3d/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c30daec36701a90b896852d78214adb0d709f15b Binary files /dev/null and b/skp/models/segmentation/decoders/deeplabv3_3d/__pycache__/decoder.cpython-39.pyc differ diff --git a/skp/models/segmentation/decoders/deeplabv3_3d/__pycache__/model.cpython-39.pyc b/skp/models/segmentation/decoders/deeplabv3_3d/__pycache__/model.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..b25e769ad607c74947d9c22a03007a20653e9a1d Binary files /dev/null and b/skp/models/segmentation/decoders/deeplabv3_3d/__pycache__/model.cpython-39.pyc differ diff --git a/skp/models/segmentation/decoders/deeplabv3_3d/decoder.py b/skp/models/segmentation/decoders/deeplabv3_3d/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a12f5618a64603a75f6e47a2beabd47841f4fa86 --- /dev/null +++ b/skp/models/segmentation/decoders/deeplabv3_3d/decoder.py @@ -0,0 +1,223 @@ +""" +BSD 3-Clause License + +Copyright (c) Soumith Chintala 2016, +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+""" + +import torch +from torch import nn +from torch.nn import functional as F + + +def GroupNorm(num_channels): + + return nn.GroupNorm(num_groups=16, num_channels=num_channels) + + +class DeepLabV3PlusDecoder_3D(nn.Module): + def __init__( + self, + encoder_channels, + out_channels=256, + atrous_rates=(12, 24, 36), + output_stride=16, + deep_supervision=False, + norm_layer="batch_norm", + ): + super().__init__() + assert output_stride in [8, 16, 32] + + self.out_channels = out_channels + self.output_stride = output_stride + self.deep_supervision = deep_supervision + + if norm_layer == "batch_norm": + norm_layer = nn.BatchNorm3d + elif norm_layer == "group_norm": + norm_layer = GroupNorm + + self.aspp = nn.Sequential( + ASPP(encoder_channels[-1], out_channels, atrous_rates, norm_layer=norm_layer, separable=True), + SeparableConv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), + norm_layer(out_channels), + nn.ReLU(), + ) + + if output_stride == 32: + self.scale_factor = 8 + elif output_stride == 16: + self.scale_factor = 4 + else: + self.scale_factor = 2 + + highres_in_channels = encoder_channels[-4] + highres_out_channels = 48 # proposed by authors of paper + self.block1 = nn.Sequential( + nn.Conv3d(highres_in_channels, highres_out_channels, kernel_size=1, bias=False), + norm_layer(highres_out_channels), + nn.ReLU(), + ) + self.block2 = nn.Sequential( + SeparableConv3d( + highres_out_channels + out_channels, + out_channels, + kernel_size=3, + padding=1, + bias=False, + ), + norm_layer(out_channels), + nn.ReLU(), + ) + + def forward(self, *features): + aspp_features = self.aspp(features[-1]) + aspp_features = F.interpolate(aspp_features, scale_factor=self.scale_factor, mode="trilinear", align_corners=False) + high_res_features = self.block1(features[-4]) + concat_features = torch.cat([aspp_features, high_res_features], dim=1) + fused_features = self.block2(concat_features) + + if self.deep_supervision and self.training: + return aspp_features, high_res_features, fused_features + + return fused_features + + +class ASPPConv(nn.Sequential): + def __init__(self, in_channels, out_channels, dilation, norm_layer): + super().__init__( + nn.Conv3d( + in_channels, + out_channels, + kernel_size=3, + padding=dilation, + dilation=dilation, + bias=False, + ), + norm_layer(out_channels), + nn.ReLU(), + ) + + +class ASPPSeparableConv(nn.Sequential): + def __init__(self, in_channels, out_channels, dilation, norm_layer): + super().__init__( + SeparableConv3d( + in_channels, + out_channels, + kernel_size=3, + padding=dilation, + dilation=dilation, + bias=False, + ), + norm_layer(out_channels), + nn.ReLU(), + ) + + +class ASPPPooling(nn.Sequential): + def __init__(self, in_channels, out_channels, norm_layer): + super().__init__( + nn.AdaptiveAvgPool3d(1), + nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False), + norm_layer(out_channels), + nn.ReLU(), + ) + + def forward(self, x): + size = x.shape[-3:] + for mod in self: + x = mod(x) + return F.interpolate(x, size=size, mode="trilinear", align_corners=False) + + +class ASPP(nn.Module): + def __init__(self, in_channels, out_channels, atrous_rates, norm_layer, separable=False): + super(ASPP, self).__init__() + modules = [] + modules.append( + nn.Sequential( + nn.Conv3d(in_channels, out_channels, 1, bias=False), + norm_layer(out_channels), + nn.ReLU(), + ) + ) + + rate1, rate2, rate3 = tuple(atrous_rates) + ASPPConvModule = ASPPConv if not separable else ASPPSeparableConv + + modules.append(ASPPConvModule(in_channels, 
out_channels, rate1, norm_layer=norm_layer)) + modules.append(ASPPConvModule(in_channels, out_channels, rate2, norm_layer=norm_layer)) + modules.append(ASPPConvModule(in_channels, out_channels, rate3, norm_layer=norm_layer)) + modules.append(ASPPPooling(in_channels, out_channels, norm_layer=norm_layer)) + + self.convs = nn.ModuleList(modules) + + self.project = nn.Sequential( + nn.Conv3d(5 * out_channels, out_channels, kernel_size=1, bias=False), + norm_layer(out_channels), + nn.ReLU(), + nn.Dropout(0.5), + ) + + def forward(self, x): + res = [] + for conv in self.convs: + res.append(conv(x)) + res = torch.cat(res, dim=1) + return self.project(res) + + +class SeparableConv3d(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + bias=True, + ): + dephtwise_conv = nn.Conv3d( + in_channels, + in_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=False, + ) + pointwise_conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size=1, + bias=bias, + ) + super().__init__(dephtwise_conv, pointwise_conv) diff --git a/skp/models/segmentation/decoders/deeplabv3_3d/model.py b/skp/models/segmentation/decoders/deeplabv3_3d/model.py new file mode 100644 index 0000000000000000000000000000000000000000..8782a6ccfb5ca1af6d8ac7c981d866087b437813 --- /dev/null +++ b/skp/models/segmentation/decoders/deeplabv3_3d/model.py @@ -0,0 +1,125 @@ +from torch import nn +from typing import Optional + +from ...base import ( + SegmentationModel, + SegmentationHead_3D, + ClassificationHead, +) +from ...encoders.create import create_encoder +from .decoder import DeepLabV3PlusDecoder_3D + + +class DeepLabV3Plus_3D(SegmentationModel): + """DeepLabV3+ implementation from "Encoder-Decoder with Atrous Separable + Convolution for Semantic Image Segmentation" + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution + encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features + two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features + with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). + Default is 5 + encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and + other pretrained weights (see table with available weights for each encoder_name) + encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation) + decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values) + decoder_channels: A number of convolution filters in ASPP module. Default is 256 + in_channels: A number of input channels for the model, default is 3 (RGB images) + classes: A number of classes for output mask (or you can think as a number of channels of output mask) + activation: An activation function to apply after the final convolution layer. + Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. + Default is **None** + upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity + aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build + on top of encoder if **aux_params** is not **None** (default). 
Supported params: + - classes (int): A number of classes + - pooling (str): One of "max", "avg". Default is "avg" + - dropout (float): Dropout factor in [0, 1) + - activation (str): An activation function to apply "sigmoid"/"softmax" + (could be **None** to return logits) + Returns: + ``torch.nn.Module``: **DeepLabV3Plus** + + Reference: + https://arxiv.org/abs/1802.02611v3 + + """ + + def __init__( + self, + encoder_name: str, + encoder_params: dict = {"pretrained": True, "output_stride": 16}, + decoder_channels: int = 256, + decoder_atrous_rates: tuple = (12, 24, 36), + dropout: float = 0.2, + in_channels: int = 3, + classes: int = 1, + deep_supervision: bool = False, + norm_layer: str = "batch_norm", + activation: Optional[str] = None, + upsampling: int = 4, + aux_params: Optional[dict] = None, + ): + super().__init__() + + assert "x3d" in encoder_name, "Only X3D backbone is currently supported for 3D segmentation" + encoder_output_stride = encoder_params.pop("output_stride", None) + if encoder_output_stride not in [8, 16, 32]: + raise ValueError("Encoder output stride should be 8, 16, or 32; got {}".format(encoder_output_stride)) + + self.encoder = create_encoder( + name=encoder_name, + encoder_params=encoder_params, + encoder_output_stride=encoder_output_stride, + in_channels=in_channels + ) + + assert norm_layer in ["batch_norm", "group_norm"] + + self.decoder = DeepLabV3PlusDecoder_3D( + encoder_channels=self.encoder.out_channels, + out_channels=decoder_channels, + atrous_rates=decoder_atrous_rates, + output_stride=encoder_output_stride, + deep_supervision = deep_supervision, + norm_layer=norm_layer, + ) + + self.segmentation_head = SegmentationHead_3D( + in_channels=self.decoder.out_channels, + out_channels=classes, + kernel_size=1, + dropout=dropout, + upsampling=upsampling, + ) + + self.deep_supervision = deep_supervision + if self.deep_supervision: + self.supervisor_heads = [] + self.supervisor_heads.append( + SegmentationHead_3D( + in_channels=48, + out_channels=classes, + dropout=dropout, + kernel_size=3, + upsampling=1, + ) + ) + self.supervisor_heads.append( + SegmentationHead_3D( + in_channels=decoder_channels, + out_channels=classes, + dropout=dropout, + kernel_size=3, + upsampling=1, + ) + ) + self.supervisor_heads = nn.Sequential(*self.supervisor_heads) + + if aux_params is not None: + self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) + else: + self.classification_head = None diff --git a/skp/models/segmentation/decoders/fpn/__init__.py b/skp/models/segmentation/decoders/fpn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd72b0731afe38fdcd5fea3ffe9ca18b2aba86d --- /dev/null +++ b/skp/models/segmentation/decoders/fpn/__init__.py @@ -0,0 +1 @@ +from .model import FPN diff --git a/skp/models/segmentation/decoders/fpn/__pycache__/__init__.cpython-39.pyc b/skp/models/segmentation/decoders/fpn/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a178924dfba3e29fa85cd3787011068764d638a8 Binary files /dev/null and b/skp/models/segmentation/decoders/fpn/__pycache__/__init__.cpython-39.pyc differ diff --git a/skp/models/segmentation/decoders/fpn/__pycache__/decoder.cpython-39.pyc b/skp/models/segmentation/decoders/fpn/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eca293870fdb0d8d820f42fb7f2d026f89e021d7 Binary files /dev/null and 
b/skp/models/segmentation/decoders/fpn/__pycache__/decoder.cpython-39.pyc differ diff --git a/skp/models/segmentation/decoders/fpn/__pycache__/model.cpython-39.pyc b/skp/models/segmentation/decoders/fpn/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b1bb37ab36942ed6345f57809a8df4fde4ea450 Binary files /dev/null and b/skp/models/segmentation/decoders/fpn/__pycache__/model.cpython-39.pyc differ diff --git a/skp/models/segmentation/decoders/fpn/decoder.py b/skp/models/segmentation/decoders/fpn/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..caff0a547e6fd975f4dc4c6dc0cddd91de0190db --- /dev/null +++ b/skp/models/segmentation/decoders/fpn/decoder.py @@ -0,0 +1,112 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Conv3x3GNReLU(nn.Module): + def __init__(self, in_channels, out_channels, upsample=False): + super().__init__() + self.upsample = upsample + self.block = nn.Sequential( + nn.Conv2d(in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False), + nn.GroupNorm(32, out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + x = self.block(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + return x + + +class FPNBlock(nn.Module): + def __init__(self, pyramid_channels, skip_channels): + super().__init__() + self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1) + + def forward(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode="nearest") + skip = self.skip_conv(skip) + x = x + skip + return x + + +class SegmentationBlock(nn.Module): + def __init__(self, in_channels, out_channels, n_upsamples=0): + super().__init__() + + blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))] + + if n_upsamples > 1: + for _ in range(1, n_upsamples): + blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True)) + + self.block = nn.Sequential(*blocks) + + def forward(self, x): + return self.block(x) + + +class MergeBlock(nn.Module): + def __init__(self, policy): + super().__init__() + if policy not in ["add", "cat"]: + raise ValueError("`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy)) + self.policy = policy + + def forward(self, x): + if self.policy == "add": + return sum(x) + elif self.policy == "cat": + return torch.cat(x, dim=1) + else: + raise ValueError("`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy)) + + +class FPNDecoder(nn.Module): + def __init__( + self, + encoder_channels, + encoder_depth=5, + pyramid_channels=256, + segmentation_channels=128, + merge_policy="add", + deep_supervision=False, + ): + super().__init__() + + self.out_channels = segmentation_channels if merge_policy == "add" else segmentation_channels * 4 + if encoder_depth < 3: + raise ValueError("Encoder depth for FPN decoder cannot be less than 3, got {}.".format(encoder_depth)) + + encoder_channels = encoder_channels[::-1] + encoder_channels = encoder_channels[: encoder_depth + 1] + + self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1) + self.p4 = FPNBlock(pyramid_channels, encoder_channels[1]) + self.p3 = FPNBlock(pyramid_channels, encoder_channels[2]) + self.p2 = FPNBlock(pyramid_channels, encoder_channels[3]) + + self.seg_blocks = nn.ModuleList( + [ + SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=n_upsamples) + for n_upsamples in [3, 2, 1, 0] + ] + ) + + self.merge = 
MergeBlock(merge_policy) + + def forward(self, *features): + c2, c3, c4, c5 = features[-4:] + + p5 = self.p5(c5) + p4 = self.p4(p5, c4) + p3 = self.p3(p4, c3) + p2 = self.p2(p3, c2) + + + feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])] + x = self.merge(feature_pyramid) + + return x diff --git a/skp/models/segmentation/decoders/fpn/model.py b/skp/models/segmentation/decoders/fpn/model.py new file mode 100644 index 0000000000000000000000000000000000000000..78b5a42f3a67c9957c74dd096c94ba3f7531b4a3 --- /dev/null +++ b/skp/models/segmentation/decoders/fpn/model.py @@ -0,0 +1,98 @@ +from typing import Optional, Union + +from ...base import ( + SegmentationModel, + SegmentationHead, + ClassificationHead, +) +from ...encoders.create import create_encoder +from .decoder import FPNDecoder + + +class FPN(SegmentationModel): + """FPN_ is a fully convolution neural network for image semantic segmentation. + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution + encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features + two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features + with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). + Default is 5 + encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and + other pretrained weights (see table with available weights for each encoder_name) + decoder_pyramid_channels: A number of convolution filters in Feature Pyramid of FPN_ + decoder_segmentation_channels: A number of convolution filters in segmentation blocks of FPN_ + decoder_merge_policy: Determines how to merge pyramid features inside FPN. Available options are **add** + and **cat** + decoder_dropout: Spatial dropout rate in range (0, 1) for feature pyramid in FPN_ + in_channels: A number of input channels for the model, default is 3 (RGB images) + classes: A number of classes for output mask (or you can think as a number of channels of output mask) + activation: An activation function to apply after the final convolution layer. + Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. + Default is **None** + upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity + aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build + on top of encoder if **aux_params** is not **None** (default). Supported params: + - classes (int): A number of classes + - pooling (str): One of "max", "avg". Default is "avg" + - dropout (float): Dropout factor in [0, 1) + - activation (str): An activation function to apply "sigmoid"/"softmax" + (could be **None** to return logits) + + Returns: + ``torch.nn.Module``: **FPN** + + .. 
_FPN: + http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf + + """ + + def __init__( + self, + encoder_name: str, + encoder_params: dict = {"pretrained": True, "depth": 5}, + decoder_pyramid_channels: int = 256, + decoder_segmentation_channels: int = 128, + decoder_merge_policy: str = "add", + dropout: float = 0.2, + in_channels: int = 3, + classes: int = 1, + deep_supervision: bool = False, + activation: Optional[str] = None, + upsampling: int = 4, + aux_params: Optional[dict] = None, + ): + super().__init__() + + encoder_depth = encoder_params.pop("depth", 5) + self.encoder = create_encoder( + name=encoder_name, + encoder_params=encoder_params, + in_channels=in_channels + ) + + self.decoder = FPNDecoder( + encoder_channels=self.encoder.out_channels, + encoder_depth=encoder_depth, + pyramid_channels=decoder_pyramid_channels, + segmentation_channels=decoder_segmentation_channels, + merge_policy=decoder_merge_policy, + ) + + self.segmentation_head = SegmentationHead( + in_channels=self.decoder.out_channels, + out_channels=classes, + kernel_size=1, + dropout=dropout, + upsampling=upsampling, + ) + + if aux_params is not None: + self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) + else: + self.classification_head = None + + self.name = "fpn-{}".format(encoder_name) + self.initialize() diff --git a/skp/models/segmentation/decoders/nas_fpn/__init__.py b/skp/models/segmentation/decoders/nas_fpn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c43ca279b006497b3b4d8052f392d8452db513a --- /dev/null +++ b/skp/models/segmentation/decoders/nas_fpn/__init__.py @@ -0,0 +1 @@ +from .model import NASFPN diff --git a/skp/models/segmentation/decoders/nas_fpn/__pycache__/__init__.cpython-39.pyc b/skp/models/segmentation/decoders/nas_fpn/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82de91c090cd02810bbfa43d3b568c044db86924 Binary files /dev/null and b/skp/models/segmentation/decoders/nas_fpn/__pycache__/__init__.cpython-39.pyc differ diff --git a/skp/models/segmentation/decoders/nas_fpn/__pycache__/decoder.cpython-39.pyc b/skp/models/segmentation/decoders/nas_fpn/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5edcbaafb6a57e96be794ac3177e0a2abc8cf780 Binary files /dev/null and b/skp/models/segmentation/decoders/nas_fpn/__pycache__/decoder.cpython-39.pyc differ diff --git a/skp/models/segmentation/decoders/nas_fpn/__pycache__/model.cpython-39.pyc b/skp/models/segmentation/decoders/nas_fpn/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c178ed54abf376904771120f9981e896cebf8db Binary files /dev/null and b/skp/models/segmentation/decoders/nas_fpn/__pycache__/model.cpython-39.pyc differ diff --git a/skp/models/segmentation/decoders/nas_fpn/decoder.py b/skp/models/segmentation/decoders/nas_fpn/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ae80805d40d20f75031e2aa16e2078c649fa21 --- /dev/null +++ b/skp/models/segmentation/decoders/nas_fpn/decoder.py @@ -0,0 +1,365 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
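# Standalone illustration (not part of this file): how the FPNDecoder defined in
# skp/models/segmentation/decoders/fpn/decoder.py above fuses a feature pyramid.
# It assumes the full skp package (including its encoder modules, not shown in
# this diff) is importable. The channel counts and spatial sizes are hypothetical
# and mimic a stride-32 backbone on a 224x224 input; each pyramid level is
# brought to 1/4 resolution by its SegmentationBlock before the "add" merge.
import torch
from skp.models.segmentation.decoders.fpn.decoder import FPNDecoder

feats = [
    torch.randn(1, 64, 112, 112),
    torch.randn(1, 256, 56, 56),
    torch.randn(1, 512, 28, 28),
    torch.randn(1, 1024, 14, 14),
    torch.randn(1, 2048, 7, 7),
]
decoder = FPNDecoder(
    encoder_channels=[3, 64, 256, 512, 1024, 2048],
    encoder_depth=5,
    pyramid_channels=256,
    segmentation_channels=128,
    merge_policy="add",
)
print(decoder(*feats).shape)  # torch.Size([1, 128, 56, 56])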
+import math +from abc import abstractmethod + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ConvModule(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False, + activation="leaky_relu", + order=("conv", "norm", "act"), + act_inplace=True): + + super().__init__() + self.conv = nn.Conv2d(in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias) + self.norm = nn.BatchNorm2d(out_channels) + if activation: + if activation == "leaky_relu": + self.act = nn.LeakyReLU(negative_slope=0.01, inplace=act_inplace) + elif activation == "silu": + self.act = nn.SiLU(inplace=act_inplace) + elif activation == "gelu": + self.act = nn.GELU() + else: + self.act = nn.Identity() + self.order = order + + def forward(self, x): + for i in self.order: + x = getattr(self, i)(x) + return x + + +class BaseMergeCell(nn.Module): + """The basic class for cells used in NAS-FPN and NAS-FCOS. + BaseMergeCell takes 2 inputs. After applying convolution + on them, they are resized to the target size. Then, + they go through binary_op, which depends on the type of cell. + If with_out_conv is True, the result of output will go through + another convolution layer. + Args: + in_channels (int): number of input channels in out_conv layer. + out_channels (int): number of output channels in out_conv layer. + with_out_conv (bool): Whether to use out_conv layer + out_conv_cfg (dict): Config dict for convolution layer, which should + contain "groups", "kernel_size", "padding", "bias" to build + out_conv layer. + out_norm_cfg (dict): Config dict for normalization layer in out_conv. + out_conv_order (tuple): The order of conv/norm/activation layers in + out_conv. + with_input1_conv (bool): Whether to use convolution on input1. + with_input2_conv (bool): Whether to use convolution on input2. + input_conv_cfg (dict): Config dict for building input1_conv layer and + input2_conv layer, which is expected to contain the type of + convolution. + Default: None, which means using conv2d. + input_norm_cfg (dict): Config dict for normalization layer in + input1_conv and input2_conv layer. Default: None. + upsample_mode (str): Interpolation method used to resize the output + of input1_conv and input2_conv to target size. Currently, we + support ['nearest', 'bilinear']. Default: 'nearest'. 
+ """ + + def __init__(self, + fused_channels=256, + out_channels=256, + with_out_conv=True, + out_conv_cfg=dict( + groups=1, kernel_size=3, padding=1, bias=True), + out_conv_order=('act', 'conv', 'norm'), + with_input1_conv=False, + with_input2_conv=False, + upsample_mode='nearest'): + super().__init__() + assert upsample_mode in ['nearest', 'bilinear'] + self.with_out_conv = with_out_conv + self.with_input1_conv = with_input1_conv + self.with_input2_conv = with_input2_conv + self.upsample_mode = upsample_mode + + if self.with_out_conv: + self.out_conv = ConvModule( + fused_channels, + out_channels, + **out_conv_cfg, + order=out_conv_order) + + self.input1_conv = self._build_input_conv( + out_channels) if with_input1_conv else nn.Sequential() + self.input2_conv = self._build_input_conv( + out_channels) if with_input2_conv else nn.Sequential() + + def _build_input_conv(self, channel): + return ConvModule( + channel, + channel, + 3, + padding=1, + bias=True) + + @abstractmethod + def _binary_op(self, x1, x2): + pass + + def _resize(self, x, size): + if x.shape[-2:] == size: + return x + elif x.shape[-2:] < size: + return F.interpolate(x, size=size, mode=self.upsample_mode) + else: + if x.shape[-2] % size[-2] != 0 or x.shape[-1] % size[-1] != 0: + h, w = x.shape[-2:] + target_h, target_w = size + pad_h = math.ceil(h / target_h) * target_h - h + pad_w = math.ceil(w / target_w) * target_w - w + pad_l = pad_w // 2 + pad_r = pad_w - pad_l + pad_t = pad_h // 2 + pad_b = pad_h - pad_t + pad = (pad_l, pad_r, pad_t, pad_b) + x = F.pad(x, pad, mode='constant', value=0.0) + kernel_size = (x.shape[-2] // size[-2], x.shape[-1] // size[-1]) + x = F.max_pool2d(x, kernel_size=kernel_size, stride=kernel_size) + return x + + def forward(self, x1, x2, out_size=None): + assert x1.shape[:2] == x2.shape[:2] + assert out_size is None or len(out_size) == 2 + if out_size is None: # resize to larger one + out_size = max(x1.size()[2:], x2.size()[2:]) + + x1 = self.input1_conv(x1) + x2 = self.input2_conv(x2) + + x1 = self._resize(x1, out_size) + x2 = self._resize(x2, out_size) + + x = self._binary_op(x1, x2) + if self.with_out_conv: + x = self.out_conv(x) + return x + + +class SumCell(BaseMergeCell): + + def __init__(self, in_channels, out_channels, **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + + def _binary_op(self, x1, x2): + return x1 + x2 + + +class ConcatCell(BaseMergeCell): + + def __init__(self, in_channels, out_channels, **kwargs): + super().__init__(in_channels * 2, out_channels, **kwargs) + + def _binary_op(self, x1, x2): + ret = torch.cat([x1, x2], dim=1) + return ret + + +class GlobalPoolingCell(BaseMergeCell): + + def __init__(self, in_channels=None, out_channels=None, **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) + + def _binary_op(self, x1, x2): + x2_att = self.global_pool(x2).sigmoid() + return x2 + x2_att * x1 + + +class Conv3x3GNReLU(nn.Module): + def __init__(self, in_channels, out_channels, upsample=False): + super().__init__() + self.upsample = upsample + self.block = nn.Sequential( + nn.Conv2d(in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False), + nn.GroupNorm(32, out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + x = self.block(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + return x + + +class SegmentationBlock(nn.Module): + def __init__(self, in_channels, out_channels, n_upsamples=0): + super().__init__() + + 
blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))] + + if n_upsamples > 1: + for _ in range(1, n_upsamples): + blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True)) + + self.block = nn.Sequential(*blocks) + + def forward(self, x): + return self.block(x) + + +class MergeBlock(nn.Module): + def __init__(self, policy): + super().__init__() + if policy not in ["add", "cat"]: + raise ValueError("`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy)) + self.policy = policy + + def forward(self, x): + if self.policy == "add": + return sum(x) + elif self.policy == "cat": + return torch.cat(x, dim=1) + else: + raise ValueError("`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy)) + + +class NASFPNDecoder(nn.Module): + """NAS-FPN. + + Implementation of `NAS-FPN: Learning Scalable Feature Pyramid Architecture + for Object Detection `_ + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale) + depth (int): Number of output scales. + stack_times (int): The number of times the pyramid architecture will + be stacked. + """ + + def __init__(self, + in_channels, + pyramid_channels=256, + segmentation_channels=128, + depth=5, + stack_times=3, + merge_policy="add", + deep_supervision=False): + super().__init__() + assert isinstance(in_channels, (list, tuple)) + self.in_channels = in_channels + self.pyramid_channels = pyramid_channels + self.num_ins = len(in_channels) # num of input feature levels + self.depth = depth # num of output feature levels + assert self.num_ins == self.depth + self.stack_times = stack_times + + self.out_channels = segmentation_channels if merge_policy == "add" else segmentation_channels * 5 + self.deep_supervision = deep_supervision + + # add lateral connections + self.lateral_convs = nn.ModuleList() + for i in range(depth): + l_conv = ConvModule( + in_channels[i], + pyramid_channels, + 1, + activation=None) + self.lateral_convs.append(l_conv) + + # add NAS FPN connections + self.fpn_stages = nn.ModuleList() + for _ in range(self.stack_times): + stage = nn.ModuleDict() + # gp(p6, p4) -> p4_1 + stage['gp_64_4'] = GlobalPoolingCell( + in_channels=pyramid_channels, + out_channels=pyramid_channels) + # sum(p4_1, p4) -> p4_2 + stage['sum_44_4'] = SumCell( + in_channels=pyramid_channels, + out_channels=pyramid_channels) + # sum(p4_2, p3) -> p3_out + stage['sum_43_3'] = SumCell( + in_channels=pyramid_channels, + out_channels=pyramid_channels) + # sum(p3_out, p4_2) -> p4_out + stage['sum_34_4'] = SumCell( + in_channels=pyramid_channels, + out_channels=pyramid_channels) + # sum(p5, gp(p4_out, p3_out)) -> p5_out + stage['gp_43_5'] = GlobalPoolingCell(with_out_conv=False) + stage['sum_55_5'] = SumCell( + in_channels=pyramid_channels, + out_channels=pyramid_channels) + # sum(p7, gp(p5_out, p4_2)) -> p7_out + stage['gp_54_7'] = GlobalPoolingCell(with_out_conv=False) + stage['sum_77_7'] = SumCell( + in_channels=pyramid_channels, + out_channels=pyramid_channels) + # gp(p7_out, p5_out) -> p6_out + stage['gp_75_6'] = GlobalPoolingCell( + in_channels=pyramid_channels, + out_channels=pyramid_channels) + self.fpn_stages.append(stage) + + self.seg_blocks = nn.ModuleList( + [ + SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=n_upsamples) + for n_upsamples in [4, 3, 2, 1, 0] + ] + ) + + self.merge = MergeBlock(merge_policy) + + def forward(self, *features): + """Forward function.""" + # build P1-P5 + features = [ + 
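+            # 1x1 lateral convs project every encoder feature map to pyramid_channels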
lateral_conv(features[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # This is actually P1-P5 but too lazy to change the naming scheme + p3, p4, p5, p6, p7 = features[-5:] + + for stage in self.fpn_stages: + # gp(p6, p4) -> p4_1 + p4_1 = stage['gp_64_4'](p6, p4, out_size=p4.shape[-2:]) + # sum(p4_1, p4) -> p4_2 + p4_2 = stage['sum_44_4'](p4_1, p4, out_size=p4.shape[-2:]) + # sum(p4_2, p3) -> p3_out + p3 = stage['sum_43_3'](p4_2, p3, out_size=p3.shape[-2:]) + # sum(p3_out, p4_2) -> p4_out + p4 = stage['sum_34_4'](p3, p4_2, out_size=p4.shape[-2:]) + # sum(p5, gp(p4_out, p3_out)) -> p5_out + p5_tmp = stage['gp_43_5'](p4, p3, out_size=p5.shape[-2:]) + p5 = stage['sum_55_5'](p5, p5_tmp, out_size=p5.shape[-2:]) + # sum(p7, gp(p5_out, p4_2)) -> p7_out + p7_tmp = stage['gp_54_7'](p5, p4_2, out_size=p7.shape[-2:]) + p7 = stage['sum_77_7'](p7, p7_tmp, out_size=p7.shape[-2:]) + # gp(p7_out, p5_out) -> p6_out + p6 = stage['gp_75_6'](p7, p5, out_size=p6.shape[-2:]) + + feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, [p7, p6, p5, p4, p3])] + x = self.merge(feature_pyramid) + + if self.deep_supervision and self.training: + return p4, p3, x + + return x \ No newline at end of file diff --git a/skp/models/segmentation/decoders/nas_fpn/model.py b/skp/models/segmentation/decoders/nas_fpn/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f244d22c20e4a19ea92c2f1fc78acfb1c6cffe1f --- /dev/null +++ b/skp/models/segmentation/decoders/nas_fpn/model.py @@ -0,0 +1,125 @@ +from torch import nn +from typing import Optional, Union + +from ...base import ( + SegmentationModel, + SegmentationHead, + ClassificationHead, +) +from ...encoders.create import create_encoder +from .decoder import NASFPNDecoder + + +class NASFPN(SegmentationModel): + """FPN_ is a fully convolution neural network for image semantic segmentation. + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution + encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features + two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features + with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). + Default is 5 + encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and + other pretrained weights (see table with available weights for each encoder_name) + decoder_pyramid_channels: A number of convolution filters in Feature Pyramid of FPN_ + decoder_segmentation_channels: A number of convolution filters in segmentation blocks of FPN_ + decoder_merge_policy: Determines how to merge pyramid features inside FPN. Available options are **add** + and **cat** + decoder_dropout: Spatial dropout rate in range (0, 1) for feature pyramid in FPN_ + in_channels: A number of input channels for the model, default is 3 (RGB images) + classes: A number of classes for output mask (or you can think as a number of channels of output mask) + activation: An activation function to apply after the final convolution layer. + Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. + Default is **None** + upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity + aux_params: Dictionary with parameters of the auxiliary output (classification head). 
Auxiliary output is build + on top of encoder if **aux_params** is not **None** (default). Supported params: + - classes (int): A number of classes + - pooling (str): One of "max", "avg". Default is "avg" + - dropout (float): Dropout factor in [0, 1) + - activation (str): An activation function to apply "sigmoid"/"softmax" + (could be **None** to return logits) + + Returns: + ``torch.nn.Module``: **FPN** + + .. _FPN: + http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf + + """ + + def __init__( + self, + encoder_name: str, + encoder_params: dict = {"pretrained": True, "depth": 5}, + decoder_pyramid_channels: int = 256, + decoder_segmentation_channels: int = 128, + decoder_stack_times: int = 3, + decoder_merge_policy: str = "add", + dropout: float = 0.2, + in_channels: int = 3, + classes: int = 1, + deep_supervision: bool = False, + activation: Optional[str] = None, + upsampling: int = 2, + aux_params: Optional[dict] = None, + ): + super().__init__() + + encoder_depth = encoder_params.pop("depth", 5) + self.encoder = create_encoder( + name=encoder_name, + encoder_params=encoder_params, + in_channels=in_channels + ) + + self.decoder = NASFPNDecoder( + in_channels=self.encoder.out_channels, + depth=encoder_depth, + pyramid_channels=decoder_pyramid_channels, + segmentation_channels=decoder_segmentation_channels, + stack_times=decoder_stack_times, + merge_policy=decoder_merge_policy, + deep_supervision=deep_supervision + ) + + self.segmentation_head = SegmentationHead( + in_channels=self.decoder.out_channels, + out_channels=classes, + kernel_size=1, + dropout=dropout, + upsampling=upsampling, + ) + + self.deep_supervision = deep_supervision + if self.deep_supervision: + self.supervisor_heads = [] + self.supervisor_heads.append( + SegmentationHead( + in_channels=decoder_pyramid_channels, + out_channels=classes, + dropout=dropout, + kernel_size=3, + upsampling=1, + ) + ) + self.supervisor_heads.append( + SegmentationHead( + in_channels=decoder_pyramid_channels, + out_channels=classes, + dropout=dropout, + kernel_size=3, + upsampling=1, + ) + ) + self.supervisor_heads = nn.Sequential(*self.supervisor_heads) + + if aux_params is not None: + self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) + else: + self.classification_head = None + + self.name = "fpn-{}".format(encoder_name) + self.initialize() diff --git a/skp/models/segmentation/decoders/unet/__init__.py b/skp/models/segmentation/decoders/unet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b9a367cd7338999a742961fbc1a93289a6380da --- /dev/null +++ b/skp/models/segmentation/decoders/unet/__init__.py @@ -0,0 +1 @@ +from .model import Unet diff --git a/skp/models/segmentation/decoders/unet/__pycache__/__init__.cpython-39.pyc b/skp/models/segmentation/decoders/unet/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6ad6915fc0efaa76e2613f3b373421add37f748 Binary files /dev/null and b/skp/models/segmentation/decoders/unet/__pycache__/__init__.cpython-39.pyc differ diff --git a/skp/models/segmentation/decoders/unet/__pycache__/decoder.cpython-39.pyc b/skp/models/segmentation/decoders/unet/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ef7c088477a6d4e772806194023de85224ad0e0 Binary files /dev/null and b/skp/models/segmentation/decoders/unet/__pycache__/decoder.cpython-39.pyc differ diff --git 
a/skp/models/segmentation/decoders/unet/__pycache__/model.cpython-39.pyc b/skp/models/segmentation/decoders/unet/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5f9b4cd18435d67d2eea784c4539faeb0722194 Binary files /dev/null and b/skp/models/segmentation/decoders/unet/__pycache__/model.cpython-39.pyc differ diff --git a/skp/models/segmentation/decoders/unet/decoder.py b/skp/models/segmentation/decoders/unet/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..dbe7cb6a910e5b07636e5b8e6817ae8c2bc0f8dd --- /dev/null +++ b/skp/models/segmentation/decoders/unet/decoder.py @@ -0,0 +1,134 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...base import modules as md + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + skip_channels, + out_channels, + use_batchnorm=True, + attention_type=[None, None], + ): + super().__init__() + self.conv1 = md.Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.attention1 = md.Attention(attention_type[0], in_channels=in_channels + skip_channels) + self.conv2 = md.Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.attention2 = md.Attention(attention_type[1], in_channels=out_channels) + + def forward(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode="nearest") + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.attention1(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.attention2(x) + return x + + +class CenterBlock(nn.Sequential): + def __init__(self, in_channels, out_channels, use_batchnorm=True): + conv1 = md.Conv2dReLU( + in_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + conv2 = md.Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + super().__init__(conv1, conv2) + + +class UnetDecoder(nn.Module): + def __init__( + self, + encoder_channels, + decoder_channels, + n_blocks=5, + use_batchnorm=True, + attention_type=None, + center=False, + deep_supervision=False, + ): + super().__init__() + + if n_blocks != len(decoder_channels): + raise ValueError( + "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( + n_blocks, len(decoder_channels) + ) + ) + + # reverse channels to start from head of encoder + encoder_channels = encoder_channels[::-1] + + # computing blocks input and output channels + head_channels = encoder_channels[0] + in_channels = [head_channels] + list(decoder_channels[:-1]) + skip_channels = list(encoder_channels[1:]) + [0] + out_channels = decoder_channels + + if center: + self.center = CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm) + else: + self.center = nn.Identity() + + self.deep_supervision = deep_supervision + + # combine decoder keyword arguments + kwargs = dict(use_batchnorm=use_batchnorm, attention_type=[attention_type, attention_type]) + blocks = [] + for block_idx, (in_ch, skip_ch, out_ch) in enumerate(zip(in_channels, skip_channels, out_channels)): + # For the last block, attention1 is not used + if block_idx == (len(in_channels) - 1): + kwargs["attention_type"] = [None, attention_type] + blocks.append(DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)) + blocks = [ + DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) + for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, 
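+                                              # pairs each decoder block's input, skip, and output channels;
+                                              # hypothetical example with a ResNet-50-style encoder
+                                              # (64, 256, 512, 1024, 2048) and decoder_channels (256, 128, 64, 32, 16):
+                                              #   in_channels   = [2048, 256, 128, 64, 32]
+                                              #   skip_channels = [1024, 512, 256, 64, 0]
+                                              #   out_channels  = [256, 128, 64, 32, 16]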
out_channels) + ] + self.blocks = nn.ModuleList(blocks) + + def forward(self, *features): + + features = features[::-1] # reverse channels to start from head of encoder + + head = features[0] + skips = features[1:] + + x = self.center(head) + + if self.deep_supervision and self.training: outputs = [] + + for i, decoder_block in enumerate(self.blocks): + skip = skips[i] if i < len(skips) else None + x = decoder_block(x, skip) + if self.deep_supervision and self.training: outputs.append(x) + + if self.deep_supervision and self.training: + return outputs + + return x diff --git a/skp/models/segmentation/decoders/unet/model.py b/skp/models/segmentation/decoders/unet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..4d340e74c64835d68431a75d476abae6a1409990 --- /dev/null +++ b/skp/models/segmentation/decoders/unet/model.py @@ -0,0 +1,129 @@ +import torch.nn as nn + +from typing import Optional, Union, List + +from ...encoders.create import create_encoder +from ...base import ( + SegmentationModel, + SegmentationHead, + ClassificationHead, +) +from .decoder import UnetDecoder + + +class Unet(SegmentationModel): + """Unet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder* + and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial + resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *concatenation* + for fusing decoder blocks with skip connections. + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution + encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features + two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features + with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). + Default is 5 + encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and + other pretrained weights (see table with available weights for each encoder_name) + decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. + Length of the list should be the same as **encoder_depth** + decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers + is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. + Available options are **True, False, "inplace"** + decoder_attention_type: Attention module used in decoder of the model. Available options are + **None** and **scse** (https://arxiv.org/abs/1808.08127). + in_channels: A number of input channels for the model, default is 3 (RGB images) + classes: A number of classes for output mask (or you can think as a number of channels of output mask) + activation: An activation function to apply after the final convolution layer. + Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. + Default is **None** + aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build + on top of encoder if **aux_params** is not **None** (default). Supported params: + - classes (int): A number of classes + - pooling (str): One of "max", "avg". 
Default is "avg" + - dropout (float): Dropout factor in [0, 1) + - activation (str): An activation function to apply "sigmoid"/"softmax" + (could be **None** to return logits) + + Returns: + ``torch.nn.Module``: Unet + + .. _Unet: + https://arxiv.org/abs/1505.04597 + + """ + + def __init__( + self, + encoder_name: str, + encoder_params: dict = {"pretrained": True, "depth": 5}, + decoder_use_batchnorm: bool = True, + decoder_channels: List[int] = (256, 128, 64, 32, 16), + decoder_attention_type: Optional[str] = None, + deep_supervision: bool = False, + dropout: float = 0.2, + in_channels: int = 3, + classes: int = 1, + activation: Optional[Union[str, callable]] = None, + upsampling: int = 1, + aux_params: Optional[dict] = None, + ): + super().__init__() + + encoder_depth = encoder_params.pop("depth", 5) + self.encoder = create_encoder( + name=encoder_name, + encoder_params=encoder_params, + in_channels=in_channels + ) + + self.decoder = UnetDecoder( + encoder_channels=self.encoder.out_channels, + decoder_channels=decoder_channels, + n_blocks=encoder_depth, + use_batchnorm=decoder_use_batchnorm, + center=True if encoder_name.startswith("vgg") else False, + deep_supervision=deep_supervision, + attention_type=decoder_attention_type, + ) + + self.segmentation_head = SegmentationHead( + in_channels=decoder_channels[-1], + out_channels=classes, + dropout=dropout, + kernel_size=3, + upsampling=upsampling, + ) + + self.deep_supervision = deep_supervision + if self.deep_supervision: + self.supervisor_heads = [] + self.supervisor_heads.append( + SegmentationHead( + in_channels=decoder_channels[-2], + out_channels=classes, + dropout=dropout, + kernel_size=3, + upsampling=upsampling, + ) + ) + self.supervisor_heads.append( + SegmentationHead( + in_channels=decoder_channels[-3], + out_channels=classes, + dropout=dropout, + kernel_size=3, + upsampling=upsampling, + ) + ) + self.supervisor_heads = nn.Sequential(*self.supervisor_heads) + + if aux_params is not None: + self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) + else: + self.classification_head = None + + self.name = "u-{}".format(encoder_name) + self.initialize() diff --git a/skp/models/segmentation/decoders/unet_3d/__init__.py b/skp/models/segmentation/decoders/unet_3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea7d55746f79b0ff69a4023088baffdf5911b6c9 --- /dev/null +++ b/skp/models/segmentation/decoders/unet_3d/__init__.py @@ -0,0 +1 @@ +from .model import Unet_3D diff --git a/skp/models/segmentation/decoders/unet_3d/__pycache__/__init__.cpython-39.pyc b/skp/models/segmentation/decoders/unet_3d/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9eef022c65daaf5a1b1dc11a750690aeec714228 Binary files /dev/null and b/skp/models/segmentation/decoders/unet_3d/__pycache__/__init__.cpython-39.pyc differ diff --git a/skp/models/segmentation/decoders/unet_3d/__pycache__/decoder.cpython-39.pyc b/skp/models/segmentation/decoders/unet_3d/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8714c8ce408d89ecb5eda9fb17bf68a4309cf64d Binary files /dev/null and b/skp/models/segmentation/decoders/unet_3d/__pycache__/decoder.cpython-39.pyc differ diff --git a/skp/models/segmentation/decoders/unet_3d/__pycache__/model.cpython-39.pyc b/skp/models/segmentation/decoders/unet_3d/__pycache__/model.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..8b318e801dce82976822f63401f82121101e338e Binary files /dev/null and b/skp/models/segmentation/decoders/unet_3d/__pycache__/model.cpython-39.pyc differ diff --git a/skp/models/segmentation/decoders/unet_3d/decoder.py b/skp/models/segmentation/decoders/unet_3d/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..66d672a1b6055c011930a815fec8ba93df784a91 --- /dev/null +++ b/skp/models/segmentation/decoders/unet_3d/decoder.py @@ -0,0 +1,134 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...base import modules as md + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + skip_channels, + out_channels, + use_batchnorm=True, + attention_type=[None, None], + ): + super().__init__() + self.conv1 = md.Conv3dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.attention1 = md.Attention(attention_type[0], in_channels=in_channels + skip_channels) + self.conv2 = md.Conv3dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.attention2 = md.Attention(attention_type[1], in_channels=out_channels) + + def forward(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode="nearest") + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.attention1(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.attention2(x) + return x + + +class CenterBlock(nn.Sequential): + def __init__(self, in_channels, out_channels, use_batchnorm=True): + conv1 = md.Conv3dReLU( + in_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + conv2 = md.Conv3dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + super().__init__(conv1, conv2) + + +class UnetDecoder_3D(nn.Module): + def __init__( + self, + encoder_channels, + decoder_channels, + n_blocks=5, + use_batchnorm=True, + attention_type=None, + center=False, + deep_supervision=False, + ): + super().__init__() + + if n_blocks != len(decoder_channels): + raise ValueError( + "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( + n_blocks, len(decoder_channels) + ) + ) + + # reverse channels to start from head of encoder + encoder_channels = encoder_channels[::-1] + + # computing blocks input and output channels + head_channels = encoder_channels[0] + in_channels = [head_channels] + list(decoder_channels[:-1]) + skip_channels = list(encoder_channels[1:]) + [0] + out_channels = decoder_channels + + if center: + self.center = CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm) + else: + self.center = nn.Identity() + + self.deep_supervision = deep_supervision + + # combine decoder keyword arguments + kwargs = dict(use_batchnorm=use_batchnorm, attention_type=[attention_type, attention_type]) + blocks = [] + for block_idx, (in_ch, skip_ch, out_ch) in enumerate(zip(in_channels, skip_channels, out_channels)): + # For the last block, attention1 is not used + if block_idx == (len(in_channels) - 1): + kwargs["attention_type"] = [None, attention_type] + blocks.append(DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)) + blocks = [ + DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) + for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) + ] + self.blocks = nn.ModuleList(blocks) + + def forward(self, *features): + + features = features[::-1] # reverse channels to start from head of 
encoder + + head = features[0] + skips = features[1:] + + x = self.center(head) + + if self.deep_supervision and self.training: outputs = [] + + for i, decoder_block in enumerate(self.blocks): + skip = skips[i] if i < len(skips) else None + x = decoder_block(x, skip) + if self.deep_supervision and self.training: outputs.append(x) + + if self.deep_supervision and self.training: + return outputs + + return x diff --git a/skp/models/segmentation/decoders/unet_3d/model.py b/skp/models/segmentation/decoders/unet_3d/model.py new file mode 100644 index 0000000000000000000000000000000000000000..69d2df9f00fce7845caf28a4f82ecc8376eee24c --- /dev/null +++ b/skp/models/segmentation/decoders/unet_3d/model.py @@ -0,0 +1,131 @@ +import torch.nn as nn + +from typing import Optional, Union, List + +from ...encoders.create import create_encoder +from ...base import ( + SegmentationModel, + SegmentationHead_3D, + ClassificationHead, +) +from .decoder import UnetDecoder_3D + + +class Unet_3D(SegmentationModel): + """Unet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder* + and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial + resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *concatenation* + for fusing decoder blocks with skip connections. + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution + encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features + two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features + with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). + Default is 5 + encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and + other pretrained weights (see table with available weights for each encoder_name) + decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. + Length of the list should be the same as **encoder_depth** + decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers + is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. + Available options are **True, False, "inplace"** + decoder_attention_type: Attention module used in decoder of the model. Available options are + **None** and **scse** (https://arxiv.org/abs/1808.08127). + in_channels: A number of input channels for the model, default is 3 (RGB images) + classes: A number of classes for output mask (or you can think as a number of channels of output mask) + activation: An activation function to apply after the final convolution layer. + Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. + Default is **None** + aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build + on top of encoder if **aux_params** is not **None** (default). Supported params: + - classes (int): A number of classes + - pooling (str): One of "max", "avg". Default is "avg" + - dropout (float): Dropout factor in [0, 1) + - activation (str): An activation function to apply "sigmoid"/"softmax" + (could be **None** to return logits) + + Returns: + ``torch.nn.Module``: Unet + + .. 
_Unet: + https://arxiv.org/abs/1505.04597 + + """ + + def __init__( + self, + encoder_name: str, + encoder_params: dict = {"pretrained": True, "depth": 5}, + decoder_use_batchnorm: bool = True, + decoder_channels: List[int] = (256, 128, 64, 32, 16), + decoder_attention_type: Optional[str] = None, + deep_supervision: bool = False, + dropout: float = 0.2, + in_channels: int = 3, + classes: int = 1, + activation: Optional[Union[str, callable]] = None, + upsampling: int = 1, + aux_params: Optional[dict] = None, + ): + super().__init__() + + encoder_depth = encoder_params.pop("depth", 5) + self.encoder = create_encoder( + name=encoder_name, + encoder_params=encoder_params, + in_channels=in_channels + ) + + assert decoder_attention_type in [None, "scse_3d"] + + self.decoder = UnetDecoder_3D( + encoder_channels=self.encoder.out_channels, + decoder_channels=decoder_channels, + n_blocks=encoder_depth, + use_batchnorm=decoder_use_batchnorm, + center=True if encoder_name.startswith("vgg") else False, + deep_supervision=deep_supervision, + attention_type=decoder_attention_type, + ) + + self.segmentation_head = SegmentationHead_3D( + in_channels=decoder_channels[-1], + out_channels=classes, + dropout=dropout, + kernel_size=3, + upsampling=upsampling, + ) + + self.deep_supervision = deep_supervision + if self.deep_supervision: + self.supervisor_heads = [] + self.supervisor_heads.append( + SegmentationHead_3D( + in_channels=decoder_channels[-2], + out_channels=classes, + dropout=dropout, + kernel_size=3, + upsampling=upsampling, + ) + ) + self.supervisor_heads.append( + SegmentationHead_3D( + in_channels=decoder_channels[-3], + out_channels=classes, + dropout=dropout, + kernel_size=3, + upsampling=upsampling, + ) + ) + self.supervisor_heads = nn.Sequential(*self.supervisor_heads) + + if aux_params is not None: + self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) + else: + self.classification_head = None + + self.name = "u-{}".format(encoder_name) + self.initialize() diff --git a/skp/models/segmentation/encoders/__init__.py b/skp/models/segmentation/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/skp/models/segmentation/encoders/__pycache__/__init__.cpython-39.pyc b/skp/models/segmentation/encoders/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..652f4910fec54413930773c04d22da89d12e360f Binary files /dev/null and b/skp/models/segmentation/encoders/__pycache__/__init__.cpython-39.pyc differ diff --git a/skp/models/segmentation/encoders/__pycache__/create.cpython-39.pyc b/skp/models/segmentation/encoders/__pycache__/create.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b273f334908e6cd3f0b790f25c5ad951f9448635 Binary files /dev/null and b/skp/models/segmentation/encoders/__pycache__/create.cpython-39.pyc differ diff --git a/skp/models/segmentation/encoders/__pycache__/swin_encoder.cpython-39.pyc b/skp/models/segmentation/encoders/__pycache__/swin_encoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbf4126af21d4c7cc29762272c1b520384e3dcd6 Binary files /dev/null and b/skp/models/segmentation/encoders/__pycache__/swin_encoder.cpython-39.pyc differ diff --git a/skp/models/segmentation/encoders/create.py b/skp/models/segmentation/encoders/create.py new file mode 100644 index 
0000000000000000000000000000000000000000..222d5ec5588983167d321b8fe567208869096aed --- /dev/null +++ b/skp/models/segmentation/encoders/create.py @@ -0,0 +1,162 @@ +import re +import timm +import torch +import torch.nn as nn + +from ...backbones import create_x3d +from ...tools import change_num_input_channels +from .swin_encoder import SwinTransformer + + +def get_attribute(model, name): + """Hacked together function to retrieve the desired module from the model + based on its string attribute name. But it works. + """ + name = name.split(".") + for i, n in enumerate(name): + if i == 0: + if isinstance(n, int): + attr = model[n] + else: + attr = getattr(model, n) + else: + if isinstance(n, int): + attr = attr[n] + else: + attr = getattr(attr, n) + return attr + + +def check_if_int(s): + try: + _ = int(s) + return True + except ValueError: + return False + + +def create_encoder(name, encoder_params, encoder_output_stride=32, in_channels=3): + assert "pretrained" in encoder_params + + if name == "swin": + assert encoder_output_stride == 32, "`swin` encoders only support output_stride=32" + encoder = SwinTransformer(**encoder_params) + elif "x3d" in name: + encoder = create_x3d(name, features_only=True, **encoder_params) + assert encoder_output_stride in [16, 32] + if encoder_output_stride == 16: + encoder.model.blocks[-2].res_blocks[0].branch1_conv.stride = (1, 1, 1) + encoder.model.blocks[-2].res_blocks[0].branch2.conv_b.stride = (1, 1, 1) + else: + encoder = timm.create_model(name, features_only=True, **encoder_params) + encoder.out_channels = encoder.feature_info.channels() + + if encoder_output_stride != 32: + # Default for pretty much every model is 32 + # First, ensure that the provided stride is valid + assert 32 % encoder_output_stride == 0 + scale_factor = 32 // encoder_output_stride + layers_to_modify = 1 if scale_factor == 2 else 2 + + # First, get the layers with stride 2 + # For some models, there may be other conv layers with stride 2 + # that will need to be filtered out + # EfficientNet is OK + + if re.search(r"resnest", name): + if encoder_output_stride in [8, 16]: + encoder.layer4[0].downsample[0] = nn.Identity() + encoder.layer4[0].avd_last = nn.Identity() + if encoder_output_stride == 8: + encoder.layer3[0].downsample[0] = nn.Identity() + encoder.layer3[0].avd_last = nn.Identity() + else: + raise Exception(f"{name} only supports output stride of 8, 16, or 32") + + elif re.search(r"resnet[0-9]+d", name): + if encoder_output_stride in [8, 16]: + encoder.layer4[0].downsample[0] = nn.Identity() + encoder.layer4[0].conv1.stride = (1, 1) + encoder.layer4[0].conv2.stride = (1, 1) + if encoder_output_stride == 8: + encoder.layer3[0].downsample[0] = nn.Identity() + encoder.layer3[0].conv1.stride = (1, 1) + encoder.layer3[0].conv2.stride = (1, 1) + else: + raise Exception(f"{name} only supports output stride of 8, 16, or 32") + + elif re.search(r"regnet[x|y]", name): + downsample_convs = [] + for name, module in encoder.named_modules(): + if hasattr(module, "stride"): + if module.stride == (2, 2): + downsample_convs += [name] + + downsample_convs = downsample_convs[::-1] + for i in range(layers_to_modify * 2): + setattr(get_attribute(encoder, downsample_convs[i]), "stride", (1, 1)) + + elif re.search(r"efficientnet|regnetz|rexnet", name): + downsample_convs = [] + for name, module in encoder.named_modules(): + if hasattr(module, "stride"): + if module.stride == (2, 2): + downsample_convs += [name] + + downsample_convs = downsample_convs[::-1] + for i in range(layers_to_modify): + 
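+                # Resetting the stride of the last downsampling conv(s) to 1 halves the
+                # overall output stride once per modified layer (32 -> 16 with one layer,
+                # 32 -> 8 with two).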
setattr(get_attribute(encoder, downsample_convs[i]), "stride", (1, 1)) + + elif re.search(r"convnext", name): + downsample_convs = [] + for name, module in encoder.named_modules(): + if hasattr(module, "stride"): + if module.stride == (2, 2): + downsample_convs += [name] + + downsample_convs = downsample_convs[::-1] + for i in range(layers_to_modify): + setattr(get_attribute(encoder, downsample_convs[i]), "stride", (1, 1)) + # Need to also change the kernel size ... + # This involves creating a new layer with the appropriate kernel size + # Then modifying the weights to fit the new kernel size + # Then changing the layer in the model + in_channels = get_attribute(encoder, downsample_convs[i]).in_channels + out_channels = get_attribute(encoder, downsample_convs[i]).out_channels + conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1) + w = get_attribute(encoder, downsample_convs[i]).weight + w = w.mean(-1, keepdim=True).mean(-2, keepdim=True) + conv_layer.weight = nn.Parameter(w) + split_name = downsample_convs[i].split(".") + if check_if_int(split_name[-1]): + # If the module name ends with a number that means it's within a sequential object + # and needs to be modified by accessing the module within a list. + # + # So you have to get the SEQUENTIAL object (by getting the attribute WITHOUT the number + # at the end) and then use that number as the list index and set the layer + # to that layer. Phew. + get_attribute(encoder, ".".join(split_name[:-1]))[int(split_name[-1])] = conv_layer + else: + # If the module name ends with a string that means it can be accessed by + # just grabbing the attribute + setattr(get_attribute(encoder, ".".join(split_name[:-1])), split_name[-1], conv_layer) + + + else: + raise Exception (f"{name} is not yet supported for output stride < 32") + + # Run a quick test to make sure the output stride is correct + if "x3d" in name: + x = torch.randn((2,3,64,64,64)) + else: + x = torch.randn((2,3,128,128)) + final_out = encoder(x)[-1] + actual_output_stride = x.size(-1) // final_out.size(-1) + assert actual_output_stride == encoder_output_stride, f"Actual output stride [{actual_output_stride}] does not equal desired output stride [{encoder_output_stride}]" + print(f"Confirmed encoder output stride {encoder_output_stride} !") + encoder.output_stride = encoder_output_stride + + if in_channels != 3: + encoder = change_num_input_channels(encoder, in_channels) + + return encoder diff --git a/skp/models/segmentation/encoders/swin_encoder.py b/skp/models/segmentation/encoders/swin_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..be5885c9c244041e9be226c61060a54532e8167b --- /dev/null +++ b/skp/models/segmentation/encoders/swin_encoder.py @@ -0,0 +1,660 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu, Yutong Lin, Yixuan Wei +# -------------------------------------------------------- + +import timm +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import numpy as np + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + 
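+        # Feed-forward sub-layer used inside each Swin Transformer block;
+        # hidden_features and out_features default to in_features when not given.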
self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ Forward function. 
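+        Note (illustrative): with window_size=(7, 7) the relative position bias table
+            holds (2*7 - 1)**2 = 169 entries per head, and ``relative_position_index``
+            maps each of the 49 x 49 query/key pairs to one of those entries.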
+ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. 
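+        Note (illustrative example with assumed sizes): for H = W = 56, C = 128 and
+            window_size = 7, ``x`` has shape (B, 3136, 128); it is padded if needed,
+            split into 8 * 8 = 64 windows of 49 tokens for (shifted-)window attention,
+            and merged back to (B, 3136, 128).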
+ """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 384. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 12. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
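+        Example (illustrative; argument values are assumed, not taken from the
+            original configs)::
+
+            >>> encoder = SwinTransformer(model_size="tiny", pretrain_img_size=224,
+            ...                           window_size=7, pretrained=False)
+            >>> feats = encoder(torch.randn(1, 3, 224, 224))
+            >>> [tuple(f.shape) for f in feats]
+            [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]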
+ """ + + def __init__(self, + pretrain_img_size=384, + model_size="base", + pretrained=True, + patch_size=4, + in_chans=3, + window_size=12, + deeplab=False, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + + if model_size == "large": + embed_dim = 192 + depths = [2, 2, 18, 2] + num_heads = [6, 12, 24, 48] + elif model_size == "base": + embed_dim = 128 + depths = [2, 2, 18, 2] + num_heads = [4, 8, 16, 32] + elif model_size == "small": + assert pretrain_img_size == 224, "`small` variant only supports `pretrain_img_size=224`" + assert window_size == 7, "`small` variant only supports `window_size=7`" + embed_dim = 96 + depths = [2, 2, 18, 2] + num_heads = [3, 6, 12, 24] + elif model_size == "tiny": + assert pretrain_img_size == 224, "`tiny` variant only supports `pretrain_img_size=224`" + assert window_size == 7, "`tiny` variant only supports `window_size=7`" + embed_dim = 96 + depths = [2, 2, 6, 2] + num_heads = [3, 6, 12, 24] + else: + raise Exception("`model_size` must be one of [`tiny`, `small`, `base`, `large`]") + + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.deeplab = deeplab + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] + + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + self.out_channels = num_features + + # add a norm layer for each output + for i_layer in out_indices: + if self.deeplab and i_layer in [1, 2]: + continue + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + if pretrained: + weights = timm.create_model(f"swin_{model_size}_patch{patch_size}_window{window_size}_{pretrain_img_size}", pretrained=True) + weights = weights.state_dict() + self.load_state_dict(weights, strict=False) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + 
self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + # if isinstance(pretrained, str): + # self.apply(_init_weights) + # load_checkpoint(self, pretrained, strict=False, logger=logger) + # elif pretrained is None: + # self.apply(_init_weights) + # else: + # raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + if not self.deeplab or i not in [1, 2]: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() \ No newline at end of file diff --git a/skp/models/sequence.py b/skp/models/sequence.py new file mode 100644 index 0000000000000000000000000000000000000000..62c9d32737456453ca0c5f8029337e8d6f92a5fb --- /dev/null +++ b/skp/models/sequence.py @@ -0,0 +1,232 @@ +import numpy as np +import torch +import torch.nn as nn + +from transformers.models.distilbert.modeling_distilbert import Transformer as T +from .pooling import create_pool1d_layer + + +class Config: + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +class Transformer(nn.Module): + """ + If predict_sequence is True, then the model will predict an output + for each element in the sequence. If False, then the model will + predict a single output for the sequence. 
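# Illustrative sketch, not part of the diff: with out_indices=(0, 1, 2, 3) the backbone above
# returns a 4-level feature pyramid whose channel widths follow embed_dim * 2**i (strides of
# roughly 4/8/16/32 relative to the padded input), which is what self.out_channels exposes to
# a downstream decoder.
embed_dim = 128                                   # "base" variant
num_features = [int(embed_dim * 2 ** i) for i in range(4)]
strides = [4 * 2 ** i for i in range(4)]
print(num_features)                               # [128, 256, 512, 1024]
print(strides)                                    # [4, 8, 16, 32]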
+ + e.g., classifying each image in a CT scan vs the entire CT scan + """ + def __init__(self, + num_classes, + embedding_dim=512, + hidden_dim=1024, + n_layers=4, + n_heads=16, + dropout=0.2, + attention_dropout=0.1, + output_attentions=False, + activation='gelu', + output_hidden_states=False, + chunk_size_feed_forward=0, + predict_sequence=True, + pool=None + ): + super().__init__() + config = Config(**{ + 'dim': embedding_dim, + 'hidden_dim': hidden_dim, + 'n_layers': n_layers, + 'n_heads': n_heads, + 'dropout': dropout, + 'attention_dropout': attention_dropout, + 'output_attentions': output_attentions, + 'activation': activation, + 'output_hidden_states': output_hidden_states, + 'chunk_size_feed_forward': chunk_size_feed_forward + }) + + self.transformer = T(config) + self.predict_sequence = predict_sequence + + if not predict_sequence: + if isinstance(pool, str): + self.pool_layer = create_pool1d_layer(pool) + if pool == "catavgmax": + embedding_dim *= 2 + else: + self.pool_layer = nn.Identity() + + self.classifier = nn.Linear(embedding_dim, num_classes) + + def extract_features(self, x): + x, mask = x + x = self.transformer(x, attn_mask=mask, head_mask=[None]*x.size(1)) + x = x[0] + + if not self.predict_sequence: + if isinstance(self.pool_layer, nn.Identity): + # Just take the last vector in the sequence + x = x[:, 0] + else: + x = self.pool_layer(x.transpose(-1, -2)) + + return x + + def classify(self, x): + + if not self.predict_sequence: + if isinstance(self.pool_layer, nn.Identity): + # Just take the last vector in the sequence + x = x[:, 0] + else: + x = self.pool_layer(x.transpose(-1, -2)) + + out = self.classifier(x) + + if self.classifier.out_features == 1: + return out[..., 0] + else: + return out + + def forward_tr(self, x, mask): + output = self.transformer(x, attn_mask=mask, head_mask=[None]*x.size(1)) + return self.classify(output[0]) + + def forward(self, x): + x, mask = x + return self.forward_tr(x, mask) + + +class DualTransformer(nn.Module): + """ + Essentially the same as above except predicts both sequence and study labels. 
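# Illustrative sketch, not part of the diff: the sequence models in this file consume a tuple
# (x, mask), where x stacks the per-slice CNN embeddings of one study and mask marks real (1)
# vs padded (0) positions; with num_classes=1 and predict_sequence=True the output would be one
# logit per slice. Sizes below are assumptions for illustration; the model call is left commented
# since it requires the skp package and transformers to be importable.
import torch

B, N, D = 2, 24, 512                 # batch, slices per study, embedding_dim
x = torch.randn(B, N, D)             # per-slice feature vectors from a CNN backbone
mask = torch.ones(B, N)              # 1 = real slice, 0 = padding (DistilBERT attn_mask convention)
# model = Transformer(num_classes=1, embedding_dim=D, predict_sequence=True)
# model((x, mask)).shape             # -> torch.Size([2, 24]) slice-level logits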
+ """ + def __init__(self, + num_classes, + embedding_dim=512, + hidden_dim=1024, + n_layers=4, + n_heads=16, + dropout=0.2, + attention_dropout=0.1, + output_attentions=False, + activation='gelu', + output_hidden_states=False, + chunk_size_feed_forward=0, + pool=None + ): + super().__init__() + config = Config(**{ + 'dim': embedding_dim, + 'hidden_dim': hidden_dim, + 'n_layers': n_layers, + 'n_heads': n_heads, + 'dropout': dropout, + 'attention_dropout': attention_dropout, + 'output_attentions': output_attentions, + 'activation': activation, + 'output_hidden_states': output_hidden_states, + 'chunk_size_feed_forward': chunk_size_feed_forward + }) + + self.transformer = T(config) + + self.pool_layer = nn.Identity() + + self.classifier1 = nn.Linear(embedding_dim, num_classes) + self.classifier2 = nn.Linear(embedding_dim, num_classes) + + def classify(self, x): + + if isinstance(self.pool_layer, nn.Identity): + # Just take the last vector in the sequence + x_summ = x[:, 0] + else: + x_summ = self.pool_layer(x.transpose(-1, -2)) + + # Element-wise labels + out1 = self.classifier1(x)[:, :, 0] + # Single label for whole sequence + out2 = self.classifier2(x_summ) + out = torch.cat([out1, out2], dim=1) + return out + + def forward_tr(self, x, mask): + output = self.transformer(x, attn_mask=mask, head_mask=[None]*x.size(1)) + return self.classify(output[0]) + + def forward(self, x): + x, mask = x + return self.forward_tr(x, mask) + + +class DualTransformerV2(nn.Module): + """ + More complicated variant of DualTransformer. + Returns tuple of: (element-wise prediction, sequence prediction) + """ + def __init__(self, + num_seq_classes, + num_classes, + embedding_dim=512, + hidden_dim=1024, + n_layers=4, + n_heads=16, + dropout=0.2, + attention_dropout=0.1, + output_attentions=False, + activation='gelu', + output_hidden_states=False, + chunk_size_feed_forward=0, + pool=None + ): + super().__init__() + config = Config(**{ + 'dim': embedding_dim, + 'hidden_dim': hidden_dim, + 'n_layers': n_layers, + 'n_heads': n_heads, + 'dropout': dropout, + 'attention_dropout': attention_dropout, + 'output_attentions': output_attentions, + 'activation': activation, + 'output_hidden_states': output_hidden_states, + 'chunk_size_feed_forward': chunk_size_feed_forward + }) + + self.transformer = T(config) + + self.pool_layer = nn.Identity() + + self.classifier1 = nn.Linear(embedding_dim, num_seq_classes) + self.classifier2 = nn.Linear(embedding_dim, num_classes) + + def classify(self, x): + + if isinstance(self.pool_layer, nn.Identity): + # Just take the last vector in the sequence + x_summ = x[:, 0] + else: + x_summ = self.pool_layer(x.transpose(-1, -2)) + + # Element-wise labels + out1 = self.classifier1(x) + if self.classifier1.out_features == 1: + out1 = out1[:, :, 0] + # Single label for whole sequence + out2 = self.classifier2(x_summ) + return out1, out2 + + def forward_tr(self, x, mask): + output = self.transformer(x, attn_mask=mask, head_mask=[None]*x.size(1)) + return self.classify(output[0]) + + def forward(self, x): + x, mask = x + return self.forward_tr(x, mask) + diff --git a/skp/models/tools.py b/skp/models/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..e92b946c9828256fa9e19833bd7b2b72677b3c17 --- /dev/null +++ b/skp/models/tools.py @@ -0,0 +1,30 @@ +import torch.nn as nn + + +def change_num_input_channels(model, in_channels=1): + """ + Assumes number of input channels in model is 3. 
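# Illustrative sketch, not part of the diff: output layout of the two dual-head models above.
# DualTransformer concatenates per-element and whole-sequence logits into one tensor, while
# DualTransformerV2 returns them as a tuple. The shapes below reproduce classify() with the
# default nn.Identity pooling, where position 0 of the sequence serves as the summary vector.
import torch
import torch.nn as nn

B, N, D, C = 2, 24, 512, 1
x = torch.randn(B, N, D)                      # transformer output
classifier1, classifier2 = nn.Linear(D, C), nn.Linear(D, C)

out1 = classifier1(x)[:, :, 0]                # (B, N)  element-wise (per-slice) logits
out2 = classifier2(x[:, 0])                   # (B, C)  whole-sequence logits
print(torch.cat([out1, out2], dim=1).shape)   # DualTransformer:   torch.Size([2, 25]) = (B, N + C)
# DualTransformerV2 would return the pair (out1, out2) instead of concatenating.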
+ """ + for i, m in enumerate(model.modules()): + if isinstance(m, (nn.Conv2d,nn.Conv3d)) and m.in_channels == 3: + m.in_channels = in_channels + # First, sum across channels + W = m.weight.sum(1, keepdim=True) + # Then, divide by number of channels + W = W / in_channels + # Then, repeat by number of channels + size = [1] * W.ndim + size[1] = in_channels + W = W.repeat(size) + m.weight = nn.Parameter(W) + break + return model + + +def change_initial_stride(model, stride, in_channels): + + for i, m in enumerate(model.modules()): + if isinstance(m, (nn.Conv2d, nn.Conv3d)) and m.in_channels == in_channels: + m.stride = stride + break + return model \ No newline at end of file diff --git a/skp/models/vmz/__init__.py b/skp/models/vmz/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..328697f133eb93d54a4bc84e52f1c0bf329fca62 --- /dev/null +++ b/skp/models/vmz/__init__.py @@ -0,0 +1,2 @@ +from .csn import * +from .r2plus1d import * diff --git a/skp/models/vmz/__pycache__/__init__.cpython-39.pyc b/skp/models/vmz/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29f62abc0b2abb7c8885ac526dc78f0edbd41ffe Binary files /dev/null and b/skp/models/vmz/__pycache__/__init__.cpython-39.pyc differ diff --git a/skp/models/vmz/__pycache__/backbones.cpython-39.pyc b/skp/models/vmz/__pycache__/backbones.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd4d6bfd251069442211eb05df879b7a5effa66a Binary files /dev/null and b/skp/models/vmz/__pycache__/backbones.cpython-39.pyc differ diff --git a/skp/models/vmz/__pycache__/csn.cpython-39.pyc b/skp/models/vmz/__pycache__/csn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d353525b39ec63cb4102ce82b61eb92f9f5e7d22 Binary files /dev/null and b/skp/models/vmz/__pycache__/csn.cpython-39.pyc differ diff --git a/skp/models/vmz/__pycache__/r2plus1d.cpython-39.pyc b/skp/models/vmz/__pycache__/r2plus1d.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ec28b170a5e7dcb1ce5bf75931cd21d9f520e11 Binary files /dev/null and b/skp/models/vmz/__pycache__/r2plus1d.cpython-39.pyc differ diff --git a/skp/models/vmz/__pycache__/utils.cpython-39.pyc b/skp/models/vmz/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3708a27f3ff6af43880108bc389e190313bc2e1f Binary files /dev/null and b/skp/models/vmz/__pycache__/utils.cpython-39.pyc differ diff --git a/skp/models/vmz/backbones.py b/skp/models/vmz/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..4c29cbaf2db0ba41ada1eb3f31014d8c09f2e4d9 --- /dev/null +++ b/skp/models/vmz/backbones.py @@ -0,0 +1,53 @@ +import torch.nn as nn + +from . import csn +from . 
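# Illustrative sketch, not part of the diff: what change_num_input_channels() above does to the
# first conv layer -- collapse the pretrained RGB kernels into one channel (sum, rescale by the
# new channel count, then replicate), so a 3-channel pretrained stem can take grayscale CT input.
# The conv below is a stand-in stem, not a layer from the repo.
import torch
import torch.nn as nn

in_channels = 1
conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

W = conv.weight.sum(1, keepdim=True)           # (64, 1, 7, 7): sum over the 3 RGB kernels
W = W / in_channels                            # rescale for the new channel count
size = [1] * W.ndim
size[1] = in_channels
conv.weight = nn.Parameter(W.repeat(size))     # replicate to (64, in_channels, 7, 7)
conv.in_channels = in_channels

print(conv(torch.randn(2, 1, 64, 64)).shape)   # torch.Size([2, 64, 32, 32])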
import r2plus1d + + +def ir_csn_152(pretrained=True, **kwargs): + model = csn.ir_csn_152(pretraining='ig65m_32frms' if pretrained else '', num_classes=359) + model.avgpool = nn.Identity() + model.fc = nn.Identity() + return model + + +def ir_csn_101(pretrained=True, **kwargs): + model = ir_csn_152(pretrained=pretrained, **kwargs) + model.layer2 = model.layer2[:4] + model.layer3 = model.layer3[:23] + return model + + +def ir_csn_50(pretrained=True, **kwargs): + model = ir_csn_152(pretrained=pretrained, **kwargs) + model.layer2 = model.layer2[:4] + model.layer3 = model.layer3[:6] + return model + + +def ip_csn_152(pretrained=True, **kwargs): + model = csn.ip_csn_152(pretraining='ig65m_32frms' if pretrained else '', num_classes=359) + model.avgpool = nn.Identity() + model.fc = nn.Identity() + return model + + +def ip_csn_101(pretrained=True, **kwargs): + model = ip_csn_152(pretrained=pretrained, **kwargs) + model.layer2 = model.layer2[:4] + model.layer3 = model.layer3[:23] + return model + + +def ip_csn_50(pretrained=True, **kwargs): + model = ip_csn_152(pretrained=pretrained, **kwargs) + model.layer2 = model.layer2[:4] + model.layer3 = model.layer3[:6] + return model + + +def r2plus1d_34(pretrained=True, **kwargs): + model = r2plus1d.r2plus1d_34(pretraining='32_ig65m' if pretrained else '', num_classes=359) + model.avgpool = nn.Identity() + model.fc = nn.Identity() + return model \ No newline at end of file diff --git a/skp/models/vmz/csn.py b/skp/models/vmz/csn.py new file mode 100644 index 0000000000000000000000000000000000000000..e7112b0f7586e062c6566e4905838d42a22e0350 --- /dev/null +++ b/skp/models/vmz/csn.py @@ -0,0 +1,68 @@ +import warnings + +import torch.hub +import torch.nn as nn +from torchvision.models.video.resnet import BasicStem, BasicBlock, Bottleneck + +from .utils import _generic_resnet, Conv3DDepthwise, BasicStem_Pool, IPConv3DDepthwise + + +__all__ = ["ir_csn_152", "ip_csn_152"] + + +def ir_csn_152(pretraining="", use_pool1=True, progress=False, **kwargs): + avail_pretrainings = [ + "ig65m_32frms", + "ig_ft_kinetics_32frms", + "sports1m_32frms", + "sports1m_ft_kinetics_32frms", + ] + + if pretraining in avail_pretrainings: + arch = "ir_csn_152_" + pretraining + pretrained = True + else: + arch = "ir_csn_152" + pretrained = False + + model = _generic_resnet( + arch, + pretrained, + progress, + block=Bottleneck, + conv_makers=[Conv3DDepthwise] * 4, + layers=[3, 8, 36, 3], + stem=BasicStem_Pool if use_pool1 else BasicStem, + **kwargs, + ) + + return model + + +def ip_csn_152(pretraining="", use_pool1=True, progress=False, **kwargs): + avail_pretrainings = [ + "ig65m_32frms", + "ig_ft_kinetics_32frms", + "sports1m_32frms", + "sports1m_ft_kinetics_32frms", + ] + + if pretraining in avail_pretrainings: + arch = "ip_csn_152_" + pretraining + pretrained = True + else: + arch = "ip_csn_152" + pretrained = False + + model = _generic_resnet( + arch, + pretrained, + progress, + block=Bottleneck, + conv_makers=[IPConv3DDepthwise] * 4, + layers=[3, 8, 36, 3], + stem=BasicStem_Pool if use_pool1 else BasicStem, + **kwargs, + ) + + return model diff --git a/skp/models/vmz/r2plus1d.py b/skp/models/vmz/r2plus1d.py new file mode 100644 index 0000000000000000000000000000000000000000..f252795dd5db99c991975eb22dc74a3ed541a388 --- /dev/null +++ b/skp/models/vmz/r2plus1d.py @@ -0,0 +1,100 @@ +import warnings + +import torch.hub +import torch.nn as nn +from torchvision.models.video.resnet import R2Plus1dStem, BasicBlock, Bottleneck + + +from .utils import _generic_resnet, R2Plus1dStem_Pool, 
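# Illustrative sketch, not part of the diff: the *_csn_101 / *_csn_50 helpers above derive the
# shallower backbones from the pretrained 152-layer model by slicing its nn.Sequential stages
# down to the standard ResNet block counts (152: [3, 8, 36, 3], 101: [3, 4, 23, 3], 50: [3, 4, 6, 3]).
import torch.nn as nn

layer3 = nn.Sequential(*[nn.Identity() for _ in range(36)])   # stand-in for the 36-block stage
print(len(layer3[:23]), len(layer3[:6]))                      # 23 6  -> ResNet-101 / ResNet-50 depths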
Conv2Plus1D, model_urls + + +__all__ = ["r2plus1d_34", "r2plus1d_152"] + + +def r2plus1d_34(pretraining="", use_pool1=False, progress=False, **kwargs): + avail_pretrainings = [ + "kinetics_8frms", + "kinetics_32frms", + "ig65m_8frms", + "ig65m_32frms", + "32_ig65m" + ] + if pretraining in avail_pretrainings: + arch = "r2plus1d_34_" + pretraining + pretrained = True + else: + warnings.warn( + f"Unrecognized pretraining dataset, continuing with randomly initialized network." + " Available pretrainings: {avail_pretrainings}", + UserWarning, + ) + arch = "r2plus1d_34" + pretrained = False + + model = _generic_resnet( + arch, + pretrained, + progress, + block=BasicBlock, + conv_makers=[Conv2Plus1D] * 4, + layers=[3, 4, 6, 3], + stem=R2Plus1dStem_Pool if use_pool1 else R2Plus1dStem, + **kwargs, + ) + # We need exact Caffe2 momentum for BatchNorm scaling + for m in model.modules(): + if isinstance(m, nn.BatchNorm3d): + m.eps = 1e-3 + m.momentum = 0.9 + + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + model_urls[arch], progress=progress + ) + model.load_state_dict(state_dict) + + return model + + +def r2plus1d_152(pretraining="", use_pool1=True, progress=False, **kwargs): + avail_pretrainings = [ + "ig65m_32frms", + "ig_ft_kinetics_32frms", + "sports1m_32frms", + "sports1m_ft_kinetics_32frms", + ] + if pretraining in avail_pretrainings: + arch = "r2plus1d_" + pretraining + pretrained = True + else: + warnings.warn( + f"Unrecognized pretraining dataset, continuing with randomly initialized network." + " Available pretrainings: {avail_pretrainings}", + UserWarning, + ) + arch = "r2plus1d_34" + pretrained = False + + model = _generic_resnet( + arch, + pretrained, + progress, + block=Bottleneck, + conv_makers=[Conv2Plus1D] * 4, + layers=[3, 8, 36, 3], + stem=R2Plus1dStem_Pool if use_pool1 else R2Plus1dStem, + **kwargs, + ) + # We need exact Caffe2 momentum for BatchNorm scaling + for m in model.modules(): + if isinstance(m, nn.BatchNorm3d): + m.eps = 1e-3 + m.momentum = 0.9 + + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + model_urls[arch], progress=progress + ) + model.load_state_dict(state_dict) + + return model diff --git a/skp/models/vmz/utils.py b/skp/models/vmz/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0ad591f27a3fc2eb80f89ec41070f6fe02e9c41b --- /dev/null +++ b/skp/models/vmz/utils.py @@ -0,0 +1,250 @@ +import torch +import torch.nn as nn + +from torch import Tensor +from torchvision.models.video.resnet import BasicBlock, Bottleneck, Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D +from typing import Callable, List, Sequence, Type, Union + + +# TODO: upload models and load them +model_urls = { + "r2plus1d_34_8_ig65m": "https://github.com/moabitcoin/ig65m-pytorch/releases/download/v1.0.0/r2plus1d_34_clip8_ig65m_from_scratch-9bae36ae.pth", # noqa: E501 + "r2plus1d_34_32_ig65m": "https://github.com/moabitcoin/ig65m-pytorch/releases/download/v1.0.0/r2plus1d_34_clip32_ig65m_from_scratch-449a7af9.pth", # noqa: E501 + "r2plus1d_34_8_kinetics": "https://github.com/moabitcoin/ig65m-pytorch/releases/download/v1.0.0/r2plus1d_34_clip8_ft_kinetics_from_ig65m-0aa0550b.pth", # noqa: E501 + "r2plus1d_34_32_kinetics": "https://github.com/moabitcoin/ig65m-pytorch/releases/download/v1.0.0/r2plus1d_34_clip32_ft_kinetics_from_ig65m-ade133f1.pth", # noqa: E501 + "r2plus1d_152_ig65m_32frms": "https://github.com/bjuncek/VMZ/releases/download/test_models/r2plus1d_152_ig65m_from_scratch_f106380637.pth", + "r2plus1d_152_ig_ft_kinetics_32frms": 
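# Illustrative sketch, not part of the diff: how a `pretraining` tag selects a checkpoint for
# r2plus1d_34. The tag is appended to the arch name, and the resulting string must match a key
# in model_urls (defined in utils.py below); "32_ig65m" is the tag that backbones.r2plus1d_34
# passes when pretrained=True.
pretraining = "32_ig65m"
arch = "r2plus1d_34_" + pretraining
print(arch)                         # r2plus1d_34_32_ig65m  -> IG-65M, 32-frame clip weights
# torch.hub.load_state_dict_from_url(model_urls[arch]) then fetches the matching state_dict.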
"https://github.com/bjuncek/VMZ/releases/download/test_models/r2plus1d_152_ft_kinetics_from_ig65m_f107107466.pth", + "r2plus1d_152_sports1m_32frms": "", + "r2plus1d_152_sports1m_ft_kinetics_32frms": "https://github.com/bjuncek/VMZ/releases/download/test_models/r2plus1d_152_ft_kinetics_from_sports1m_f128957437.pth", + "ir_csn_152_ig65m_32frms": "https://github.com/bjuncek/VMZ/releases/download/test_models/irCSN_152_ig65m_from_scratch_f125286141.pth", + "ir_csn_152_ig_ft_kinetics_32frms": "https://github.com/bjuncek/VMZ/releases/download/test_models/irCSN_152_ft_kinetics_from_ig65m_f126851907.pth", + "ir_csn_152_sports1m_32frms": "https://github.com/bjuncek/VMZ/releases/download/test_models/irCSN_152_Sports1M_from_scratch_f99918785.pth", + "ir_csn_152_sports1m_ft_kinetics_32frms": "https://github.com/bjuncek/VMZ/releases/download/test_models/irCSN_152_ft_kinetics_from_Sports1M_f101599884.pth", + "ip_csn_152_ig65m_32frms": "https://github.com/bjuncek/VMZ/releases/download/test_models/ipCSN_152_ig65m_from_scratch_f130601052.pth", + "ip_csn_152_ig_ft_kinetics_32frms": "https://github.com/bjuncek/VMZ/releases/download/test_models/ipCSN_152_ft_kinetics_from_ig65m_f133090949.pth", + "ip_csn_152_sports1m_32frms": "https://github.com/bjuncek/VMZ/releases/download/test_models/ipCSN_152_Sports1M_from_scratch_f111018543.pth", + "ip_csn_152_sports1m_ft_kinetics_32frms": "https://github.com/bjuncek/VMZ/releases/download/test_models/ipCSN_152_ft_kinetics_from_Sports1M_f111279053.pth", +} + + +class VideoResNet(nn.Module): + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]], + layers: List[int], + stem: Callable[..., nn.Module], + num_classes: int = 400, + zero_init_residual: bool = False, + ) -> None: + """Generic resnet video generator. + + Args: + block (Type[Union[BasicBlock, Bottleneck]]): resnet building block + conv_makers (List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]]): generator + function for each layer + layers (List[int]): number of blocks per layer + stem (Callable[..., nn.Module]): module specifying the ResNet stem. + num_classes (int, optional): Dimension of the final FC layer. Defaults to 400. + zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. 
+ """ + super().__init__() + self.inplanes = 64 + + self.stem = stem() + + self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1) + self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2) + + self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + # init weights + for m in self.modules(): + if isinstance(m, nn.Conv3d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm3d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[union-attr, arg-type] + + def forward(self, x: Tensor) -> Tensor: + x = self.stem(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = self.fc(x) + + return x + + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + conv_builder: Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]], + planes: int, + blocks: int, + stride: int = 1, + ) -> nn.Sequential: + downsample = None + + if stride != 1 or self.inplanes != planes * block.expansion: + ds_stride = conv_builder.get_downsample_stride(stride) + downsample = nn.Sequential( + nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False), + nn.BatchNorm3d(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) + + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, conv_builder)) + + return nn.Sequential(*layers) + + +def _generic_resnet(arch, pretrained=False, progress=False, **kwargs): + model = VideoResNet(**kwargs) + + # We need exact Caffe2 momentum for BatchNorm scaling + for m in model.modules(): + if isinstance(m, nn.BatchNorm3d): + m.eps = 1e-3 + m.momentum = 0.9 + + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + model_urls[arch], progress=progress + ) + model.load_state_dict(state_dict) + + return model + + +class BasicStem_Pool(nn.Sequential): + def __init__(self): + super(BasicStem_Pool, self).__init__( + nn.Conv3d( + 3, + 64, + kernel_size=(3, 7, 7), + stride=(1, 2, 2), + padding=(1, 3, 3), + bias=False, + ), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True), + nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)), + ) + + +class R2Plus1dStem_Pool(nn.Sequential): + """R(2+1)D stem is different than the default one as it uses separated 3D convolution + """ + + def __init__(self): + super(R2Plus1dStem_Pool, self).__init__( + nn.Conv3d( + 3, + 45, + kernel_size=(1, 7, 7), + stride=(1, 2, 2), + padding=(0, 3, 3), + bias=False, + ), + nn.BatchNorm3d(45), + nn.ReLU(inplace=True), + nn.Conv3d( + 45, + 64, + kernel_size=(3, 1, 1), + stride=(1, 1, 1), + padding=(1, 0, 0), + bias=False, + ), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True), + nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)), + ) + + +class Conv3DDepthwise(nn.Conv3d): + def __init__(self, in_planes, out_planes, midplanes=None, 
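# Illustrative sketch, not part of the diff: BasicStem_Pool above keeps the clip length but
# halves H/W twice (strided conv, then max-pool), so a (T, H, W) clip reaches layer1 at
# (T, H/4, W/4) with 64 channels. Assumes the skp package from this diff is importable.
import torch
from skp.models.vmz.utils import BasicStem_Pool

stem = BasicStem_Pool()
clip = torch.randn(1, 3, 32, 224, 224)        # (B, C, T, H, W)
print(stem(clip).shape)                       # torch.Size([1, 64, 32, 56, 56])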
stride=1, padding=1): + + assert in_planes == out_planes + super(Conv3DDepthwise, self).__init__( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=(3, 3, 3), + stride=stride, + padding=padding, + groups=in_planes, + bias=False, + ) + + @staticmethod + def get_downsample_stride(stride): + return (stride, stride, stride) + + +class IPConv3DDepthwise(nn.Sequential): + def __init__(self, in_planes, out_planes, midplanes, stride=1, padding=1): + + assert in_planes == out_planes + super(IPConv3DDepthwise, self).__init__( + nn.Conv3d(in_planes, out_planes, kernel_size=1, bias=False), + nn.BatchNorm3d(out_planes), + # nn.ReLU(inplace=True), + Conv3DDepthwise(out_planes, out_planes, None, stride), + ) + + @staticmethod + def get_downsample_stride(stride): + return (stride, stride, stride) + + +class Conv2Plus1D(nn.Sequential): + def __init__(self, in_planes, out_planes, midplanes, stride=1, padding=1): + + midplanes = (in_planes * out_planes * 3 * 3 * 3) // ( + in_planes * 3 * 3 + 3 * out_planes + ) + super(Conv2Plus1D, self).__init__( + nn.Conv3d( + in_planes, + midplanes, + kernel_size=(1, 3, 3), + stride=(1, stride, stride), + padding=(0, padding, padding), + bias=False, + ), + nn.BatchNorm3d(midplanes), + nn.ReLU(inplace=True), + nn.Conv3d( + midplanes, + out_planes, + kernel_size=(3, 1, 1), + stride=(stride, 1, 1), + padding=(padding, 0, 0), + bias=False, + ), + ) + + @staticmethod + def get_downsample_stride(stride): + return (stride, stride, stride) diff --git a/x3d.ckpt b/x3d.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..1c125628eb8f74d939c11241ac145f54688bc35c --- /dev/null +++ b/x3d.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ef70897cfa159a662e0b5bb77059a93aeab895206350b9b1ac12420431ee8c7 +size 18530917
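# Illustrative sketch, not part of the diff: the midplanes formula in Conv2Plus1D above sizes the
# intermediate tensor so that the factorized pair (1x3x3 spatial conv + 3x1x1 temporal conv) has
# roughly the same number of weights as the dense 3x3x3 convolution it replaces, following the
# R(2+1)D design.
in_planes, out_planes = 128, 128
midplanes = (in_planes * out_planes * 3 * 3 * 3) // (in_planes * 3 * 3 + 3 * out_planes)
full_3d = in_planes * out_planes * 3 * 3 * 3                      # weights of a dense 3x3x3 conv
factored = in_planes * midplanes * 3 * 3 + midplanes * out_planes * 3
print(midplanes, full_3d, factored)                               # 288 442368 442368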