Spaces:
Configuration error
Configuration error
""" PyTorch FX Based Feature Extraction Helpers | |
Using https://pytorch.org/vision/stable/feature_extraction.html | |
""" | |
from typing import Callable, List, Dict, Union, Type | |
import torch | |
from torch import nn | |
from .features import _get_feature_info | |
try: | |
from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor | |
has_fx_feature_extraction = True | |
except ImportError: | |
has_fx_feature_extraction = False | |
# Layers we went to treat as leaf modules | |
from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame | |
from .layers.non_local_attn import BilinearAttnTransform | |
from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame | |
# NOTE: By default, any modules from custom_timm.models.layers that we want to treat as leaf modules go here | |
# BUT modules from custom_timm.models should use the registration mechanism below | |
_leaf_modules = { | |
BilinearAttnTransform, # reason: flow control t <= 1 | |
# Reason: get_same_padding has a max which raises a control flow error | |
Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, | |
CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) | |
} | |
try: | |
from .layers import InplaceAbn | |
_leaf_modules.add(InplaceAbn) | |
except ImportError: | |
pass | |
def register_notrace_module(module: Type[nn.Module]): | |
""" | |
Any module not under timm.models.layers should get this decorator if we don't want to trace through it. | |
""" | |
_leaf_modules.add(module) | |
return module | |
# Functions we want to autowrap (treat them as leaves) | |
_autowrap_functions = set() | |
def register_notrace_function(func: Callable): | |
""" | |
Decorator for functions which ought not to be traced through | |
""" | |
_autowrap_functions.add(func) | |
return func | |
def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]): | |
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' | |
return _create_feature_extractor( | |
model, return_nodes, | |
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)} | |
) | |
class FeatureGraphNet(nn.Module): | |
""" A FX Graph based feature extractor that works with the model feature_info metadata | |
""" | |
def __init__(self, model, out_indices, out_map=None): | |
super().__init__() | |
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' | |
self.feature_info = _get_feature_info(model, out_indices) | |
if out_map is not None: | |
assert len(out_map) == len(out_indices) | |
return_nodes = { | |
info['module']: out_map[i] if out_map is not None else info['module'] | |
for i, info in enumerate(self.feature_info) if i in out_indices} | |
self.graph_module = create_feature_extractor(model, return_nodes) | |
def forward(self, x): | |
return list(self.graph_module(x).values()) | |
class GraphExtractNet(nn.Module): | |
""" A standalone feature extraction wrapper that maps dict -> list or single tensor | |
NOTE: | |
* one can use feature_extractor directly if dictionary output is desired | |
* unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info | |
metadata for builtin feature extraction mode | |
* create_feature_extractor can be used directly if dictionary output is acceptable | |
Args: | |
model: model to extract features from | |
return_nodes: node names to return features from (dict or list) | |
squeeze_out: if only one output, and output in list format, flatten to single tensor | |
""" | |
def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True): | |
super().__init__() | |
self.squeeze_out = squeeze_out | |
self.graph_module = create_feature_extractor(model, return_nodes) | |
def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]: | |
out = list(self.graph_module(x).values()) | |
if self.squeeze_out and len(out) == 1: | |
return out[0] | |
return out | |