File size: 1,880 Bytes
e85fecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""

https://github.com/pytorch/vision/blob/main/torchvision/models/_utils.py



Copyright(c) 2023 lyuwenyu. All Rights Reserved.

"""

from collections import OrderedDict
from typing import Dict, List

import torch.nn as nn


class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model.

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Args:
        model: Model whose direct children (``model.named_children()``) are
            re-registered here and run sequentially in registration order.
        return_layers: Names of direct children of ``model`` whose outputs
            ``forward`` collects. Any iterable of names is accepted; it is
            copied into a list at construction time and stored as
            ``self.return_layers``.

    Raises:
        ValueError: If any name in ``return_layers`` is not a direct child
            of ``model``. The message lists the available child names.

    Only the children up to and including the last requested one are kept;
    later children are dropped. If ``return_layers`` is empty, only the
    first child is kept and ``forward`` returns an empty list.
    """

    _version = 3

    def __init__(self, model: nn.Module, return_layers: List[str]) -> None:
        # Materialize once: the names are consumed several times below, so a
        # one-shot iterable (e.g. a generator) would otherwise be exhausted by
        # the first pass. The copy also keeps later mutation of the caller's
        # list from desynchronizing it from the modules retained here.
        requested = list(return_layers)
        child_names = [name for name, _ in model.named_children()]
        if not set(requested).issubset(child_names):
            raise ValueError(
                "return_layers are not present in model. {}".format(child_names)
            )

        pending = {str(k) for k in requested}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            pending.discard(name)
            # Stop once every requested layer has been reached; later
            # children are not needed to compute any requested output.
            if not pending:
                break

        super().__init__(layers)
        self.return_layers = requested

    def forward(self, x):
        """
        Run ``x`` through the retained children in order.

        Each child's output is fed as the input to the next child.

        Returns:
            A list with the output of every child named in
            ``self.return_layers``, in module-registration order (not in
            the order the names were given).
        """
        outputs = []
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                outputs.append(x)
        return outputs