Spaces: NoelShin — Runtime error

Commit 35188e4 · Parent(s): 6c45278
NoelShin committed: Add application file

This view is limited to 50 files because it contains too many changes. See raw diff.
- .DS_Store +0 -0
- .idea/.gitignore +8 -0
- .idea/deployment.xml +15 -0
- .idea/inspectionProfiles/Project_Default.xml +23 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +4 -0
- .idea/modules.xml +8 -0
- .idea/selfmask_demo.iml +8 -0
- .idea/sonarlint/issuestore/index.pb +0 -0
- .idea/webServers.xml +14 -0
- __pycache__/bilateral_solver.cpython-38.pyc +0 -0
- __pycache__/utils.cpython-38.pyc +0 -0
- app.py +134 -0
- bilateral_solver.py +206 -0
- duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml +56 -0
- networks/__init__.py +0 -0
- networks/__pycache__/__init__.cpython-38.pyc +0 -0
- networks/__pycache__/timm_deit.cpython-38.pyc +0 -0
- networks/__pycache__/timm_vit.cpython-38.pyc +0 -0
- networks/__pycache__/vision_transformer.cpython-38.pyc +0 -0
- networks/maskformer/__pycache__/maskformer.cpython-38.pyc +0 -0
- networks/maskformer/__pycache__/transformer_decoder.cpython-38.pyc +0 -0
- networks/maskformer/maskformer.py +267 -0
- networks/maskformer/positional_embedding.py +48 -0
- networks/maskformer/transformer_decoder.py +376 -0
- networks/module_helper.py +176 -0
- networks/resnet.py +60 -0
- networks/resnet_backbone.py +194 -0
- networks/resnet_models.py +273 -0
- networks/timm_deit.py +254 -0
- networks/timm_vit.py +819 -0
- networks/vision_transformer.py +569 -0
- resources/.DS_Store +0 -0
- resources/0053.jpg +0 -0
- resources/0236.jpg +0 -0
- resources/0239.jpg +0 -0
- resources/0403.jpg +0 -0
- resources/0412.jpg +0 -0
- resources/ILSVRC2012_test_00005309.jpg +0 -0
- resources/ILSVRC2012_test_00012622.jpg +0 -0
- resources/ILSVRC2012_test_00022698.jpg +0 -0
- resources/ILSVRC2012_test_00040725.jpg +0 -0
- resources/ILSVRC2012_test_00075738.jpg +0 -0
- resources/ILSVRC2012_test_00080683.jpg +0 -0
- resources/ILSVRC2012_test_00085874.jpg +0 -0
- resources/im052.jpg +0 -0
- resources/sun_ainjbonxmervsvpv.jpg +0 -0
- resources/sun_alfntqzssslakmss.jpg +0 -0
- resources/sun_amnrcxhisjfrliwa.jpg +0 -0
- resources/sun_bvyxpvkouzlfwwod.jpg +0 -0

.DS_Store
ADDED
Binary file (6.15 kB)

.idea/.gitignore
ADDED
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml

.idea/deployment.xml
ADDED
@@ -0,0 +1,15 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="PublishConfigData" autoUpload="Always" serverName="mydev" remoteFilesAllowedToDisappearOnAutoupload="false">
+    <serverData>
+      <paths name="mydev">
+        <serverdata>
+          <mappings>
+            <mapping deploy="/" local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+    </serverData>
+    <option name="myAutoUpload" value="ALWAYS" />
+  </component>
+</project>

.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,23 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+      <option name="ignoredPackages">
+        <value>
+          <list size="10">
+            <item index="0" class="java.lang.String" itemvalue="prettytable" />
+            <item index="1" class="java.lang.String" itemvalue="interrogate" />
+            <item index="2" class="java.lang.String" itemvalue="pytest" />
+            <item index="3" class="java.lang.String" itemvalue="yapf" />
+            <item index="4" class="java.lang.String" itemvalue="cityscapesscripts" />
+            <item index="5" class="java.lang.String" itemvalue="Wand" />
+            <item index="6" class="java.lang.String" itemvalue="isort" />
+            <item index="7" class="java.lang.String" itemvalue="xdoctest" />
+            <item index="8" class="java.lang.String" itemvalue="codecov" />
+            <item index="9" class="java.lang.String" itemvalue="flake8" />
+          </list>
+        </value>
+      </option>
+    </inspection_tool>
+  </profile>
+</component>

.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>

.idea/misc.xml
ADDED
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (pytorch)" project-jdk-type="Python SDK" />
+</project>

.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/selfmask_demo.iml" filepath="$PROJECT_DIR$/.idea/selfmask_demo.iml" />
+    </modules>
+  </component>
+</project>

.idea/selfmask_demo.iml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>

.idea/sonarlint/issuestore/index.pb
ADDED
File without changes

.idea/webServers.xml
ADDED
@@ -0,0 +1,14 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="WebServers">
+    <option name="servers">
+      <webServer id="12e2cf4d-3b81-4241-9665-54a333f70567" name="mydev">
+        <fileTransfer rootFolder="/users/gyungin/selfmask_demo" accessType="SFTP" host="mydev" port="22" sshConfigId="3e23a652-ab3c-4dc2-a117-84c2bf217891" sshConfig="gyungin@mydev:22 password">
+          <advancedOptions>
+            <advancedOptions dataProtectionLevel="Private" passiveMode="true" shareSSLContext="true" />
+          </advancedOptions>
+        </fileTransfer>
+      </webServer>
+    </option>
+  </component>
+</project>

__pycache__/bilateral_solver.cpython-38.pyc
ADDED
Binary file (6.76 kB)

__pycache__/utils.cpython-38.pyc
ADDED
Binary file (2.9 kB)

app.py
ADDED
@@ -0,0 +1,134 @@
+from argparse import ArgumentParser, Namespace
+from typing import Dict, List, Tuple
+import yaml
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+import torch.nn.functional as F
+from torchvision.transforms.functional import to_tensor, normalize, resize
+import gradio as gr
+from utils import get_model
+from bilateral_solver import bilateral_solver_output
+import os
+os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+state_dict: dict = torch.hub.load_state_dict_from_url(
+    "https://github.com/NoelShin/selfmask/releases/download/v1.0.0/selfmask_nq20.pt",
+    map_location=device  # "cuda" if torch.cuda.is_available() else "cpu"
+)["model"]
+
+parser = ArgumentParser("SelfMask demo")
+parser.add_argument(
+    "--config",
+    type=str,
+    default="duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml"
+)
+
+# parser.add_argument(
+#     "--p_state_dict",
+#     type=str,
+#     default="/users/gyungin/selfmask_bak/ckpt/nq20_ndl6_bc_sr10100_duts_pm_all_k2,3,4_md_seed0_final/eval/hku_is/best_model.pt",
+# )
+#
+# parser.add_argument(
+#     "--dataset_name", '-dn', type=str, default="duts",
+#     choices=["dut_omron", "duts", "ecssd"]
+# )
+
+# independent variables
+# parser.add_argument("--use_gpu", type=bool, default=True)
+# parser.add_argument('--seed', default=0, type=int)
+# parser.add_argument("--dir_root", type=str, default="..")
+# parser.add_argument("--gpu_id", type=int, default=2)
+# parser.add_argument("--suffix", type=str, default='')
+args: Namespace = parser.parse_args()
+base_args = yaml.safe_load(open(f"{args.config}", 'r'))
+base_args.pop("dataset_name")
+args: dict = vars(args)
+args.update(base_args)
+args: Namespace = Namespace(**args)
+
+model = get_model(arch="maskformer", configs=args).to(device)
+model.load_state_dict(state_dict)
+model.eval()
+
+
+@torch.no_grad()
+def main(
+    image: Image.Image,
+    size: int = 384,
+    max_size: int = 512,
+    mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
+    std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
+):
+    pil_image: Image.Image = resize(image, size=size, max_size=max_size)
+    image: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std))  # 3 x H x W
+    dict_outputs = model(image[None].to(device))
+
+    batch_pred_masks: torch.Tensor = dict_outputs["mask_pred"]  # [0, 1]
+    batch_objectness: torch.Tensor = dict_outputs.get("objectness", None)  # [0, 1]
+
+    if len(batch_pred_masks.shape) == 5:
+        # b x n_layers x n_queries x h x w -> b x n_queries x h x w
+        batch_pred_masks = batch_pred_masks[:, -1, ...]  # extract the output from the last decoder layer
+
+        if batch_objectness is not None:
+            # b x n_layers x n_queries x 1 -> b x n_queries x 1
+            batch_objectness = batch_objectness[:, -1, ...]
+
+    # resize prediction to original resolution
+    # note: upsampling by 4 and cutting the padded region allows for a better result
+    H, W = image.shape[-2:]
+    batch_pred_masks = F.interpolate(
+        batch_pred_masks, scale_factor=4, mode="bilinear", align_corners=False
+    )[..., :H, :W]
+
+    # iterate over batch dimension
+    for batch_index, pred_masks in enumerate(batch_pred_masks):
+        # n_queries x 1 -> n_queries
+        objectness: torch.Tensor = batch_objectness[batch_index].squeeze(dim=-1)
+        ranks = torch.argsort(objectness, descending=True)  # n_queries
+        pred_mask: torch.Tensor = pred_masks[ranks[0]]  # H x W
+        pred_mask: np.ndarray = (pred_mask > 0.5).cpu().numpy().astype(np.uint8) * 255
+
+        pred_mask_bi, _ = bilateral_solver_output(img=pil_image, target=pred_mask)  # float64
+        pred_mask_bi: np.ndarray = np.clip(pred_mask_bi, 0, 255).astype(np.uint8)
+
+        attn_map = cv2.cvtColor(cv2.applyColorMap(pred_mask_bi, cv2.COLORMAP_VIRIDIS), cv2.COLOR_BGR2RGB)
+        super_imposed_img = cv2.addWeighted(attn_map, 0.5, np.array(pil_image), 0.5, 0)
+        return super_imposed_img
+    # return pred_mask_bi
+
+demo = gr.Interface(
+    fn=main,
+    inputs=gr.inputs.Image(type="pil"),
+    outputs="image",
+    examples=[f"resources/{fname}.jpg" for fname in [
+        "0053",
+        "0236",
+        "0239",
+        "0403",
+        "0412",
+        "ILSVRC2012_test_00005309",
+        "ILSVRC2012_test_00012622",
+        "ILSVRC2012_test_00022698",
+        "ILSVRC2012_test_00040725",
+        "ILSVRC2012_test_00075738",
+        "ILSVRC2012_test_00080683",
+        "ILSVRC2012_test_00085874",
+        "im052",
+        "sun_ainjbonxmervsvpv",
+        "sun_alfntqzssslakmss",
+        "sun_amnrcxhisjfrliwa",
+        "sun_bvyxpvkouzlfwwod"
+    ]],
+    title="Unsupervised Salient Object Detection with Spectral Cluster Voting",
+    allow_flagging="never",
+    analytics_enabled=False
+)
+
+demo.launch(
+    # share=True
+)

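For reference, the inference path that app.py wires into Gradio can be distilled as below. This is an illustrative sketch, not part of the commit: it assumes a `model` already built via utils.get_model(arch="maskformer", configs=args) and loaded with the released checkpoint exactly as above, and the helper name run_selfmask is hypothetical.

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms.functional import normalize, resize, to_tensor


@torch.no_grad()
def run_selfmask(model, pil_img: Image.Image, device: str = "cpu") -> np.ndarray:
    # same preprocessing as main(): resize + ImageNet normalisation
    pil_img = resize(pil_img, size=384, max_size=512)
    x = normalize(to_tensor(pil_img), mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    out = model(x[None].to(device))

    masks = out["mask_pred"][:, -1]          # last decoder layer: b x n_queries x h x w
    scores = out["objectness"][:, -1, :, 0]  # b x n_queries

    # upsample by 4 and crop the padded region, as in main()
    H, W = x.shape[-2:]
    masks = F.interpolate(masks, scale_factor=4, mode="bilinear", align_corners=False)[..., :H, :W]

    # keep the query with the highest objectness and binarise it
    best = masks[0, scores[0].argmax()]
    return (best > 0.5).cpu().numpy().astype(np.uint8) * 255

This returns the raw binary mask before refinement; the Gradio handler above additionally smooths it with bilateral_solver_output and blends the colour map onto the resized input.
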
bilateral_solver.py
ADDED
@@ -0,0 +1,206 @@
+from scipy.sparse import diags
+from scipy.sparse.linalg import cg
+from scipy.sparse import csr_matrix
+import numpy as np
+from skimage.io import imread
+from scipy import ndimage
+import torch
+import PIL.Image as Image
+import os
+from argparse import ArgumentParser, Namespace
+from typing import Dict, Union
+from collections import defaultdict
+import yaml
+import ujson as json
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+
+
+RGB_TO_YUV = np.array([
+    [0.299, 0.587, 0.114],
+    [-0.168736, -0.331264, 0.5],
+    [0.5, -0.418688, -0.081312]])
+YUV_TO_RGB = np.array([
+    [1.0, 0.0, 1.402],
+    [1.0, -0.34414, -0.71414],
+    [1.0, 1.772, 0.0]])
+YUV_OFFSET = np.array([0, 128.0, 128.0]).reshape(1, 1, -1)
+MAX_VAL = 255.0
+
+
+def rgb2yuv(im):
+    return np.tensordot(im, RGB_TO_YUV, ([2], [1])) + YUV_OFFSET
+
+
+def yuv2rgb(im):
+    return np.tensordot(im.astype(float) - YUV_OFFSET, YUV_TO_RGB, ([2], [1]))
+
+
+def get_valid_idx(valid, candidates):
+    """Find which values are present in a list and where they are located"""
+    locs = np.searchsorted(valid, candidates)
+    # Handle edge case where the candidate is larger than all valid values
+    locs = np.clip(locs, 0, len(valid) - 1)
+    # Identify which values are actually present
+    valid_idx = np.flatnonzero(valid[locs] == candidates)
+    locs = locs[valid_idx]
+    return valid_idx, locs
+
+
+class BilateralGrid(object):
+    def __init__(self, im, sigma_spatial=32, sigma_luma=8, sigma_chroma=8):
+        im_yuv = rgb2yuv(im)
+        # Compute 5-dimensional XYLUV bilateral-space coordinates
+        Iy, Ix = np.mgrid[:im.shape[0], :im.shape[1]]
+        x_coords = (Ix / sigma_spatial).astype(int)
+        y_coords = (Iy / sigma_spatial).astype(int)
+        luma_coords = (im_yuv[..., 0] / sigma_luma).astype(int)
+        chroma_coords = (im_yuv[..., 1:] / sigma_chroma).astype(int)
+        coords = np.dstack((x_coords, y_coords, luma_coords, chroma_coords))
+        coords_flat = coords.reshape(-1, coords.shape[-1])
+        self.npixels, self.dim = coords_flat.shape
+        # Hacky "hash vector" for coordinates,
+        # Requires all scaled coordinates be < MAX_VAL
+        self.hash_vec = (MAX_VAL ** np.arange(self.dim))
+        # Construct S and B matrix
+        self._compute_factorization(coords_flat)
+
+    def _compute_factorization(self, coords_flat):
+        # Hash each coordinate in grid to a unique value
+        hashed_coords = self._hash_coords(coords_flat)
+        unique_hashes, unique_idx, idx = \
+            np.unique(hashed_coords, return_index=True, return_inverse=True)
+        # Identify unique set of vertices
+        unique_coords = coords_flat[unique_idx]
+        self.nvertices = len(unique_coords)
+        # Construct sparse splat matrix that maps from pixels to vertices
+        self.S = csr_matrix((np.ones(self.npixels), (idx, np.arange(self.npixels))))
+        # Construct sparse blur matrices.
+        # Note that these represent [1 0 1] blurs, excluding the central element
+        self.blurs = []
+        for d in range(self.dim):
+            blur = 0.0
+            for offset in (-1, 1):
+                offset_vec = np.zeros((1, self.dim))
+                offset_vec[:, d] = offset
+                neighbor_hash = self._hash_coords(unique_coords + offset_vec)
+                valid_coord, idx = get_valid_idx(unique_hashes, neighbor_hash)
+                blur = blur + csr_matrix((np.ones((len(valid_coord),)),
+                                          (valid_coord, idx)),
+                                         shape=(self.nvertices, self.nvertices))
+            self.blurs.append(blur)
+
+    def _hash_coords(self, coord):
+        """Hacky function to turn a coordinate into a unique value"""
+        return np.dot(coord.reshape(-1, self.dim), self.hash_vec)
+
+    def splat(self, x):
+        return self.S.dot(x)
+
+    def slice(self, y):
+        return self.S.T.dot(y)
+
+    def blur(self, x):
+        """Blur a bilateral-space vector with a 1 2 1 kernel in each dimension"""
+        assert x.shape[0] == self.nvertices
+        out = 2 * self.dim * x
+        for blur in self.blurs:
+            out = out + blur.dot(x)
+        return out
+
+    def filter(self, x):
+        """Apply bilateral filter to an input x"""
+        return self.slice(self.blur(self.splat(x))) / \
+               self.slice(self.blur(self.splat(np.ones_like(x))))
+
+
+def bistochastize(grid, maxiter=10):
+    """Compute diagonal matrices to bistochastize a bilateral grid"""
+    m = grid.splat(np.ones(grid.npixels))
+    n = np.ones(grid.nvertices)
+    for i in range(maxiter):
+        n = np.sqrt(n * m / grid.blur(n))
+    # Correct m to satisfy the assumption of bistochastization regardless
+    # of how many iterations have been run.
+    m = n * grid.blur(n)
+    Dm = diags(m, 0)
+    Dn = diags(n, 0)
+    return Dn, Dm
+
+
+class BilateralSolver(object):
+    def __init__(self, grid, params):
+        self.grid = grid
+        self.params = params
+        self.Dn, self.Dm = bistochastize(grid)
+
+    def solve(self, x, w):
+        # Check that w is a vector or a nx1 matrix
+        if w.ndim == 2:
+            assert (w.shape[1] == 1)
+        elif w.dim == 1:
+            w = w.reshape(w.shape[0], 1)
+        A_smooth = (self.Dm - self.Dn.dot(self.grid.blur(self.Dn)))
+        w_splat = self.grid.splat(w)
+        A_data = diags(w_splat[:, 0], 0)
+        A = self.params["lam"] * A_smooth + A_data
+        xw = x * w
+        b = self.grid.splat(xw)
+        # Use simple Jacobi preconditioner
+        A_diag = np.maximum(A.diagonal(), self.params["A_diag_min"])
+        M = diags(1 / A_diag, 0)
+        # Flat initialization
+        y0 = self.grid.splat(xw) / w_splat
+        yhat = np.empty_like(y0)
+        for d in range(x.shape[-1]):
+            yhat[..., d], info = cg(A, b[..., d], x0=y0[..., d], M=M, maxiter=self.params["cg_maxiter"],
+                                    tol=self.params["cg_tol"])
+        xhat = self.grid.slice(yhat)
+        return xhat
+
+
+def bilateral_solver_output(
+    img: Image.Image,
+    target: np.ndarray,
+    sigma_spatial=16,
+    sigma_luma=16,
+    sigma_chroma=8
+):
+    reference = np.array(img)
+    h, w = target.shape
+    confidence = np.ones((h, w)) * 0.999
+
+    grid_params = {
+        'sigma_luma': sigma_luma,  # Brightness bandwidth
+        'sigma_chroma': sigma_chroma,  # Color bandwidth
+        'sigma_spatial': sigma_spatial  # Spatial bandwidth
+    }
+
+    bs_params = {
+        'lam': 256,  # The strength of the smoothness parameter
+        'A_diag_min': 1e-5,  # Clamp the diagonal of the A diagonal in the Jacobi preconditioner.
+        'cg_tol': 1e-5,  # The tolerance on the convergence in PCG
+        'cg_maxiter': 25  # The number of PCG iterations
+    }
+
+    grid = BilateralGrid(reference, **grid_params)
+
+    t = target.reshape(-1, 1).astype(np.double)
+    c = confidence.reshape(-1, 1).astype(np.double)
+
+    ## output solver, which is a soft value
+    output_solver = BilateralSolver(grid, bs_params).solve(t, c).reshape((h, w))
+
+    binary_solver = ndimage.binary_fill_holes(output_solver > 0.5)
+    labeled, nr_objects = ndimage.label(binary_solver)
+
+    nb_pixel = [np.sum(labeled == i) for i in range(nr_objects + 1)]
+    pixel_order = np.argsort(nb_pixel)
+    try:
+        binary_solver = labeled == pixel_order[-2]
+    except:
+        binary_solver = np.ones((h, w), dtype=bool)
+
+    return output_solver, binary_solver

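A small usage sketch (not in the commit) of bilateral_solver_output, which is how app.py refines the raw prediction: given an RGB PIL image and a uint8 mask with the same height and width, it returns a smoothed float map and a hole-filled binary mask. The file names and the rectangular placeholder mask below are illustrative only.

import numpy as np
from PIL import Image

from bilateral_solver import bilateral_solver_output

img = Image.open("resources/0053.jpg").convert("RGB")
# placeholder: any rough uint8 mask with the same (h, w) as the image
rough = np.zeros((img.height, img.width), dtype=np.uint8)
rough[img.height // 4: 3 * img.height // 4, img.width // 4: 3 * img.width // 4] = 255

soft, binary = bilateral_solver_output(img=img, target=rough)  # float64 map, bool mask
Image.fromarray(np.clip(soft, 0, 255).astype(np.uint8)).save("mask_refined.png")
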
duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml
ADDED
@@ -0,0 +1,56 @@
+# augmentations
+use_copy_paste: false
+scale_range: [ 0.1, 1.0 ]
+repeat_image: false
+
+# base directories
+dir_ckpt: "/users/gyungin/selfmask/ckpt"  # "/work/gyungin/selfmask/ckpt"
+dir_dataset: "/scratch/shared/beegfs/gyungin/datasets"
+
+# clustering
+k: [2, 3, 4]
+clustering_mode: "spectral"
+use_gpu: true  # if you want to use gpu-accelerated code for clustering
+scale_factor: 2  # "how much you want to upsample encoder features before clustering"
+
+# dataset
+dataset_name: "duts"
+use_pseudo_masks: true
+train_image_size: 224
+eval_image_size: 224
+n_percent: 100
+n_copy_pastes: null
+pseudo_masks_fp: "/users/gyungin/selfmask/datasets/swav_mocov2_dino_p16_k234.json"
+
+# dataloader:
+batch_size: 8
+num_workers: 4
+pin_memory: true
+
+# networks
+abs_2d_pe_init: false
+arch: "vit_small"
+lateral_connection: false
+learnable_pixel_decoder: false  # if False, use the bilinear interpolation
+use_binary_classifier: true  # if True, use a binary classifier to get an objectness for each query from transformer decoder
+n_decoder_layers: 6
+n_queries: 20
+num_layers: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
+patch_size: 8
+training_method: "dino"  # "supervised", "deit", "dino", "mocov2", "swav"
+
+# objective
+loss_every_decoder_layer: true
+weight_dice_loss: 1.0
+weight_focal_loss: 0.0
+
+# optimizer
+lr: 0.000006  # default: 0.00006
+lr_warmup_duration: 0  # 5
+momentum: 0.9
+n_epochs: 12
+weight_decay: 0.01
+optimizer_type: "adamw"
+
+# validation
+benchmarks: null

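The following sketch mirrors how app.py consumes this file: the YAML is loaded, the dataset_name key is dropped, and the remaining keys are merged into the argparse Namespace that is handed to get_model as configs.

import yaml
from argparse import Namespace

config_fp = "duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml"
cfg: dict = yaml.safe_load(open(config_fp, "r"))
cfg.pop("dataset_name")  # app.py removes this key before merging

args = Namespace(config=config_fp, **cfg)
print(args.arch, args.patch_size, args.n_queries, args.training_method)  # vit_small 8 20 dino
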
networks/__init__.py
ADDED
File without changes

networks/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (146 Bytes)

networks/__pycache__/timm_deit.cpython-38.pyc
ADDED
Binary file (7.08 kB)

networks/__pycache__/timm_vit.cpython-38.pyc
ADDED
Binary file (27.7 kB)

networks/__pycache__/vision_transformer.cpython-38.pyc
ADDED
Binary file (15.8 kB)

networks/maskformer/__pycache__/maskformer.cpython-38.pyc
ADDED
Binary file (8.51 kB)

networks/maskformer/__pycache__/transformer_decoder.cpython-38.pyc
ADDED
Binary file (8.83 kB)

networks/maskformer/maskformer.py
ADDED
@@ -0,0 +1,267 @@
+from typing import Dict, List
+from math import sqrt, log
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from networks.maskformer.transformer_decoder import TransformerDecoderLayer, TransformerDecoder
+from utils import get_model
+
+
+class MaskFormer(nn.Module):
+    def __init__(
+        self,
+        n_queries: int = 100,
+        arch: str = "vit_small",
+        patch_size: int = 8,
+        training_method: str = "dino",
+        n_decoder_layers: int = 6,
+        normalize_before: bool = False,
+        return_intermediate: bool = False,
+        learnable_pixel_decoder: bool = False,
+        lateral_connection: bool = False,
+        scale_factor: int = 2,
+        abs_2d_pe_init: bool = False,
+        use_binary_classifier: bool = False
+    ):
+        """Define a encoder and decoder along with queries to be learned through the decoder."""
+        super(MaskFormer, self).__init__()
+
+        if arch == "vit_small":
+            self.encoder = get_model(arch=arch, patch_size=patch_size, training_method=training_method)
+            n_dims: int = self.encoder.n_embs
+            n_heads: int = self.encoder.n_heads
+            mlp_ratio: int = self.encoder.mlp_ratio
+        else:
+            self.encoder = get_model(arch=arch, training_method=training_method)
+            n_dims_resnet: int = self.encoder.n_embs
+            n_dims: int = 384
+            n_heads: int = 6
+            mlp_ratio: int = 4
+            self.linear_layer = nn.Conv2d(n_dims_resnet, n_dims, kernel_size=1)
+
+        decoder_layer = TransformerDecoderLayer(
+            n_dims, n_heads, n_dims * mlp_ratio, 0., activation="relu", normalize_before=normalize_before
+        )
+        self.decoder = TransformerDecoder(
+            decoder_layer,
+            n_decoder_layers,
+            norm=nn.LayerNorm(n_dims),
+            return_intermediate=return_intermediate
+        )
+
+        self.query_embed = nn.Embedding(n_queries, n_dims).weight  # initialized with gaussian(0, 1)
+
+        if use_binary_classifier:
+            # self.ffn = MLP(n_dims, n_dims, n_dims, num_layers=3)
+            # self.linear_classifier = nn.Linear(n_dims, 1)
+            self.ffn = MLP(n_dims, n_dims, 1, num_layers=3)
+            # self.norm = nn.LayerNorm(n_dims)
+        else:
+            # self.ffn = None
+            # self.linear_classifier = None
+            # self.norm = None
+            self.ffn = MLP(n_dims, n_dims, n_dims, num_layers=3)
+            self.linear_classifier = nn.Linear(n_dims, 2)
+            self.norm = nn.LayerNorm(n_dims)
+
+        self.arch = arch
+        self.use_binary_classifier = use_binary_classifier
+        self.lateral_connection = lateral_connection
+        self.learnable_pixel_decoder = learnable_pixel_decoder
+        self.scale_factor = scale_factor
+
+    # copy-pasted from https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py
+    @staticmethod
+    def positional_encoding_2d(n_dims: int, height: int, width: int):
+        """
+        :param n_dims: dimension of the model
+        :param height: height of the positions
+        :param width: width of the positions
+        :return: d_model*height*width position matrix
+        """
+        if n_dims % 4 != 0:
+            raise ValueError("Cannot use sin/cos positional encoding with "
+                             "odd dimension (got dim={:d})".format(n_dims))
+        pe = torch.zeros(n_dims, height, width)
+        # Each dimension use half of d_model
+        d_model = int(n_dims / 2)
+        div_term = torch.exp(torch.arange(0., d_model, 2) * -(log(10000.0) / d_model))
+        pos_w = torch.arange(0., width).unsqueeze(1)
+        pos_h = torch.arange(0., height).unsqueeze(1)
+        pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
+        pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
+        pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
+        pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
+
+        return pe
+
+    def forward_encoder(self, x: torch.Tensor):
+        """
+        :param x: b x c x h x w
+        :return patch_tokens: b x depth x hw x n_dims
+        """
+        if self.arch == "vit_small":
+            encoder_outputs: Dict[str, torch.Tensor] = self.encoder(x)  # [:, 1:, :]
+            all_patch_tokens: List[torch.Tensor] = list()
+            for layer_name in [f"layer{num_layer}" for num_layer in range(1, self.encoder.depth + 1)]:
+                patch_tokens: torch.Tensor = encoder_outputs[layer_name][:, 1:, :]  # b x hw x n_dims
+                all_patch_tokens.append(patch_tokens)
+
+            all_patch_tokens: torch.Tensor = torch.stack(all_patch_tokens, dim=0)  # depth x b x hw x n_dims
+            all_patch_tokens = all_patch_tokens.permute(1, 0, 3, 2)  # b x depth x n_dims x hw
+            return all_patch_tokens
+        else:
+            encoder_outputs = self.linear_layer(self.encoder(x)[-1])  # b x n_dims x h x w
+            return encoder_outputs
+
+    def forward_transformer_decoder(self, patch_tokens: torch.Tensor, skip_decoder: bool = False) -> torch.Tensor:
+        """Forward transformer decoder given patch tokens from the encoder's last layer.
+        :param patch_tokens: b x n_dims x hw -> hw x b x n_dims
+        :param skip_decoder: if True, skip the decoder and produce mask predictions directly by matrix multiplication
+        between learnable queries and encoder features (i.e., patch tokens). This is for the purpose of an overfitting
+        experiment.
+        :return queries: n_queries x b x n_dims -> b x n_queries x n_dims or b x n_layers x n_queries x n_dims
+        """
+        b = patch_tokens.shape[0]
+        patch_tokens = patch_tokens.permute(2, 0, 1)  # b x n_dims x hw -> hw x b x n_dims
+
+        # n_queries x n_dims -> n_queries x b x n_dims
+        queries: torch.Tensor = self.query_embed.unsqueeze(1).repeat(1, b, 1)
+        queries: torch.Tensor = self.decoder.forward(
+            tgt=torch.zeros_like(queries),
+            memory=patch_tokens,
+            query_pos=queries
+        ).squeeze(dim=0)
+
+        if len(queries.shape) == 3:
+            queries: torch.Tensor = queries.permute(1, 0, 2)  # n_queries x b x n_dims -> b x n_queries x n_dims
+        elif len(queries.shape) == 4:
+            # n_layers x n_queries x b x n_dims -> b x n_layers x n_queries x n_dims
+            queries: torch.Tensor = queries.permute(2, 0, 1, 3)
+        return queries
+
+    def forward_pixel_decoder(self, patch_tokens: torch.Tensor, input_size=None):
+        """ Upsample patch tokens by self.scale_factor and produce mask predictions
+        :param patch_tokens: b (x depth) x n_dims x hw -> b (x depth) x n_dims x h x w
+        :param queries: b x n_queries x n_dims
+        :return mask_predictions: b x n_queries x h x w
+        """
+
+        if input_size is None:
+            # assume square shape features
+            hw = patch_tokens.shape[-1]
+            h = w = int(sqrt(hw))
+        else:
+            # arbitrary shape features
+            h, w = input_size
+        patch_tokens = patch_tokens.view(*patch_tokens.shape[:-1], h, w)
+
+        assert len(patch_tokens.shape) == 4
+        patch_tokens = F.interpolate(patch_tokens, scale_factor=self.scale_factor, mode="bilinear")
+        return patch_tokens
+
+    def forward(self, x, encoder_only=False, skip_decoder: bool = False):
+        """
+        x: b x c x h x w
+        patch_tokens: b x n_patches x n_dims -> n_patches x b x n_dims
+        query_emb: n_queries x n_dims -> n_queries x b x n_dims
+        """
+        dict_outputs: dict = dict()
+
+        # b x depth x n_dims x hw (vit) or b x n_dims x h x w (resnet50)
+        features: torch.Tensor = self.forward_encoder(x)
+
+        if self.arch == "vit_small":
+            # extract the last layer for decoder input
+            last_layer_features: torch.Tensor = features[:, -1, ...]  # b x n_dims x hw
+        else:
+            # transform the shape of the features to the one compatible with transformer decoder
+            b, n_dims, h, w = features.shape
+            last_layer_features: torch.Tensor = features.view(b, n_dims, h * w)  # b x n_dims x hw
+
+        if encoder_only:
+            _h, _w = self.encoder.make_input_divisible(x).shape[-2:]
+            _h, _w = _h // self.encoder.patch_size, _w // self.encoder.patch_size
+
+            b, n_dims, hw = last_layer_features.shape
+            dict_outputs.update({"patch_tokens": last_layer_features.view(b, _h, _w, n_dims)})
+            return dict_outputs
+
+        # transformer decoder forward
+        queries: torch.Tensor = self.forward_transformer_decoder(
+            last_layer_features,
+            skip_decoder=skip_decoder
+        )  # b x n_queries x n_dims or b x n_layers x n_queries x n_dims
+
+        # pixel decoder forward (upsampling the patch tokens by self.scale_factor)
+        if self.arch == "vit_small":
+            _h, _w = self.encoder.make_input_divisible(x).shape[-2:]
+            _h, _w = _h // self.encoder.patch_size, _w // self.encoder.patch_size
+        else:
+            _h, _w = h, w
+        features: torch.Tensor = self.forward_pixel_decoder(
+            patch_tokens=features if self.lateral_connection else last_layer_features,
+            input_size=(_h, _w)
+        )  # b x n_dims x h x w
+
+        # queries: b x n_queries x n_dims or b x n_layers x n_queries x n_dims
+        # features: b x n_dims x h x w
+        # mask_pred: b x n_queries x h x w or b x n_layers x n_queries x h x w
+        if len(queries.shape) == 3:
+            mask_pred = torch.einsum("bqn,bnhw->bqhw", queries, features)
+        else:
+            if self.use_binary_classifier:
+                mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", queries, features))
+            else:
+                mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", self.ffn(queries), features))
+
+        if self.use_binary_classifier:
+            # queries: b x n_layers x n_queries x n_dims -> n_layers x b x n_queries x n_dims
+            queries = queries.permute(1, 0, 2, 3)
+            objectness: List[torch.Tensor] = list()
+            for n_layer, queries_per_layer in enumerate(queries):  # queries_per_layer: b x n_queries x n_dims
+                # objectness_per_layer = self.linear_classifier(
+                #     self.ffn(self.norm(queries_per_layer))
+                # )  # b x n_queries x 1
+                objectness_per_layer = self.ffn(queries_per_layer)  # b x n_queries x 1
+                objectness.append(objectness_per_layer)
+            # n_layers x b x n_queries x 1 -> # b x n_layers x n_queries x 1
+            objectness: torch.Tensor = torch.stack(objectness).permute(1, 0, 2, 3)
+            dict_outputs.update({
+                "objectness": torch.sigmoid(objectness),
+                "mask_pred": mask_pred
+            })
+
+        return dict_outputs
+
+
+class MLP(nn.Module):
+    """Very simple multi-layer perceptron (also called FFN)"""
+
+    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(
+            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+        )
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
+
+
+class UpsampleBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, n_groups=32, scale_factor=2):
+        super(UpsampleBlock, self).__init__()
+        self.block = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding),
+            nn.GroupNorm(n_groups, out_channels),
+            nn.ReLU()
+        )
+        self.scale_factor = scale_factor
+
+    def forward(self, x):
+        return F.interpolate(self.block(x), scale_factor=self.scale_factor, mode="bilinear")

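As a shape-level illustration (not part of the commit), a MaskFormer built with the settings from the YAML above returns per-layer mask and objectness predictions. Note that the constructor relies on utils.get_model to build the ViT encoder, which is not included in this diff, so this sketch assumes the wider SelfMask code base is importable.

import torch

from networks.maskformer.maskformer import MaskFormer

model = MaskFormer(
    n_queries=20,
    arch="vit_small",
    patch_size=8,
    training_method="dino",
    n_decoder_layers=6,
    return_intermediate=True,    # keep every decoder layer's output
    use_binary_classifier=True,  # adds the per-query objectness head
).eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))

print(out["mask_pred"].shape)   # b x n_layers x n_queries x h x w
print(out["objectness"].shape)  # b x n_layers x n_queries x 1
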
networks/maskformer/positional_embedding.py
ADDED
@@ -0,0 +1,48 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
+"""
+Various positional encodings for the transformer.
+"""
+import math
+
+import torch
+from torch import nn
+
+
+class PositionEmbeddingSine(nn.Module):
+    """
+    This is a more standard version of the position embedding, very similar to the one
+    used by the Attention is all you need paper, generalized to work on images.
+    """
+
+    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+        super().__init__()
+        self.num_pos_feats = num_pos_feats
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and normalize is False:
+            raise ValueError("normalize should be True if scale is passed")
+        if scale is None:
+            scale = 2 * math.pi
+        self.scale = scale
+
+    def forward(self, x, mask=None):
+        if mask is None:
+            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
+        not_mask = ~mask
+        y_embed = not_mask.cumsum(1, dtype=torch.float32)
+        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+        if self.normalize:
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+        pos_x = x_embed[:, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, None] / dim_t
+        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        return pos

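A quick, self-contained check (illustrative only) of PositionEmbeddingSine: for a b x c x h x w feature map it returns a b x (2 * num_pos_feats) x h x w sinusoidal embedding, so num_pos_feats is typically set to half the model width.

import torch

from networks.maskformer.positional_embedding import PositionEmbeddingSine

pos_enc = PositionEmbeddingSine(num_pos_feats=192, normalize=True)  # 2 * 192 = 384 channels
features = torch.randn(2, 384, 28, 28)                              # b x c x h x w
pos = pos_enc(features)
print(pos.shape)  # torch.Size([2, 384, 28, 28])
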
networks/maskformer/transformer_decoder.py
ADDED
@@ -0,0 +1,376 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
+"""
+Transformer class.
+Copy-paste from torch.nn.Transformer with modifications:
+    * positional encodings are passed in MHattention
+    * extra LN at the end of encoder is removed
+    * decoder returns a stack of activations from all decoding layers
+"""
+import copy
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        d_model=512,
+        nhead=8,
+        num_encoder_layers=6,
+        num_decoder_layers=6,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",  # noel - dino used GeLU
+        normalize_before=False,
+        return_intermediate_dec=False,
+    ):
+        super().__init__()
+
+        encoder_layer = TransformerEncoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+        decoder_layer = TransformerDecoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        decoder_norm = nn.LayerNorm(d_model)
+        self.decoder = TransformerDecoder(
+            decoder_layer,
+            num_decoder_layers,
+            decoder_norm,
+            return_intermediate=return_intermediate_dec,
+        )
+
+        self._reset_parameters()
+
+        self.d_model = d_model
+        self.nhead = nhead
+
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(self, src, mask, query_embed, pos_embed):
+        # flatten NxCxHxW to HWxNxC
+        bs, c, h, w = src.shape
+        src = src.flatten(2).permute(2, 0, 1)
+        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+        if mask is not None:
+            mask = mask.flatten(1)
+
+        tgt = torch.zeros_like(query_embed)
+        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+        hs = self.decoder(
+            tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
+        )
+        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(self, encoder_layer, num_layers, norm=None):
+        super().__init__()
+        self.layers = _get_clones(encoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+
+    def forward(
+        self,
+        src,
+        mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
+        output = src
+
+        for layer in self.layers:
+            output = layer(
+                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
+            )
+
+        if self.norm is not None:
+            output = self.norm(output)
+
+        return output
+
+
+class TransformerDecoder(nn.Module):
+    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+        super().__init__()
+        self.layers: nn.ModuleList = _get_clones(decoder_layer, num_layers)
+        self.num_layers: int = num_layers
+        self.norm = norm
+        self.return_intermediate: bool = return_intermediate
+
+    def forward(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
+        output = tgt
+
+        intermediate = []
+
+        for layer in self.layers:
+            output = layer(
+                output,
+                memory,
+                tgt_mask=tgt_mask,
+                memory_mask=memory_mask,
+                tgt_key_padding_mask=tgt_key_padding_mask,
+                memory_key_padding_mask=memory_key_padding_mask,
+                pos=pos,
+                query_pos=query_pos,
+            )
+            if self.return_intermediate:
+                intermediate.append(self.norm(output))
+
+        if self.norm is not None:
+            output = self.norm(output)
+            if self.return_intermediate:
+                intermediate.pop()
+                intermediate.append(output)
+
+        if self.return_intermediate:
+            return torch.stack(intermediate)
+
+        return output.unsqueeze(0)
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(
+        self,
+        d_model,
+        nhead,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+    ):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(
+        self,
+        src,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
+        q = k = self.with_pos_embed(src, pos)
+        src2 = self.self_attn(
+            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+        )[0]
+        src = src + self.dropout1(src2)
+        src = self.norm1(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+        src = src + self.dropout2(src2)
+        src = self.norm2(src)
+        return src
+
+    def forward_pre(
+        self,
+        src,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
+        src2 = self.norm1(src)
+        q = k = self.with_pos_embed(src2, pos)
+        src2 = self.self_attn(
+            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+        )[0]
+        src = src + self.dropout1(src2)
+        src2 = self.norm2(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+        src = src + self.dropout2(src2)
+        return src
+
+    def forward(
+        self,
+        src,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
+        if self.normalize_before:
+            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+class TransformerDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        d_model,
+        nhead,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+    ):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
+        q = k = self.with_pos_embed(tgt, query_pos)
+
+        tgt2 = self.self_attn(
+            q,
+            k,
+            value=tgt,
+            attn_mask=tgt_mask,
+            key_padding_mask=tgt_key_padding_mask
+        )[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt = self.norm1(tgt)
+
+        tgt2 = self.multihead_attn(
+            query=self.with_pos_embed(tgt, query_pos),
+            key=self.with_pos_embed(memory, pos),
+            value=memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+        tgt = tgt + self.dropout3(tgt2)
+        tgt = self.norm3(tgt)
+
+        return tgt
+
+    def forward_pre(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
+        tgt2 = self.norm1(tgt)
+        q = k = self.with_pos_embed(tgt2, query_pos)
+        tgt2 = self.self_attn(
+            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+        )[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt2 = self.norm2(tgt)
+        tgt2 = self.multihead_attn(
+            query=self.with_pos_embed(tgt2, query_pos),
+            key=self.with_pos_embed(memory, pos),
+            value=memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt2 = self.norm3(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout3(tgt2)
+        return tgt
+
+    def forward(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
+        if self.normalize_before:
+            return self.forward_pre(
+                tgt,
+                memory,
+                tgt_mask,
+                memory_mask,
+                tgt_key_padding_mask,
+                memory_key_padding_mask,
+                pos,
+                query_pos,
+            )
+        return self.forward_post(
+            tgt,
+            memory,
+            tgt_mask,
+            memory_mask,
+            tgt_key_padding_mask,
+            memory_key_padding_mask,
+            pos,
+            query_pos,
+        )
+
+
+def _get_clones(module, N):
+    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation):
+    """Return an activation function given a string"""
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
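Note: the layers above follow the DETR-style residual ordering, where `normalize_before` switches between the post-norm path (`forward_post`) and the pre-norm path (`forward_pre`). A minimal usage sketch follows, illustrative only and with hypothetical sizes, assuming the classes above are importable; tensors use nn.MultiheadAttention's (sequence, batch, dim) layout.

import torch

layer = TransformerDecoderLayer(d_model=256, nhead=8, normalize_before=False)
queries = torch.zeros(20, 2, 256)        # 20 object queries, batch of 2
memory = torch.randn(14 * 14, 2, 256)    # flattened image features as memory
out = layer(queries, memory)             # same shape as queries: (20, 2, 256)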
networks/module_helper.py
ADDED
@@ -0,0 +1,176 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author: Donny You ([email protected])
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from urllib import urlretrieve
except ImportError:
    from urllib.request import urlretrieve


class FixedBatchNorm(nn.BatchNorm2d):
    def forward(self, input):
        return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, training=False, eps=self.eps)


class ModuleHelper(object):
    @staticmethod
    def BNReLU(num_features, norm_type=None, **kwargs):
        if norm_type == 'batchnorm':
            return nn.Sequential(
                nn.BatchNorm2d(num_features, **kwargs),
                nn.ReLU()
            )
        elif norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm2d
            return nn.Sequential(
                BatchNorm2d(num_features, **kwargs),
                nn.ReLU()
            )
        elif norm_type == 'instancenorm':
            return nn.Sequential(
                nn.InstanceNorm2d(num_features, **kwargs),
                nn.ReLU()
            )
        elif norm_type == 'fixed_batchnorm':
            return nn.Sequential(
                FixedBatchNorm(num_features, **kwargs),
                nn.ReLU()
            )
        else:
            raise ValueError('Not support BN type: {}.'.format(norm_type))

    @staticmethod
    def BatchNorm3d(norm_type=None, ret_cls=False):
        if norm_type == 'batchnorm':
            return nn.BatchNorm3d
        elif norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm3d
            return BatchNorm3d
        elif norm_type == 'instancenorm':
            return nn.InstanceNorm3d
        else:
            raise ValueError('Not support BN type: {}.'.format(norm_type))

    @staticmethod
    def BatchNorm2d(norm_type=None, ret_cls=False):
        if norm_type == 'batchnorm':
            return nn.BatchNorm2d
        elif norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm2d
            return BatchNorm2d

        elif norm_type == 'instancenorm':
            return nn.InstanceNorm2d
        else:
            raise ValueError('Not support BN type: {}.'.format(norm_type))

    @staticmethod
    def BatchNorm1d(norm_type=None, ret_cls=False):
        if norm_type == 'batchnorm':
            return nn.BatchNorm1d
        elif norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm1d
            return BatchNorm1d
        elif norm_type == 'instancenorm':
            return nn.InstanceNorm1d
        else:
            raise ValueError('Not support BN type: {}.'.format(norm_type))

    @staticmethod
    def load_model(model, pretrained=None, all_match=True, map_location='cpu'):
        if pretrained is None:
            return model

        if not os.path.exists(pretrained):
            pretrained = pretrained.replace("..", "/home/gishin-temp/projects/open_set/segmentation")
            if os.path.exists(pretrained):
                pass
            else:
                raise FileNotFoundError('{} not exists.'.format(pretrained))

        print('Loading pretrained model:{}'.format(pretrained))
        if all_match:
            pretrained_dict = torch.load(pretrained, map_location=map_location)
            model_dict = model.state_dict()
            load_dict = dict()
            for k, v in pretrained_dict.items():
                if 'prefix.{}'.format(k) in model_dict:
                    load_dict['prefix.{}'.format(k)] = v
                else:
                    load_dict[k] = v
            model.load_state_dict(load_dict)

        else:
            pretrained_dict = torch.load(pretrained)
            model_dict = model.state_dict()
            load_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            print('Matched Keys: {}'.format(load_dict.keys()))
            model_dict.update(load_dict)
            model.load_state_dict(model_dict)

        return model

    @staticmethod
    def load_url(url, map_location=None):
        model_dir = os.path.join('~', '.TorchCV', 'model')
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        filename = url.split('/')[-1]
        cached_file = os.path.join(model_dir, filename)
        if not os.path.exists(cached_file):
            print('Downloading: "{}" to {}\n'.format(url, cached_file))
            urlretrieve(url, cached_file)

        print('Loading pretrained model:{}'.format(cached_file))
        return torch.load(cached_file, map_location=map_location)

    @staticmethod
    def constant_init(module, val, bias=0):
        nn.init.constant_(module.weight, val)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def xavier_init(module, gain=1, bias=0, distribution='normal'):
        assert distribution in ['uniform', 'normal']
        if distribution == 'uniform':
            nn.init.xavier_uniform_(module.weight, gain=gain)
        else:
            nn.init.xavier_normal_(module.weight, gain=gain)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def normal_init(module, mean=0, std=1, bias=0):
        nn.init.normal_(module.weight, mean, std)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def uniform_init(module, a=0, b=1, bias=0):
        nn.init.uniform_(module.weight, a, b)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def kaiming_init(module,
                     mode='fan_in',
                     nonlinearity='leaky_relu',
                     bias=0,
                     distribution='normal'):
        assert distribution in ['uniform', 'normal']
        if distribution == 'uniform':
            nn.init.kaiming_uniform_(
                module.weight, mode=mode, nonlinearity=nonlinearity)
        else:
            nn.init.kaiming_normal_(
                module.weight, mode=mode, nonlinearity=nonlinearity)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)
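Note: `ModuleHelper.BatchNorm1d/2d/3d` return a normalisation class rather than an instance (the `ret_cls` flag is accepted but not used inside the helper), and callers then instantiate that class with a channel count. A minimal sketch, illustrative only and assuming this module is importable:

import torch

norm_cls = ModuleHelper.BatchNorm2d(norm_type='batchnorm')   # returns nn.BatchNorm2d itself
bn = norm_cls(64)                                            # instance for 64 channels
block = ModuleHelper.BNReLU(64, norm_type='batchnorm')       # Sequential(BatchNorm2d, ReLU)
y = block(torch.randn(1, 64, 8, 8))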
networks/resnet.py
ADDED
@@ -0,0 +1,60 @@
import os

import torch
import torch.nn as nn
from .resnet_backbone import ResNetBackbone


class ResNet50(nn.Module):
    def __init__(
        self,
        weight_type: str = "supervised",
        use_dilated_resnet: bool = True
    ):
        super(ResNet50, self).__init__()
        self.network = ResNetBackbone(backbone=f"resnet50{'_dilated8' if use_dilated_resnet else ''}", pretrained=None)
        self.n_embs = self.network.num_features
        self.use_dilated_resnet = use_dilated_resnet
        self._load_pretrained(weight_type)

    def _load_pretrained(self, training_method: str) -> None:
        curr_state_dict = self.network.state_dict()
        if training_method == "mocov2":
            state_dict = torch.load("/users/gyungin/sos/networks/pretrained/moco_v2_800ep_pretrain.pth.tar")["state_dict"]

            for k in list(state_dict.keys()):
                if any([k.find(w) != -1 for w in ("fc.0", "fc.2")]):
                    state_dict.pop(k)

        elif training_method == "swav":
            state_dict = torch.load("/users/gyungin/sos/networks/pretrained/swav_800ep_pretrain.pth.tar")
            for k in list(state_dict.keys()):
                if any([k.find(w) != -1 for w in ("projection_head", "prototypes")]):
                    state_dict.pop(k)

        elif training_method == "supervised":
            # Note - pytorch resnet50 model doesn't have num_batches_tracked layers. Need to know why.
            # for k in list(curr_state_dict.keys()):
            #     if k.find("num_batches_tracked") != -1:
            #         curr_state_dict.pop(k)
            # state_dict = torch.load("../networks/pretrained/resnet50-pytorch.pth")

            from torchvision.models.resnet import resnet50
            resnet50_supervised = resnet50(True, True)
            state_dict = resnet50_supervised.state_dict()
            for k in list(state_dict.keys()):
                if any([k.find(w) != -1 for w in ("fc.weight", "fc.bias")]):
                    state_dict.pop(k)

        assert len(curr_state_dict) == len(state_dict), f"# layers are different: {len(curr_state_dict)} != {len(state_dict)}"
        for k_curr, k in zip(curr_state_dict.keys(), state_dict.keys()):
            curr_state_dict[k_curr].copy_(state_dict[k])
        print(f"ResNet50{' (dilated)' if self.use_dilated_resnet else ''} initialised with {training_method} weights.")
        return

    def forward(self, x):
        return self.network(x)


if __name__ == '__main__':
    resnet = ResNet50("mocov2")
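Note: the `mocov2` and `swav` branches above read checkpoints from absolute paths on the author's machine, so only the `supervised` branch is portable; it downloads torchvision's ImageNet weights. A rough, illustrative sketch of the intended use, assuming those weights can be fetched in the current environment:

import torch

model = ResNet50(weight_type="supervised", use_dilated_resnet=True)
stages = model(torch.randn(1, 3, 224, 224))   # list of the 4 stage feature maps
# with the dilated backbone the final stage is expected to keep an output stride of 8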
networks/resnet_backbone.py
ADDED
@@ -0,0 +1,194 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author: Donny You([email protected])


import torch.nn as nn
from networks.resnet_models import *


class NormalResnetBackbone(nn.Module):
    def __init__(self, orig_resnet):
        super(NormalResnetBackbone, self).__init__()

        self.num_features = 2048
        # take pretrained resnet, except AvgPool and FC
        self.prefix = orig_resnet.prefix
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def get_num_features(self):
        return self.num_features

    def forward(self, x):
        tuple_features = list()
        x = self.prefix(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        tuple_features.append(x)
        x = self.layer2(x)
        tuple_features.append(x)
        x = self.layer3(x)
        tuple_features.append(x)
        x = self.layer4(x)
        tuple_features.append(x)

        return tuple_features


class DilatedResnetBackbone(nn.Module):
    def __init__(self, orig_resnet, dilate_scale=8, multi_grid=(1, 2, 4)):
        super(DilatedResnetBackbone, self).__init__()

        self.num_features = 2048
        from functools import partial

        if dilate_scale == 8:
            orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
            if multi_grid is None:
                orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
            else:
                for i, r in enumerate(multi_grid):
                    orig_resnet.layer4[i].apply(partial(self._nostride_dilate, dilate=int(4 * r)))

        elif dilate_scale == 16:
            if multi_grid is None:
                orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))
            else:
                for i, r in enumerate(multi_grid):
                    orig_resnet.layer4[i].apply(partial(self._nostride_dilate, dilate=int(2 * r)))

        # Take pretrained resnet, except AvgPool and FC
        self.prefix = orig_resnet.prefix
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # the convolution with stride
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            # other convolutions
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def get_num_features(self):
        return self.num_features

    def forward(self, x):
        tuple_features = list()

        x = self.prefix(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        tuple_features.append(x)
        x = self.layer2(x)
        tuple_features.append(x)
        x = self.layer3(x)
        tuple_features.append(x)
        x = self.layer4(x)
        tuple_features.append(x)

        return tuple_features


def ResNetBackbone(backbone=None, width_multiplier=1.0, pretrained=None, multi_grid=None, norm_type='batchnorm'):
    arch = backbone

    if arch == 'resnet18':
        orig_resnet = resnet18(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)
        arch_net.num_features = 512

    elif arch == 'resnet18_dilated8':
        orig_resnet = resnet18(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
        arch_net.num_features = 512

    elif arch == 'resnet34':
        orig_resnet = resnet34(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)
        arch_net.num_features = 512

    elif arch == 'resnet34_dilated8':
        orig_resnet = resnet34(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
        arch_net.num_features = 512

    elif arch == 'resnet34_dilated16':
        orig_resnet = resnet34(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
        arch_net.num_features = 512

    elif arch == 'resnet50':
        orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier)
        arch_net = NormalResnetBackbone(orig_resnet)

    elif arch == 'resnet50_dilated8':
        orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)

    elif arch == 'resnet50_dilated16':
        orig_resnet = resnet50(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)

    elif arch == 'deepbase_resnet50':
        if pretrained:
            pretrained = 'models/backbones/pretrained/3x3resnet50-imagenet.pth'
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)

    elif arch == 'deepbase_resnet50_dilated8':
        if pretrained:
            pretrained = 'models/backbones/pretrained/3x3resnet50-imagenet.pth'
            # pretrained = "/home/gishin/Projects/DeepLearning/Oxford/cct/models/backbones/pretrained/3x3resnet50-imagenet.pth"
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)

    elif arch == 'deepbase_resnet50_dilated16':
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)

    elif arch == 'resnet101':
        orig_resnet = resnet101(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)

    elif arch == 'resnet101_dilated8':
        orig_resnet = resnet101(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)

    elif arch == 'resnet101_dilated16':
        orig_resnet = resnet101(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)

    elif arch == 'deepbase_resnet101':
        orig_resnet = deepbase_resnet101(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)

    elif arch == 'deepbase_resnet101_dilated8':
        if pretrained:
            pretrained = 'backbones/backbones/pretrained/3x3resnet101-imagenet.pth'
        orig_resnet = deepbase_resnet101(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)

    elif arch == 'deepbase_resnet101_dilated16':
        orig_resnet = deepbase_resnet101(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)

    else:
        raise Exception('Architecture undefined!')

    return arch_net
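Note: the dilated variants above replace stride-2 convolutions in the later stages with dilated ones, so the final feature map is denser. A minimal sketch, illustrative only and assuming this module is importable (no pretrained weights involved):

import torch

plain = ResNetBackbone(backbone='resnet50', pretrained=None)
dilated = ResNetBackbone(backbone='resnet50_dilated8', pretrained=None)
x = torch.randn(1, 3, 224, 224)
print(plain(x)[-1].shape)    # output stride 32 -> torch.Size([1, 2048, 7, 7])
print(dilated(x)[-1].shape)  # output stride 8  -> torch.Size([1, 2048, 28, 28])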
networks/resnet_models.py
ADDED
@@ -0,0 +1,273 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author: Donny You([email protected])
import math
import torch.nn as nn
from collections import OrderedDict
from .module_helper import ModuleHelper


model_urls = {
    'resnet18': 'https://download.pytorch.org/backbones/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/backbones/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/backbones/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/backbones/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/backbones/resnet152-b121ed2d.pth'
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_type=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_type=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, width_multiplier=1.0, num_classes=1000, deep_base=False, norm_type=None):
        super(ResNet, self).__init__()
        self.inplanes = 128 if deep_base else int(64 * width_multiplier)
        self.width_multiplier = width_multiplier
        if deep_base:
            self.prefix = nn.Sequential(OrderedDict([
                ('conv1', nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)),
                ('bn1', ModuleHelper.BatchNorm2d(norm_type=norm_type)(64)),
                ('relu1', nn.ReLU(inplace=False)),
                ('conv2', nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)),
                ('bn2', ModuleHelper.BatchNorm2d(norm_type=norm_type)(64)),
                ('relu2', nn.ReLU(inplace=False)),
                ('conv3', nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)),
                ('bn3', ModuleHelper.BatchNorm2d(norm_type=norm_type)(self.inplanes)),
                ('relu3', nn.ReLU(inplace=False))]
            ))
        else:
            self.prefix = nn.Sequential(OrderedDict([
                ('conv1', nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
                ('bn1', ModuleHelper.BatchNorm2d(norm_type=norm_type)(self.inplanes)),
                ('relu', nn.ReLU(inplace=False))]
            ))

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)  # change.

        self.layer1 = self._make_layer(block, int(64 * width_multiplier), layers[0], norm_type=norm_type)
        self.layer2 = self._make_layer(block, int(128 * width_multiplier), layers[1], stride=2, norm_type=norm_type)
        self.layer3 = self._make_layer(block, int(256 * width_multiplier), layers[2], stride=2, norm_type=norm_type)
        self.layer4 = self._make_layer(block, int(512 * width_multiplier), layers[3], stride=2, norm_type=norm_type)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(int(512 * block.expansion * width_multiplier), num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, ModuleHelper.BatchNorm2d(norm_type=norm_type, ret_cls=True)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, norm_type=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                ModuleHelper.BatchNorm2d(norm_type=norm_type)(int(planes * block.expansion * self.width_multiplier)),
            )

        layers = []
        layers.append(block(self.inplanes, planes,
                            stride, downsample, norm_type=norm_type))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, norm_type=norm_type))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.prefix(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
        norm_type (str): choose norm type
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, deep_base=False, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def deepbase_resnet18(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, deep_base=True, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def resnet34(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def deepbase_resnet34(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def resnet50(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type,
                   width_multiplier=kwargs["width_multiplier"])
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def deepbase_resnet50(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def resnet101(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def deepbase_resnet101(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def resnet152(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def deepbase_resnet152(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model
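Note: unlike the other constructors in this file, `resnet50()` indexes `kwargs["width_multiplier"]` directly, so callers must pass it explicitly (as `ResNetBackbone` does). A one-line illustrative sketch:

model = resnet50(num_classes=1000, pretrained=None, width_multiplier=1.0)
# resnet50(pretrained=None) would raise KeyError: 'width_multiplier'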
networks/timm_deit.py
ADDED
@@ -0,0 +1,254 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import math
import torch
import torch.nn as nn
from functools import partial

from networks.timm_vit import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_


__all__ = [
    'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
    'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
    'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
    'deit_base_distilled_patch16_384',
]


class DistilledVisionTransformer(VisionTransformer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

        trunc_normal_(self.dist_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        self.head_dist.apply(self._init_weights)

    def forward_features(self, x):
        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the dist_token
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x[:, 0], x[:, 1]

    def forward(self, x):
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.training:
            return x, x_dist
        else:
            # during inference, return the average of both classifier predictions
            return (x + x_dist) / 2

    def interpolate_pos_encoding(self, x, pos_embed):
        """Interpolate the learnable positional encoding to match the number of patches.

        x: B x (1 + 1 + N patches) x dim_embedding
        pos_embed: B x (1 + 1 + N patches) x dim_embedding

        return interpolated positional embedding
        """

        npatch = x.shape[1] - 2  # (H // patch_size * W // patch_size)
        N = pos_embed.shape[1] - 2  # 784 (= 28 x 28)

        if npatch == N:
            return pos_embed

        class_emb, distil_token, pos_embed = pos_embed[:, 0], pos_embed[:, 1], pos_embed[:, 2:]  # a learnable CLS token, learnable position embeddings

        dim = x.shape[-1]  # dimension of embeddings
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),  # B x dim x 28 x 28
            scale_factor=math.sqrt(npatch / N) + 1e-5,  # noel: this can be a float, but the output shape will be integer.
            recompute_scale_factor=True,
            mode='bicubic'
        )
        # print("pos_embed", pos_embed.shape, npatch, N, math.sqrt(npatch/N), math.sqrt(npatch/N) * int(math.sqrt(N)))
        # exit(12)
        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        pos_embed = torch.cat((class_emb.unsqueeze(0), distil_token.unsqueeze(0), pos_embed), dim=1)
        return pos_embed

    def get_tokens(
        self,
        x,
        layers: list,
        patch_tokens: bool = False,
        norm: bool = True,
        input_tokens: bool = False,
        post_pe: bool = False
    ):
        """Return intermediate tokens."""
        list_tokens: list = []

        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_token = self.dist_token.expand(B, -1, -1)

        x = torch.cat((cls_tokens, dist_token, x), dim=1)

        if input_tokens:
            list_tokens.append(x)

        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed

        if post_pe:
            list_tokens.append(x)

        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x)  # B x # patches x dim
            if layers is None or i in layers:
                list_tokens.append(self.norm(x) if norm else x)

        tokens = torch.stack(list_tokens, dim=1)  # B x n_layers x (1 + # patches) x dim

        if not patch_tokens:
            return tokens[:, :, 0, :]  # index [CLS] tokens only, B x n_layers x dim

        else:
            return torch.cat((tokens[:, :, 0, :].unsqueeze(dim=2), tokens[:, :, 2:, :]), dim=2)  # exclude distil token.


@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_small_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
networks/timm_vit.py
ADDED
|
@@ -0,0 +1,819 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" Vision Transformer (ViT) in PyTorch

A PyTorch implement of Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929

The official jax code is released and available at https://github.com/google-research/vision_transformer

DeiT model defs and weights from https://github.com/facebookresearch/deit,
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877

Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
  for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert

Hacked together by / Copyright 2020 Ross Wightman
"""
import math
import logging
from functools import partial
from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg
from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
from timm.models.registry import register_model

_logger = logging.getLogger(__name__)


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    # patch models (my experiments)
    'vit_small_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
    ),

    # patch models (weights ported from official Google JAX impl)
    'vit_base_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
    'vit_base_patch32_224': _cfg(
        url='',  # no official model weights for this combo, only for in21k
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_base_patch16_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
    'vit_base_patch32_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
    'vit_large_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_large_patch32_224': _cfg(
        url='',  # no official model weights for this combo, only for in21k
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_large_patch16_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
    'vit_large_patch32_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),

    # patch models, imagenet21k (weights ported from official Google JAX impl)
    'vit_base_patch16_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_base_patch32_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_large_patch16_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_large_patch32_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_huge_patch14_224_in21k': _cfg(
        hf_hub='timm/vit_huge_patch14_224_in21k',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),

    # deit models (FB weights)
    'vit_deit_tiny_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
    'vit_deit_small_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
    'vit_deit_base_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
    'vit_deit_base_patch16_384': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_deit_tiny_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
        classifier=('head', 'head_dist')),
    'vit_deit_small_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
        classifier=('head', 'head_dist')),
    'vit_deit_base_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
        classifier=('head', 'head_dist')),
    'vit_deit_base_distilled_patch16_384': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
        input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),

    # ViT ImageNet-21K-P pretraining
    'vit_base_patch16_224_miil_in21k': _cfg(
        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
        mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
    ),
    'vit_base_patch16_224_miil': _cfg(
        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
            '/vit_base_patch16_224_1k_miil_84_4.pth',
        mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
    ),
}


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929

    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None, weight_init='',
                 # noel
                 img_size_eval: int = 224):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            weight_init: (str): weight init scheme
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
        trunc_normal_(self.pos_embed, std=.02)
        if self.dist_token is not None:
            trunc_normal_(self.dist_token, std=.02)
        if weight_init.startswith('jax'):
            # leave cls token as zeros to match jax impl
            for n, m in self.named_modules():
                _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
        else:
            trunc_normal_(self.cls_token, std=.02)
            self.apply(_init_vit_weights)

        # noel
        self.depth = depth
        self.distilled = distilled
        self.patch_size = patch_size
        self.patch_embed.img_size = (img_size_eval, img_size_eval)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        _init_vit_weights(m)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token', 'dist_token'}

    def get_classifier(self):
        if self.dist_token is None:
            return self.head
        else:
            return self.head, self.head_dist

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        if self.num_tokens == 2:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    # def forward(self, x):
    #     x = self.forward_features(x)
    #     if self.head_dist is not None:
    #         x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
    #         if self.training and not torch.jit.is_scripting():
    #             # during inference, return the average of both classifier predictions
    #             return x, x_dist
    #         else:
    #             return (x + x_dist) / 2
    #     else:
    #         x = self.head(x)
    #     return x

    # noel - start
    def make_square(self, x: torch.Tensor):
        """Pad some pixels to make the input size divisible by the patch size."""
        B, _, H_0, W_0 = x.shape
        pad_w = (self.patch_size - W_0 % self.patch_size) % self.patch_size
        pad_h = (self.patch_size - H_0 % self.patch_size) % self.patch_size
        x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=x.mean())

        H_p, W_p = H_0 + pad_h, W_0 + pad_w
        x = nn.functional.pad(x, (0, H_p - W_p, 0, 0) if H_p > W_p else (0, 0, 0, W_p - H_p), value=x.mean())
        return x

    def interpolate_pos_encoding(self, x, pos_embed, size):
        """Interpolate the learnable positional encoding to match the number of patches.

        x: B x (1 + N patches) x dim_embedding
        pos_embed: B x (1 + N patches) x dim_embedding

        return interpolated positional embedding
        """
        npatch = x.shape[1] - 1  # (H // patch_size * W // patch_size)
        N = pos_embed.shape[1] - 1  # 784 (= 28 x 28)
        if npatch == N:
            return pos_embed
        class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:]  # a learnable CLS token, learnable position embeddings

        dim = x.shape[-1]  # dimension of embeddings
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),  # B x dim x 28 x 28
            size=size,
            mode='bicubic',
            align_corners=False
        )

        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
        return pos_embed

    # def interpolate_pos_encoding(self, x, pos_embed):
    #     """Interpolate the learnable positional encoding to match the number of patches.
    #
    #     x: B x (1 + N patches) x dim_embedding
    #     pos_embed: B x (1 + N patches) x dim_embedding
    #
    #     return interpolated positional embedding
    #     """
    #     npatch = x.shape[1] - 1  # (H // patch_size * W // patch_size)
    #     N = pos_embed.shape[1] - 1  # 784 (= 28 x 28)
    #     if npatch == N:
    #         return pos_embed
    #     class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:]  # a learnable CLS token, learnable position embeddings
    #
    #     dim = x.shape[-1]  # dimension of embeddings
    #     pos_embed = nn.functional.interpolate(
    #         pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),  # B x dim x 28 x 28
    #         scale_factor=math.sqrt(npatch / N) + 1e-5,  # noel: this can be a float, but the output shape will be integer.
    #         recompute_scale_factor=True,
    #         mode='bicubic',
    #         align_corners=False
    #     )
    #
    #     pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    #     pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
    #     return pos_embed

    def prepare_tokens(self, x):
        B, nc, h, w = x.shape
        patch_embed_h, patch_embed_w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, self.pos_embed, size=(patch_embed_h, patch_embed_w))
        return self.pos_drop(x)

    def get_tokens(
            self,
            x,
            layers: list,
            patch_tokens: bool = False,
            norm: bool = True,
            input_tokens: bool = False,
            post_pe: bool = False
    ):
        """Return intermediate tokens."""
        list_tokens: list = []

        B = x.shape[0]
        # patch grid of the input, used to resize the positional encoding below
        patch_embed_h, patch_embed_w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)

        x = torch.cat((cls_tokens, x), dim=1)

        if input_tokens:
            list_tokens.append(x)

        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed, size=(patch_embed_h, patch_embed_w))
        x = x + pos_embed

        if post_pe:
            list_tokens.append(x)

        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x)  # B x # patches x dim
            if layers is None or i in layers:
                list_tokens.append(self.norm(x) if norm else x)

        tokens = torch.stack(list_tokens, dim=1)  # B x n_layers x (1 + # patches) x dim

        if not patch_tokens:
            return tokens[:, :, 0, :]  # index [CLS] tokens only, B x n_layers x dim

        else:
            return tokens

    def forward(self, x, layer: str = None):
        x = self.prepare_tokens(x)

        features: dict = {}
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            features[f"layer{i + 1}"] = self.norm(x)

        if layer is not None:
            return features[layer]
        else:
            return features["layer12"]
    # noel - end


def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
    """ ViT weight initialization
    * When called without n, head_bias, jax_impl args it will behave exactly the same
      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
    * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
    """
    if isinstance(m, nn.Linear):
        if n.startswith('head'):
            nn.init.zeros_(m.weight)
            nn.init.constant_(m.bias, head_bias)
        elif n.startswith('pre_logits'):
            lecun_normal_(m.weight)
            nn.init.zeros_(m.bias)
        else:
            if jax_impl:
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    if 'mlp' in n:
                        nn.init.normal_(m.bias, std=1e-6)
                    else:
                        nn.init.zeros_(m.bias)
            else:
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    elif jax_impl and isinstance(m, nn.Conv2d):
        # NOTE conv was left to pytorch default in my original init
        lecun_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)


def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
    ntok_new = posemb_new.shape[1]
    if num_tokens:
        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
        ntok_new -= num_tokens
    else:
        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
    gs_old = int(math.sqrt(len(posemb_grid)))
    if not len(gs_new):  # backwards compatibility
        gs_new = [int(math.sqrt(ntok_new))] * 2
    assert len(gs_new) >= 2
    _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear')
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
    return posemb


def checkpoint_filter_fn(state_dict, model):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    if 'model' in state_dict:
        # For deit models
        state_dict = state_dict['model']
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
            # For old models that I trained prior to conv based patchification
            O, I, H, W = model.patch_embed.proj.weight.shape
            v = v.reshape(O, -1, H, W)
        elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
            # To resize pos embedding when using model at different size from pretrained weights
            v = resize_pos_embed(
                v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
        out_dict[k] = v
    return out_dict


def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
    default_cfg = default_cfg or default_cfgs[variant]
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    # NOTE this extra code to support handling of repr size for in21k pretrained models
    default_num_classes = default_cfg['num_classes']
    num_classes = kwargs.get('num_classes', default_num_classes)
    repr_size = kwargs.pop('representation_size', None)
    if repr_size is not None and num_classes != default_num_classes:
        # Remove representation layer if fine-tuning. This may not always be the desired action,
        # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
        _logger.warning("Removing representation layer for fine-tuning.")
        repr_size = None

    model = build_model_with_cfg(
        VisionTransformer, variant, pretrained,
        default_cfg=default_cfg,
        representation_size=repr_size,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs)
    return model


@register_model
def vit_small_patch16_224(pretrained=False, **kwargs):
    """ My custom 'small' ViT model. embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.
    NOTE:
        * this differs from the DeiT based 'small' definitions with embed_dim=384, depth=12, num_heads=6
        * this model does not have a bias for QKV (unlike the official ViT and DeiT models)
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
        qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
    if pretrained:
        # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
        model_kwargs.setdefault('qk_scale', 768 ** -0.5)
    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch32_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
    """
    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_384(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch32_384(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch16_224(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch32_224(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
    """
    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch16_384(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch32_384(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(
        patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
    model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
    model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(
        patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
    model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
    """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: converted weights not currently available, too large for github release hosting.
    """
    model_kwargs = dict(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
    model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
    """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_deit_small_patch16_224(pretrained=False, **kwargs):
    """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_deit_base_patch16_224(pretrained=False, **kwargs):
    """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_deit_base_patch16_384(pretrained=False, **kwargs):
    """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer(
        'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


@register_model
def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
    """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer(
        'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


@register_model
def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
    """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(
        'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


@register_model
def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
    """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(
        'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


@register_model
def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_224_miil(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
    return model
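For orientation, here is a minimal usage sketch of the registry functions defined above; it is not part of the commit. It assumes the file is importable as `networks.timm_vit`, that `torch` and a `timm` version compatible with the imports at the top of the file are installed, and it feeds a random tensor in place of a real image. Note that the modified `forward` returns LayerNorm-normalised tokens from a chosen block rather than classification logits.

# Usage sketch (illustrative only): build one of the registered ViT variants and
# read out the normalised token sequence of an intermediate block.
import torch

from networks.timm_vit import vit_base_patch16_224  # assumed import path for the file above

model = vit_base_patch16_224(pretrained=False).eval()

with torch.no_grad():
    image = torch.randn(1, 3, 224, 224)    # B x C x H x W dummy input
    tokens = model(image, layer="layer6")  # tokens after block 6, incl. the [CLS] token
print(tokens.shape)                         # torch.Size([1, 197, 768]) = B x (1 + 14*14) x dim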
networks/vision_transformer.py ADDED
@@ -0,0 +1,569 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
"""
|
| 3 |
+
Mostly copy-paste from timm library.
|
| 4 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
| 5 |
+
"""
|
| 6 |
+
from typing import Optional
|
| 7 |
+
import math
|
| 8 |
+
from functools import partial
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
| 15 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 16 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 17 |
+
def norm_cdf(x):
|
| 18 |
+
# Computes standard normal cumulative distribution function
|
| 19 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
| 20 |
+
|
| 21 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 22 |
+
warnings.warn(
|
| 23 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. The distribution of values may be incorrect.",
|
| 24 |
+
stacklevel=2
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
with torch.no_grad():
|
| 28 |
+
# Values are generated by using a truncated uniform distribution and
|
| 29 |
+
# then using the inverse CDF for the normal distribution.
|
| 30 |
+
# Get upper and lower cdf values
|
| 31 |
+
l = norm_cdf((a - mean) / std)
|
| 32 |
+
u = norm_cdf((b - mean) / std)
|
| 33 |
+
|
| 34 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 35 |
+
# [2l-1, 2u-1].
|
| 36 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 37 |
+
|
| 38 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 39 |
+
# standard normal
|
| 40 |
+
tensor.erfinv_()
|
| 41 |
+
|
| 42 |
+
# Transform to proper mean, std
|
| 43 |
+
tensor.mul_(std * math.sqrt(2.))
|
| 44 |
+
tensor.add_(mean)
|
| 45 |
+
|
| 46 |
+
# Clamp to ensure it's in the proper range
|
| 47 |
+
tensor.clamp_(min=a, max=b)
|
| 48 |
+
return tensor
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
| 52 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
| 53 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
| 57 |
+
if drop_prob == 0. or not training:
|
| 58 |
+
return x
|
| 59 |
+
keep_prob = 1 - drop_prob
|
| 60 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 61 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
| 62 |
+
random_tensor.floor_() # binarize
|
| 63 |
+
output = x.div(keep_prob) * random_tensor
|
| 64 |
+
return output
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class DropPath(nn.Module):
|
| 68 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 69 |
+
"""
|
| 70 |
+
def __init__(self, drop_prob=None):
|
| 71 |
+
super(DropPath, self).__init__()
|
| 72 |
+
self.drop_prob = drop_prob
|
| 73 |
+
|
| 74 |
+
def forward(self, x):
|
| 75 |
+
return drop_path(x, self.drop_prob, self.training)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class Mlp(nn.Module):
|
| 79 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
| 80 |
+
super().__init__()
|
| 81 |
+
out_features = out_features or in_features
|
| 82 |
+
hidden_features = hidden_features or in_features
|
| 83 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 84 |
+
self.act = act_layer()
|
| 85 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 86 |
+
self.drop = nn.Dropout(drop)
|
| 87 |
+
|
| 88 |
+
def forward(self, x):
|
| 89 |
+
x = self.fc1(x)
|
| 90 |
+
x = self.act(x)
|
| 91 |
+
x = self.drop(x)
|
| 92 |
+
x = self.fc2(x)
|
| 93 |
+
x = self.drop(x)
|
| 94 |
+
return x
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class Attention(nn.Module):
|
| 98 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
| 99 |
+
super().__init__()
|
| 100 |
+
self.num_heads = num_heads
|
| 101 |
+
head_dim = dim // num_heads
|
| 102 |
+
self.scale = qk_scale or head_dim ** -0.5 # square root of dimension for normalisation
|
| 103 |
+
|
| 104 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 105 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 106 |
+
|
| 107 |
+
self.proj = nn.Linear(dim, dim)
|
| 108 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 109 |
+
|
| 110 |
+
def forward(self, x):
|
| 111 |
+
B, N, C = x.shape # B x (cls token + # patch tokens) x dim
|
| 112 |
+
|
| 113 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 114 |
+
# qkv: 3 x B x Nh x (cls token + # patch tokens) x (dim // Nh)
|
| 115 |
+
|
| 116 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 117 |
+
# q, k, v: B x Nh x (cls token + # patch tokens) x (dim // Nh)
|
| 118 |
+
|
| 119 |
+
# q: B x Nh x (cls token + # patch tokens) x (dim // Nh)
|
| 120 |
+
# k.transpose(-2, -1) = B x Nh x (dim // Nh) x (cls token + # patch tokens)
|
| 121 |
+
# attn: B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens)
|
| 122 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale # @ operator is for matrix multiplication
|
| 123 |
+
attn = attn.softmax(dim=-1) # B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens)
|
| 124 |
+
attn = self.attn_drop(attn)
|
| 125 |
+
|
| 126 |
+
# attn = B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens)
|
| 127 |
+
# v = B x Nh x (cls token + # patch tokens) x (dim // Nh)
|
| 128 |
+
# attn @ v = B x Nh x (cls token + # patch tokens) x (dim // Nh)
|
| 129 |
+
# (attn @ v).transpose(1, 2) = B x (cls token + # patch tokens) x Nh x (dim // Nh)
|
| 130 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C) # B x (cls token + # patch tokens) x dim
|
| 131 |
+
x = self.proj(x) # B x (cls token + # patch tokens) x dim
|
| 132 |
+
x = self.proj_drop(x)
|
| 133 |
+
return x, attn
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class Block(nn.Module):
|
| 137 |
+
def __init__(self,
|
| 138 |
+
dim, num_heads,
|
| 139 |
+
mlp_ratio=4.,
|
| 140 |
+
qkv_bias=False,
|
| 141 |
+
qk_scale=None,
|
| 142 |
+
drop=0.,
|
| 143 |
+
attn_drop=0.,
|
| 144 |
+
drop_path=0.,
|
| 145 |
+
act_layer=nn.GELU,
|
| 146 |
+
norm_layer=nn.LayerNorm):
|
| 147 |
+
super().__init__()
|
| 148 |
+
self.norm1 = norm_layer(dim)
|
| 149 |
+
self.attn = Attention(
|
| 150 |
+
dim,
|
| 151 |
+
num_heads=num_heads,
|
| 152 |
+
qkv_bias=qkv_bias,
|
| 153 |
+
qk_scale=qk_scale,
|
| 154 |
+
attn_drop=attn_drop,
|
| 155 |
+
proj_drop=drop
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 159 |
+
|
| 160 |
+
self.norm2 = norm_layer(dim)
|
| 161 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 162 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 163 |
+
|
| 164 |
+
def forward(self, x, return_attention=False):
|
| 165 |
+
y, attn = self.attn(self.norm1(x))
|
| 166 |
+
if return_attention:
|
| 167 |
+
return attn
|
| 168 |
+
x = x + self.drop_path(y)
|
| 169 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 170 |
+
return x
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class PatchEmbed(nn.Module):
|
| 174 |
+
""" Image to Patch Embedding"""
|
| 175 |
+
def __init__(self, img_size=(224, 224), patch_size=16, in_chans=3, embed_dim=768):
|
| 176 |
+
super().__init__()
|
| 177 |
+
num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
|
| 178 |
+
self.img_size = img_size
|
| 179 |
+
self.patch_size = patch_size
|
| 180 |
+
self.num_patches = num_patches
|
| 181 |
+
|
| 182 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 183 |
+
|
| 184 |
+
def forward(self, x):
|
| 185 |
+
B, C, H, W = x.shape
|
| 186 |
+
x = self.proj(x)
|
| 187 |
+
x = x.flatten(2).transpose(1, 2) # B x (P_H * P_W) x C
|
| 188 |
+
return x
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class VisionTransformer(nn.Module):
|
| 192 |
+
""" Vision Transformer """
|
| 193 |
+
def __init__(self,
|
| 194 |
+
img_size=(224, 224),
|
| 195 |
+
patch_size=16,
|
| 196 |
+
in_chans=3,
|
| 197 |
+
num_classes=0,
|
| 198 |
+
embed_dim=768,
|
| 199 |
+
depth=12,
|
| 200 |
+
num_heads=12,
|
| 201 |
+
mlp_ratio=4.,
|
| 202 |
+
qkv_bias=False,
|
| 203 |
+
qk_scale=None,
|
| 204 |
+
drop_rate=0.,
|
| 205 |
+
attn_drop_rate=0.,
|
| 206 |
+
drop_path_rate=0.,
|
| 207 |
+
norm_layer=nn.LayerNorm):
|
| 208 |
+
super().__init__()
|
| 209 |
+
self.num_features = self.embed_dim = embed_dim
|
| 210 |
+
|
| 211 |
+
self.patch_embed = PatchEmbed(
|
| 212 |
+
img_size=(224, 224), # noel: this is to load pretrained model.
|
| 213 |
+
patch_size=patch_size,
|
| 214 |
+
in_chans=in_chans,
|
| 215 |
+
embed_dim=embed_dim
|
| 216 |
+
)
|
| 217 |
+
num_patches = self.patch_embed.num_patches
|
| 218 |
+
|
| 219 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 220 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
| 221 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 222 |
+
|
| 223 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 224 |
+
self.blocks = nn.ModuleList([
|
| 225 |
+
Block(
|
| 226 |
+
dim=embed_dim,
|
| 227 |
+
num_heads=num_heads,
|
| 228 |
+
mlp_ratio=mlp_ratio,
|
| 229 |
+
qkv_bias=qkv_bias,
|
| 230 |
+
qk_scale=qk_scale,
|
| 231 |
+
drop=drop_rate,
|
| 232 |
+
attn_drop=attn_drop_rate,
|
| 233 |
+
drop_path=dpr[i],
|
| 234 |
+
norm_layer=norm_layer
|
| 235 |
+
) for i in range(depth)])
|
| 236 |
+
self.norm = norm_layer(embed_dim)
|
| 237 |
+
|
| 238 |
+
# Classifier head
|
| 239 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
| 240 |
+
|
| 241 |
+
trunc_normal_(self.pos_embed, std=.02)
|
| 242 |
+
trunc_normal_(self.cls_token, std=.02)
|
| 243 |
+
self.apply(self._init_weights)
|
| 244 |
+
|
| 245 |
+
self.depth = depth
|
| 246 |
+
self.embed_dim = self.n_embs = embed_dim
|
| 247 |
+
self.mlp_ratio = mlp_ratio
|
| 248 |
+
self.n_heads = num_heads
|
| 249 |
+
self.patch_size = patch_size
|
| 250 |
+
|
| 251 |
+
def _init_weights(self, m):
|
| 252 |
+
if isinstance(m, nn.Linear):
|
| 253 |
+
trunc_normal_(m.weight, std=.02)
|
| 254 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 255 |
+
nn.init.constant_(m.bias, 0)
|
| 256 |
+
elif isinstance(m, nn.LayerNorm):
|
| 257 |
+
nn.init.constant_(m.bias, 0)
|
| 258 |
+
nn.init.constant_(m.weight, 1.0)
|
| 259 |
+
|
| 260 |
+
def make_input_divisible(self, x: torch.Tensor) -> torch.Tensor:
|
| 261 |
+
"""Pad some pixels to make the input size divisible by the patch size."""
|
| 262 |
+
B, _, H_0, W_0 = x.shape
|
| 263 |
+
pad_w = (self.patch_size - W_0 % self.patch_size) % self.patch_size
|
| 264 |
+
pad_h = (self.patch_size - H_0 % self.patch_size) % self.patch_size
|
| 265 |
+
|
| 266 |
+
x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=0)
|
| 267 |
+
return x
|
| 268 |
+
|
| 269 |
+
def prepare_tokens(self, x):
|
| 270 |
+
B, nc, h, w = x.shape
|
| 271 |
+
x: torch.Tensor = self.make_input_divisible(x)
|
| 272 |
+
patch_embed_h, patch_embed_w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
|
| 273 |
+
|
| 274 |
+
x = self.patch_embed(x) # patch linear embedding
|
| 275 |
+
|
| 276 |
+
# add positional encoding to each token
|
| 277 |
+
# add the [CLS] token to the embed patch tokens
|
| 278 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
| 279 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 280 |
+
x = x + self.interpolate_pos_encoding(x, self.pos_embed, size=(patch_embed_h, patch_embed_w))
|
| 281 |
+
return self.pos_drop(x)
|
| 282 |
+
|
| 283 |
+
    @staticmethod
    def split_token(x, token_type: str):
        if token_type == "cls":
            return x[:, 0, :]
        elif token_type == "patch":
            return x[:, 1:, :]
        else:
            return x

    # noel
    def forward(self, x, layer: Optional[str] = None):
        x: torch.Tensor = self.prepare_tokens(x)

        features: dict = {}
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            features[f"layer{i + 1}"] = self.norm(x)

        if layer is not None:
            return features[layer]
        else:
            return features

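    # forward, illustrative example (not part of the original commit): for a 12-block model,
    # forward(x) returns {"layer1": ..., ..., "layer12": ...}, each a normalised
    # B x (1 + n_patches) x embed_dim tensor; forward(x, layer="layer12") returns only the last one.
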
    # noel - for DINO's visual
    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

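    # get_last_selfattention, illustrative note (not part of the original commit): assuming the
    # Block implementation mirrors DINO's and returns the post-softmax attention map when
    # return_attention=True, the output has shape B x num_heads x (1 + n_patches) x (1 + n_patches).
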
    def get_tokens(
            self,
            x,
            layers: list,
            patch_tokens: bool = False,
            norm: bool = True,
            input_tokens: bool = False,
            post_pe: bool = False
    ):
        """Return intermediate tokens."""
        list_tokens: list = []

        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        if input_tokens:
            list_tokens.append(x)

        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed

        if post_pe:
            list_tokens.append(x)

        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x)  # B x # patches x dim
            if layers is None or i in layers:
                list_tokens.append(self.norm(x) if norm else x)

        tokens = torch.stack(list_tokens, dim=1)  # B x n_layers x (1 + # patches) x dim

        if not patch_tokens:
            return tokens[:, :, 0, :]  # index [CLS] tokens only, B x n_layers x dim
        else:
            return tokens

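    # get_tokens, illustrative example (not part of the original commit):
    # get_tokens(x, layers=[8, 9, 10, 11]) collects the outputs of the last four blocks of a
    # 12-block model; with the default patch_tokens=False it returns their [CLS] embeddings
    # as a B x 4 x embed_dim tensor. Note that this call omits `size`, so it relies on the
    # square-grid fallback in interpolate_pos_encoding below.
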
    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)

        return x[:, 0]

    def interpolate_pos_encoding(self, x, pos_embed, size=None):
        """Interpolate the learnable positional encoding to match the number of patches.

        x: B x (1 + n_patches) x dim_embedding
        pos_embed: 1 x (1 + n_pretrained_patches) x dim_embedding
        size: (height, width) of the target patch grid; if None, a square grid is assumed.

        return: interpolated positional embedding
        """
        npatch = x.shape[1] - 1  # (H // patch_size) * (W // patch_size)
        N = pos_embed.shape[1] - 1  # e.g., 784 (= 28 x 28)
        if npatch == N:
            return pos_embed
        if size is None:
            size = (int(math.sqrt(npatch)), int(math.sqrt(npatch)))  # fallback: assume a square patch grid
        class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:]  # learnable [CLS] token, learnable position embeddings

        dim = x.shape[-1]  # dimension of embeddings
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),  # 1 x dim x sqrt(N) x sqrt(N)
            size=size,
            mode='bicubic',
            align_corners=False
        )

        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
        return pos_embed

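    # interpolate_pos_encoding, illustrative example (not part of the original commit):
    # if the pretrained embedding covers a 28 x 28 grid (N = 784) and the input yields a
    # 20 x 15 patch grid, the 28 x 28 position-embedding grid is resized bicubically to
    # 20 x 15, re-flattened, and prepended with the [CLS] embedding.
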
    def forward_selfattention(self, x, return_interm_attn=False):
        B, nc, w, h = x.shape
        N = self.pos_embed.shape[1] - 1
        x = self.patch_embed(x)

        # interpolate patch embeddings
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic'
        )
        if w0 != patch_pos_embed.shape[-2]:
            helper = torch.zeros(h0)[None, None, None, :].repeat(1, dim, w0 - patch_pos_embed.shape[-2], 1).to(x.device)
            patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-2)
        if h0 != patch_pos_embed.shape[-1]:
            helper = torch.zeros(w0)[None, None, :, None].repeat(1, dim, 1, h0 - patch_pos_embed.shape[-1]).to(x.device)
            patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-1)
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        pos_embed = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # self.cls_token: 1 x 1 x emb_dim -> B x 1 x emb_dim
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + pos_embed
        x = self.pos_drop(x)

        if return_interm_attn:
            list_attn = []
            for i, blk in enumerate(self.blocks):
                attn = blk(x, return_attention=True)
                x = blk(x)
                list_attn.append(attn)
            return torch.cat(list_attn, dim=0)

        else:
            for i, blk in enumerate(self.blocks):
                if i < len(self.blocks) - 1:
                    x = blk(x)
                else:
                    return blk(x, return_attention=True)

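    # forward_selfattention, illustrative note (not part of the original commit): with
    # return_interm_attn=True the attention maps of all blocks are concatenated along the batch
    # dimension, giving a (depth * B) x num_heads x n_tokens x n_tokens tensor (assuming a
    # DINO-style Block); each block is run twice, once for its attention map and once for its output.
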
    def forward_return_n_last_blocks(self, x, n=1, return_patch_avgpool=False):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed
        x = self.pos_drop(x)

        # we will return the [CLS] tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                # get only CLS token (B x dim)
                output.append(self.norm(x)[:, 0])
        if return_patch_avgpool:
            x = self.norm(x)
            # In addition to the [CLS] tokens from the `n` last blocks, we also return
            # the patch tokens from the last block. This is useful for linear eval.
            output.append(torch.mean(x[:, 1:], dim=1))
        return torch.cat(output, dim=-1)

    def return_patch_emb_from_n_last_blocks(self, x, n=1, return_patch_avgpool=False):
        """Return intermediate patch embeddings, rather than the CLS token, from the last n blocks."""
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed
        x = self.pos_drop(x)

        # we will return the patch tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x)[:, 1:])  # get only the patch tokens (B x n_patches x dim)

        if return_patch_avgpool:
            x = self.norm(x)
            # In addition to the patch tokens from the `n` last blocks, we also return
            # the average-pooled patch tokens from the last block. This is useful for linear eval.
            output.append(torch.mean(x[:, 1:], dim=1))
        return torch.stack(output, dim=-1)  # B x n_patches x dim x n


def deit_tiny(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    return model


def deit_small(patch_size=16, **kwargs):
    depth = kwargs.pop("depth") if "depth" in kwargs else 12
    model = VisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=depth,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


def vit_base(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


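# Illustrative usage sketch (not part of the original commit):
#   model = deit_small(patch_size=16)  # ViT-S/16 backbone, embed_dim=384
#   feats = model(torch.randn(1, 3, 224, 224), layer="layer12")  # 1 x 197 x 384
# Pretrained DINO weights would still need to be loaded separately.
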
class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*layers)
        self.apply(self._init_weights)
        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

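    # Illustrative note (not part of the original commit): with the defaults, the head maps
    # in_dim -> 2048 -> 2048 -> 256, l2-normalises the 256-d bottleneck, and projects it with a
    # weight-normalised, bias-free linear layer to out_dim prototype scores.
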
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x

resources/.DS_Store
ADDED
Binary file (8.2 kB)

resources/0053.jpg
ADDED

resources/0236.jpg
ADDED

resources/0239.jpg
ADDED

resources/0403.jpg
ADDED

resources/0412.jpg
ADDED

resources/ILSVRC2012_test_00005309.jpg
ADDED

resources/ILSVRC2012_test_00012622.jpg
ADDED

resources/ILSVRC2012_test_00022698.jpg
ADDED

resources/ILSVRC2012_test_00040725.jpg
ADDED

resources/ILSVRC2012_test_00075738.jpg
ADDED

resources/ILSVRC2012_test_00080683.jpg
ADDED

resources/ILSVRC2012_test_00085874.jpg
ADDED

resources/im052.jpg
ADDED

resources/sun_ainjbonxmervsvpv.jpg
ADDED

resources/sun_alfntqzssslakmss.jpg
ADDED

resources/sun_amnrcxhisjfrliwa.jpg
ADDED

resources/sun_bvyxpvkouzlfwwod.jpg
ADDED