File size: 2,792 Bytes
231edce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import warnings
import torch.hub
import torch.nn as nn
from torchvision.models.video.resnet import R2Plus1dStem, BasicBlock, Bottleneck
from .utils import _generic_resnet, R2Plus1dStem_Pool, Conv2Plus1D, model_urls
# Public model constructors exported by ``from <module> import *``.
__all__ = [
    "r2plus1d_34",
    "r2plus1d_152",
]
def r2plus1d_34(pretraining="", use_pool1=False, progress=False, **kwargs):
    """Build an R(2+1)D-34 video model, optionally loading pretrained weights.

    Args:
        pretraining: Name of a pretrained checkpoint. Recognized values are
            ``"kinetics_8frms"``, ``"kinetics_32frms"``, ``"ig65m_8frms"``,
            ``"ig65m_32frms"`` and ``"32_ig65m"``. Any other value yields a
            randomly initialized network; a non-empty unrecognized value
            additionally emits a ``UserWarning``. The default ``""`` means
            "no pretraining" and is silent.
        use_pool1: If True, use ``R2Plus1dStem_Pool`` as the stem instead of
            torchvision's ``R2Plus1dStem``.
        progress: Passed to ``_generic_resnet`` and to
            ``torch.hub.load_state_dict_from_url``.
        **kwargs: Forwarded unchanged to ``_generic_resnet``.

    Returns:
        The model built by ``_generic_resnet`` (``BasicBlock``, layers
        ``[3, 4, 6, 3]``, ``Conv2Plus1D`` conv makers). Every
        ``nn.BatchNorm3d`` is set to ``eps=1e-3`` and ``momentum=0.1``. When
        ``pretraining`` is recognized, the checkpoint at
        ``model_urls["r2plus1d_34_" + pretraining]`` is loaded into it.
    """
    avail_pretrainings = [
        "kinetics_8frms",
        "kinetics_32frms",
        "ig65m_8frms",
        "ig65m_32frms",
        "32_ig65m",
    ]
    if pretraining in avail_pretrainings:
        # Key into ``model_urls``, e.g. "r2plus1d_34_ig65m_32frms".
        arch = "r2plus1d_34_" + pretraining
        pretrained = True
    else:
        # ``""`` is the default meaning "no pretraining"; only warn when the
        # caller asked for a checkpoint name we do not know.
        if pretraining:
            warnings.warn(
                f"Unrecognized pretraining dataset {pretraining!r}, continuing "
                "with randomly initialized network."
                f" Available pretrainings: {avail_pretrainings}",
                UserWarning,
            )
        arch = "r2plus1d_34"
        pretrained = False
    # NOTE(review): ``pretrained`` is also handed to ``_generic_resnet``;
    # confirm it does not load weights itself, or they are fetched twice.
    model = _generic_resnet(
        arch,
        pretrained,
        progress,
        block=BasicBlock,
        conv_makers=[Conv2Plus1D] * 4,
        layers=[3, 4, 6, 3],
        stem=R2Plus1dStem_Pool if use_pool1 else R2Plus1dStem,
        **kwargs,
    )
    # Match Caffe2's BatchNorm settings. Caffe2's SpatialBN updates running
    # stats as ``running = momentum * running + (1 - momentum) * batch`` with
    # momentum 0.9. PyTorch uses ``running = (1 - momentum) * running +
    # momentum * batch``, so the PyTorch equivalent is momentum = 0.1.
    for m in model.modules():
        if isinstance(m, nn.BatchNorm3d):
            m.eps = 1e-3
            m.momentum = 0.1
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            model_urls[arch], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
def r2plus1d_152(pretraining="", use_pool1=True, progress=False, **kwargs):
    """Build an R(2+1)D-152 video model, optionally loading pretrained weights.

    Args:
        pretraining: Name of a pretrained checkpoint. Recognized values are
            ``"ig65m_32frms"``, ``"ig_ft_kinetics_32frms"``,
            ``"sports1m_32frms"`` and ``"sports1m_ft_kinetics_32frms"``. Any
            other value yields a randomly initialized network; a non-empty
            unrecognized value additionally emits a ``UserWarning``. The
            default ``""`` means "no pretraining" and is silent.
        use_pool1: If True (the default here), use ``R2Plus1dStem_Pool`` as
            the stem instead of torchvision's ``R2Plus1dStem``.
        progress: Passed to ``_generic_resnet`` and to
            ``torch.hub.load_state_dict_from_url``.
        **kwargs: Forwarded unchanged to ``_generic_resnet``.

    Returns:
        The model built by ``_generic_resnet`` (``Bottleneck``, layers
        ``[3, 8, 36, 3]``, ``Conv2Plus1D`` conv makers). Every
        ``nn.BatchNorm3d`` is set to ``eps=1e-3`` and ``momentum=0.1``. When
        ``pretraining`` is recognized, the checkpoint at
        ``model_urls["r2plus1d_152_" + pretraining]`` is loaded into it.
    """
    avail_pretrainings = [
        "ig65m_32frms",
        "ig_ft_kinetics_32frms",
        "sports1m_32frms",
        "sports1m_ft_kinetics_32frms",
    ]
    if pretraining in avail_pretrainings:
        # Key into ``model_urls``, e.g. "r2plus1d_152_ig65m_32frms". This
        # mirrors the "r2plus1d_34_<pretraining>" scheme of ``r2plus1d_34``.
        # NOTE(review): confirm these keys exist in ``.utils.model_urls``.
        arch = "r2plus1d_152_" + pretraining
        pretrained = True
    else:
        # ``""`` is the default meaning "no pretraining"; only warn when the
        # caller asked for a checkpoint name we do not know.
        if pretraining:
            warnings.warn(
                f"Unrecognized pretraining dataset {pretraining!r}, continuing "
                "with randomly initialized network."
                f" Available pretrainings: {avail_pretrainings}",
                UserWarning,
            )
        # This is the 152-layer variant; the fallback arch name must say so.
        arch = "r2plus1d_152"
        pretrained = False
    # NOTE(review): ``pretrained`` is also handed to ``_generic_resnet``;
    # confirm it does not load weights itself, or they are fetched twice.
    model = _generic_resnet(
        arch,
        pretrained,
        progress,
        block=Bottleneck,
        conv_makers=[Conv2Plus1D] * 4,
        layers=[3, 8, 36, 3],
        stem=R2Plus1dStem_Pool if use_pool1 else R2Plus1dStem,
        **kwargs,
    )
    # Match Caffe2's BatchNorm settings. Caffe2's SpatialBN updates running
    # stats as ``running = momentum * running + (1 - momentum) * batch`` with
    # momentum 0.9. PyTorch uses ``running = (1 - momentum) * running +
    # momentum * batch``, so the PyTorch equivalent is momentum = 0.1.
    for m in model.modules():
        if isinstance(m, nn.BatchNorm3d):
            m.eps = 1e-3
            m.momentum = 0.1
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            model_urls[arch], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
|