Spaces:
Runtime error
Runtime error
| # coding=utf-8 | |
| # Copyright 2023 The HuggingFace Inc. team. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ Collection of utils to be used by backbones and their components.""" | |
| import enum | |
| import inspect | |
| from typing import Iterable, List, Optional, Tuple, Union | |
class BackboneType(enum.Enum):
    """Library a backbone implementation comes from."""

    # A backbone wrapping a timm model.
    TIMM = "timm"
    # A backbone implemented natively in transformers.
    TRANSFORMERS = "transformers"
def verify_out_features_out_indices(
    out_features: Optional[Iterable[str]], out_indices: Optional[Iterable[int]], stage_names: Optional[Iterable[str]]
):
    """
    Verify that out_indices and out_features are valid for the given stage_names.

    Args:
        out_features: Names of the stages to output, or None. When set it must be a `list` whose entries all
            appear in `stage_names`.
        out_indices: Indices of the stages to output, or None. When set it must be a `list` or `tuple` of
            indices into `stage_names`. Python-style negative indices are accepted as long as they are in range.
        stage_names: Names of all stages of the backbone. It is indexed and measured with `len`, so in practice
            it must be a sequence.

    Raises:
        ValueError: If `stage_names` is None.
        ValueError: If either argument has the wrong container type or refers to a stage outside `stage_names`.
        ValueError: If both arguments are set but do not describe the same stages in the same order.
    """
    if stage_names is None:
        raise ValueError("Stage_names must be set for transformers backbones")

    if out_features is not None:
        if not isinstance(out_features, (list,)):
            raise ValueError(f"out_features must be a list {type(out_features)}")
        if any(feat not in stage_names for feat in out_features):
            raise ValueError(f"out_features must be a subset of stage_names: {stage_names} got {out_features}")

    if out_indices is not None:
        if not isinstance(out_indices, (list, tuple)):
            raise ValueError(f"out_indices must be a list or tuple, got {type(out_indices)}")
        num_stages = len(stage_names)
        # Reject any index that would raise IndexError when it is later used to index stage_names.
        # Negative indices in [-num_stages, -1] are valid Python indices and remain accepted.
        if any(not -num_stages <= idx < num_stages for idx in out_indices):
            raise ValueError(f"out_indices must be valid indices for stage_names {stage_names}, got {out_indices}")

    if out_features is not None and out_indices is not None:
        if len(out_features) != len(out_indices):
            raise ValueError("out_features and out_indices should have the same length if both are set")
        # Both describe the same stages only if indexing stage_names by out_indices reproduces out_features.
        if out_features != [stage_names[idx] for idx in out_indices]:
            raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
| def _align_output_features_output_indices( | |
| out_features: Optional[List[str]], | |
| out_indices: Optional[Union[List[int], Tuple[int]]], | |
| stage_names: List[str], | |
| ): | |
| """ | |
| Finds the corresponding `out_features` and `out_indices` for the given `stage_names`. | |
| The logic is as follows: | |
| - `out_features` not set, `out_indices` set: `out_features` is set to the `out_features` corresponding to the | |
| `out_indices`. | |
| - `out_indices` not set, `out_features` set: `out_indices` is set to the `out_indices` corresponding to the | |
| `out_features`. | |
| - `out_indices` and `out_features` not set: `out_indices` and `out_features` are set to the last stage. | |
| - `out_indices` and `out_features` set: input `out_indices` and `out_features` are returned. | |
| Args: | |
| out_features (`List[str]`): The names of the features for the backbone to output. | |
| out_indices (`List[int]` or `Tuple[int]`): The indices of the features for the backbone to output. | |
| stage_names (`List[str]`): The names of the stages of the backbone. | |
| """ | |
| if out_indices is None and out_features is None: | |
| out_indices = [len(stage_names) - 1] | |
| out_features = [stage_names[-1]] | |
| elif out_indices is None and out_features is not None: | |
| out_indices = [stage_names.index(layer) for layer in out_features] | |
| elif out_features is None and out_indices is not None: | |
| out_features = [stage_names[idx] for idx in out_indices] | |
| return out_features, out_indices | |
def get_aligned_output_features_output_indices(
    out_features: Optional[List[str]],
    out_indices: Optional[Union[List[int], Tuple[int]]],
    stage_names: List[str],
) -> Tuple[List[str], List[int]]:
    """
    Validate `out_features` / `out_indices`, complete whichever is missing, and validate the completed pair.

    Cases:
        - only `out_indices` is set: `out_features` is derived from it.
        - only `out_features` is set: `out_indices` is derived from it.
        - neither is set: both default to the last stage.
        - both are set: they are checked to describe the same stages.

    Args:
        out_features (`List[str]`): The names of the features for the backbone to output.
        out_indices (`List[int]` or `Tuple[int]`): The indices of the features for the backbone to output.
        stage_names (`List[str]`): The names of the stages of the backbone.

    Returns:
        The aligned `(out_features, out_indices)` pair. A caller-supplied tuple of indices is passed through
        unchanged.

    Raises:
        ValueError: Propagated from `verify_out_features_out_indices` when either the input or the aligned pair
            is invalid.
    """
    # Reject bad caller input before attempting to derive anything from it.
    verify_out_features_out_indices(out_features=out_features, out_indices=out_indices, stage_names=stage_names)

    aligned_features, aligned_indices = _align_output_features_output_indices(
        out_features=out_features, out_indices=out_indices, stage_names=stage_names
    )

    # The derived half of the pair must satisfy the same rules as user input.
    verify_out_features_out_indices(
        out_features=aligned_features, out_indices=aligned_indices, stage_names=stage_names
    )
    return aligned_features, aligned_indices
class BackboneMixin:
    """
    Mixin for backbone models that tracks which stages the backbone outputs (`out_features` / `out_indices`).

    `_init_backbone` populates `stage_names`, `num_features`, `_out_features` and `_out_indices`. It must run
    before the properties below are used.
    """

    # Overwritten on the instance by `_init_backbone`.
    backbone_type: Optional[BackboneType] = None

    def _init_timm_backbone(self, config) -> None:
        """
        Initialize the backbone from timm. The timm model must already be loaded into `self._backbone`.

        Raises:
            ValueError: If `self._backbone` is missing or None.
            ValueError: If timm's reported out features/indices fail `verify_out_features_out_indices`.
        """
        if getattr(self, "_backbone", None) is None:
            raise ValueError("self._backbone must be set before calling _init_timm_backbone")

        # These will disagree with the defaults for the transformers models e.g. for resnet50
        # the transformer model has out_features = ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
        # the timm model has out_features = ['act', 'layer1', 'layer2', 'layer3', 'layer4']
        self.stage_names = [stage["module"] for stage in self._backbone.feature_info.info]
        self.num_features = [stage["num_chs"] for stage in self._backbone.feature_info.info]
        out_indices = self._backbone.feature_info.out_indices
        out_features = self._backbone.feature_info.module_name()

        # We verify the out indices and out features are valid
        verify_out_features_out_indices(
            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
        )
        self._out_features, self._out_indices = out_features, out_indices

    def _init_transformers_backbone(self, config) -> None:
        """
        Initialize the backbone from a transformers config.

        `config.stage_names` is required. `config.out_features` / `config.out_indices` may be absent or None;
        missing values are derived via `get_aligned_output_features_output_indices`.
        """
        stage_names = getattr(config, "stage_names")
        out_features = getattr(config, "out_features", None)
        out_indices = getattr(config, "out_indices", None)

        self.stage_names = stage_names
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=out_indices, stage_names=stage_names
        )
        # Number of channels for each stage. Left as None here; presumably filled in by the concrete
        # transformers backbone model's __init__ (TODO confirm against subclasses).
        self.num_features = None

    def _init_backbone(self, config) -> None:
        """
        Initialize the backbone bookkeeping from `config`.

        `config.use_timm_backbone` (default False) selects between the timm and transformers initializers.
        Intended to be called from the backbone model's constructor.
        """
        self.config = config

        self.use_timm_backbone = getattr(config, "use_timm_backbone", False)
        self.backbone_type = BackboneType.TIMM if self.use_timm_backbone else BackboneType.TRANSFORMERS

        if self.backbone_type == BackboneType.TIMM:
            self._init_timm_backbone(config)
        elif self.backbone_type == BackboneType.TRANSFORMERS:
            self._init_transformers_backbone(config)
        else:
            # Defensive: unreachable with the two-member BackboneType above.
            raise ValueError(f"backbone_type {self.backbone_type} not supported.")

    @property
    def out_features(self):
        """Names of the stages the backbone outputs."""
        return self._out_features

    @out_features.setter
    def out_features(self, out_features: List[str]):
        # Setting the names re-derives out_indices so the two stay aligned.
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=None, stage_names=self.stage_names
        )

    @property
    def out_indices(self):
        """Indices (into `stage_names`) of the stages the backbone outputs."""
        return self._out_indices

    @out_indices.setter
    def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
        # Setting the indices re-derives out_features so the two stay aligned.
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=None, out_indices=out_indices, stage_names=self.stage_names
        )

    @property
    def out_feature_channels(self):
        """
        Map every stage name to its channel count from `num_features`.

        Stages that are not in `out_features` are included too. Requires `num_features` to be a list aligned
        with `stage_names`.
        """
        # the current backbones will output the number of channels for each stage
        # even if that stage is not in the out_features list.
        return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}

    @property
    def channels(self):
        """Channel counts of the output stages, in `out_features` order."""
        return [self.out_feature_channels[name] for name in self.out_features]

    def forward_with_filtered_kwargs(self, *args, **kwargs):
        """
        Call `self(*args, ...)`, forwarding only keyword arguments whose names are parameters of `self.forward`.
        """
        signature = dict(inspect.signature(self.forward).parameters)
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
        return self(*args, **filtered_kwargs)

    def forward(
        self,
        pixel_values,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Abstract forward pass; concrete backbones must override it."""
        raise NotImplementedError("This method should be implemented by the derived class.")

    def to_dict(self):
        """
        Serialize via the next `to_dict()` in the MRO, then rename the private keys to public ones.

        The private `_out_features` / `_out_indices` keys become `out_features` / `out_indices`. The next
        class in the MRO is presumably `PretrainedConfig`; this must be mixed in ahead of a class that
        provides `to_dict()`.
        """
        output = super().to_dict()
        output["out_features"] = output.pop("_out_features")
        output["out_indices"] = output.pop("_out_indices")
        return output
class BackboneConfigMixin:
    """
    A Mixin to support handling the `out_features` and `out_indices` attributes for the backbone configurations.

    Values are stored privately in `_out_features` / `_out_indices` and exposed through properties. Assigning
    either property re-derives the other from `stage_names`, which the host class must provide.
    """

    @property
    def out_features(self):
        """Names of the stages the backbone outputs."""
        return self._out_features

    @out_features.setter
    def out_features(self, out_features: List[str]):
        # Setting the names re-derives out_indices so the two stay aligned.
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=None, stage_names=self.stage_names
        )

    @property
    def out_indices(self):
        """Indices (into `stage_names`) of the stages the backbone outputs."""
        return self._out_indices

    @out_indices.setter
    def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
        # Setting the indices re-derives out_features so the two stay aligned.
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=None, out_indices=out_indices, stage_names=self.stage_names
        )

    def to_dict(self):
        """
        Serialize via the next `to_dict()` in the MRO, then rename the private keys to public ones.

        The private `_out_features` / `_out_indices` keys become `out_features` / `out_indices`. The next
        class in the MRO is presumably `PretrainedConfig`.
        """
        output = super().to_dict()
        output["out_features"] = output.pop("_out_features")
        output["out_indices"] = output.pop("_out_indices")
        return output