import math

import numpy as np
import torch
from torch import nn

from .vits_config import VitsConfig

#.............................................
def _rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    reverse,
    tail_bound,
    min_bin_width,
    min_bin_height,
    min_derivative,
):
    """
    This transformation represents a monotonically increasing piecewise rational quadratic function. Unlike the
    function `_unconstrained_rational_quadratic_spline`, this function behaves the same across the `tail_bound`.

    Args:
        inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Second half of the hidden-states input to the Vits convolutional flow module.
        unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module.
        unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module.
        unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module.
        reverse (`bool`):
            Whether the model is being run in reverse mode.
        tail_bound (`float`):
            Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
            transform behaves as an identity function.
        min_bin_width (`float`):
            Minimum bin value across the width dimension for the piecewise rational quadratic function.
        min_bin_height (`float`):
            Minimum bin value across the height dimension for the piecewise rational quadratic function.
        min_derivative (`float`):
            Minimum bin value across the derivatives for the piecewise rational quadratic function.

    Returns:
        outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Hidden-states as transformed by the piecewise rational quadratic function.
        log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Logarithm of the absolute value of the determinants corresponding to the `outputs`.
    """
    upper_bound = tail_bound
    lower_bound = -tail_bound

    if torch.min(inputs) < lower_bound or torch.max(inputs) > upper_bound:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError(f"Minimal bin width {min_bin_width} too large for the number of bins {num_bins}")
    if min_bin_height * num_bins > 1.0:
        raise ValueError(f"Minimal bin height {min_bin_height} too large for the number of bins {num_bins}")

    widths = nn.functional.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = nn.functional.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (upper_bound - lower_bound) * cumwidths + lower_bound
    cumwidths[..., 0] = lower_bound
    cumwidths[..., -1] = upper_bound
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    derivatives = min_derivative + nn.functional.softplus(unnormalized_derivatives)

    heights = nn.functional.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = nn.functional.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    cumheights = (upper_bound - lower_bound) * cumheights + lower_bound
    cumheights[..., 0] = lower_bound
    cumheights[..., -1] = upper_bound
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    bin_locations = cumheights if reverse else cumwidths
    bin_locations[..., -1] += 1e-6
    bin_idx = torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
    bin_idx = bin_idx[..., None]

    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]

    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]
    intermediate1 = input_derivatives + input_derivatives_plus_one - 2 * input_delta

    if not reverse:
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
        denominator = input_delta + intermediate1 * theta_one_minus_theta
        outputs = input_cumheights + numerator / denominator

        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - theta).pow(2)
        )
        log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
        return outputs, log_abs_det
    else:
        # find the roots of a quadratic equation
        intermediate2 = inputs - input_cumheights
        intermediate3 = intermediate2 * intermediate1
        a = input_heights * (input_delta - input_derivatives) + intermediate3
        b = input_heights * input_derivatives - intermediate3
        c = -input_delta * intermediate2

        discriminant = b.pow(2) - 4 * a * c
        if not (discriminant >= 0).all():
            raise RuntimeError(f"invalid discriminant {discriminant}")

        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + intermediate1 * theta_one_minus_theta
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - root).pow(2)
        )
        log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
        return outputs, -log_abs_det


#.............................................
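# For an input x in bin k, with theta = (x - x_k) / w_k and delta_k = h_k / w_k, the
# forward branch above evaluates the rational-quadratic map used in neural spline
# flows (Durkan et al., 2019):
#
#     y = y_k + h_k * [delta_k * theta^2 + d_k * theta * (1 - theta)]
#               / [delta_k + (d_k + d_{k+1} - 2 * delta_k) * theta * (1 - theta)]
#
# with log-derivative
#
#     log dy/dx = log(delta_k^2 * [d_{k+1} * theta^2 + 2 * delta_k * theta * (1 - theta)
#                 + d_k * (1 - theta)^2]) - 2 * log(denominator),
#
# which is exactly `log_abs_det`. The reverse branch inverts the same relation by
# solving a quadratic in theta in closed form.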
def _unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    reverse=False,
    tail_bound=5.0,
    min_bin_width=1e-3,
    min_bin_height=1e-3,
    min_derivative=1e-3,
):
    """
    This transformation represents a monotonically increasing piecewise rational quadratic function. Outside of the
    `tail_bound`, the transform behaves as an identity function.

    Args:
        inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Second half of the hidden-states input to the Vits convolutional flow module.
        unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module.
        unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module.
        unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module.
        reverse (`bool`, *optional*, defaults to `False`):
            Whether the model is being run in reverse mode.
        tail_bound (`float`, *optional*, defaults to 5.0):
            Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
            transform behaves as an identity function.
        min_bin_width (`float`, *optional*, defaults to 1e-3):
            Minimum bin value across the width dimension for the piecewise rational quadratic function.
        min_bin_height (`float`, *optional*, defaults to 1e-3):
            Minimum bin value across the height dimension for the piecewise rational quadratic function.
        min_derivative (`float`, *optional*, defaults to 1e-3):
            Minimum bin value across the derivatives for the piecewise rational quadratic function.

    Returns:
        outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Hidden-states as transformed by the piecewise rational quadratic function with the `tail_bound` limits
            applied.
        log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Logarithm of the absolute value of the determinants corresponding to the `outputs` with the `tail_bound`
            limits applied.
    """
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    log_abs_det = torch.zeros_like(inputs)
    constant = np.log(np.exp(1 - min_derivative) - 1)

    unnormalized_derivatives = nn.functional.pad(unnormalized_derivatives, pad=(1, 1))
    unnormalized_derivatives[..., 0] = constant
    unnormalized_derivatives[..., -1] = constant
    outputs[outside_interval_mask] = inputs[outside_interval_mask]
    log_abs_det[outside_interval_mask] = 0.0

    outputs[inside_interval_mask], log_abs_det[inside_interval_mask] = _rational_quadratic_spline(
        inputs=inputs[inside_interval_mask],
        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
        reverse=reverse,
        tail_bound=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )
    return outputs, log_abs_det


#.............................................................................................
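# Hedged sketch (not part of the original module): round-trips random values through
# the unconstrained spline to illustrate that reverse mode inverts forward mode and
# that values beyond `tail_bound` pass through unchanged. Shapes mirror what
# `VitsConvFlow` produces; all numbers are arbitrary.
def _sketch_spline_roundtrip(num_bins=10, tail_bound=5.0):
    torch.manual_seed(0)
    inputs = torch.empty(2, 1, 6).uniform_(-2 * tail_bound, 2 * tail_bound)
    unnormalized_widths = torch.randn(2, 1, 6, num_bins)
    unnormalized_heights = torch.randn(2, 1, 6, num_bins)
    # num_bins - 1 values: the function pads one derivative on each side internally.
    unnormalized_derivatives = torch.randn(2, 1, 6, num_bins - 1)

    forward, log_abs_det = _unconstrained_rational_quadratic_spline(
        inputs, unnormalized_widths, unnormalized_heights, unnormalized_derivatives,
        reverse=False, tail_bound=tail_bound,
    )
    recovered, _ = _unconstrained_rational_quadratic_spline(
        forward, unnormalized_widths, unnormalized_heights, unnormalized_derivatives,
        reverse=True, tail_bound=tail_bound,
    )
    # `recovered` matches `inputs` up to numerical error; `log_abs_det` is zero for
    # positions that fell outside the interval (identity tails).
    return torch.allclose(recovered, inputs, atol=1e-4), log_abs_det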
class VitsConvFlow(nn.Module):
    def __init__(self, config: VitsConfig):
        super().__init__()
        self.filter_channels = config.hidden_size
        self.half_channels = config.depth_separable_channels // 2
        self.num_bins = config.duration_predictor_flow_bins
        self.tail_bound = config.duration_predictor_tail_bound

        self.conv_pre = nn.Conv1d(self.half_channels, self.filter_channels, 1)
        self.conv_dds = VitsDilatedDepthSeparableConv(config)
        self.conv_proj = nn.Conv1d(self.filter_channels, self.half_channels * (self.num_bins * 3 - 1), 1)

    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
        first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)

        hidden_states = self.conv_pre(first_half)
        hidden_states = self.conv_dds(hidden_states, padding_mask, global_conditioning)
        hidden_states = self.conv_proj(hidden_states) * padding_mask

        batch_size, channels, length = first_half.shape
        hidden_states = hidden_states.reshape(batch_size, channels, -1, length).permute(0, 1, 3, 2)

        unnormalized_widths = hidden_states[..., : self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_heights = hidden_states[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_derivatives = hidden_states[..., 2 * self.num_bins :]

        second_half, log_abs_det = _unconstrained_rational_quadratic_spline(
            second_half,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            reverse=reverse,
            tail_bound=self.tail_bound,
        )

        outputs = torch.cat([first_half, second_half], dim=1) * padding_mask
        if not reverse:
            log_determinant = torch.sum(log_abs_det * padding_mask, [1, 2])
            return outputs, log_determinant
        else:
            return outputs, None


#.............................................................................................
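# `VitsConvFlow` is a coupling layer: the first half of the channels passes through
# unchanged and, via the dilated depth-separable stack and the 1x1 projection, predicts
# 3 * num_bins - 1 spline parameters per remaining channel and timestep. Those
# parameters drive the unconstrained rational-quadratic spline that transforms the
# second half, so the layer stays invertible and its log-determinant reduces to the
# sum of `log_abs_det` over non-padded positions (returned only in the forward
# direction, where the training objective needs it).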
class VitsElementwiseAffine(nn.Module):
    def __init__(self, config: VitsConfig):
        super().__init__()
        self.channels = config.depth_separable_channels
        self.translate = nn.Parameter(torch.zeros(self.channels, 1))
        self.log_scale = nn.Parameter(torch.zeros(self.channels, 1))

    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
        if not reverse:
            outputs = self.translate + torch.exp(self.log_scale) * inputs
            outputs = outputs * padding_mask
            log_determinant = torch.sum(self.log_scale * padding_mask, [1, 2])
            return outputs, log_determinant
        else:
            outputs = (inputs - self.translate) * torch.exp(-self.log_scale) * padding_mask
            return outputs, None


#.............................................................................................
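# The affine flow above applies y = translate + exp(log_scale) * x in the forward
# direction and x = (y - translate) * exp(-log_scale) in reverse, so its
# log-Jacobian-determinant is simply `log_scale` summed over the non-padded positions,
# which is what `log_determinant` computes.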
class VitsDilatedDepthSeparableConv(nn.Module):
    def __init__(self, config: VitsConfig, dropout_rate=0.0):
        super().__init__()
        kernel_size = config.duration_predictor_kernel_size
        channels = config.hidden_size
        self.num_layers = config.depth_separable_num_layers

        self.dropout = nn.Dropout(dropout_rate)
        self.convs_dilated = nn.ModuleList()
        self.convs_pointwise = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(self.num_layers):
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_dilated.append(
                nn.Conv1d(
                    in_channels=channels,
                    out_channels=channels,
                    kernel_size=kernel_size,
                    groups=channels,
                    dilation=dilation,
                    padding=padding,
                )
            )
            self.convs_pointwise.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(nn.LayerNorm(channels))
            self.norms_2.append(nn.LayerNorm(channels))

    def forward(self, inputs, padding_mask, global_conditioning=None):
        if global_conditioning is not None:
            inputs = inputs + global_conditioning

        for i in range(self.num_layers):
            hidden_states = self.convs_dilated[i](inputs * padding_mask)
            hidden_states = self.norms_1[i](hidden_states.transpose(1, -1)).transpose(1, -1)
            hidden_states = nn.functional.gelu(hidden_states)
            hidden_states = self.convs_pointwise[i](hidden_states)
            hidden_states = self.norms_2[i](hidden_states.transpose(1, -1)).transpose(1, -1)
            hidden_states = nn.functional.gelu(hidden_states)
            hidden_states = self.dropout(hidden_states)
            inputs = inputs + hidden_states

        return inputs * padding_mask


#.............................................................................................
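# Each layer above is a depthwise (grouped) convolution followed by a 1x1 pointwise
# convolution, with LayerNorm applied over the channel dimension (hence the
# transpose(1, -1) round trips) and a residual connection. The dilation grows as
# kernel_size**i, so with, e.g., kernel_size=3 and 3 layers the dilations are 1, 3
# and 9, and `padding = (kernel_size * dilation - dilation) // 2` keeps the sequence
# length unchanged ("same" padding for odd kernel sizes).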
class VitsStochasticDurationPredictor(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.speaker_embedding_size
        filter_channels = config.hidden_size

        self.conv_pre = nn.Conv1d(filter_channels, filter_channels, 1)
        self.conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.conv_dds = VitsDilatedDepthSeparableConv(
            config,
            dropout_rate=config.duration_predictor_dropout,
        )

        if embed_dim != 0:
            self.cond = nn.Conv1d(embed_dim, filter_channels, 1)

        self.flows = nn.ModuleList()
        self.flows.append(VitsElementwiseAffine(config))
        for _ in range(config.duration_predictor_num_flows):
            self.flows.append(VitsConvFlow(config))

        self.post_conv_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_conv_dds = VitsDilatedDepthSeparableConv(
            config,
            dropout_rate=config.duration_predictor_dropout,
        )

        self.post_flows = nn.ModuleList()
        self.post_flows.append(VitsElementwiseAffine(config))
        for _ in range(config.duration_predictor_num_flows):
            self.post_flows.append(VitsConvFlow(config))

        self.filter_channels = filter_channels

    def resize_speaker_embeddings(self, speaker_embedding_size):
        self.cond = nn.Conv1d(speaker_embedding_size, self.filter_channels, 1)

    def forward(self, inputs, padding_mask, global_conditioning=None, durations=None, reverse=False, noise_scale=1.0):
        inputs = torch.detach(inputs)
        inputs = self.conv_pre(inputs)

        if global_conditioning is not None:
            global_conditioning = torch.detach(global_conditioning)
            inputs = inputs + self.cond(global_conditioning)

        inputs = self.conv_dds(inputs, padding_mask)
        inputs = self.conv_proj(inputs) * padding_mask

        if not reverse:
            hidden_states = self.post_conv_pre(durations)
            hidden_states = self.post_conv_dds(hidden_states, padding_mask)
            hidden_states = self.post_conv_proj(hidden_states) * padding_mask

            random_posterior = (
                torch.randn(durations.size(0), 2, durations.size(2)).to(device=inputs.device, dtype=inputs.dtype)
                * padding_mask
            )
            latents_posterior = random_posterior
            latents_posterior, log_determinant = self.post_flows[0](
                latents_posterior, padding_mask, global_conditioning=inputs + hidden_states
            )
            log_determinant_posterior_sum = log_determinant
            for flow in self.post_flows[1:]:
                latents_posterior, log_determinant = flow(
                    latents_posterior, padding_mask, global_conditioning=inputs + hidden_states
                )
                latents_posterior = torch.flip(latents_posterior, [1])
                log_determinant_posterior_sum += log_determinant

            first_half, second_half = torch.split(latents_posterior, [1, 1], dim=1)

            log_determinant_posterior_sum += torch.sum(
                (nn.functional.logsigmoid(first_half) + nn.functional.logsigmoid(-first_half)) * padding_mask, [1, 2]
            )
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (random_posterior**2)) * padding_mask, [1, 2])
                - log_determinant_posterior_sum
            )

            first_half = (durations - torch.sigmoid(first_half)) * padding_mask
            first_half = torch.log(torch.clamp_min(first_half, 1e-5)) * padding_mask
            log_determinant_sum = torch.sum(-first_half, [1, 2])

            latents = torch.cat([first_half, second_half], dim=1)
            latents, log_determinant = self.flows[0](latents, padding_mask, global_conditioning=inputs)
            log_determinant_sum += log_determinant
            for flow in self.flows[1:]:
                latents, log_determinant = flow(latents, padding_mask, global_conditioning=inputs)
                latents = torch.flip(latents, [1])
                log_determinant_sum += log_determinant

            nll = torch.sum(0.5 * (math.log(2 * math.pi) + (latents**2)) * padding_mask, [1, 2]) - log_determinant_sum
            return nll + logq
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow

            latents = (
                torch.randn(inputs.size(0), 2, inputs.size(2)).to(device=inputs.device, dtype=inputs.dtype)
                * noise_scale
            )
            for flow in flows:
                latents = torch.flip(latents, [1])
                latents, _ = flow(latents, padding_mask, global_conditioning=inputs, reverse=True)

            log_duration, _ = torch.split(latents, [1, 1], dim=1)
            return log_duration


#.............................................................................................
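# Hedged usage sketch (not part of the original module): sampling log-durations in
# reverse mode. It assumes `VitsConfig()` can be instantiated with defaults that
# populate the fields read above (`hidden_size`, `duration_predictor_num_flows`, ...);
# real inputs would come from the text encoder rather than `torch.randn`.
def _sketch_sample_durations(seq_len=12, noise_scale=0.8):
    config = VitsConfig()
    predictor = VitsStochasticDurationPredictor(config)
    predictor.eval()  # disable dropout inside the depth-separable convolutions

    hidden_states = torch.randn(1, config.hidden_size, seq_len)  # stand-in encoder output
    padding_mask = torch.ones(1, 1, seq_len)

    with torch.no_grad():
        log_durations = predictor(hidden_states, padding_mask, reverse=True, noise_scale=noise_scale)
    # One way to turn log-durations into frame counts: exponentiate and round up.
    return torch.ceil(torch.exp(log_durations) * padding_mask)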
class VitsDurationPredictor(nn.Module):
    def __init__(self, config):
        super().__init__()
        kernel_size = config.duration_predictor_kernel_size
        filter_channels = config.duration_predictor_filter_channels

        self.dropout = nn.Dropout(config.duration_predictor_dropout)
        self.conv_1 = nn.Conv1d(config.hidden_size, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if config.speaker_embedding_size != 0:
            self.cond = nn.Conv1d(config.speaker_embedding_size, config.hidden_size, 1)

        self.hidden_size = config.hidden_size

    def resize_speaker_embeddings(self, speaker_embedding_size):
        self.cond = nn.Conv1d(speaker_embedding_size, self.hidden_size, 1)

    def forward(self, inputs, padding_mask, global_conditioning=None):
        inputs = torch.detach(inputs)

        if global_conditioning is not None:
            global_conditioning = torch.detach(global_conditioning)
            inputs = inputs + self.cond(global_conditioning)

        inputs = self.conv_1(inputs * padding_mask)
        inputs = torch.relu(inputs)
        inputs = self.norm_1(inputs.transpose(1, -1)).transpose(1, -1)
        inputs = self.dropout(inputs)

        inputs = self.conv_2(inputs * padding_mask)
        inputs = torch.relu(inputs)
        inputs = self.norm_2(inputs.transpose(1, -1)).transpose(1, -1)
        inputs = self.dropout(inputs)

        inputs = self.proj(inputs * padding_mask)
        return inputs * padding_mask


#.............................................................................................
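# `VitsDurationPredictor` is the deterministic counterpart of the stochastic predictor
# above: two convolution + ReLU + LayerNorm + dropout blocks followed by a 1x1
# projection regress a single log-duration per frame, with no sampling and no
# likelihood term. In both predictors, `torch.detach` on the inputs (and on the
# speaker conditioning) keeps gradients of the duration objective from flowing back
# into the encoder hidden-states.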