Spaces:
Runtime error
Runtime error
import warnings | |
import torch.hub | |
import torch.nn as nn | |
from torchvision.models.video.resnet import R2Plus1dStem, BasicBlock, Bottleneck | |
from .utils import _generic_resnet, R2Plus1dStem_Pool, Conv2Plus1D, model_urls | |
__all__ = ["r2plus1d_34", "r2plus1d_152"] | |
def r2plus1d_34(pretraining="", use_pool1=False, progress=False, **kwargs):
    """Build an R(2+1)D network from ``BasicBlock`` with layers [3, 4, 6, 3].

    Args:
        pretraining: Name of the pretrained weight set. When it is one of
            the entries of ``avail_pretrainings`` below, the checkpoint
            registered under ``"r2plus1d_34_" + pretraining`` in
            ``model_urls`` is downloaded and loaded. Any other value
            (including the default ``""``) emits a ``UserWarning`` and no
            checkpoint is loaded by this function.
        use_pool1: If true, ``R2Plus1dStem_Pool`` is used as the stem;
            otherwise ``R2Plus1dStem``.
        progress: Forwarded to ``_generic_resnet`` and to
            ``torch.hub.load_state_dict_from_url``.
        **kwargs: Forwarded unchanged to ``_generic_resnet``.

    Returns:
        The model produced by ``_generic_resnet``, with ``eps`` and
        ``momentum`` overridden on every ``nn.BatchNorm3d`` module.
    """
    avail_pretrainings = [
        "kinetics_8frms",
        "kinetics_32frms",
        "ig65m_8frms",
        "ig65m_32frms",
        "32_ig65m",
    ]

    if pretraining in avail_pretrainings:
        arch = "r2plus1d_34_" + pretraining
        pretrained = True
    else:
        # The f-prefix must sit on the literal that holds the placeholder;
        # adjacent literals are concatenated but each keeps its own prefix.
        warnings.warn(
            "Unrecognized pretraining dataset, continuing with randomly initialized network."
            f" Available pretrainings: {avail_pretrainings}",
            UserWarning,
        )
        arch = "r2plus1d_34"
        pretrained = False

    # NOTE(review): `pretrained` is also handed to `_generic_resnet`; whether
    # that helper loads weights on its own is not visible from this file.
    model = _generic_resnet(
        arch,
        pretrained,
        progress,
        block=BasicBlock,
        conv_makers=[Conv2Plus1D] * 4,
        layers=[3, 4, 6, 3],
        stem=R2Plus1dStem_Pool if use_pool1 else R2Plus1dStem,
        **kwargs,
    )

    # We need exact Caffe2 momentum for BatchNorm scaling
    for m in model.modules():
        if isinstance(m, nn.BatchNorm3d):
            m.eps = 1e-3
            m.momentum = 0.9

    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            model_urls[arch], progress=progress
        )
        model.load_state_dict(state_dict)

    return model
def r2plus1d_152(pretraining="", use_pool1=True, progress=False, **kwargs):
    """Build an R(2+1)D network from ``Bottleneck`` with layers [3, 8, 36, 3].

    Args:
        pretraining: Name of the pretrained weight set. When it is one of
            the entries of ``avail_pretrainings`` below, the checkpoint
            registered under ``"r2plus1d_" + pretraining`` in
            ``model_urls`` is downloaded and loaded. Any other value
            (including the default ``""``) emits a ``UserWarning`` and no
            checkpoint is loaded by this function.
        use_pool1: If true (the default here), ``R2Plus1dStem_Pool`` is used
            as the stem; otherwise ``R2Plus1dStem``.
        progress: Forwarded to ``_generic_resnet`` and to
            ``torch.hub.load_state_dict_from_url``.
        **kwargs: Forwarded unchanged to ``_generic_resnet``.

    Returns:
        The model produced by ``_generic_resnet``, with ``eps`` and
        ``momentum`` overridden on every ``nn.BatchNorm3d`` module.
    """
    avail_pretrainings = [
        "ig65m_32frms",
        "ig_ft_kinetics_32frms",
        "sports1m_32frms",
        "sports1m_ft_kinetics_32frms",
    ]

    if pretraining in avail_pretrainings:
        arch = "r2plus1d_" + pretraining
        pretrained = True
    else:
        # The f-prefix must sit on the literal that holds the placeholder;
        # adjacent literals are concatenated but each keeps its own prefix.
        warnings.warn(
            "Unrecognized pretraining dataset, continuing with randomly initialized network."
            f" Available pretrainings: {avail_pretrainings}",
            UserWarning,
        )
        # NOTE(review): this label says "34" inside the 152-layer factory and
        # looks like a copy-paste leftover. It is kept as-is because the only
        # consumer here is `_generic_resnet`, whose use of `arch` is not
        # visible from this file — confirm before renaming.
        arch = "r2plus1d_34"
        pretrained = False

    # NOTE(review): `pretrained` is also handed to `_generic_resnet`; whether
    # that helper loads weights on its own is not visible from this file.
    model = _generic_resnet(
        arch,
        pretrained,
        progress,
        block=Bottleneck,
        conv_makers=[Conv2Plus1D] * 4,
        layers=[3, 8, 36, 3],
        stem=R2Plus1dStem_Pool if use_pool1 else R2Plus1dStem,
        **kwargs,
    )

    # We need exact Caffe2 momentum for BatchNorm scaling
    for m in model.modules():
        if isinstance(m, nn.BatchNorm3d):
            m.eps = 1e-3
            m.momentum = 0.9

    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            model_urls[arch], progress=progress
        )
        model.load_state_dict(state_dict)

    return model