|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass |
|
from math import pi |
|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.utils.checkpoint |
|
|
|
from ...configuration_utils import ConfigMixin, register_to_config |
|
from ...models.modeling_utils import ModelMixin |
|
from ...utils import BaseOutput, logging |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class StableAudioPositionalEmbedding(nn.Module):
    """Fourier-feature embedding for continuous time values.

    Maps each scalar time to `dim + 1` features: the raw time followed by
    `dim // 2` sine and `dim // 2` cosine components with learned random
    frequencies.
    """

    def __init__(self, dim: int):
        """
        Args:
            dim (`int`):
                Number of Fourier features. Must be even (split evenly between
                sine and cosine components).

        Raises:
            ValueError: If `dim` is odd.
        """
        super().__init__()
        # `assert` is stripped under `python -O`; validate explicitly instead.
        if dim % 2 != 0:
            raise ValueError(f"`dim` must be even, got {dim}.")
        half_dim = dim // 2
        # One learned frequency per sine/cosine pair.
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, times: torch.Tensor) -> torch.Tensor:
        """Embed `times` of shape `(...,)` into shape `(..., dim + 1)`."""
        times = times[..., None]
        freqs = times * self.weights[None] * 2 * pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        # Prepend the raw time so the downstream projection can use it directly.
        fouriered = torch.cat((times, fouriered), dim=-1)
        return fouriered
|
|
|
|
|
@dataclass
class StableAudioProjectionModelOutput(BaseOutput):
    """
    Class for StableAudio projection layer's outputs.

    Args:
        text_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states obtained by linearly projecting the hidden-states for the text encoder.
        seconds_start_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
            Sequence of hidden-states obtained by linearly projecting the audio start hidden states.
        seconds_end_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
            Sequence of hidden-states obtained by linearly projecting the audio end hidden states.
    """

    # Projected text-encoder hidden states; `None` when no text input was given.
    text_hidden_states: Optional[torch.Tensor] = None
    # Projected "seconds start" embedding; `None` when no start time was given.
    seconds_start_hidden_states: Optional[torch.Tensor] = None
    # Projected "seconds end" embedding; `None` when no end time was given.
    seconds_end_hidden_states: Optional[torch.Tensor] = None
|
|
|
|
|
class StableAudioNumberConditioner(nn.Module):
    """
    A simple linear projection model to map numbers to a latent space.

    Args:
        number_embedding_dim (`int`):
            Dimensionality of the number embeddings.
        min_value (`int`):
            The minimum value of the seconds number conditioning modules.
        max_value (`int`):
            The maximum value of the seconds number conditioning modules.
        internal_dim (`int`, *optional*, defaults to 256):
            Dimensionality of the intermediate number hidden states.
    """

    def __init__(
        self,
        number_embedding_dim: int,
        min_value: int,
        max_value: int,
        # NOTE: optional-to-supply, but `None` is not a valid value (it would
        # crash inside StableAudioPositionalEmbedding), so the annotation is
        # `int`, not `Optional[int]`.
        internal_dim: int = 256,
    ):
        super().__init__()
        # Fourier-embed the normalized number, then project to the target dim.
        # The `+ 1` accounts for the raw value the positional embedding prepends.
        self.time_positional_embedding = nn.Sequential(
            StableAudioPositionalEmbedding(internal_dim),
            nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim),
        )

        self.number_embedding_dim = number_embedding_dim
        self.min_value = min_value
        self.max_value = max_value

    def forward(
        self,
        floats: torch.Tensor,
    ):
        """Project `floats` to embeddings of shape `(batch, 1, number_embedding_dim)`.

        Values are clamped to `[min_value, max_value]` and normalized to
        `[0, 1]` before being embedded.
        """
        floats = floats.clamp(self.min_value, self.max_value)

        normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value)

        # Cast to the embedder's parameter dtype (e.g. fp16) before projecting.
        embedder_dtype = next(self.time_positional_embedding.parameters()).dtype
        normalized_floats = normalized_floats.to(embedder_dtype)

        embedding = self.time_positional_embedding(normalized_floats)
        float_embeds = embedding.view(-1, 1, self.number_embedding_dim)

        return float_embeds
|
|
|
|
|
class StableAudioProjectionModel(ModelMixin, ConfigMixin):
    """
    A simple linear projection model to map the conditioning values to a shared latent space.

    Args:
        text_encoder_dim (`int`):
            Dimensionality of the text embeddings from the text encoder (T5).
        conditioning_dim (`int`):
            Dimensionality of the output conditioning tensors.
        min_value (`int`):
            The minimum value of the seconds number conditioning modules.
        max_value (`int`):
            The maximum value of the seconds number conditioning modules
    """

    @register_to_config
    def __init__(self, text_encoder_dim, conditioning_dim, min_value, max_value):
        super().__init__()
        # Only project the text features when the dimensions actually differ;
        # otherwise pass them through unchanged.
        if conditioning_dim == text_encoder_dim:
            self.text_projection = nn.Identity()
        else:
            self.text_projection = nn.Linear(text_encoder_dim, conditioning_dim)
        self.start_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)
        self.end_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)

    def forward(
        self,
        text_hidden_states: Optional[torch.Tensor] = None,
        start_seconds: Optional[torch.Tensor] = None,
        end_seconds: Optional[torch.Tensor] = None,
    ):
        """Project each provided conditioning input; `None` inputs pass through as `None`."""
        if text_hidden_states is not None:
            text_hidden_states = self.text_projection(text_hidden_states)
        if start_seconds is not None:
            start_seconds = self.start_number_conditioner(start_seconds)
        if end_seconds is not None:
            end_seconds = self.end_number_conditioner(end_seconds)

        return StableAudioProjectionModelOutput(
            text_hidden_states=text_hidden_states,
            seconds_start_hidden_states=start_seconds,
            seconds_end_hidden_states=end_seconds,
        )
|
|