"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
from typing import List

from torch import nn

logger = logging.getLogger(__name__)

# Maximum nesting depth walked before assuming the module graph contains a cycle.
_MAX_TIE_DEPTH = 500


def tie_encoder_decoder_weights(
    encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key: str
) -> List[str]:
    """Make ``encoder`` share the parameters of ``decoder`` wherever their structures match.

    Both module trees are walked in parallel. Consider each decoder submodule that has a
    ``weight`` attribute and whose path does not contain ``skip_key``. The matching
    encoder submodule's ``weight`` is replaced by the decoder's object, and so is its
    ``bias`` if the decoder submodule has one. After this the two modules share those
    parameters.

    Numbered children (e.g. entries of an ``nn.ModuleList``) are matched by position.
    Suppose the two child lists differ in length and a decoder entry's type does not
    match the encoder entry at the current position. That decoder entry is treated as a
    layer the encoder lacks (for instance an added cross-attention) and is skipped.
    Later decoder entries are then paired with encoder entries shifted back by one.

    Args:
        encoder: Module whose parameters are replaced.
        decoder: Module whose parameters are shared into ``encoder``.
        base_model_prefix: Root of the ``/``-separated module paths. These paths are
            used in log messages, in ``skip_key`` matching and in the returned list.
        skip_key: A submodule whose path contains this substring does not have its own
            weight/bias tied, but its children are still visited.

    Returns:
        Sorted paths, per visited level, of encoder child modules that had no decoder
        counterpart and were therefore left untied.

    Raises:
        AssertionError: if the two module trees are structurally incompatible.
        ValueError: if the walk nests deeper than ``_MAX_TIE_DEPTH`` levels, which
            suggests a circular reference between modules.
    """
    uninitialized_encoder_weights: List[str] = []
    if decoder.__class__ != encoder.__class__:
        logger.info(
            "%s and %s are not equal. In this case make sure that all encoder weights "
            "are correctly initialized.",
            decoder.__class__,
            encoder.__class__,
        )

    def tie_encoder_to_decoder_recursively(
        decoder_pointer: nn.Module,
        encoder_pointer: nn.Module,
        module_name: str,
        depth: int = 0,
    ) -> None:
        """Tie ``encoder_pointer`` to ``decoder_pointer``; ``module_name`` is the decoder-side path."""
        assert isinstance(decoder_pointer, nn.Module) and isinstance(
            encoder_pointer, nn.Module
        ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"

        # A module that owns a `weight` is treated as a leaf: share its parameters
        # and stop descending. Paths containing `skip_key` are not tied here, so
        # their children are still visited below.
        if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
            assert hasattr(
                encoder_pointer, "weight"
            ), f"{module_name}: encoder module has no `weight` to tie"
            encoder_pointer.weight = decoder_pointer.weight
            if hasattr(decoder_pointer, "bias"):
                assert hasattr(
                    encoder_pointer, "bias"
                ), f"{module_name}: encoder module has no `bias` to tie"
                encoder_pointer.bias = decoder_pointer.bias
            logger.info("%s is tied", module_name)
            return

        encoder_modules = encoder_pointer._modules
        decoder_modules = decoder_pointer._modules
        if not decoder_modules:
            return
        assert (
            encoder_modules
        ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"

        # Encoder children that have not (yet) been matched to a decoder child.
        unmatched = {f"{module_name}/{sub_name}" for sub_name in encoder_modules}
        # Offset added to a numbered decoder index to get the matching encoder
        # index. It is decremented each time a decoder-only layer is skipped.
        encoder_layer_pos = 0
        for name in decoder_modules:
            if name.isdigit():
                decoder_name = name
                encoder_name = str(int(name) + encoder_layer_pos)
                if encoder_name not in encoder_modules:
                    # The decoder has more numbered children than the encoder can match.
                    continue
                if not isinstance(
                    decoder_modules[decoder_name],
                    type(encoder_modules[encoder_name]),
                ) and len(encoder_modules) != len(decoder_modules):
                    # The name is a position in a list of layers, and the decoder has
                    # an extra layer here (e.g. a cross-attention) that the encoder
                    # lacks. Skip it and shift subsequent encoder positions back by one.
                    encoder_layer_pos -= 1
                    continue
            elif name not in encoder_modules:
                continue
            else:
                decoder_name = encoder_name = name

            # Checked before every descent, numbered or named, so that a cyclic
            # module graph raises a clear error instead of a RecursionError.
            if depth > _MAX_TIE_DEPTH:
                raise ValueError(
                    "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
                )
            tie_encoder_to_decoder_recursively(
                decoder_modules[decoder_name],
                encoder_modules[encoder_name],
                f"{module_name}/{name}",
                depth=depth + 1,
            )
            unmatched.discard(f"{module_name}/{encoder_name}")

        # Sorted so the reported list does not depend on set iteration order.
        uninitialized_encoder_weights.extend(sorted(unmatched))

    tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix)
    if uninitialized_encoder_weights:
        logger.warning(
            "The following encoder weights were not tied to the decoder: %s",
            uninitialized_encoder_weights,
        )
    return uninitialized_encoder_weights