|
import warnings |
|
|
|
import torch.hub |
|
import torch.nn as nn |
|
from torchvision.models.video.resnet import R2Plus1dStem, BasicBlock, Bottleneck |
|
|
|
|
|
from .utils import _generic_resnet, R2Plus1dStem_Pool, Conv2Plus1D, model_urls |
|
|
|
|
|
# Public model builders exported by this module.
__all__ = [
    "r2plus1d_34",
    "r2plus1d_152",
]
|
|
|
|
|
def r2plus1d_34(pretraining="", use_pool1=False, progress=False, **kwargs):
    """Build an R(2+1)D-34 video network, optionally loading pretrained weights.

    Args:
        pretraining: Name of the pretrained checkpoint to load; must be one of
            the entries in ``avail_pretrainings`` below. The checkpoint URL is
            looked up as ``model_urls["r2plus1d_34_" + pretraining]``. The
            default empty string builds a randomly initialized network
            silently; any other unrecognized value emits a ``UserWarning``
            and also falls back to random initialization.
        use_pool1: If True, use ``R2Plus1dStem_Pool`` as the stem instead of
            ``R2Plus1dStem``.
        progress: Passed to ``_generic_resnet`` and to
            ``torch.hub.load_state_dict_from_url`` (download progress bar).
        **kwargs: Forwarded unchanged to ``_generic_resnet``.

    Returns:
        The model built by ``_generic_resnet`` with ``BasicBlock`` residual
        blocks, stage depths ``[3, 4, 6, 3]`` and ``Conv2Plus1D`` in all four
        stages. Every ``nn.BatchNorm3d`` gets ``eps=1e-3`` and
        ``momentum=0.9``. When ``pretraining`` is recognized, the downloaded
        state dict is loaded into the model.

    Raises:
        Whatever ``model_urls[arch]`` raises for a missing key (``KeyError``
        if it is a plain dict), and any error raised by
        ``load_state_dict_from_url`` / ``load_state_dict``.
    """
    # NOTE(review): "32_ig65m" follows a different naming pattern from the
    # other entries; confirm every "r2plus1d_34_<name>" key exists in
    # model_urls.
    avail_pretrainings = [
        "kinetics_8frms",
        "kinetics_32frms",
        "ig65m_8frms",
        "ig65m_32frms",
        "32_ig65m",
    ]

    if pretraining in avail_pretrainings:
        arch = "r2plus1d_34_" + pretraining
        pretrained = True
    else:
        # The default "" simply means "no pretraining": only warn when the
        # caller asked for something we do not recognize.
        if pretraining:
            warnings.warn(
                "Unrecognized pretraining dataset, continuing with randomly "
                "initialized network. "
                f"Available pretrainings: {avail_pretrainings}",
                UserWarning,
            )
        arch = "r2plus1d_34"
        pretrained = False

    model = _generic_resnet(
        arch,
        pretrained,
        progress,
        block=BasicBlock,
        conv_makers=[Conv2Plus1D] * 4,
        layers=[3, 4, 6, 3],
        stem=R2Plus1dStem_Pool if use_pool1 else R2Plus1dStem,
        **kwargs,
    )

    # Override BatchNorm hyper-parameters on every 3D BN layer.
    # NOTE(review): in PyTorch, `momentum` weights the *new* batch statistic
    # (running = (1 - momentum) * running + momentum * batch), so 0.9 makes
    # the running stats follow the latest batch closely. This is the reverse
    # of the TF/Caffe2 "decay" convention. Kept as-is to preserve behavior;
    # confirm it is intended.
    for m in model.modules():
        if isinstance(m, nn.BatchNorm3d):
            m.eps = 1e-3
            m.momentum = 0.9

    if pretrained:
        # NOTE(review): `pretrained` is also passed to _generic_resnet above;
        # confirm that helper does not already load these weights.
        state_dict = torch.hub.load_state_dict_from_url(
            model_urls[arch], progress=progress
        )
        model.load_state_dict(state_dict)

    return model
|
|
|
|
|
def r2plus1d_152(pretraining="", use_pool1=True, progress=False, **kwargs):
    """Build an R(2+1)D-152 video network, optionally loading pretrained weights.

    Args:
        pretraining: Name of the pretrained checkpoint to load; must be one of
            the entries in ``avail_pretrainings`` below. The checkpoint URL is
            looked up as ``model_urls["r2plus1d_" + pretraining]``. The
            default empty string builds a randomly initialized network
            silently; any other unrecognized value emits a ``UserWarning``
            and also falls back to random initialization.
        use_pool1: If True (the default here), use ``R2Plus1dStem_Pool`` as
            the stem instead of ``R2Plus1dStem``.
        progress: Passed to ``_generic_resnet`` and to
            ``torch.hub.load_state_dict_from_url`` (download progress bar).
        **kwargs: Forwarded unchanged to ``_generic_resnet``.

    Returns:
        The model built by ``_generic_resnet`` with ``Bottleneck`` residual
        blocks, stage depths ``[3, 8, 36, 3]`` and ``Conv2Plus1D`` in all four
        stages. Every ``nn.BatchNorm3d`` gets ``eps=1e-3`` and
        ``momentum=0.9``. When ``pretraining`` is recognized, the downloaded
        state dict is loaded into the model.

    Raises:
        Whatever ``model_urls[arch]`` raises for a missing key (``KeyError``
        if it is a plain dict), and any error raised by
        ``load_state_dict_from_url`` / ``load_state_dict``.
    """
    avail_pretrainings = [
        "ig65m_32frms",
        "ig_ft_kinetics_32frms",
        "sports1m_32frms",
        "sports1m_ft_kinetics_32frms",
    ]

    if pretraining in avail_pretrainings:
        # NOTE(review): unlike r2plus1d_34 (prefix "r2plus1d_34_"), this key
        # omits the depth, e.g. "r2plus1d_ig65m_32frms"; confirm it matches
        # the corresponding entry in model_urls.
        arch = "r2plus1d_" + pretraining
        pretrained = True
    else:
        # The default "" simply means "no pretraining": only warn when the
        # caller asked for something we do not recognize.
        if pretraining:
            warnings.warn(
                "Unrecognized pretraining dataset, continuing with randomly "
                "initialized network. "
                f"Available pretrainings: {avail_pretrainings}",
                UserWarning,
            )
        # Name the architecture that is actually built (this previously said
        # "r2plus1d_34", a copy-paste slip).
        arch = "r2plus1d_152"
        pretrained = False

    model = _generic_resnet(
        arch,
        pretrained,
        progress,
        block=Bottleneck,
        conv_makers=[Conv2Plus1D] * 4,
        layers=[3, 8, 36, 3],
        stem=R2Plus1dStem_Pool if use_pool1 else R2Plus1dStem,
        **kwargs,
    )

    # Override BatchNorm hyper-parameters on every 3D BN layer.
    # NOTE(review): in PyTorch, `momentum` weights the *new* batch statistic
    # (running = (1 - momentum) * running + momentum * batch), so 0.9 makes
    # the running stats follow the latest batch closely. This is the reverse
    # of the TF/Caffe2 "decay" convention. Kept as-is to preserve behavior;
    # confirm it is intended.
    for m in model.modules():
        if isinstance(m, nn.BatchNorm3d):
            m.eps = 1e-3
            m.momentum = 0.9

    if pretrained:
        # NOTE(review): `pretrained` is also passed to _generic_resnet above;
        # confirm that helper does not already load these weights.
        state_dict = torch.hub.load_state_dict_from_url(
            model_urls[arch], progress=progress
        )
        model.load_state_dict(state_dict)

    return model
|
|