julien.blanchon
add app
c8c12e9
raw
history blame
7.86 kB
"""Utils for NNCf optimization."""
# Copyright (C) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import logging
from copy import copy
from typing import Any, Dict, Iterator, List, Tuple
from nncf import NNCFConfig
from nncf.api.compression import CompressionAlgorithmController
from nncf.torch import create_compressed_model, load_state, register_default_init_args
from nncf.torch.initialization import PTInitializingDataLoader
from nncf.torch.nncf_network import NNCFNetwork
from torch import nn
from torch.utils.data.dataloader import DataLoader
logger = logging.getLogger(name="NNCF compression")
class InitLoader(PTInitializingDataLoader):
    """NNCF initializing data loader for unsupervised training algorithms.

    Iterating over this loader yields only the ``"image"`` entry of every batch
    produced by the wrapped dataloader; no target is ever provided.
    """

    def __init__(self, data_loader: DataLoader):
        super().__init__(data_loader)
        # Only declared here; the iterator is bound on each ``__iter__`` call.
        self._data_loader_iter: Iterator

    def __iter__(self):
        """Start a fresh pass over the wrapped dataloader and return ``self``."""
        # NOTE(review): ``self._data_loader`` is presumably set by the base-class
        # constructor -- it is not assigned anywhere in this class.
        self._data_loader_iter = iter(self._data_loader)
        return self

    def __next__(self) -> Any:
        """Fetch the next batch and return its ``"image"`` entry."""
        return next(self._data_loader_iter)["image"]

    def get_inputs(self, dataloader_output) -> Tuple[Tuple, Dict]:
        """Build the model call arguments from one dataloader output.

        Returns:
            Tuple[Tuple, Dict]: positional arguments (the dataloader output as the
                single element) and an empty dict of keyword arguments.
        """
        positional_args = (dataloader_output,)
        keyword_args: Dict = {}
        return positional_args, keyword_args

    def get_target(self, _):
        """Return the ground truth for the loss criterion.

        Placeholder implementation: the argument is ignored.

        Returns:
            None
        """
        return None
def wrap_nncf_model(
    model: nn.Module, config: Dict, dataloader: DataLoader = None, init_state_dict: Dict = None
) -> Tuple[CompressionAlgorithmController, NNCFNetwork]:
    """Wrap model by NNCF.

    :param model: Anomalib model.
    :param config: NNCF config (a plain dict, converted with ``NNCFConfig.from_dict``).
    :param dataloader: Dataloader for initialization of NNCF model. When given, it is
        wrapped in ``InitLoader`` and registered as the default init args.
    :param init_state_dict: Optional checkpoint dict. Its ``"model"`` entry (if any) is
        loaded into the compressed model and its ``"compression_state"`` entry (if any)
        is passed to ``create_compressed_model``.
    :return: compression controller, compressed model
    """
    nncf_config = NNCFConfig.from_dict(config)

    # Compare with None instead of relying on truthiness: the truth value of a
    # DataLoader is derived from ``__len__``, which raises TypeError for
    # iterable-style datasets that do not define a length.
    has_dataloader = dataloader is not None

    if not has_dataloader and not init_state_dict:
        logger.warning(
            "Either dataloader or NNCF pre-trained "
            "model checkpoint should be set. Without this, "
            "quantizers will not be initialized"
        )

    compression_state = None
    resuming_state_dict = None
    if init_state_dict:
        resuming_state_dict = init_state_dict.get("model")
        compression_state = init_state_dict.get("compression_state")

    if has_dataloader:
        init_loader = InitLoader(dataloader)  # type: ignore
        nncf_config = register_default_init_args(nncf_config, init_loader)

    nncf_ctrl, nncf_model = create_compressed_model(
        model=model, config=nncf_config, dump_graphs=False, compression_state=compression_state
    )

    # Restore the weights only after the model has been wrapped by NNCF.
    if resuming_state_dict:
        load_state(nncf_model, resuming_state_dict, is_resume=True)

    return nncf_ctrl, nncf_model
def is_state_nncf(state: Dict) -> bool:
    """Check whether a state is the result of an NNCF-compressed model.

    :param state: State dict; the flag is read from ``state["meta"]["nncf_enable_compression"]``.
    :return: Truth value of the flag, ``False`` when ``"meta"`` or the flag is missing.
    """
    meta = state.get("meta", {})
    compression_flag = meta.get("nncf_enable_compression", False)
    return bool(compression_flag)
def compose_nncf_config(nncf_config: Dict, enabled_options: List[str]) -> Dict:
    """Compose NNCF config by selected options.

    The ``"base"`` part is taken as the starting point and every enabled part is
    merged into it, following the order given by ``"order_of_parts"``. When the
    config has no ``"order_of_parts"`` field, no part is merged and the ``"base"``
    part is returned as is.

    :param nncf_config: Optimisation config; must contain the ``"base"`` part and may
        contain ``"order_of_parts"`` plus one entry per optional part.
    :param enabled_options: Names of the parts to merge into ``"base"``.
    :return: config
    :raises AssertionError: if ``"order_of_parts"`` is not a list, an enabled option is
        absent from it, ``"base"`` is missing, or a selected part is missing.
    :raises RuntimeError: if merging a part into the config fails.
    """
    optimisation_parts = nncf_config
    optimisation_parts_to_choose = []
    if "order_of_parts" in optimisation_parts:
        # The result of applying the changes from optimisation parts
        # may depend on the order of applying the changes
        # (e.g. if for nncf_quantization it is sufficient to have `total_epochs=2`,
        #  but for sparsity it is required `total_epochs=50`)
        # So, user can define `order_of_parts` in the optimisation_config
        # to specify the order of applying the parts.
        order_of_parts = optimisation_parts["order_of_parts"]
        assert isinstance(order_of_parts, list), 'The field "order_of_parts" in optimisation config should be a list'

        for part in enabled_options:
            # Single f-string so that the list of known parts is actually interpolated.
            assert (
                part in order_of_parts
            ), f"The part {part} is selected, but it is absent in order_of_parts={order_of_parts}"

        optimisation_parts_to_choose = [part for part in order_of_parts if part in enabled_options]

    assert "base" in optimisation_parts, 'Error: the optimisation config does not contain the "base" part'
    nncf_config_part = optimisation_parts["base"]

    for part in optimisation_parts_to_choose:
        assert part in optimisation_parts, f'Error: the optimisation config does not contain the part "{part}"'
        optimisation_part_dict = optimisation_parts[part]
        try:
            nncf_config_part = merge_dicts_and_lists_b_into_a(nncf_config_part, optimisation_part_dict)
        except AssertionError as cur_error:
            # The merge helper reports type mismatches via assertions; re-raise them
            # with the merge context attached.
            err_descr = (
                f"Error during merging the parts of nncf configs:\n"
                f"the current part={part}, "
                f"the order of merging parts into base is {optimisation_parts_to_choose}.\n"
                f"The error is:\n{cur_error}"
            )
            raise RuntimeError(err_descr) from None

    return nncf_config_part
# pylint: disable=invalid-name
def merge_dicts_and_lists_b_into_a(a, b):
    """Merge config ``b`` into config ``a`` and return the merged result.

    Public entry point: delegates to ``_merge_dicts_and_lists_b_into_a`` starting
    from an empty key path.
    """
    merged = _merge_dicts_and_lists_b_into_a(a, b, cur_key="")
    return merged
def _merge_dicts_and_lists_b_into_a(a, b, cur_key=None):
"""The function is inspired by mmcf.Config._merge_a_into_b.
* works with usual dicts and lists and derived types
* supports merging of lists (by concatenating the lists)
* makes recursive merging for dict + dict case
* overwrites when merging scalar into scalar
Note that we merge b into a (whereas Config makes merge a into b),
since otherwise the order of list merging is counter-intuitive.
"""
def _err_str(_a, _b, _key):
if _key is None:
_key_str = "of whole structures"
else:
_key_str = f"during merging for key=`{_key}`"
return (
f"Error in merging parts of config: different types {_key_str},"
f" type(a) = {type(_a)},"
f" type(b) = {type(_b)}"
)
assert isinstance(a, (dict, list)), f"Can merge only dicts and lists, whereas type(a)={type(a)}"
assert isinstance(b, (dict, list)), _err_str(a, b, cur_key)
assert isinstance(a, list) == isinstance(b, list), _err_str(a, b, cur_key)
if isinstance(a, list):
# the main diff w.r.t. mmcf.Config -- merging of lists
return a + b
a = copy(a)
for k in b.keys():
if k not in a:
a[k] = copy(b[k])
continue
new_cur_key = cur_key + "." + k if cur_key else k
if isinstance(a[k], (dict, list)):
a[k] = _merge_dicts_and_lists_b_into_a(a[k], b[k], new_cur_key)
continue
assert not isinstance(b[k], (dict, list)), _err_str(a[k], b[k], new_cur_key)
# suppose here that a[k] and b[k] are scalars, just overwrite
a[k] = b[k]
return a