# -*- coding: utf-8 -*-

# Copyright 2024 Yiwei Guo
# Derived mostly from fairseq (https://github.com/facebookresearch/fairseq)

"""Prompt Pre-net Modules."""

import math

import torch
import torch.nn as nn

from vec2wav2.models.fairseq_modules.fp32_group_norm import Fp32GroupNorm
from vec2wav2.models.fairseq_modules.layer_norm import Fp32LayerNorm
from vec2wav2.models.fairseq_modules.transpose_last import TransposeLast


def norm_block(is_layer_norm, dim, affine=True):
    """Build an fp32 normalization block over `dim` channels.

    LayerNorm normalizes the last dimension, so the input (B, C, T) is
    transposed to (B, T, C), normalized, and transposed back. GroupNorm with
    a single group normalizes jointly over all channels and time steps.
    """
    if is_layer_norm:
        mod = nn.Sequential(
            TransposeLast(),
            Fp32LayerNorm(dim, elementwise_affine=affine),
            TransposeLast(),
        )
    else:
        mod = Fp32GroupNorm(1, dim, affine=affine)
    return mod


class ZeroPad1d(nn.Module):
    """Zero-pad a tensor along its last (time) dimension."""

    def __init__(self, pad_left, pad_right):
        super().__init__()
        self.pad_left = pad_left
        self.pad_right = pad_right

    def forward(self, x):
        return nn.functional.pad(x, (self.pad_left, self.pad_right))


class ConvPromptPrenet(nn.Module):
    """Convolutional prompt pre-net.

    A stack of Conv1d -> Dropout -> GroupNorm -> activation blocks with
    optional scaled residual connections, following fairseq's wav2vec
    convolutional aggregator. Each entry of `conv_layers` is a tuple
    (dim, kernel_size, stride, padding).
    """

    def __init__(
        self,
        conv_layers,
        embed,
        dropout,
        skip_connections,
        residual_scale,
        non_affine_group_norm,
        conv_bias,
        activation,
    ):
        super().__init__()

        def block(n_in, n_out, k, stride, pad):
            # The same activation module instance is shared across all blocks,
            # which is safe for stateless activations such as ReLU/GELU.
            return nn.Sequential(
                nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias, padding=pad),
                nn.Dropout(p=dropout),
                norm_block(False, n_out, affine=not non_affine_group_norm),
                activation,
            )

        in_d = embed
        self.conv_layers = nn.ModuleList()
        self.residual_proj = nn.ModuleList()
        for dim, k, stride, pad in conv_layers:
            if in_d != dim and skip_connections:
                # 1x1 convolution to match channel counts on the residual path.
                self.residual_proj.append(nn.Conv1d(in_d, dim, 1, bias=False))
            else:
                self.residual_proj.append(None)

            self.conv_layers.append(block(in_d, dim, k, stride, pad))
            in_d = dim
        self.conv_layers = nn.Sequential(*self.conv_layers)
        self.skip_connections = skip_connections
        self.residual_scale = math.sqrt(residual_scale)

    def forward(self, x):
        # x: (batch, channels, time)
        for rproj, conv in zip(self.residual_proj, self.conv_layers):
            residual = x
            x = conv(x)
            if self.skip_connections:
                if rproj is not None:
                    residual = rproj(residual)
                # Scale the residual sum (fairseq wav2vec convention) to keep
                # activation magnitudes stable across stacked blocks.
                x = (x + residual) * self.residual_scale
        return x
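

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). The
# layer spec, channel counts, and hyperparameters below are hypothetical
# placeholders; vec2wav2 configures ConvPromptPrenet from its own config
# files. This only demonstrates the expected (batch, channels, time) layout
# and how the (dim, kernel, stride, pad) tuples in `conv_layers` are read.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    prenet = ConvPromptPrenet(
        conv_layers=[(256, 3, 1, 1), (256, 3, 1, 1)],  # (dim, kernel, stride, pad)
        embed=80,                 # hypothetical input channel count
        dropout=0.1,
        skip_connections=True,
        residual_scale=0.5,       # stored internally as sqrt(0.5)
        non_affine_group_norm=False,
        conv_bias=False,
        activation=nn.ReLU(),
    )
    x = torch.randn(4, 80, 100)   # (batch, embed, time)
    y = prenet(x)
    print(y.shape)                # torch.Size([4, 256, 100]) with stride 1, pad 1

    # ZeroPad1d pads only the last (time) dimension.
    padded = ZeroPad1d(2, 3)(x)
    print(padded.shape)           # torch.Size([4, 80, 105])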