# Author: ianpan
# Commit: 231edce ("Initial commit")
import warnings
import torch.hub
import torch.nn as nn
from torchvision.models.video.resnet import R2Plus1dStem, BasicBlock, Bottleneck
from .utils import _generic_resnet, R2Plus1dStem_Pool, Conv2Plus1D, model_urls
# Names exported by `from <this module> import *`: the R(2+1)D factory functions.
__all__ = [
    "r2plus1d_34",
    "r2plus1d_152",
]
def r2plus1d_34(
    pretraining: str = "",
    use_pool1: bool = False,
    progress: bool = False,
    **kwargs,
) -> nn.Module:
    """Build an R(2+1)D-34 video ResNet, optionally loading pretrained weights.

    Args:
        pretraining: Name of the checkpoint to load. Must be one of
            ``"kinetics_8frms"``, ``"kinetics_32frms"``, ``"ig65m_8frms"``,
            ``"ig65m_32frms"`` or ``"32_ig65m"``. The default empty string
            builds a randomly initialized network without warning.
            Any other non-empty value also builds a randomly initialized
            network, but emits a ``UserWarning`` listing the valid choices.
        use_pool1: If True, use ``R2Plus1dStem_Pool`` as the stem; otherwise
            use ``R2Plus1dStem``.
        progress: Forwarded to ``_generic_resnet`` and to
            ``torch.hub.load_state_dict_from_url`` (download progress bar).
        **kwargs: Extra keyword arguments forwarded to ``_generic_resnet``.

    Returns:
        The model built by ``_generic_resnet``. Every ``nn.BatchNorm3d`` has
        ``eps=1e-3`` and ``momentum=0.9``. When ``pretraining`` is
        recognized, the state dict from
        ``model_urls["r2plus1d_34_" + pretraining]`` is loaded into it.
    """
    avail_pretrainings = [
        "kinetics_8frms",
        "kinetics_32frms",
        "ig65m_8frms",
        "ig65m_32frms",
        "32_ig65m",
    ]
    if pretraining in avail_pretrainings:
        arch = "r2plus1d_34_" + pretraining
        pretrained = True
    else:
        # "" is the default (no pretraining requested), so only warn when the
        # caller explicitly asked for a checkpoint we do not know about.
        if pretraining:
            warnings.warn(
                "Unrecognized pretraining dataset, continuing with randomly initialized network."
                f" Available pretrainings: {avail_pretrainings}",
                UserWarning,
            )
        arch = "r2plus1d_34"
        pretrained = False
    model = _generic_resnet(
        arch,
        pretrained,
        progress,
        block=BasicBlock,
        conv_makers=[Conv2Plus1D] * 4,
        layers=[3, 4, 6, 3],
        stem=R2Plus1dStem_Pool if use_pool1 else R2Plus1dStem,
        **kwargs,
    )
    # We need exact Caffe2 momentum for BatchNorm scaling.
    # NOTE(review): PyTorch updates running stats as
    # (1 - momentum) * running + momentum * batch, which is the complement of
    # the Caffe2 convention; confirm momentum=0.9 (vs. 0.1) is intended.
    for m in model.modules():
        if isinstance(m, nn.BatchNorm3d):
            m.eps = 1e-3
            m.momentum = 0.9
    if pretrained:
        # NOTE(review): `pretrained` is also passed to `_generic_resnet`; if
        # that helper already loads weights, this is a second download/load.
        state_dict = torch.hub.load_state_dict_from_url(
            model_urls[arch], progress=progress
        )
        model.load_state_dict(state_dict)
    return model
def r2plus1d_152(
    pretraining: str = "",
    use_pool1: bool = True,
    progress: bool = False,
    **kwargs,
) -> nn.Module:
    """Build an R(2+1)D-152 video ResNet, optionally loading pretrained weights.

    Args:
        pretraining: Name of the checkpoint to load. Must be one of
            ``"ig65m_32frms"``, ``"ig_ft_kinetics_32frms"``,
            ``"sports1m_32frms"`` or ``"sports1m_ft_kinetics_32frms"``.
            The default empty string builds a randomly initialized network
            without warning. Any other non-empty value also builds a randomly
            initialized network, but emits a ``UserWarning`` listing the valid
            choices.
        use_pool1: If True (the default), use ``R2Plus1dStem_Pool`` as the
            stem; otherwise use ``R2Plus1dStem``.
        progress: Forwarded to ``_generic_resnet`` and to
            ``torch.hub.load_state_dict_from_url`` (download progress bar).
        **kwargs: Extra keyword arguments forwarded to ``_generic_resnet``.

    Returns:
        The model built by ``_generic_resnet`` with ``Bottleneck`` blocks and
        layers ``[3, 8, 36, 3]``. Every ``nn.BatchNorm3d`` has ``eps=1e-3``
        and ``momentum=0.9``. When ``pretraining`` is recognized, the state
        dict from ``model_urls["r2plus1d_" + pretraining]`` is loaded into it.
    """
    avail_pretrainings = [
        "ig65m_32frms",
        "ig_ft_kinetics_32frms",
        "sports1m_32frms",
        "sports1m_ft_kinetics_32frms",
    ]
    if pretraining in avail_pretrainings:
        # NOTE(review): unlike r2plus1d_34, this key omits the depth
        # (e.g. "r2plus1d_ig65m_32frms"); confirm it matches `model_urls`.
        arch = "r2plus1d_" + pretraining
        pretrained = True
    else:
        # "" is the default (no pretraining requested), so only warn when the
        # caller explicitly asked for a checkpoint we do not know about.
        if pretraining:
            warnings.warn(
                "Unrecognized pretraining dataset, continuing with randomly initialized network."
                f" Available pretrainings: {avail_pretrainings}",
                UserWarning,
            )
        # This factory builds the 152-layer variant, not the 34-layer one.
        arch = "r2plus1d_152"
        pretrained = False
    model = _generic_resnet(
        arch,
        pretrained,
        progress,
        block=Bottleneck,
        conv_makers=[Conv2Plus1D] * 4,
        layers=[3, 8, 36, 3],
        stem=R2Plus1dStem_Pool if use_pool1 else R2Plus1dStem,
        **kwargs,
    )
    # We need exact Caffe2 momentum for BatchNorm scaling.
    # NOTE(review): PyTorch updates running stats as
    # (1 - momentum) * running + momentum * batch, which is the complement of
    # the Caffe2 convention; confirm momentum=0.9 (vs. 0.1) is intended.
    for m in model.modules():
        if isinstance(m, nn.BatchNorm3d):
            m.eps = 1e-3
            m.momentum = 0.9
    if pretrained:
        # NOTE(review): `pretrained` is also passed to `_generic_resnet`; if
        # that helper already loads weights, this is a second download/load.
        state_dict = torch.hub.load_state_dict_from_url(
            model_urls[arch], progress=progress
        )
        model.load_state_dict(state_dict)
    return model