|
"""
|
|
https://github.com/pytorch/vision/blob/main/torchvision/models/_utils.py
|
|
|
|
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
|
|
"""
|
|
|
|
from collections import OrderedDict
|
|
from typing import Dict, List
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model.

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    The direct children of ``model`` are copied, in registration order, up to
    and including the last requested layer; later children are dropped.
    ``forward`` runs the kept children one after another and returns a list
    with the output of every requested layer, ordered by registration order
    in ``model`` (not by the order of ``return_layers``).

    Args:
        model: model whose direct children are wrapped.
        return_layers: names of direct children whose outputs are returned.
            Any iterable of names is accepted; it is copied, so mutating the
            caller's container afterwards does not affect this module.

    Raises:
        ValueError: if any name in ``return_layers`` is not a direct child
            of ``model``.
    """

    _version = 3

    def __init__(self, model: nn.Module, return_layers: List[str]) -> None:
        # Take a private copy: the names are iterated more than once below
        # (so a one-shot iterator must be materialized first), and keeping a
        # copy stops later caller-side mutation from changing forward().
        return_layers = list(return_layers)

        children = list(model.named_children())
        child_names = [name for name, _ in children]
        if not set(return_layers).issubset(child_names):
            raise ValueError(
                "return_layers are not present in model. {}".format(child_names)
            )

        # Requested names not yet reached; once empty, the remaining children
        # can never contribute an output, so they are not kept.
        pending = {str(k) for k in return_layers}
        layers = OrderedDict()
        for name, module in children:
            layers[name] = module
            pending.discard(name)
            if not pending:
                break

        super().__init__(layers)
        self.return_layers = return_layers

    def forward(self, x):
        """Run the kept children sequentially and collect requested outputs.

        Args:
            x: input fed to the first kept child; each child's output is the
                next child's input.

        Returns:
            list: outputs of the children named in ``self.return_layers``,
            in registration order.
        """
        outputs = []
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                outputs.append(x)

        return outputs
|
|
|