diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ab702a0189fe6aa06ca41545d7a6db43f639bf5
--- /dev/null
+++ b/app.py
@@ -0,0 +1,23 @@
+from orator.src.orator.tts import OratorTTS
+import gradio as gr
+
+
+model = OratorTTS.from_pretrained("cuda")
+
+
+def generate(text, audio_prompt_path, emotion_adv):
+    wav = model.generate(text, audio_prompt_path=audio_prompt_path, emotion_adv=emotion_adv)
+    return 24000, wav.squeeze(0).numpy()
+
+demo = gr.Interface(
+    generate,
+    [
+        gr.Textbox(value="What does the fox say?", label="Text to synthesize"),
+        gr.Audio(sources="upload", type="filepath", label="Input Audio File"),
+        gr.Slider(0, 1, step=.05, label="emotion_adv", value=.5),
+    ],
+    "audio",
+)
+
+if __name__ == "__main__":
+    demo.launch()
diff --git a/orator/src/orator.egg-info/PKG-INFO b/orator/src/orator.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..8a2974074c3fca0ccd532f1159641db692426a02
--- /dev/null
+++ b/orator/src/orator.egg-info/PKG-INFO
@@ -0,0 +1,17 @@
+Metadata-Version: 2.4
+Name: orator
+Version: 0.1
+Description-Content-Type: text/markdown
+Requires-Dist: numpy==1.26.0
+Requires-Dist: resampy==0.4.3
+Requires-Dist: librosa==0.10.0
+Requires-Dist: s3tokenizer
+Requires-Dist: torch==2.6.0
+Requires-Dist: torchaudio==2.6.0
+Requires-Dist: transformers==4.46.3
+Requires-Dist: diffusers==0.29.0
+Requires-Dist: omegaconf==2.3.0
+Requires-Dist: conformer==0.3.2
+
+# orator
+Open source TTS model
diff --git a/orator/src/orator.egg-info/SOURCES.txt b/orator/src/orator.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cada9c6ef8507252610ee20e70be2314eeb74cca
--- /dev/null
+++ b/orator/src/orator.egg-info/SOURCES.txt
@@ -0,0 +1,52 @@
+README.md
+pyproject.toml
+src/orator/__init__.py
+src/orator/model_checkpoints.py
+src/orator/tts.py
+src/orator.egg-info/PKG-INFO
+src/orator.egg-info/SOURCES.txt
+src/orator.egg-info/dependency_links.txt
+src/orator.egg-info/requires.txt
+src/orator.egg-info/top_level.txt
+src/orator/models/s3gen/__init__.py
+src/orator/models/s3gen/const.py
+src/orator/models/s3gen/decoder.py
+src/orator/models/s3gen/f0_predictor.py
+src/orator/models/s3gen/flow.py
+src/orator/models/s3gen/flow_matching.py
+src/orator/models/s3gen/hifigan.py
+src/orator/models/s3gen/s3gen.py
+src/orator/models/s3gen/xvector.py
+src/orator/models/s3gen/matcha/decoder.py
+src/orator/models/s3gen/matcha/flow_matching.py
+src/orator/models/s3gen/matcha/text_encoder.py
+src/orator/models/s3gen/matcha/transformer.py
+src/orator/models/s3gen/transformer/__init__.py
+src/orator/models/s3gen/transformer/activation.py
+src/orator/models/s3gen/transformer/attention.py
+src/orator/models/s3gen/transformer/convolution.py
+src/orator/models/s3gen/transformer/embedding.py
+src/orator/models/s3gen/transformer/encoder_layer.py
+src/orator/models/s3gen/transformer/positionwise_feed_forward.py
+src/orator/models/s3gen/transformer/subsampling.py
+src/orator/models/s3gen/transformer/upsample_encoder.py
+src/orator/models/s3gen/utils/class_utils.py
+src/orator/models/s3gen/utils/mask.py
+src/orator/models/s3gen/utils/mel.py
+src/orator/models/s3tokenizer/__init__.py
+src/orator/models/s3tokenizer/s3tokenizer.py
+src/orator/models/t3/__init__.py
+src/orator/models/t3/llama_configs.py
+src/orator/models/t3/t3.py
+src/orator/models/t3/inference/t3_hf_backend.py
+src/orator/models/t3/modules/cond_enc.py
+src/orator/models/t3/modules/learned_pos_emb.py
+src/orator/models/t3/modules/perceiver.py
+src/orator/models/t3/modules/t3_config.py
+src/orator/models/tokenizers/__init__.py
+src/orator/models/tokenizers/tokenizer.py
+src/orator/models/voice_encoder/__init__.py
+src/orator/models/voice_encoder/voice_encoder.py
+src/orator/transforms/spectrogram.py
+src/orator/transforms/syn_transforms.py
+src/orator/transforms/webrtc.py
\ No newline at end of file
diff --git a/orator/src/orator.egg-info/dependency_links.txt b/orator/src/orator.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/orator/src/orator.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/orator/src/orator.egg-info/requires.txt b/orator/src/orator.egg-info/requires.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d5214bc4cfaf3ee370117d092878abedde29e924
--- /dev/null
+++ b/orator/src/orator.egg-info/requires.txt
@@ -0,0 +1,10 @@
+numpy==1.26.0
+resampy==0.4.3
+librosa==0.10.0
+s3tokenizer
+torch==2.6.0
+torchaudio==2.6.0
+transformers==4.46.3
+diffusers==0.29.0
+omegaconf==2.3.0
+conformer==0.3.2
diff --git a/orator/src/orator.egg-info/top_level.txt b/orator/src/orator.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..043cdf6363c1d150134985603544ec5a33b0d53a
--- /dev/null
+++ b/orator/src/orator.egg-info/top_level.txt
@@ -0,0 +1 @@
+orator
diff --git a/orator/src/orator/__init__.py b/orator/src/orator/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b13b095b627341de9c4a1e63ca91b180910ee3
--- /dev/null
+++ b/orator/src/orator/__init__.py
@@ -0,0 +1 @@
+from .tts import OratorTTS
\ No newline at end of file
diff --git a/orator/src/orator/__pycache__/__init__.cpython-311.pyc b/orator/src/orator/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ad8f5de7c9ca3e1810d0327c25196f38459a597
Binary files /dev/null and b/orator/src/orator/__pycache__/__init__.cpython-311.pyc differ
diff --git a/orator/src/orator/__pycache__/tts.cpython-311.pyc b/orator/src/orator/__pycache__/tts.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..29395581a11f480b8650fe8156c2c05d04b7c889
Binary files /dev/null and b/orator/src/orator/__pycache__/tts.cpython-311.pyc differ
diff --git a/orator/src/orator/model_checkpoints.py b/orator/src/orator/model_checkpoints.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/orator/src/orator/models/s3gen/__init__.py b/orator/src/orator/models/s3gen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bef618df9cff52479712a67364c1922b2e27ebff
--- /dev/null
+++ b/orator/src/orator/models/s3gen/__init__.py
@@ -0,0 +1,2 @@
+from .s3gen import S3Token2Wav as S3Gen
+from .const import S3GEN_SR
diff --git a/orator/src/orator/models/s3gen/__pycache__/__init__.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c1699d85577c0eb3fe46fe8d05804981f0498e1
Binary files /dev/null and b/orator/src/orator/models/s3gen/__pycache__/__init__.cpython-311.pyc differ
diff --git a/orator/src/orator/models/s3gen/__pycache__/const.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/const.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d6fe0e739088104765e9d4b0a805d61b5fca4bc9
Binary files /dev/null and b/orator/src/orator/models/s3gen/__pycache__/const.cpython-311.pyc differ
diff --git a/orator/src/orator/models/s3gen/__pycache__/decoder.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/decoder.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15039b8ac9d6b33e48e3171267d0f625f572d501
Binary files /dev/null and b/orator/src/orator/models/s3gen/__pycache__/decoder.cpython-311.pyc differ
diff --git a/orator/src/orator/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a166e2f186d6fdf7d82570266e1ac96ec6add09
Binary files /dev/null and b/orator/src/orator/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc differ
diff --git a/orator/src/orator/models/s3gen/__pycache__/flow.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/flow.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fdbe55fdb8e75b258ce91e94ca09cc743adbb2eb
Binary files /dev/null and b/orator/src/orator/models/s3gen/__pycache__/flow.cpython-311.pyc differ
diff --git a/orator/src/orator/models/s3gen/__pycache__/flow_matching.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/flow_matching.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96606c6e34fb092f600b96dad307e292162edbd6
Binary files /dev/null and b/orator/src/orator/models/s3gen/__pycache__/flow_matching.cpython-311.pyc differ
diff --git a/orator/src/orator/models/s3gen/__pycache__/hifigan.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/hifigan.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ab636cef70ced1ae7dbd376e19e73f03d089a04
Binary files /dev/null and b/orator/src/orator/models/s3gen/__pycache__/hifigan.cpython-311.pyc differ
diff --git a/orator/src/orator/models/s3gen/__pycache__/s3gen.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/s3gen.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f643c02e44e0ad043029e5379998c15ea69e9c0f
Binary files /dev/null and b/orator/src/orator/models/s3gen/__pycache__/s3gen.cpython-311.pyc differ
diff --git a/orator/src/orator/models/s3gen/__pycache__/xvector.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/xvector.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2bf0cbd41dc525b9131d9c18f96021ecccd4b72
Binary files /dev/null and b/orator/src/orator/models/s3gen/__pycache__/xvector.cpython-311.pyc differ
diff --git a/orator/src/orator/models/s3gen/const.py b/orator/src/orator/models/s3gen/const.py
new file mode 100644
index 0000000000000000000000000000000000000000..72de6a2355d1c30dc9ff3ad7ab83df64ea8a17df
--- /dev/null
+++ b/orator/src/orator/models/s3gen/const.py
@@ -0,0 +1 @@
+S3GEN_SR = 24000
diff --git a/orator/src/orator/models/s3gen/decoder.py b/orator/src/orator/models/s3gen/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c568c2dfabd760aa2ee7dcfc688a19a2b5bc6484
--- /dev/null
+++ b/orator/src/orator/models/s3gen/decoder.py
@@ -0,0 +1,317 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import pack, rearrange, repeat + +from .utils.mask import add_optional_chunk_mask +from .matcha.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, \ + TimestepEmbedding, Upsample1D +from .matcha.transformer import BasicTransformerBlock + + +def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + assert mask.dtype == torch.bool + assert dtype in [torch.float32, torch.bfloat16, torch.float16] + mask = mask.to(dtype) + # attention mask bias + # NOTE(Mddct): torch.finfo jit issues + # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min + mask = (1.0 - mask) * -1.0e+10 + return mask + + + +class Transpose(torch.nn.Module): + def __init__(self, dim0: int, dim1: int): + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x: torch.Tensor): + x = torch.transpose(x, self.dim0, self.dim1) + return x + + +class CausalBlock1D(Block1D): + def __init__(self, dim: int, dim_out: int): + super(CausalBlock1D, self).__init__(dim, dim_out) + self.block = torch.nn.Sequential( + CausalConv1d(dim, dim_out, 3), + Transpose(1, 2), + nn.LayerNorm(dim_out), + Transpose(1, 2), + nn.Mish(), + ) + + def forward(self, x: torch.Tensor, mask: torch.Tensor): + output = self.block(x * mask) + return output * mask + + +class CausalResnetBlock1D(ResnetBlock1D): + def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8): + super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups) + self.block1 = CausalBlock1D(dim, dim_out) + self.block2 = CausalBlock1D(dim_out, dim_out) + + +class CausalConv1d(torch.nn.Conv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + device=None, + dtype=None + ) -> None: + super(CausalConv1d, self).__init__(in_channels, out_channels, + kernel_size, stride, + padding=0, dilation=dilation, + groups=groups, bias=bias, + padding_mode=padding_mode, + device=device, dtype=dtype) + assert stride == 1 + self.causal_padding = (kernel_size - 1, 0) + + def forward(self, x: torch.Tensor): + x = F.pad(x, self.causal_padding) + x = super(CausalConv1d, self).forward(x) + return x + + +class ConditionalDecoder(nn.Module): + def __init__( + self, + in_channels=320, + out_channels=80, + causal=True, + channels=[256], + dropout=0.0, + attention_head_dim=64, + n_blocks=4, + num_mid_blocks=12, + num_heads=8, + act_fn="gelu", + ): + """ + This decoder requires an input with the same shape of the target. So, if your text content + is shorter or longer than the outputs, please re-sampling it before feeding to the decoder. 
+ """ + super().__init__() + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + self.causal = causal + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + self.down_blocks = nn.ModuleList([]) + self.mid_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + # NOTE jrm: `static_chunk_size` is missing? + self.static_chunk_size = 0 + + output_channel = in_channels + for i in range(len(channels)): # pylint: disable=consider-using-enumerate + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \ + ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) if not is_last else + CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) + + for _ in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \ + ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i in range(len(channels) - 1): + input_channel = channels[i] * 2 + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + resnet = CausalResnetBlock1D( + dim=input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) if self.causal else ResnetBlock1D( + dim=input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) + self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1]) + self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) + self.initialize_weights() + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.GroupNorm): + 
nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, mask, mu, t, spks=None, cond=None): + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. + + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + + t = self.time_embeddings(t).to(t.dtype) + t = self.time_mlp(t) + + x = pack([x, mu], "b * t")[0] + + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = pack([x, spks], "b * t")[0] + if cond is not None: + x = pack([x, cond], "b * t")[0] + + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x = resnet(x, mask_down, t) + x = rearrange(x, "b c t -> b t c").contiguous() + # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) + attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1) + attn_mask = mask_to_bias(attn_mask == 1, x.dtype) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + hiddens.append(x) # Save hidden states for skip connections + x = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + masks = masks[:-1] + mask_mid = masks[-1] + + for resnet, transformer_blocks in self.mid_blocks: + x = resnet(x, mask_mid, t) + x = rearrange(x, "b c t -> b t c").contiguous() + # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) + attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1) + attn_mask = mask_to_bias(attn_mask == 1, x.dtype) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + skip = hiddens.pop() + x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] + x = resnet(x, mask_up, t) + x = rearrange(x, "b c t -> b t c").contiguous() + # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) + attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1) + attn_mask = mask_to_bias(attn_mask == 1, x.dtype) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + x = upsample(x * mask_up) + x = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + return output * mask diff --git a/orator/src/orator/models/s3gen/f0_predictor.py b/orator/src/orator/models/s3gen/f0_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..172c5f50bdece3d4ac2b3874b0a32deb9f957b93 --- /dev/null +++ b/orator/src/orator/models/s3gen/f0_predictor.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) +# +# Licensed under the Apache 
License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn +from torch.nn.utils.parametrizations import weight_norm + + +class ConvRNNF0Predictor(nn.Module): + def __init__(self, + num_class: int = 1, + in_channels: int = 80, + cond_channels: int = 512 + ): + super().__init__() + + self.num_class = num_class + self.condnet = nn.Sequential( + weight_norm( + nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + ) + self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.condnet(x) + x = x.transpose(1, 2) + return torch.abs(self.classifier(x).squeeze(-1)) diff --git a/orator/src/orator/models/s3gen/flow.py b/orator/src/orator/models/s3gen/flow.py new file mode 100644 index 0000000000000000000000000000000000000000..a460ddef5db032967e849a2c4e134fcdf58d622d --- /dev/null +++ b/orator/src/orator/models/s3gen/flow.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +import random +from typing import Dict, Optional +import torch +import torch.nn as nn +from torch.nn import functional as F +from omegaconf import DictConfig +from .utils.mask import make_pad_mask + + +class MaskedDiffWithXvec(torch.nn.Module): + def __init__(self, + input_size: int = 512, + output_size: int = 80, + spk_embed_dim: int = 192, + output_type: str = "mel", + vocab_size: int = 4096, + input_frame_rate: int = 50, + only_mask_loss: bool = True, + encoder: torch.nn.Module = None, + length_regulator: torch.nn.Module = None, + decoder: torch.nn.Module = None, + decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, + 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', + 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), + 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, + 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}, + mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, + 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.decoder_conf = decoder_conf + self.mel_feat_conf = mel_feat_conf + self.vocab_size = vocab_size + self.output_type = output_type + self.input_frame_rate = input_frame_rate + logging.info(f"input frame rate={self.input_frame_rate}") + self.input_embedding = nn.Embedding(vocab_size, input_size) + self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size) + self.encoder = encoder + self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size) + self.decoder = decoder + self.length_regulator = length_regulator + self.only_mask_loss = only_mask_loss + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + token = batch['speech_token'].to(device) + token_len = batch['speech_token_len'].to(device) + feat = batch['speech_feat'].to(device) + feat_len = batch['speech_feat_len'].to(device) + embedding = batch['embedding'].to(device) + + # xvec projection + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) + + # concat text and prompt_text + mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device) + token = self.input_embedding(torch.clamp(token, min=0)) * mask + + # text encode + h, h_lengths = self.encoder(token, token_len) + h = self.encoder_proj(h) + h, h_lengths = self.length_regulator(h, feat_len) + + # get conditions + conds = torch.zeros(feat.shape, device=token.device) + for i, j in enumerate(feat_len): + if random.random() < 0.5: + continue + index = random.randint(0, int(0.3 * j)) + conds[i, :index] = feat[i, :index] + conds = conds.transpose(1, 2) + + mask = (~make_pad_mask(feat_len)).to(h) + feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1) + loss, _ = self.decoder.compute_loss( + feat.transpose(1, 2).contiguous(), + mask.unsqueeze(1), + h.transpose(1, 2).contiguous(), + embedding, + cond=conds + ) + return {'loss': loss} + + @torch.inference_mode() + def inference(self, + token, + token_len, + prompt_token, + prompt_token_len, + prompt_feat, + prompt_feat_len, + embedding, + flow_cache): + if self.fp16 is True: + prompt_feat = prompt_feat.half() + embedding = embedding.half() + + assert token.shape[0] == 1 + # xvec projection + embedding = F.normalize(embedding, dim=1) + embedding 
= self.spk_embed_affine_layer(embedding) + + # concat text and prompt_text + token_len1, token_len2 = prompt_token.shape[1], token.shape[1] + token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len + mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) + token = self.input_embedding(torch.clamp(token, min=0)) * mask + + # text encode + h, h_lengths = self.encoder(token, token_len) + h = self.encoder_proj(h) + mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256) + h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate) + + # get conditions + conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype) + conds[:, :mel_len1] = prompt_feat + conds = conds.transpose(1, 2) + + mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) + feat, flow_cache = self.decoder( + mu=h.transpose(1, 2).contiguous(), + mask=mask.unsqueeze(1), + spks=embedding, + cond=conds, + n_timesteps=10, + prompt_len=mel_len1, + flow_cache=flow_cache + ) + feat = feat[:, :, mel_len1:] + assert feat.shape[2] == mel_len2 + return feat.float(), flow_cache + + +class CausalMaskedDiffWithXvec(torch.nn.Module): + def __init__(self, + input_size: int = 512, + output_size: int = 80, + spk_embed_dim: int = 192, + output_type: str = "mel", + vocab_size: int = 6561, + input_frame_rate: int = 25, + only_mask_loss: bool = True, + token_mel_ratio: int = 2, + pre_lookahead_len: int = 3, + encoder: torch.nn.Module = None, + decoder: torch.nn.Module = None, + decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, + 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', + 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), + 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, + 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}, + mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, + 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.decoder_conf = decoder_conf + self.mel_feat_conf = mel_feat_conf + self.vocab_size = vocab_size + self.output_type = output_type + self.input_frame_rate = input_frame_rate + logging.info(f"input frame rate={self.input_frame_rate}") + self.input_embedding = nn.Embedding(vocab_size, input_size) + self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size) + self.encoder = encoder + self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size) + self.decoder = decoder + self.only_mask_loss = only_mask_loss + self.token_mel_ratio = token_mel_ratio + self.pre_lookahead_len = pre_lookahead_len + + # FIXME: this was missing - just putting it in as false + self.fp16 = False + + @torch.inference_mode() + def inference(self, + token, + token_len, + prompt_token, + prompt_token_len, + prompt_feat, + prompt_feat_len, + embedding, + finalize): + if self.fp16 is True: + prompt_feat = prompt_feat.half() + embedding = embedding.half() + + assert token.shape[0] == 1 + # xvec projection + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) + + # concat text and prompt_text + token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len + mask = 
(~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) + token = self.input_embedding(torch.clamp(token, min=0)) * mask + + # text encode + h, h_lengths = self.encoder(token, token_len) + if finalize is False: + h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio] + mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1] + h = self.encoder_proj(h) + + # get conditions + conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype) + conds[:, :mel_len1] = prompt_feat + conds = conds.transpose(1, 2) + + mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) + feat, _ = self.decoder( + mu=h.transpose(1, 2).contiguous(), + mask=mask.unsqueeze(1), + spks=embedding, + cond=conds, + n_timesteps=10 + ) + feat = feat[:, :, mel_len1:] + assert feat.shape[2] == mel_len2 + return feat.float(), None # NOTE jrm: why are they returning None here? diff --git a/orator/src/orator/models/s3gen/flow_matching.py b/orator/src/orator/models/s3gen/flow_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..8307e3c0d6120a81b6ff414fafa30e9fc63d015c --- /dev/null +++ b/orator/src/orator/models/s3gen/flow_matching.py @@ -0,0 +1,228 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import threading +import torch +import torch.nn.functional as F +from .matcha.flow_matching import BASECFM +from omegaconf import OmegaConf + + +CFM_PARAMS = OmegaConf.create({ + "sigma_min": 1e-06, + "solver": "euler", + "t_scheduler": "cosine", + "training_cfg_rate": 0.2, + "inference_cfg_rate": 0.7, + "reg_loss_type": "l1" +}) + + +class ConditionalCFM(BASECFM): + def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None): + super().__init__( + n_feats=in_channels, + cfm_params=cfm_params, + n_spks=n_spks, + spk_emb_dim=spk_emb_dim, + ) + self.t_scheduler = cfm_params.t_scheduler + self.training_cfg_rate = cfm_params.training_cfg_rate + self.inference_cfg_rate = cfm_params.inference_cfg_rate + in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0) + # Just change the architecture of the estimator here + self.estimator = estimator + self.lock = threading.Lock() + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + spks (torch.Tensor, optional): speaker ids. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + + z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature + cache_size = flow_cache.shape[2] + # fix prompt and overlap part mu and z + if cache_size != 0: + z[:, :, :cache_size] = flow_cache[:, :, :, 0] + mu[:, :, :cache_size] = flow_cache[:, :, :, 1] + z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2) + mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2) + flow_cache = torch.stack([z_cache, mu_cache], dim=-1) + + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) + if self.t_scheduler == 'cosine': + t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) + return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache + + def solve_euler(self, x, t_span, mu, mask, spks, cond): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + """ + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + t = t.unsqueeze(dim=0) + + # I am storing this because I can later plot it by putting a debugger here and saving it to a file + # Or in future might add like a return_all_steps flag + sol = [] + + # Do not use concat, it may cause memory format changed and trt infer with wrong results! 
+ x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) + mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype) + mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) + t_in = torch.zeros([2], device=x.device, dtype=x.dtype) + spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype) + cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) + for step in range(1, len(t_span)): + # Classifier-Free Guidance inference introduced in VoiceBox + x_in[:] = x + mask_in[:] = mask + mu_in[0] = mu + t_in[:] = t.unsqueeze(0) + spks_in[0] = spks + cond_in[0] = cond + dphi_dt = self.forward_estimator( + x_in, mask_in, + mu_in, t_in, + spks_in, + cond_in + ) + dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0) + dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1].float() + + def forward_estimator(self, x, mask, mu, t, spks, cond): + if isinstance(self.estimator, torch.nn.Module): + return self.estimator.forward(x, mask, mu, t, spks, cond) + else: + with self.lock: + self.estimator.set_input_shape('x', (2, 80, x.size(2))) + self.estimator.set_input_shape('mask', (2, 1, x.size(2))) + self.estimator.set_input_shape('mu', (2, 80, x.size(2))) + self.estimator.set_input_shape('t', (2,)) + self.estimator.set_input_shape('spks', (2, 80)) + self.estimator.set_input_shape('cond', (2, 80, x.size(2))) + # run trt engine + self.estimator.execute_v2([x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr()]) + return x + + def compute_loss(self, x1, mask, mu, spks=None, cond=None): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): target mask + shape: (batch_size, 1, mel_timesteps) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + spks (torch.Tensor, optional): speaker embedding. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) + """ + b, _, t = mu.shape + + # random timestep + t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) + if self.t_scheduler == 'cosine': + t = 1 - torch.cos(t * 0.5 * torch.pi) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + + # during training, we randomly drop condition to trade off mode coverage and sample fidelity + if self.training_cfg_rate > 0: + cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate + mu = mu * cfg_mask.view(-1, 1, 1) + spks = spks * cfg_mask.view(-1, 1) + cond = cond * cfg_mask.view(-1, 1, 1) + + pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond) + loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1]) + return loss, y + + +class CausalConditionalCFM(ConditionalCFM): + def __init__(self, in_channels=240, cfm_params=CFM_PARAMS, n_spks=1, spk_emb_dim=80, estimator=None): + super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator) + self.rand_noise = torch.randn([1, 80, 50 * 300]) + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + + z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature + # fix prompt and overlap part mu and z + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) + if self.t_scheduler == 'cosine': + t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) + return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None diff --git a/orator/src/orator/models/s3gen/hifigan.py b/orator/src/orator/models/s3gen/hifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..33f9387e8018169d175fba777a9d70d89035348a --- /dev/null +++ b/orator/src/orator/models/s3gen/hifigan.py @@ -0,0 +1,474 @@ +# jrm: adapted from CosyVoice/cosyvoice/hifigan/generator.py +# most modules should be reusable, but I found their SineGen changed a git. + +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""HIFI-GAN""" + +from typing import Dict, Optional, List +import numpy as np +from scipy.signal import get_window +import torch +import torch.nn.functional as F +from torch.nn import Conv1d +from torch.nn import ConvTranspose1d +from torch.nn.utils import remove_weight_norm +from torch.nn.utils.parametrizations import weight_norm +from torch.distributions.uniform import Uniform +from torch import nn, sin, pow +from torch.nn import Parameter + + +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + Snake ∶= x + 1/a * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +"""hifigan based generator implementation. 
+ +This code is modified from https://github.com/jik876/hifi-gan + ,https://github.com/kan-bayashi/ParallelWaveGAN and + https://github.com/NVIDIA/BigVGAN + +""" + + +class ResBlock(torch.nn.Module): + """Residual block module in HiFiGAN/BigVGAN.""" + def __init__( + self, + channels: int = 512, + kernel_size: int = 3, + dilations: List[int] = [1, 3, 5], + ): + super(ResBlock, self).__init__() + self.convs1 = nn.ModuleList() + self.convs2 = nn.ModuleList() + + for dilation in dilations: + self.convs1.append( + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation, + padding=get_padding(kernel_size, dilation) + ) + ) + ) + self.convs2.append( + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1) + ) + ) + ) + self.convs1.apply(init_weights) + self.convs2.apply(init_weights) + self.activations1 = nn.ModuleList([ + Snake(channels, alpha_logscale=False) + for _ in range(len(self.convs1)) + ]) + self.activations2 = nn.ModuleList([ + Snake(channels, alpha_logscale=False) + for _ in range(len(self.convs2)) + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for idx in range(len(self.convs1)): + xt = self.activations1[idx](x) + xt = self.convs1[idx](xt) + xt = self.activations2[idx](xt) + xt = self.convs2[idx](xt) + x = xt + x + return x + + def remove_weight_norm(self): + for idx in range(len(self.convs1)): + remove_weight_norm(self.convs1[idx]) + remove_weight_norm(self.convs2[idx]) + + +class SineGen(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + @torch.no_grad() + def forward(self, f0): + """ + :param f0: [B, 1, sample_len], Hz + :return: [B, 1, sample_len] + """ + + F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device) + for i in range(self.harmonic_num + 1): + F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate + + theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1) + u_dist = Uniform(low=-np.pi, high=np.pi) + phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device) + phase_vec[:, 0, :] = 0 + + # generate sine waveforms + sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec) + + # generate uv signal + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . 
for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + with torch.no_grad(): + sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2)) + sine_wavs = sine_wavs.transpose(1, 2) + uv = uv.transpose(1, 2) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + +class HiFTGenerator(nn.Module): + """ + HiFTNet Generator: Neural Source Filter + ISTFTNet + https://arxiv.org/abs/2309.09493 + """ + def __init__( + self, + in_channels: int = 80, + base_channels: int = 512, + nb_harmonics: int = 8, + sampling_rate: int = 22050, + nsf_alpha: float = 0.1, + nsf_sigma: float = 0.003, + nsf_voiced_threshold: float = 10, + upsample_rates: List[int] = [8, 8], + upsample_kernel_sizes: List[int] = [16, 16], + istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4}, + resblock_kernel_sizes: List[int] = [3, 7, 11], + resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + source_resblock_kernel_sizes: List[int] = [7, 11], + source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]], + lrelu_slope: float = 0.1, + audio_limit: float = 0.99, + f0_predictor: torch.nn.Module = None, + ): + super(HiFTGenerator, self).__init__() + + self.out_channels = 1 + self.nb_harmonics = nb_harmonics + self.sampling_rate = sampling_rate + self.istft_params = istft_params + self.lrelu_slope = lrelu_slope + self.audio_limit = audio_limit + + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.m_source = SourceModuleHnNSF( + sampling_rate=sampling_rate, + upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"], + harmonic_num=nb_harmonics, + sine_amp=nsf_alpha, + add_noise_std=nsf_sigma, + 
voiced_threshod=nsf_voiced_threshold) + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"]) + + self.conv_pre = weight_norm( + Conv1d(in_channels, base_channels, 7, 1, padding=3) + ) + + # Up + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + base_channels // (2**i), + base_channels // (2**(i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + # Down + self.source_downs = nn.ModuleList() + self.source_resblocks = nn.ModuleList() + downsample_rates = [1] + upsample_rates[::-1][:-1] + downsample_cum_rates = np.cumprod(downsample_rates) + for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)): + if u == 1: + self.source_downs.append( + Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1) + ) + else: + self.source_downs.append( + Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2)) + ) + + self.source_resblocks.append( + ResBlock(base_channels // (2 ** (i + 1)), k, d) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = base_channels // (2**(i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(ResBlock(ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.reflection_pad = nn.ReflectionPad1d((1, 0)) + self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32)) + self.f0_predictor = f0_predictor + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + self.m_source.remove_weight_norm() + for l in self.source_downs: + remove_weight_norm(l) + for l in self.source_resblocks: + l.remove_weight_norm() + + def _stft(self, x): + spec = torch.stft( + x, + self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device), + return_complex=True) + spec = torch.view_as_real(spec) # [B, F, TT, 2] + return spec[..., 0], spec[..., 1] + + def _istft(self, magnitude, phase): + magnitude = torch.clip(magnitude, max=1e2) + real = magnitude * torch.cos(phase) + img = magnitude * torch.sin(phase) + inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], + self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device)) + return inverse_transform + + def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: + s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) + s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) + + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, self.lrelu_slope) + x = self.ups[i](x) + + if i == self.num_upsamples - 1: + x = self.reflection_pad(x) + + # fusion + si = self.source_downs[i](s_stft) + si = self.source_resblocks[i](si) + x = x + si + + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + x = F.leaky_relu(x) + x = 
self.conv_post(x) + magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :]) + phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy + + x = self._istft(magnitude, phase) + x = torch.clamp(x, -self.audio_limit, self.audio_limit) + return x + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + speech_feat = batch['speech_feat'].transpose(1, 2).to(device) + # mel->f0 + f0 = self.f0_predictor(speech_feat) + # f0->source + s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + s, _, _ = self.m_source(s) + s = s.transpose(1, 2) + # mel+source->speech + generated_speech = self.decode(x=speech_feat, s=s) + return generated_speech, f0 + + @torch.inference_mode() + def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: + # mel->f0 + f0 = self.f0_predictor(speech_feat) + # f0->source + s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + s, _, _ = self.m_source(s) + s = s.transpose(1, 2) + # use cache_source to avoid glitch + if cache_source.shape[2] != 0: + s[:, :, :cache_source.shape[2]] = cache_source + generated_speech = self.decode(x=speech_feat, s=s) + return generated_speech, s diff --git a/orator/src/orator/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc b/orator/src/orator/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42e1166c8d244884c63a58957d94ec95160ea70b Binary files /dev/null and b/orator/src/orator/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc b/orator/src/orator/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf4494fa06afb946d3cc8e613a9b6fb2cf6f2411 Binary files /dev/null and b/orator/src/orator/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc b/orator/src/orator/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..562bb3e3db1dc26262d483b98ff00e9184264a81 Binary files /dev/null and b/orator/src/orator/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/matcha/decoder.py b/orator/src/orator/models/s3gen/matcha/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6919f32d9d0a04a0c734190d8b815abb40ad69db --- /dev/null +++ b/orator/src/orator/models/s3gen/matcha/decoder.py @@ -0,0 +1,443 @@ +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from conformer import ConformerBlock +from diffusers.models.activations import get_activation +from einops import pack, rearrange, repeat + +from .transformer import BasicTransformerBlock + + +class SinusoidalPosEmb(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even" + + def forward(self, x, scale=1000): + if x.ndim < 1: + x = x.unsqueeze(0) + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = 
torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class Block1D(torch.nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.block = torch.nn.Sequential( + torch.nn.Conv1d(dim, dim_out, 3, padding=1), + torch.nn.GroupNorm(groups, dim_out), + nn.Mish(), + ) + + def forward(self, x, mask): + output = self.block(x * mask) + return output * mask + + +class ResnetBlock1D(torch.nn.Module): + def __init__(self, dim, dim_out, time_emb_dim, groups=8): + super().__init__() + self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)) + + self.block1 = Block1D(dim, dim_out, groups=groups) + self.block2 = Block1D(dim_out, dim_out, groups=groups) + + self.res_conv = torch.nn.Conv1d(dim, dim_out, 1) + + def forward(self, x, mask, time_emb): + h = self.block1(x, mask) + h += self.mlp(time_emb).unsqueeze(-1) + h = self.block2(h, mask) + output = h + self.res_conv(x * mask) + return output + + +class Downsample1D(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Upsample1D(nn.Module): + """A 1D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. 
+ """ + + def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + self.conv = None + if use_conv_transpose: + self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, inputs): + assert inputs.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(inputs) + + outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") + + if self.use_conv: + outputs = self.conv(outputs) + + return outputs + + +class ConformerWrapper(ConformerBlock): + def __init__( # pylint: disable=useless-super-delegation + self, + *, + dim, + dim_head=64, + heads=8, + ff_mult=4, + conv_expansion_factor=2, + conv_kernel_size=31, + attn_dropout=0, + ff_dropout=0, + conv_dropout=0, + conv_causal=False, + ): + super().__init__( + dim=dim, + dim_head=dim_head, + heads=heads, + ff_mult=ff_mult, + conv_expansion_factor=conv_expansion_factor, + conv_kernel_size=conv_kernel_size, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + conv_dropout=conv_dropout, + conv_causal=conv_causal, + ) + + def forward( + self, + hidden_states, + attention_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + ): + return super().forward(x=hidden_states, mask=attention_mask.bool()) + + +class Decoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + channels=(256, 256), + dropout=0.05, + attention_head_dim=64, + n_blocks=1, + num_mid_blocks=2, + num_heads=4, + act_fn="snake", + down_block_type="transformer", + mid_block_type="transformer", + up_block_type="transformer", + ): + super().__init__() + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + + self.down_blocks = nn.ModuleList([]) + self.mid_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + output_channel = in_channels + for i in range(len(channels)): # pylint: disable=consider-using-enumerate + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + transformer_blocks = nn.ModuleList( + [ + self.get_block( + down_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + + self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) + + for i in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + + transformer_blocks = nn.ModuleList( + [ + self.get_block( + mid_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i 
in range(len(channels) - 1): + input_channel = channels[i] + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + + resnet = ResnetBlock1D( + dim=2 * input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.ModuleList( + [ + self.get_block( + up_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + + self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) + + self.final_block = Block1D(channels[-1], channels[-1]) + self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) + + self.initialize_weights() + # nn.init.normal_(self.final_proj.weight) + + @staticmethod + def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn): + if block_type == "conformer": + block = ConformerWrapper( + dim=dim, + dim_head=attention_head_dim, + heads=num_heads, + ff_mult=1, + conv_expansion_factor=2, + ff_dropout=dropout, + attn_dropout=dropout, + conv_dropout=dropout, + conv_kernel_size=31, + ) + elif block_type == "transformer": + block = BasicTransformerBlock( + dim=dim, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + else: + raise ValueError(f"Unknown block type {block_type}") + + return block + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, mask, mu, t, spks=None, cond=None): + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. 
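+ mu (torch.Tensor): output of the encoder used for conditioning; same shape as `x` and concatenated with it along the channel dimension.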
+ + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + + t = self.time_embeddings(t) + t = self.time_mlp(t) + + x = pack([x, mu], "b * t")[0] + + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = pack([x, spks], "b * t")[0] + + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x = resnet(x, mask_down, t) + x = rearrange(x, "b c t -> b t c") + mask_down = rearrange(mask_down, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_down, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_down = rearrange(mask_down, "b t -> b 1 t") + hiddens.append(x) # Save hidden states for skip connections + x = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + + masks = masks[:-1] + mask_mid = masks[-1] + + for resnet, transformer_blocks in self.mid_blocks: + x = resnet(x, mask_mid, t) + x = rearrange(x, "b c t -> b t c") + mask_mid = rearrange(mask_mid, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_mid, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_mid = rearrange(mask_mid, "b t -> b 1 t") + + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t) + x = rearrange(x, "b c t -> b t c") + mask_up = rearrange(mask_up, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_up, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_up = rearrange(mask_up, "b t -> b 1 t") + x = upsample(x * mask_up) + + x = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + + return output * mask diff --git a/orator/src/orator/models/s3gen/matcha/flow_matching.py b/orator/src/orator/models/s3gen/matcha/flow_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..add7b08c4661ae7b56a19898cab3e088414a1b40 --- /dev/null +++ b/orator/src/orator/models/s3gen/matcha/flow_matching.py @@ -0,0 +1,129 @@ +from abc import ABC + +import torch +import torch.nn.functional as F + +from .decoder import Decoder + + +class BASECFM(torch.nn.Module, ABC): + def __init__( + self, + n_feats, + cfm_params, + n_spks=1, + spk_emb_dim=128, + ): + super().__init__() + self.n_feats = n_feats + self.n_spks = n_spks + self.spk_emb_dim = spk_emb_dim + self.solver = cfm_params.solver + if hasattr(cfm_params, "sigma_min"): + self.sigma_min = cfm_params.sigma_min + else: + self.sigma_min = 1e-4 + + self.estimator = None + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + spks (torch.Tensor, optional): speaker ids. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + z = torch.randn_like(mu) * temperature + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) + return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) + + def solve_euler(self, x, t_span, mu, mask, spks, cond): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + """ + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + + # I am storing this because I can later plot it by putting a debugger here and saving it to a file + # Or in future might add like a return_all_steps flag + sol = [] + + for step in range(1, len(t_span)): + dphi_dt = self.estimator(x, mask, mu, t, spks, cond) + + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1] + + def compute_loss(self, x1, mask, mu, spks=None, cond=None): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): target mask + shape: (batch_size, 1, mel_timesteps) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + spks (torch.Tensor, optional): speaker embedding. Defaults to None. + shape: (batch_size, spk_emb_dim) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) + """ + b, _, t = mu.shape + + # random timestep + t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + + loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / ( + torch.sum(mask) * u.shape[1] + ) + return loss, y + + +class CFM(BASECFM): + def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64): + super().__init__( + n_feats=in_channels, + cfm_params=cfm_params, + n_spks=n_spks, + spk_emb_dim=spk_emb_dim, + ) + + in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0) + # Just change the architecture of the estimator here + self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params) diff --git a/orator/src/orator/models/s3gen/matcha/text_encoder.py b/orator/src/orator/models/s3gen/matcha/text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..276eee7350b884cd37fd313f1e44db487a77f577 --- /dev/null +++ b/orator/src/orator/models/s3gen/matcha/text_encoder.py @@ -0,0 +1,413 @@ +""" from https://github.com/jaywalnut310/glow-tts """ + +import math + +import torch +import torch.nn as nn +from einops import rearrange + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-4): + super().__init__() 
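+ # channel-wise LayerNorm: mean and variance are computed over dim 1 (channels), with a learnable per-channel scale (gamma) and shift (beta)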
+ self.channels = channels + self.eps = eps + + self.gamma = torch.nn.Parameter(torch.ones(channels)) + self.beta = torch.nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + n_dims = len(x.shape) + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) + + x = (x - mean) * torch.rsqrt(variance + self.eps) + + shape = [1, -1] + [1] * (n_dims - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.conv_layers = torch.nn.ModuleList() + self.norm_layers = torch.nn.ModuleList() + self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append( + torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = LayerNorm(filter_channels) + self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = LayerNorm(filter_channels) + self.proj = torch.nn.Conv1d(filter_channels, 1, 1) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class RotaryPositionalEmbeddings(nn.Module): + """ + ## RoPE module + + Rotary encoding transforms pairs of features by rotating in the 2D plane. + That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. + Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it + by an angle depending on the position of the token. 
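+ Concretely, the k-th feature pair at position m is rotated by the angle m * theta_k, where theta_k = base^(-2k/d) (see `_build_cache`).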
+ """ + + def __init__(self, d: int, base: int = 10_000): + r""" + * `d` is the number of features $d$ + * `base` is the constant used for calculating $\Theta$ + """ + super().__init__() + + self.base = base + self.d = int(d) + self.cos_cached = None + self.sin_cached = None + + def _build_cache(self, x: torch.Tensor): + r""" + Cache $\cos$ and $\sin$ values + """ + # Return if cache is already built + if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]: + return + + # Get sequence length + seq_len = x.shape[0] + + # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.einsum("n,d->nd", seq_idx, theta) + + # Concatenate so that for row $m$ we have + # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$ + idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1) + + # Cache them + self.cos_cached = idx_theta2.cos()[:, None, None, :] + self.sin_cached = idx_theta2.sin()[:, None, None, :] + + def _neg_half(self, x: torch.Tensor): + # $\frac{d}{2}$ + d_2 = self.d // 2 + + # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1) + + def forward(self, x: torch.Tensor): + """ + * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]` + """ + # Cache $\cos$ and $\sin$ values + x = rearrange(x, "b h t d -> t b h d") + + self._build_cache(x) + + # Split the features, we can choose to apply rotary embeddings only to a partial set of features. 
+ x_rope, x_pass = x[..., : self.d], x[..., self.d :] + + # Calculate + # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + neg_half_x = self._neg_half(x_rope) + + x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]]) + + return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d") + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + heads_share=True, + p_dropout=0.0, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.heads_share = heads_share + self.proximal_bias = proximal_bias + self.p_dropout = p_dropout + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = torch.nn.Conv1d(channels, channels, 1) + self.conv_k = torch.nn.Conv1d(channels, channels, 1) + self.conv_v = torch.nn.Conv1d(channels, channels, 1) + + # from https://nn.labml.ai/transformers/rope/index.html + self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + + self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) + self.drop = torch.nn.Dropout(p_dropout) + + torch.nn.init.xavier_uniform_(self.conv_q.weight) + torch.nn.init.xavier_uniform_(self.conv_k.weight) + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + torch.nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads) + key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads) + value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads) + + query = self.query_rotary_pe(query) + key = self.key_rotary_pe(key) + + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." 
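+ # proximal bias adds -log(1 + |i - j|) to the attention scores, penalizing attention to distant positions (see _attention_bias_proximal below)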
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + p_attn = torch.nn.functional.softmax(scores, dim=-1) + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) + return output, p_attn + + @staticmethod + def _attention_bias_proximal(length): + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2) + self.drop = torch.nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + return x * x_mask + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + **kwargs, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.attn_layers = torch.nn.ModuleList() + self.norm_layers_1 = torch.nn.ModuleList() + self.ffn_layers = torch.nn.ModuleList() + self.norm_layers_2 = torch.nn.ModuleList() + for _ in range(self.n_layers): + self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.n_layers): + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class TextEncoder(nn.Module): + def __init__( + self, + encoder_type, + encoder_params, + duration_predictor_params, + n_vocab, + n_spks=1, + spk_emb_dim=128, + ): + super().__init__() + self.encoder_type = encoder_type + self.n_vocab = n_vocab + self.n_feats = encoder_params.n_feats + self.n_channels = encoder_params.n_channels + self.spk_emb_dim = spk_emb_dim + self.n_spks = n_spks + + self.emb = torch.nn.Embedding(n_vocab, self.n_channels) + torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5) + + if encoder_params.prenet: + self.prenet = ConvReluNorm( + self.n_channels, + self.n_channels, + self.n_channels, + kernel_size=5, + n_layers=3, + p_dropout=0.5, + ) + else: + self.prenet = lambda x, x_mask: x + + self.encoder = Encoder( + encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0), + encoder_params.filter_channels, + 
encoder_params.n_heads, + encoder_params.n_layers, + encoder_params.kernel_size, + encoder_params.p_dropout, + ) + + self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1) + self.proj_w = DurationPredictor( + self.n_channels + (spk_emb_dim if n_spks > 1 else 0), + duration_predictor_params.filter_channels_dp, + duration_predictor_params.kernel_size, + duration_predictor_params.p_dropout, + ) + + def forward(self, x, x_lengths, spks=None): + """Run forward pass to the transformer based encoder and duration predictor + + Args: + x (torch.Tensor): text input + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): text input lengths + shape: (batch_size,) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size,) + + Returns: + mu (torch.Tensor): average output of the encoder + shape: (batch_size, n_feats, max_text_length) + logw (torch.Tensor): log duration predicted by the duration predictor + shape: (batch_size, 1, max_text_length) + x_mask (torch.Tensor): mask for the text input + shape: (batch_size, 1, max_text_length) + """ + x = self.emb(x) * math.sqrt(self.n_channels) + x = torch.transpose(x, 1, -1) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.prenet(x, x_mask) + if self.n_spks > 1: + x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1) + x = self.encoder(x, x_mask) + mu = self.proj_m(x) * x_mask + + x_dp = torch.detach(x) + logw = self.proj_w(x_dp, x_mask) + + return mu, logw, x_mask diff --git a/orator/src/orator/models/s3gen/matcha/transformer.py b/orator/src/orator/models/s3gen/matcha/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..dd1afa3aff5383912209e508676c6885e13ef4ee --- /dev/null +++ b/orator/src/orator/models/s3gen/matcha/transformer.py @@ -0,0 +1,316 @@ +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +from diffusers.models.attention import ( + GEGLU, + GELU, + AdaLayerNorm, + AdaLayerNormZero, + ApproximateGELU, +) +from diffusers.models.attention_processor import Attention +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.utils.torch_utils import maybe_allow_in_graph + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. 
+ """ + super().__init__() + self.in_features = out_features if isinstance(out_features, list) else [out_features] + self.proj = LoRACompatibleLinear(in_features, out_features) + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha) + self.beta = nn.Parameter(torch.ones(self.in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + x = self.proj(x) + if self.alpha_logscale: + alpha = torch.exp(self.alpha) + beta = torch.exp(self.beta) + else: + alpha = self.alpha + beta = self.beta + + x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2) + + return x + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + elif activation_fn == "snakebeta": + act_fn = SnakeBeta(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. 
In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + # scale_qk=False, # uncomment this to not to use flash attention + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. 
Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 
+ ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states diff --git a/orator/src/orator/models/s3gen/s3gen.py b/orator/src/orator/models/s3gen/s3gen.py new file mode 100644 index 0000000000000000000000000000000000000000..97b7c0bd40ad6cd258ca3c4bd4ae752c78f28b19 --- /dev/null +++ b/orator/src/orator/models/s3gen/s3gen.py @@ -0,0 +1,305 @@ +# Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import numpy as np +import torch +import torchaudio as ta +from functools import lru_cache +from typing import Optional +from omegaconf import DictConfig + +from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer +from .const import S3GEN_SR +from .flow import CausalMaskedDiffWithXvec +from .xvector import CAMPPlus +from .utils.mel import mel_spectrogram +from .f0_predictor import ConvRNNF0Predictor +from .hifigan import HiFTGenerator +from .transformer.upsample_encoder import UpsampleConformerEncoder +from .flow_matching import CausalConditionalCFM +from .decoder import ConditionalDecoder + + +def drop_invalid_tokens(x): + assert len(x.shape) <= 2 and x.shape[0] == 1, "only batch size of one allowed for now" + return x[x < SPEECH_VOCAB_SIZE] + + +# TODO: global resampler cache +@lru_cache(100) +def get_resampler(src_sr, dst_sr, device): + return ta.transforms.Resample(src_sr, dst_sr).to(device) + + +class S3Token2Mel(torch.nn.Module): + """ + CosyVoice2's CFM decoder maps S3 speech tokens to mel-spectrograms. + + TODO: make these modules configurable? + """ + def __init__(self): + super().__init__() + self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz") + self.mel_extractor = mel_spectrogram # TODO: make it a torch module? 
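+ # CAMPPlus (xvector.py) is the speaker encoder; embed_ref() feeds it the reference audio resampled to 16 kHz to obtain the timbre embedding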
+ self.speaker_encoder = CAMPPlus() # use default args + + encoder = UpsampleConformerEncoder( + output_size=512, + attention_heads=8, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.1, + normalize_before=True, + input_layer='linear', + pos_enc_layer_type='rel_pos_espnet', + selfattention_layer_type='rel_selfattn', + input_size=512, + use_cnn_module=False, + macaron_style=False, + ) + + estimator = ConditionalDecoder( + in_channels=320, + out_channels=80, + causal=True, + channels=[256], + dropout=0.0, + attention_head_dim=64, + n_blocks=4, + num_mid_blocks=12, + num_heads=8, + act_fn='gelu', + ) + cfm_params = DictConfig({ + "sigma_min": 1e-06, + "solver": 'euler', + "t_scheduler": 'cosine', + "training_cfg_rate": 0.2, + "inference_cfg_rate": 0.7, + "reg_loss_type": 'l1', + }) + decoder = CausalConditionalCFM( + spk_emb_dim=80, + cfm_params=cfm_params, + estimator=estimator, + ) + + self.flow = CausalMaskedDiffWithXvec( + encoder=encoder, + decoder=decoder + ) + + self.resamplers = {} + + @property + def device(self): + params = self.tokenizer.parameters() + return next(params).device + + def embed_ref( + self, + ref_wav: torch.Tensor, + ref_sr: int, + device="auto", + ref_fade_out=True, + ): + device = self.device if device == "auto" else device + if isinstance(ref_wav, np.ndarray): + ref_wav = torch.from_numpy(ref_wav).float() + + if ref_wav.device != device: + ref_wav = ref_wav.to(device) + + if len(ref_wav.shape) == 1: + ref_wav = ref_wav.unsqueeze(0) # (B, L) + + if ref_wav.size(1) > 10 * ref_sr: + print("WARNING: cosydec received ref longer than 10s") + + ref_wav_24 = ref_wav + if ref_sr != S3GEN_SR: + ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav) + + ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(device) + ref_mels_24_len = None + + # Resample to 16kHz + ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav).to(device) + + # Speaker embedding + ref_x_vector = self.speaker_encoder.inference(ref_wav_16) + + # Tokenize 16khz reference + ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16) + + # Make sure mel_len = 2 * stoken_len (happens when the input is not padded to multiple of 40ms) + if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]: + logging.warning( + "Reference mel length is not equal to 2 * reference token length.\n" + ) + ref_speech_tokens = ref_speech_tokens[:, :ref_mels_24.shape[1] // 2] + ref_speech_token_lens[0] = ref_speech_tokens.shape[1] + + return dict( + prompt_token=ref_speech_tokens.to(device), + prompt_token_len=ref_speech_token_lens, + prompt_feat=ref_mels_24, + prompt_feat_len=ref_mels_24_len, + embedding=ref_x_vector, + ) + + def forward( + self, + speech_tokens: torch.LongTensor, + # locally-computed ref embedding (mutex with ref_dict) + ref_wav: Optional[torch.Tensor], + ref_sr: Optional[int], + # pre-computed ref embedding (prod API) + ref_dict: Optional[dict] = None, + finalize: bool = False, + ): + """ + Generate waveforms from S3 speech tokens and a reference waveform, which the speaker timbre is inferred from. + + NOTE: + - The speaker encoder accepts 16 kHz waveform. + - S3TokenizerV2 accepts 16 kHz waveform. + - The mel-spectrogram for the reference assumes 24 kHz input signal. + - This function is designed for batch_size=1 only. 
+ + Args + ---- + - `speech_tokens`: S3 speech tokens [B=1, T] + - `ref_wav`: reference waveform (`torch.Tensor` with shape=[B=1, T]) + - `ref_sr`: reference sample rate + - `finalize`: whether streaming is finished or not. Note that if False, the last 3 tokens will be ignored. + """ + assert (ref_wav is None) ^ (ref_dict is None), f"Must provide exactly one of ref_wav or ref_dict (got {ref_wav} and {ref_dict})" + + if ref_dict is None: + ref_dict = self.embed_ref(ref_wav, ref_sr) + else: + # type/device casting (all values will be numpy if it's from a prod API call) + for rk in list(ref_dict): + if isinstance(ref_dict[rk], np.ndarray): + ref_dict[rk] = torch.from_numpy(ref_dict[rk]) + if torch.is_tensor(ref_dict[rk]): + ref_dict[rk] = ref_dict[rk].to(self.device) + + if len(speech_tokens.shape) == 1: + speech_tokens = speech_tokens.unsqueeze(0) + + # assert speech_tokens.shape[0] == 1, "only batch size of one allowed for now" + speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device) + + output_mels, _ = self.flow.inference( + token=speech_tokens, + token_len=speech_token_lens, + finalize=finalize, + **ref_dict, + ) + return output_mels + + +class S3Token2Wav(S3Token2Mel): + """ + The decoder of CosyVoice2 is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules. + + TODO: make these modules configurable? + """ + + def __init__(self): + super().__init__() + + f0_predictor = ConvRNNF0Predictor() + self.mel2wav = HiFTGenerator( + sampling_rate=S3GEN_SR, + upsample_rates=[8, 5, 3], + upsample_kernel_sizes=[16, 11, 7], + source_resblock_kernel_sizes=[7, 7, 11], + source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + f0_predictor=f0_predictor, + ) + + # silence out a few ms and fade audio in to reduce artifacts + n_trim = S3GEN_SR // 50 # 20ms = half of a frame + trim_fade = torch.zeros(2 * n_trim) + trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2 + self.register_buffer("trim_fade", trim_fade, persistent=False) # (buffers get automatic device casting) + + def forward( + self, + speech_tokens, + # locally-computed ref embedding (mutex with ref_dict) + ref_wav: Optional[torch.Tensor], + ref_sr: Optional[int], + # pre-computed ref embedding (prod API) + ref_dict: Optional[dict] = None, + finalize: bool = False + ): + output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize) + + # TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now. + hift_cache_source = torch.zeros(1, 1, 0).to(self.device) + + output_wavs, *_ = self.mel2wav.inference(speech_feat=output_mels, cache_source=hift_cache_source) + + if not self.training: + # NOTE: ad-hoc method to reduce "spillover" from the reference clip. 
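+ # trim_fade zeroes the first 20 ms and cosine-fades the next 20 ms so the generated audio ramps in smoothly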
+ output_wavs[:, :len(self.trim_fade)] *= self.trim_fade + + return output_wavs + + @torch.inference_mode() + def flow_inference( + self, + speech_tokens, + # locally-computed ref embedding (mutex with ref_dict) + ref_wav: Optional[torch.Tensor] = None, + ref_sr: Optional[int] = None, + # pre-computed ref embedding (prod API) + ref_dict: Optional[dict] = None, + finalize: bool = False, + ): + return super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize) + + @torch.inference_mode() + def hift_inference(self, speech_feat, cache_source: torch.Tensor = None): + if cache_source is None: + cache_source = torch.zeros(1, 1, 0).to(self.device) + return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source) + + @torch.inference_mode() + def inference( + self, + speech_tokens, + # locally-computed ref embedding (mutex with ref_dict) + ref_wav: Optional[torch.Tensor] = None, + ref_sr: Optional[int] = None, + # pre-computed ref embedding (prod API) + ref_dict: Optional[dict] = None, + cache_source: torch.Tensor = None, # NOTE: this arg is for streaming, it can probably be removed here + finalize: bool = True, + ): + output_mels = self.flow_inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize) + output_wavs, output_sources = self.hift_inference(output_mels, cache_source) + + # NOTE: ad-hoc method to reduce "spillover" from the reference clip. + output_wavs[:, :len(self.trim_fade)] *= self.trim_fade + + return output_wavs, output_sources diff --git a/orator/src/orator/models/s3gen/transformer/__init__.py b/orator/src/orator/models/s3gen/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc b/orator/src/orator/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68975983af6e89a0802d465c904e25165e7a94f9 Binary files /dev/null and b/orator/src/orator/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc b/orator/src/orator/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4c2e7e5ed0829a5f0b1d810c25a2e6da94c302e Binary files /dev/null and b/orator/src/orator/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc b/orator/src/orator/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21b931aa26786a0bc88a3aa8cea5e1c174fdea37 Binary files /dev/null and b/orator/src/orator/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc b/orator/src/orator/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bd11bf383f2959ca171b136301a685ae73898da Binary files /dev/null and b/orator/src/orator/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc 
b/orator/src/orator/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f45bc1da1ca3ec42d268e82e89f551d5517740bd Binary files /dev/null and b/orator/src/orator/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc b/orator/src/orator/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bfca5ce699f05d5afa8e077df735eb7684e56e3 Binary files /dev/null and b/orator/src/orator/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc b/orator/src/orator/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8493bd49f27b689b902b376ed4e8e60660e05162 Binary files /dev/null and b/orator/src/orator/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc b/orator/src/orator/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..069b00e6c39c956c5398b9f9da80d0801a56fe14 Binary files /dev/null and b/orator/src/orator/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc b/orator/src/orator/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3266bb5614d1c229a25543d11fa6c9f0d39c9d74 Binary files /dev/null and b/orator/src/orator/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/activation.py b/orator/src/orator/models/s3gen/transformer/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..8cea54816385d3b6585ccc2417bc71630d578177 --- /dev/null +++ b/orator/src/orator/models/s3gen/transformer/activation.py @@ -0,0 +1,84 @@ +# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe) +# 2020 Northwestern Polytechnical University (Pengcheng Guo) +# 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Swish() activation function for Conformer.""" + +import torch +from torch import nn, sin, pow +from torch.nn import Parameter + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Return Swish activation function.""" + return x * torch.sigmoid(x) + + +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. 
+# LICENSE is in incl_licenses directory. +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + Snake ∶= x + 1/a * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x diff --git a/orator/src/orator/models/s3gen/transformer/attention.py b/orator/src/orator/models/s3gen/transformer/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..95e1d84035e84b27cfa88680c3d42fc84c0b7aed --- /dev/null +++ b/orator/src/orator/models/s3gen/transformer/attention.py @@ -0,0 +1,330 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Multi-Head Attention layer definition.""" + +import math +from typing import Tuple + +import torch +from torch import nn + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. 
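+ key_bias (bool): whether the key projection layer has a bias term. Defaults to True.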
+ + """ + + def __init__(self, + n_head: int, + n_feat: int, + dropout_rate: float, + key_bias: bool = True): + """Construct an MultiHeadedAttention object.""" + super().__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor, size + (#batch, n_head, time1, d_k). + torch.Tensor: Transformed key tensor, size + (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor, size + (#batch, n_head, time2, d_k). + + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention( + self, + value: torch.Tensor, + scores: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) + ) -> torch.Tensor: + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value, size + (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score, size + (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask, size (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + n_batch = value.size(0) + # NOTE(xcsong): When will `if mask.size(2) > 0` be True? + # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the + # 1st chunk to ease the onnx export.] + # 2. pytorch training + if mask.size(2) > 0: # time2 > 0 + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + # For last chunk, time2 might be larger than scores.size(-1) + mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2) + scores = scores.masked_fill(mask, -float('inf')) + attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0) # (batch, head, time1, time2) + # NOTE(xcsong): When will `if mask.size(2) > 0` be False? + # 1. onnx(16/-1, -1/-1, 16/0) + # 2. jit (16/-1, -1/-1, 16/0, 16/4) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = (x.transpose(1, 2).contiguous().view(n_batch, -1, + self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute scaled dot product attention. 
+ + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + 1.When applying cross attention between decoder and encoder, + the batch padding mask for input is in (#batch, 1, T) shape. + 2.When applying self attention of encoder, + the mask is in (#batch, T, T) shape. + 3.When applying self attention of decoder, + the mask is in (#batch, L, L) shape. + 4.If the different position in decoder see different block + of the encoder, such as Mocha, the passed in mask could be + in (#batch, L, T) shape. But there is no such case in current + CosyVoice. + cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + + """ + q, k, v = self.forward_qkv(query, key, value) + + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if cache.size(0) > 0: + key_cache, value_cache = torch.split(cache, + cache.size(-1) // 2, + dim=-1) + k = torch.cat([key_cache, k], dim=2) + v = torch.cat([value_cache, v], dim=2) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + new_cache = torch.cat((k, v), dim=-1) + + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask), new_cache + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + """ + + def __init__(self, + n_head: int, + n_feat: int, + dropout_rate: float, + key_bias: bool = True): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate, key_bias) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x: torch.Tensor) -> torch.Tensor: + """Compute relative positional encoding. 
+ + Args: + x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + torch.Tensor: Output tensor. + + """ + zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1), + device=x.device, + dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(x.size()[0], + x.size()[1], + x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x)[ + :, :, :, : x.size(-1) // 2 + 1 + ] # only keep the positions from 0 to time2 + return x + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, time2, size). + cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if cache.size(0) > 0: + key_cache, value_cache = torch.split(cache, + cache.size(-1) // 2, + dim=-1) + k = torch.cat([key_cache, k], dim=2) + v = torch.cat([value_cache, v], dim=2) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. 
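        # An illustrative walk-through of the cache layout (arbitrary sizes):
        # keys and values are packed along the last dim as d_k * 2, the split
        # above recovers them, and the cache grows by time1 frames per call.
        # >>> d_k = 4
        # >>> cache = torch.zeros((1, 2, 5, d_k * 2))   # (1, head, cache_t, d_k*2)
        # >>> k_new = torch.randn(1, 2, 3, d_k)         # current chunk, time1=3
        # >>> v_new = torch.randn(1, 2, 3, d_k)
        # >>> k_c, v_c = torch.split(cache, d_k, dim=-1)
        # >>> k2 = torch.cat([k_c, k_new], dim=2)
        # >>> v2 = torch.cat([v_c, v_new], dim=2)
        # >>> torch.cat((k2, v2), dim=-1).shape         # cache_t + time1 = 8
        # torch.Size([1, 2, 8, 8])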
+ new_cache = torch.cat((k, v), dim=-1) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u.to(q.device)).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v.to(q.device)).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used + if matrix_ac.shape != matrix_bd.shape: + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask), new_cache diff --git a/orator/src/orator/models/s3gen/transformer/convolution.py b/orator/src/orator/models/s3gen/transformer/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5d96149154776000991a681a666fbe55e562fe --- /dev/null +++ b/orator/src/orator/models/s3gen/transformer/convolution.py @@ -0,0 +1,145 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""ConvolutionModule definition.""" + +from typing import Tuple + +import torch +from torch import nn + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model.""" + + def __init__(self, + channels: int, + kernel_size: int = 15, + activation: nn.Module = nn.ReLU(), + norm: str = "batch_norm", + causal: bool = False, + bias: bool = True): + """Construct an ConvolutionModule object. + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernel size of conv layers. + causal (int): Whether use causal convolution or not + """ + super().__init__() + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + # self.lorder is used to distinguish if it's a causal convolution, + # if self.lorder > 0: it's a causal convolution, the input will be + # padded with self.lorder frames on the left in forward. 
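        # Worked example for the default kernel_size=15: the causal branch
        # keeps padding=0 and left-pads lorder = 15 - 1 = 14 frames at forward
        # time, while the non-causal branch bakes padding = (15 - 1) // 2 = 7
        # into the conv itself; either way the output time length matches the
        # input time length T:
        # >>> T = 100
        # >>> T + 14 - (15 - 1)          # causal: left pad, then valid conv
        # 100
        # >>> T + 2 * 7 - (15 - 1)       # non-causal: symmetric padding
        # 100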
+ # else: it's a symmetrical convolution + if causal: + padding = 0 + self.lorder = kernel_size - 1 + else: + # kernel_size should be an odd number for none causal convolution + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.lorder = 0 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + + assert norm in ['batch_norm', 'layer_norm'] + if norm == "batch_norm": + self.use_layer_norm = False + self.norm = nn.BatchNorm1d(channels) + else: + self.use_layer_norm = True + self.norm = nn.LayerNorm(channels) + + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = activation + + def forward( + self, + x: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + cache: torch.Tensor = torch.zeros((0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute convolution module. + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), + (0, 0, 0) means fake mask. + cache (torch.Tensor): left context cache, it is only + used in causal convolution (#batch, channels, cache_t), + (0, 0, 0) meas fake cache. + Returns: + torch.Tensor: Output tensor (#batch, time, channels). + """ + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) # (#batch, channels, time) + + # mask batch padding + if mask_pad.size(2) > 0: # time > 0 + x.masked_fill_(~mask_pad, 0.0) + + if self.lorder > 0: + if cache.size(2) == 0: # cache_t == 0 + x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) + else: + assert cache.size(0) == x.size(0) # equal batch + assert cache.size(1) == x.size(1) # equal channel + x = torch.cat((cache, x), dim=2) + assert (x.size(2) > self.lorder) + new_cache = x[:, :, -self.lorder:] + else: + # It's better we just return None if no cache is required, + # However, for JIT export, here we just fake one tensor instead of + # None. + new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.activation(self.norm(x)) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.pointwise_conv2(x) + # mask batch padding + if mask_pad.size(2) > 0: # time > 0 + x.masked_fill_(~mask_pad, 0.0) + + return x.transpose(1, 2), new_cache diff --git a/orator/src/orator/models/s3gen/transformer/embedding.py b/orator/src/orator/models/s3gen/transformer/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..eae8c8ecabb15b4174cc3aa73c070ae702bb5f82 --- /dev/null +++ b/orator/src/orator/models/s3gen/transformer/embedding.py @@ -0,0 +1,294 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Positonal Encoding Module.""" + +import math +from typing import Tuple, Union + +import torch +import torch.nn.functional as F +import numpy as np + + +class PositionalEncoding(torch.nn.Module): + """Positional encoding. + + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + + PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) + PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) + """ + + def __init__(self, + d_model: int, + dropout_rate: float, + max_len: int = 5000, + reverse: bool = False): + """Construct an PositionalEncoding object.""" + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.max_len = max_len + + self.pe = torch.zeros(self.max_len, self.d_model) + position = torch.arange(0, self.max_len, + dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) * + -(math.log(10000.0) / self.d_model)) + self.pe[:, 0::2] = torch.sin(position * div_term) + self.pe[:, 1::2] = torch.cos(position * div_term) + self.pe = self.pe.unsqueeze(0) + + def forward(self, + x: torch.Tensor, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) + offset (int, torch.tensor): position offset + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + torch.Tensor: for compatibility to RelPositionalEncoding + """ + + self.pe = self.pe.to(x.device) + pos_emb = self.position_encoding(offset, x.size(1), False) + x = x * self.xscale + pos_emb + return self.dropout(x), self.dropout(pos_emb) + + def position_encoding(self, + offset: Union[int, torch.Tensor], + size: int, + apply_dropout: bool = True) -> torch.Tensor: + """ For getting encoding in a streaming fashion + + Attention!!!!! + we apply dropout only once at the whole utterance level in a none + streaming way, but will call this function several times with + increasing input size in a streaming scenario, so the dropout will + be applied several times. + + Args: + offset (int or torch.tensor): start offset + size (int): required size of position encoding + + Returns: + torch.Tensor: Corresponding encoding + """ + # How to subscript a Union type: + # https://github.com/pytorch/pytorch/issues/69434 + if isinstance(offset, int): + assert offset + size <= self.max_len + pos_emb = self.pe[:, offset:offset + size] + elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar + assert offset + size <= self.max_len + pos_emb = self.pe[:, offset:offset + size] + else: # for batched streaming decoding on GPU + assert torch.max(offset) + size <= self.max_len + index = offset.unsqueeze(1) + \ + torch.arange(0, size).to(offset.device) # B X T + flag = index > 0 + # remove negative offset + index = index * flag + pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model + + if apply_dropout: + pos_emb = self.dropout(pos_emb) + return pos_emb + + +class RelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module. + See : Appendix B in https://arxiv.org/abs/1901.02860 + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. 
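    Example (an illustrative sketch; sizes are arbitrary):
        >>> pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.0)
        >>> x = torch.randn(2, 50, 256)
        >>> y, pos_emb = pos_enc(x)
        >>> y.shape, pos_emb.shape
        (torch.Size([2, 50, 256]), torch.Size([1, 50, 256]))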
+ """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): + """Initialize class.""" + super().__init__(d_model, dropout_rate, max_len, reverse=True) + + def forward(self, + x: torch.Tensor, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """Compute positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Positional embedding tensor (1, time, `*`). + """ + self.pe = self.pe.to(x.device) + x = x * self.xscale + pos_emb = self.position_encoding(offset, x.size(1), False) + return self.dropout(x), self.dropout(pos_emb) + + +class WhisperPositionalEncoding(PositionalEncoding): + """ Sinusoids position encoding used in openai-whisper.encoder + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500): + super().__init__(d_model, dropout_rate, max_len) + self.xscale = 1.0 + log_timescale_increment = np.log(10000) / (d_model // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * + torch.arange(d_model // 2)) + scaled_time = torch.arange(max_len)[:, np.newaxis] * \ + inv_timescales[np.newaxis, :] + pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + delattr(self, "pe") + self.register_buffer("pe", pe.unsqueeze(0)) + + +class LearnablePositionalEncoding(PositionalEncoding): + """ Learnable position encoding used in openai-whisper.decoder + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448): + super().__init__(d_model, dropout_rate, max_len) + # NOTE(xcsong): overwrite self.pe & self.xscale + self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model)) + self.xscale = 1.0 + + +class NoPositionalEncoding(torch.nn.Module): + """ No position encoding + """ + + def __init__(self, d_model: int, dropout_rate: float): + super().__init__() + self.d_model = d_model + self.dropout = torch.nn.Dropout(p=dropout_rate) + + def forward(self, + x: torch.Tensor, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """ Just return zero vector for interface compatibility + """ + pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device) + return self.dropout(x), pos_emb + + def position_encoding(self, offset: Union[int, torch.Tensor], + size: int) -> torch.Tensor: + return torch.zeros(1, size, self.d_model) + + +class EspnetRelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module (new implementation). + + Details can be found in https://github.com/espnet/espnet/pull/2816. + + See : Appendix B in https://arxiv.org/abs/1901.02860 + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. 
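    Example (an illustrative sketch; sizes are arbitrary):
        >>> pos_enc = EspnetRelPositionalEncoding(d_model=256, dropout_rate=0.0)
        >>> x = torch.randn(2, 50, 256)
        >>> y, pos_emb = pos_enc(x)
        >>> y.shape, pos_emb.shape    # pos_emb covers 2 * 50 - 1 relative offsets
        (torch.Size([2, 50, 256]), torch.Size([1, 99, 256]))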
+ + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): + """Construct an PositionalEncoding object.""" + super(EspnetRelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: torch.Tensor): + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[torch.Tensor, torch.Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.position_encoding(size=x.size(1), offset=offset) + return self.dropout(x), self.dropout(pos_emb) + + def position_encoding(self, + offset: Union[int, torch.Tensor], + size: int) -> torch.Tensor: + """ For getting encoding in a streaming fashion + + Attention!!!!! + we apply dropout only once at the whole utterance level in a none + streaming way, but will call this function several times with + increasing input size in a streaming scenario, so the dropout will + be applied several times. + + Args: + offset (int or torch.tensor): start offset + size (int): required size of position encoding + + Returns: + torch.Tensor: Corresponding encoding + """ + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size, + ] + return pos_emb diff --git a/orator/src/orator/models/s3gen/transformer/encoder_layer.py b/orator/src/orator/models/s3gen/transformer/encoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..efbb12dd365770bebe8bca75276fe63be260a08f --- /dev/null +++ b/orator/src/orator/models/s3gen/transformer/encoder_layer.py @@ -0,0 +1,236 @@ +# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Encoder self-attention layer definition.""" + +from typing import Optional, Tuple + +import torch +from torch import nn + + +class TransformerEncoderLayer(nn.Module): + """Encoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. 
+ `PositionwiseFeedForward`, instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: to use layer_norm after each sub-block. + """ + + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward: torch.nn.Module, + dropout_rate: float, + normalize_before: bool = True, + ): + """Construct an EncoderLayer object.""" + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = nn.LayerNorm(size, eps=1e-12) + self.norm2 = nn.LayerNorm(size, eps=1e-12) + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time,time), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): just for interface compatibility + to ConformerEncoderLayer + mask_pad (torch.Tensor): does not used in transformer layer, + just for unified api with conformer. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in conformer layer + (#batch=1, size, cache_t2), not used here, it's for interface + compatibility to ConformerEncoderLayer. + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2). + + """ + residual = x + if self.normalize_before: + x = self.norm1(x) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache) + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + return x, mask, new_att_cache, fake_cnn_cache + + +class ConformerEncoderLayer(nn.Module): + """Encoder layer module. + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + feed_forward_macaron (torch.nn.Module): Additional feed-forward module + instance. + `PositionwiseFeedForward` instance can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: use layer_norm after each sub-block. 
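    Example (an illustrative sketch, assuming `MultiHeadedAttention` and
    `PositionwiseFeedForward` from the sibling modules are in scope; the
    macaron and convolution sub-modules are simply left out here):
        >>> attn = MultiHeadedAttention(4, 256, 0.0)
        >>> ff = PositionwiseFeedForward(256, 1024, 0.0)
        >>> layer = ConformerEncoderLayer(256, attn, ff, dropout_rate=0.0)
        >>> x = torch.randn(2, 50, 256)
        >>> mask = torch.ones(2, 1, 50, dtype=torch.bool)
        >>> out, out_mask, _, _ = layer(x, mask, pos_emb=torch.empty(0))
        >>> out.shape
        torch.Size([2, 50, 256])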
+ """ + + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward: Optional[nn.Module] = None, + feed_forward_macaron: Optional[nn.Module] = None, + conv_module: Optional[nn.Module] = None, + dropout_rate: float = 0.1, + normalize_before: bool = True, + ): + """Construct an EncoderLayer object.""" + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module + self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module + if feed_forward_macaron is not None: + self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module + self.norm_final = nn.LayerNorm( + size, eps=1e-12) # for the final output of the block + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time,time), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): positional encoding, must not be None + for ConformerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. + (#batch, 1,time), (0, 0, 0) means fake mask. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in conformer layer + (#batch=1, size, cache_t2) + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). 
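        Note (illustrative summary): with all optional modules present, the
        sub-blocks below run as macaron feed-forward (scaled by 0.5), then
        self-attention, then convolution, then the second feed-forward, each
        with a residual connection, followed by a final layer norm.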
+ """ + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + + # multi-headed self-attention module + residual = x + if self.normalize_before: + x = self.norm_mha(x) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, + att_cache) + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + + # convolution module + # Fake new cnn cache here, and then change it in conv_module + new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) + x = residual + self.dropout(x) + + if not self.normalize_before: + x = self.norm_conv(x) + + # feed forward module + residual = x + if self.normalize_before: + x = self.norm_ff(x) + + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + return x, mask, new_att_cache, new_cnn_cache diff --git a/orator/src/orator/models/s3gen/transformer/positionwise_feed_forward.py b/orator/src/orator/models/s3gen/transformer/positionwise_feed_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..b7a2cf6e7315e3a5ed2794423daff0a59cc5b208 --- /dev/null +++ b/orator/src/orator/models/s3gen/transformer/positionwise_feed_forward.py @@ -0,0 +1,115 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Positionwise feed forward layer definition.""" + +import torch + + +class PositionwiseFeedForward(torch.nn.Module): + """Positionwise feed forward layer. + + FeedForward are appied on each position of the sequence. + The output dim is same with the input dim. + + Args: + idim (int): Input dimenstion. + hidden_units (int): The number of hidden units. + dropout_rate (float): Dropout rate. + activation (torch.nn.Module): Activation function + """ + + def __init__( + self, + idim: int, + hidden_units: int, + dropout_rate: float, + activation: torch.nn.Module = torch.nn.ReLU(), + ): + """Construct a PositionwiseFeedForward object.""" + super(PositionwiseFeedForward, self).__init__() + self.w_1 = torch.nn.Linear(idim, hidden_units) + self.activation = activation + self.dropout = torch.nn.Dropout(dropout_rate) + self.w_2 = torch.nn.Linear(hidden_units, idim) + + def forward(self, xs: torch.Tensor) -> torch.Tensor: + """Forward function. 
+ + Args: + xs: input tensor (B, L, D) + Returns: + output tensor, (B, L, D) + """ + return self.w_2(self.dropout(self.activation(self.w_1(xs)))) + + +class MoEFFNLayer(torch.nn.Module): + """ + Mixture of expert with Positionwise feed forward layer + See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf + The output dim is same with the input dim. + + Modified from https://github.com/Lightning-AI/lit-gpt/pull/823 + https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 + Args: + n_expert: number of expert. + n_expert_per_token: The actual number of experts used for each frame + idim (int): Input dimenstion. + hidden_units (int): The number of hidden units. + dropout_rate (float): Dropout rate. + activation (torch.nn.Module): Activation function + """ + + def __init__( + self, + n_expert: int, + n_expert_per_token: int, + idim: int, + hidden_units: int, + dropout_rate: float, + activation: torch.nn.Module = torch.nn.ReLU(), + ): + super(MoEFFNLayer, self).__init__() + self.gate = torch.nn.Linear(idim, n_expert, bias=False) + self.experts = torch.nn.ModuleList( + PositionwiseFeedForward(idim, hidden_units, dropout_rate, + activation) for _ in range(n_expert)) + self.n_expert_per_token = n_expert_per_token + + def forward(self, xs: torch.Tensor) -> torch.Tensor: + """Foward function. + Args: + xs: input tensor (B, L, D) + Returns: + output tensor, (B, L, D) + + """ + B, L, D = xs.size( + ) # batch size, sequence length, embedding dimension (idim) + xs = xs.view(-1, D) # (B*L, D) + router = self.gate(xs) # (B*L, n_expert) + logits, indices = torch.topk( + router, self.n_expert_per_token + ) # probs:(B*L, n_expert), indices: (B*L, n_expert) + weights = torch.nn.functional.softmax( + logits, dim=1, + dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token) + output = torch.zeros_like(xs) # (B*L, D) + for i, expert in enumerate(self.experts): + mask = indices == i + batch_idx, ith_expert = torch.where(mask) + output[batch_idx] += weights[batch_idx, ith_expert, None] * expert( + xs[batch_idx]) + return output.view(B, L, D) diff --git a/orator/src/orator/models/s3gen/transformer/subsampling.py b/orator/src/orator/models/s3gen/transformer/subsampling.py new file mode 100644 index 0000000000000000000000000000000000000000..e17c2e324e3afb24e1b619effe29cef07c9c5b3a --- /dev/null +++ b/orator/src/orator/models/s3gen/transformer/subsampling.py @@ -0,0 +1,383 @@ +# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Modified from ESPnet(https://github.com/espnet/espnet) +"""Subsampling layer definition.""" + +from typing import Tuple, Union + +import torch + + +class BaseSubsampling(torch.nn.Module): + + def __init__(self): + super().__init__() + self.right_context = 0 + self.subsampling_rate = 1 + + def position_encoding(self, offset: Union[int, torch.Tensor], + size: int) -> torch.Tensor: + return self.pos_enc.position_encoding(offset, size) + + +class EmbedinigNoSubsampling(BaseSubsampling): + """Embedding input without subsampling + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + super().__init__() + self.embed = torch.nn.Embedding(idim, odim) + self.pos_enc = pos_enc_class + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Input x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: linear input tensor (#batch, time', odim), + where time' = time . + torch.Tensor: linear input mask (#batch, 1, time'), + where time' = time . + + """ + x = self.embed(x) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask + + +class LinearNoSubsampling(BaseSubsampling): + """Linear transform the input without subsampling + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an linear object.""" + super().__init__() + self.out = torch.nn.Sequential( + torch.nn.Linear(idim, odim), + torch.nn.LayerNorm(odim, eps=1e-5), + torch.nn.Dropout(dropout_rate), + ) + self.pos_enc = pos_enc_class + self.right_context = 0 + self.subsampling_rate = 1 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Input x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: linear input tensor (#batch, time', odim), + where time' = time . + torch.Tensor: linear input mask (#batch, 1, time'), + where time' = time . + + """ + x = self.out(x) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask + + +class Conv1dSubsampling2(BaseSubsampling): + """Convolutional 1D subsampling (to 1/2 length). + It is designed for Whisper, ref: + https://github.com/openai/whisper/blob/main/whisper/model.py + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. 
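    Example (an illustrative sketch, assuming a `PositionalEncoding` instance
    from the sibling embedding module as the position-encoding layer):
        >>> subsample = Conv1dSubsampling2(80, 256, 0.0, PositionalEncoding(256, 0.0))
        >>> x = torch.randn(2, 100, 80)                  # (batch, time, idim)
        >>> x_mask = torch.ones(2, 1, 100, dtype=torch.bool)
        >>> y, pos_emb, y_mask = subsample(x, x_mask)
        >>> y.shape, y_mask.shape                        # time halved: 100 -> 50
        (torch.Size([2, 50, 256]), torch.Size([2, 1, 50]))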
+ + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv1dSubsampling2 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1), + torch.nn.GELU(), + torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1), + torch.nn.GELU(), + ) + self.pos_enc = pos_enc_class + # The right context for every conv layer is computed by: + # (kernel_size - 1) * frame_rate_of_this_layer + self.subsampling_rate = 2 + # 4 = (3 - 1) * 1 + (3 - 1) * 1 + self.right_context = 4 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 2. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 2. + torch.Tensor: positional encoding + + """ + time = x.size(1) + x = x.transpose(1, 2) # (b, f, t) + x = self.conv(x) + x = x.transpose(1, 2) # (b, t, f) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, (time + 1) % 2::2] + + +class Conv2dSubsampling4(BaseSubsampling): + """Convolutional 2D subsampling (to 1/4 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling4 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) + self.pos_enc = pos_enc_class + # The right context for every conv layer is computed by: + # (kernel_size - 1) * frame_rate_of_this_layer + self.subsampling_rate = 4 + # 6 = (3 - 1) * 1 + (3 - 1) * 2 + self.right_context = 6 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 4. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 4. + torch.Tensor: positional encoding + + """ + x = x.unsqueeze(1) # (b, c=1, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2] + + +class Conv2dSubsampling6(BaseSubsampling): + """Convolutional 2D subsampling (to 1/6 length). + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc (torch.nn.Module): Custom position encoding layer. 
+ """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling6 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 5, 3), + torch.nn.ReLU(), + ) + self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), + odim) + self.pos_enc = pos_enc_class + # 10 = (3 - 1) * 1 + (5 - 1) * 2 + self.subsampling_rate = 6 + self.right_context = 10 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 6. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 6. + torch.Tensor: positional encoding + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3] + + +class Conv2dSubsampling8(BaseSubsampling): + """Convolutional 2D subsampling (to 1/8 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling8 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.linear = torch.nn.Linear( + odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim) + self.pos_enc = pos_enc_class + self.subsampling_rate = 8 + # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4 + self.right_context = 14 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 8. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 8. + torch.Tensor: positional encoding + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2] + + +class LegacyLinearNoSubsampling(BaseSubsampling): + """Linear transform the input without subsampling + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. 
+ + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an linear object.""" + super().__init__() + self.out = torch.nn.Sequential( + torch.nn.Linear(idim, odim), + torch.nn.LayerNorm(odim, eps=1e-5), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + ) + self.pos_enc = pos_enc_class + self.right_context = 0 + self.subsampling_rate = 1 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Input x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: linear input tensor (#batch, time', odim), + where time' = time . + torch.Tensor: linear input mask (#batch, 1, time'), + where time' = time . + + """ + x = self.out(x) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask diff --git a/orator/src/orator/models/s3gen/transformer/upsample_encoder.py b/orator/src/orator/models/s3gen/transformer/upsample_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..766a5e4e77070ff5579b1a567607c2879391bf8a --- /dev/null +++ b/orator/src/orator/models/s3gen/transformer/upsample_encoder.py @@ -0,0 +1,318 @@ +# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Encoder definition.""" +from typing import Tuple + +import torch +from torch import nn +from torch.nn import functional as F + +from .convolution import ConvolutionModule +from .encoder_layer import ConformerEncoderLayer +from .positionwise_feed_forward import PositionwiseFeedForward +from ..utils.class_utils import ( + COSYVOICE_EMB_CLASSES, + COSYVOICE_SUBSAMPLE_CLASSES, + COSYVOICE_ATTENTION_CLASSES, + COSYVOICE_ACTIVATION_CLASSES, +) +from ..utils.mask import make_pad_mask +from ..utils.mask import add_optional_chunk_mask + + +class Upsample1D(nn.Module): + """A 1D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. 
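    Example (an illustrative sketch; sizes are arbitrary):
        >>> up = Upsample1D(channels=512, out_channels=512, stride=2)
        >>> x = torch.randn(1, 512, 10)                  # (batch, channels, time)
        >>> y, y_lens = up(x, torch.tensor([10]))
        >>> y.shape, y_lens                              # time doubled by stride=2
        (torch.Size([1, 512, 20]), tensor([20]))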
+ """ + + def __init__(self, channels: int, out_channels: int, stride: int = 2): + super().__init__() + self.channels = channels + self.out_channels = out_channels + self.stride = stride + # In this mode, first repeat interpolate, than conv with stride=1 + self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0) + + def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor): + outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest") + outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0) + outputs = self.conv(outputs) + return outputs, input_lengths * self.stride + + +class PreLookaheadLayer(nn.Module): + def __init__(self, channels: int, pre_lookahead_len: int = 1): + super().__init__() + self.channels = channels + self.pre_lookahead_len = pre_lookahead_len + self.conv1 = nn.Conv1d( + channels, channels, + kernel_size=pre_lookahead_len + 1, + stride=1, padding=0, + ) + self.conv2 = nn.Conv1d( + channels, channels, + kernel_size=3, stride=1, padding=0, + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + inputs: (batch_size, seq_len, channels) + """ + outputs = inputs.transpose(1, 2).contiguous() + # look ahead + outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0) + outputs = F.leaky_relu(self.conv1(outputs)) + # outputs + outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0) + outputs = self.conv2(outputs) + outputs = outputs.transpose(1, 2).contiguous() + + # residual connection + outputs = outputs + inputs + return outputs + + +class UpsampleConformerEncoder(torch.nn.Module): + + def __init__( + self, + input_size: int = 512, + output_size: int = 512, + attention_heads: int = 8, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.1, + input_layer: str = "linear", + pos_enc_layer_type: str = "rel_pos_espnet", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + positionwise_conv_kernel_size: int = 1, + macaron_style: bool = False, + selfattention_layer_type: str = "rel_selfattn", + activation_type: str = "swish", + use_cnn_module: bool = False, + cnn_module_kernel: int = 15, + causal: bool = False, + cnn_module_norm: str = "batch_norm", + key_bias: bool = True, + gradient_checkpointing: bool = False, + ): + """ + Args: + input_size (int): input dim + output_size (int): dimension of attention + attention_heads (int): the number of heads of multi head attention + linear_units (int): the hidden units number of position-wise feed + forward + num_blocks (int): the number of decoder blocks + dropout_rate (float): dropout rate + attention_dropout_rate (float): dropout rate in attention + positional_dropout_rate (float): dropout rate after adding + positional encoding + input_layer (str): input layer type. + optional [linear, conv2d, conv2d6, conv2d8] + pos_enc_layer_type (str): Encoder positional encoding layer type. + opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos] + normalize_before (bool): + True: use layer_norm before each sub-block of a layer. + False: use layer_norm after each sub-block of a layer. 
+ static_chunk_size (int): chunk size for static chunk training and + decoding + use_dynamic_chunk (bool): whether use dynamic chunk size for + training or not, You can only use fixed chunk(chunk_size > 0) + or dyanmic chunk size(use_dynamic_chunk = True) + global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module + use_dynamic_left_chunk (bool): whether use dynamic left chunk in + dynamic chunk training + key_bias: whether use bias in attention.linear_k, False for whisper models. + gradient_checkpointing: rerunning a forward-pass segment for each + checkpointed segment during backward. + """ + super().__init__() + self._output_size = output_size + + self.global_cmvn = global_cmvn + self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer]( + input_size, + output_size, + dropout_rate, + COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size, + positional_dropout_rate), + ) + + self.normalize_before = normalize_before + self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5) + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk + self.gradient_checkpointing = gradient_checkpointing + activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]() + # self-attention module definition + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + key_bias, + ) + # feed-forward module definition + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + ) + # convolution module definition + convolution_layer_args = (output_size, cnn_module_kernel, activation, + cnn_module_norm, causal) + self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3) + self.encoders = torch.nn.ModuleList([ + ConformerEncoderLayer( + output_size, + COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type]( + *encoder_selfattn_layer_args), + PositionwiseFeedForward(*positionwise_layer_args), + PositionwiseFeedForward( + *positionwise_layer_args) if macaron_style else None, + ConvolutionModule( + *convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + ) for _ in range(num_blocks) + ]) + self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2) + self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer]( + input_size, + output_size, + dropout_rate, + COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size, + positional_dropout_rate), + ) + self.up_encoders = torch.nn.ModuleList([ + ConformerEncoderLayer( + output_size, + COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type]( + *encoder_selfattn_layer_args), + PositionwiseFeedForward(*positionwise_layer_args), + PositionwiseFeedForward( + *positionwise_layer_args) if macaron_style else None, + ConvolutionModule( + *convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + ) for _ in range(4) + ]) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Embed positions in tensor. + + Args: + xs: padded input tensor (B, T, D) + xs_lens: input length (B) + decoding_chunk_size: decoding chunk size for dynamic chunk + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. 
+ num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + Returns: + encoder output tensor xs, and subsampled masks + xs: padded output tensor (B, T' ~= T/subsample_rate, D) + masks: torch.Tensor batch padding mask after subsample + (B, 1, T' ~= T/subsample_rate) + NOTE(xcsong): + We pass the `__call__` method of the modules instead of `forward` to the + checkpointing API because `__call__` attaches all the hooks of the module. + https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + """ + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + xs, pos_emb, masks = self.embed(xs, masks) + mask_pad = masks # (B, 1, T/subsample_rate) + chunk_masks = add_optional_chunk_mask(xs, masks, + self.use_dynamic_chunk, + self.use_dynamic_left_chunk, + decoding_chunk_size, + self.static_chunk_size, + num_decoding_left_chunks) + # lookahead + conformer encoder + xs = self.pre_lookahead_layer(xs) + xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) + + # upsample + conformer encoder + xs = xs.transpose(1, 2).contiguous() + xs, xs_lens = self.up_layer(xs, xs_lens) + xs = xs.transpose(1, 2).contiguous() + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + xs, pos_emb, masks = self.up_embed(xs, masks) + mask_pad = masks # (B, 1, T/subsample_rate) + chunk_masks = add_optional_chunk_mask(xs, masks, + self.use_dynamic_chunk, + self.use_dynamic_left_chunk, + decoding_chunk_size, + self.static_chunk_size * self.up_layer.stride, + num_decoding_left_chunks) + xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad) + + if self.normalize_before: + xs = self.after_norm(xs) + # Here we assume the mask is not changed in encoder layers, so just + # return the masks before encoder layers, and the masks will be used + # for cross attention with decoder later + return xs, masks + + def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor) -> torch.Tensor: + for layer in self.encoders: + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + return xs + + def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor) -> torch.Tensor: + for layer in self.up_encoders: + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + return xs diff --git a/orator/src/orator/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc b/orator/src/orator/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12857ba8f4f8bd70e07bd6e21b959c5d7f11ffec Binary files /dev/null and b/orator/src/orator/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/utils/__pycache__/mask.cpython-311.pyc b/orator/src/orator/models/s3gen/utils/__pycache__/mask.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9c6b04b80c213f9f464d8619bcdf4d6ccb80abb Binary files /dev/null and b/orator/src/orator/models/s3gen/utils/__pycache__/mask.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/utils/__pycache__/mel.cpython-311.pyc b/orator/src/orator/models/s3gen/utils/__pycache__/mel.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..214a3ebeab56b71cd7de1a499192d6d4b5d893ed Binary files /dev/null and b/orator/src/orator/models/s3gen/utils/__pycache__/mel.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/utils/class_utils.py b/orator/src/orator/models/s3gen/utils/class_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cd31e48029ce1ee11728a2edbffec479cc0a1bd6 --- /dev/null +++ b/orator/src/orator/models/s3gen/utils/class_utils.py @@ -0,0 +1,71 @@ +# Copyright [2023-11-28] +# 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from ..transformer.activation import Swish +from ..transformer.subsampling import ( + LinearNoSubsampling, + EmbedinigNoSubsampling, + Conv1dSubsampling2, + Conv2dSubsampling4, + Conv2dSubsampling6, + Conv2dSubsampling8, +) +from ..transformer.embedding import ( + PositionalEncoding, + RelPositionalEncoding, + WhisperPositionalEncoding, + LearnablePositionalEncoding, + NoPositionalEncoding) +from ..transformer.attention import (MultiHeadedAttention, + RelPositionMultiHeadedAttention) +from ..transformer.embedding import EspnetRelPositionalEncoding +from ..transformer.subsampling import LegacyLinearNoSubsampling + + +COSYVOICE_ACTIVATION_CLASSES = { + "hardtanh": torch.nn.Hardtanh, + "tanh": torch.nn.Tanh, + "relu": torch.nn.ReLU, + "selu": torch.nn.SELU, + "swish": getattr(torch.nn, "SiLU", Swish), + "gelu": torch.nn.GELU, +} + +COSYVOICE_SUBSAMPLE_CLASSES = { + "linear": LinearNoSubsampling, + "linear_legacy": LegacyLinearNoSubsampling, + "embed": EmbedinigNoSubsampling, + "conv1d2": Conv1dSubsampling2, + "conv2d": Conv2dSubsampling4, + "conv2d6": Conv2dSubsampling6, + "conv2d8": Conv2dSubsampling8, + 'paraformer_dummy': torch.nn.Identity +} + +COSYVOICE_EMB_CLASSES = { + "embed": PositionalEncoding, + "abs_pos": PositionalEncoding, + "rel_pos": RelPositionalEncoding, + "rel_pos_espnet": EspnetRelPositionalEncoding, + "no_pos": NoPositionalEncoding, + "abs_pos_whisper": WhisperPositionalEncoding, + "embed_learnable_pe": LearnablePositionalEncoding, +} + +COSYVOICE_ATTENTION_CLASSES = { + "selfattn": MultiHeadedAttention, + "rel_selfattn": RelPositionMultiHeadedAttention, +} diff --git a/orator/src/orator/models/s3gen/utils/mask.py b/orator/src/orator/models/s3gen/utils/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..08c97a3ed6f49d9e623b252273d2eee9d26c408b --- /dev/null +++ b/orator/src/orator/models/s3gen/utils/mask.py @@ -0,0 +1,193 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import torch
+
+'''
+def subsequent_mask(
+        size: int,
+        device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Create mask for subsequent steps (size, size).
+
+    This mask is used only in decoder which works in an auto-regressive mode.
+    This means the current step could only do attention with its left steps.
+
+    In encoder, full attention is used when streaming is not necessary and
+    the sequence is not long. In this case, no attention mask is needed.
+
+    When streaming is needed, chunk-based attention is used in encoder. See
+    subsequent_chunk_mask for the chunk-based attention mask.
+
+    Args:
+        size (int): size of mask
+        str device (str): "cpu" or "cuda" or torch.Tensor.device
+        dtype (torch.device): result dtype
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_mask(3)
+        [[1, 0, 0],
+         [1, 1, 0],
+         [1, 1, 1]]
+    """
+    ret = torch.ones(size, size, device=device, dtype=torch.bool)
+    return torch.tril(ret)
+'''
+
+
+def subsequent_chunk_mask(
+        size: int,
+        chunk_size: int,
+        num_left_chunks: int = -1,
+        device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Create mask for subsequent steps (size, size) with chunk size,
+    this is for streaming encoder
+
+    Args:
+        size (int): size of mask
+        chunk_size (int): size of chunk
+        num_left_chunks (int): number of left chunks
+            <0: use full chunk
+            >=0: use num_left_chunks
+        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_chunk_mask(4, 2)
+        [[1, 1, 0, 0],
+         [1, 1, 0, 0],
+         [1, 1, 1, 1],
+         [1, 1, 1, 1]]
+    """
+    # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
+    # actually this is not needed after we have inference cache implemented, will remove it later
+    pos_idx = torch.arange(size, device=device)
+    block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
+    ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
+    return ret
+
+
+def add_optional_chunk_mask(xs: torch.Tensor,
+                            masks: torch.Tensor,
+                            use_dynamic_chunk: bool,
+                            use_dynamic_left_chunk: bool,
+                            decoding_chunk_size: int,
+                            static_chunk_size: int,
+                            num_decoding_left_chunks: int,
+                            enable_full_context: bool = True):
+    """ Apply optional mask for encoder.
+
+    Args:
+        xs (torch.Tensor): padded input, (B, L, D), L for max length
+        mask (torch.Tensor): mask for xs, (B, 1, L)
+        use_dynamic_chunk (bool): whether to use dynamic chunk or not
+        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
+            training.
+        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
+            0: default for training, use random dynamic chunk.
+            <0: for decoding, use full chunk.
+            >0: for decoding, use fixed chunk size as set.
+        static_chunk_size (int): chunk size for static chunk training/decoding
+            if it's greater than 0, if use_dynamic_chunk is true,
+            this parameter will be ignored
+        num_decoding_left_chunks: number of left chunks, this is for decoding,
+            the chunk size is decoding_chunk_size.
+            >=0: use num_decoding_left_chunks
+            <0: use all left chunks
+        enable_full_context (bool):
+            True: chunk size is either [1, 25] or full context(max_len)
+            False: chunk size ~ U[1, 25]
+
+    Returns:
+        torch.Tensor: chunk mask of the input xs.
+    """
+    # Whether to use chunk mask or not
+    if use_dynamic_chunk:
+        max_len = xs.size(1)
+        if decoding_chunk_size < 0:
+            chunk_size = max_len
+            num_left_chunks = -1
+        elif decoding_chunk_size > 0:
+            chunk_size = decoding_chunk_size
+            num_left_chunks = num_decoding_left_chunks
+        else:
+            # chunk size is either [1, 25] or full context(max_len).
+            # Since we use 4 times subsampling and allow up to 1s(100 frames)
+            # delay, the maximum frame is 100 / 4 = 25.
+            chunk_size = torch.randint(1, max_len, (1, )).item()
+            num_left_chunks = -1
+            if chunk_size > max_len // 2 and enable_full_context:
+                chunk_size = max_len
+            else:
+                chunk_size = chunk_size % 25 + 1
+                if use_dynamic_left_chunk:
+                    max_left_chunks = (max_len - 1) // chunk_size
+                    num_left_chunks = torch.randint(0, max_left_chunks,
+                                                    (1, )).item()
+        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
+                                            num_left_chunks,
+                                            xs.device)  # (L, L)
+        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
+        chunk_masks = masks & chunk_masks  # (B, L, L)
+    elif static_chunk_size > 0:
+        num_left_chunks = num_decoding_left_chunks
+        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
+                                            num_left_chunks,
+                                            xs.device)  # (L, L)
+        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
+        chunk_masks = masks & chunk_masks  # (B, L, L)
+    else:
+        chunk_masks = masks
+    assert chunk_masks.dtype == torch.bool
+    if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
+        logging.warning('get chunk_masks all false at some timestep, force set to true, make sure they are masked in future computation!')
+        chunk_masks[chunk_masks.sum(dim=-1)==0] = True
+    return chunk_masks
+
+
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+    """Make mask tensor containing indices of padded part.
+
+    See description of make_non_pad_mask.
+
+    Args:
+        lengths (torch.Tensor): Batch of lengths (B,).
+    Returns:
+        torch.Tensor: Mask tensor containing indices of padded part.
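+            (a value of True/1 marks a padded position; callers such as the
+            upsample encoder take `~make_pad_mask(...)` to get the valid-frame mask)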
+ + Examples: + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + """ + batch_size = lengths.size(0) + max_len = max_len if max_len > 0 else lengths.max().item() + seq_range = torch.arange(0, + max_len, + dtype=torch.int64, + device=lengths.device) + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + seq_length_expand = lengths.unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + return mask diff --git a/orator/src/orator/models/s3gen/utils/mel.py b/orator/src/orator/models/s3gen/utils/mel.py new file mode 100644 index 0000000000000000000000000000000000000000..5a9ff9d11d67e1d6a96dd97d45a02366a3bba300 --- /dev/null +++ b/orator/src/orator/models/s3gen/utils/mel.py @@ -0,0 +1,81 @@ +"""mel-spectrogram extraction in Matcha-TTS""" +from librosa.filters import mel as librosa_mel_fn +import torch +import numpy as np + + +# NOTE: they decalred these global vars +mel_basis = {} +hann_window = {} + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + +""" +feat_extractor: !name:matcha.utils.audio.mel_spectrogram + n_fft: 1920 + num_mels: 80 + sampling_rate: 24000 + hop_size: 480 + win_size: 1920 + fmin: 0 + fmax: 8000 + center: False + +""" + +def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920, + fmin=0, fmax=8000, center=False): + """Copied from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/audio.py + Set default values according to Cosyvoice's config. + """ + + if isinstance(y, np.ndarray): + y = torch.tensor(y).float() + + if len(y.shape) == 1: + y = y[None, ] + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned + if f"{str(fmax)}_{str(y.device)}" not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/orator/src/orator/models/s3gen/xvector.py b/orator/src/orator/models/s3gen/xvector.py new file mode 100644 index 0000000000000000000000000000000000000000..6eb99af4aad25b33698211aa033d182d2f753379 --- /dev/null +++ b/orator/src/orator/models/s3gen/xvector.py @@ -0,0 +1,428 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
+# MIT License (https://opensource.org/licenses/MIT) +# Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker) + + +from collections import OrderedDict +import torch +import torch.nn.functional as F +import torch.utils.checkpoint as cp +import torchaudio.compliance.kaldi as Kaldi + + +def pad_list(xs, pad_value): + """Perform padding for the list of tensors. + + Args: + xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. + pad_value (float): Value for padding. + + Returns: + Tensor: Padded tensor (B, Tmax, `*`). + + Examples: + >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] + >>> x + [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] + >>> pad_list(x, 0) + tensor([[1., 1., 1., 1.], + [1., 1., 0., 0.], + [1., 0., 0., 0.]]) + + """ + n_batch = len(xs) + max_len = max(x.size(0) for x in xs) + pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) + + for i in range(n_batch): + pad[i, : xs[i].size(0)] = xs[i] + + return pad + + +def extract_feature(audio): + features = [] + feature_times = [] + feature_lengths = [] + for au in audio: + feature = Kaldi.fbank(au.unsqueeze(0), num_mel_bins=80) + feature = feature - feature.mean(dim=0, keepdim=True) + features.append(feature) + feature_times.append(au.shape[0]) + feature_lengths.append(feature.shape[0]) + # padding for batch inference + features_padded = pad_list(features, pad_value=0) + # features = torch.cat(features) + return features_padded, feature_lengths, feature_times + + +class BasicResBlock(torch.nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicResBlock, self).__init__() + self.conv1 = torch.nn.Conv2d( + in_planes, planes, kernel_size=3, stride=(stride, 1), padding=1, bias=False + ) + self.bn1 = torch.nn.BatchNorm2d(planes) + self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = torch.nn.BatchNorm2d(planes) + + self.shortcut = torch.nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = torch.nn.Sequential( + torch.nn.Conv2d( + in_planes, + self.expansion * planes, + kernel_size=1, + stride=(stride, 1), + bias=False, + ), + torch.nn.BatchNorm2d(self.expansion * planes), + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class FCM(torch.nn.Module): + def __init__(self, block=BasicResBlock, num_blocks=[2, 2], m_channels=32, feat_dim=80): + super(FCM, self).__init__() + self.in_planes = m_channels + self.conv1 = torch.nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = torch.nn.BatchNorm2d(m_channels) + + self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2) + self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2) + + self.conv2 = torch.nn.Conv2d( + m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False + ) + self.bn2 = torch.nn.BatchNorm2d(m_channels) + self.out_channels = m_channels * (feat_dim // 8) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return torch.nn.Sequential(*layers) + + def forward(self, x): + x = x.unsqueeze(1) + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = 
F.relu(self.bn2(self.conv2(out))) + + shape = out.shape + out = out.reshape(shape[0], shape[1] * shape[2], shape[3]) + return out + + +def get_nonlinear(config_str, channels): + nonlinear = torch.nn.Sequential() + for name in config_str.split("-"): + if name == "relu": + nonlinear.add_module("relu", torch.nn.ReLU(inplace=True)) + elif name == "prelu": + nonlinear.add_module("prelu", torch.nn.PReLU(channels)) + elif name == "batchnorm": + nonlinear.add_module("batchnorm", torch.nn.BatchNorm1d(channels)) + elif name == "batchnorm_": + nonlinear.add_module("batchnorm", torch.nn.BatchNorm1d(channels, affine=False)) + else: + raise ValueError("Unexpected module ({}).".format(name)) + return nonlinear + + +def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2): + mean = x.mean(dim=dim) + std = x.std(dim=dim, unbiased=unbiased) + stats = torch.cat([mean, std], dim=-1) + if keepdim: + stats = stats.unsqueeze(dim=dim) + return stats + + +class StatsPool(torch.nn.Module): + def forward(self, x): + return statistics_pooling(x) + + +class TDNNLayer(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + bias=False, + config_str="batchnorm-relu", + ): + super(TDNNLayer, self).__init__() + if padding < 0: + assert ( + kernel_size % 2 == 1 + ), "Expect equal paddings, but got even kernel size ({})".format(kernel_size) + padding = (kernel_size - 1) // 2 * dilation + self.linear = torch.nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + self.nonlinear = get_nonlinear(config_str, out_channels) + + def forward(self, x): + x = self.linear(x) + x = self.nonlinear(x) + return x + + +class CAMLayer(torch.nn.Module): + def __init__( + self, bn_channels, out_channels, kernel_size, stride, padding, dilation, bias, reduction=2 + ): + super(CAMLayer, self).__init__() + self.linear_local = torch.nn.Conv1d( + bn_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + self.linear1 = torch.nn.Conv1d(bn_channels, bn_channels // reduction, 1) + self.relu = torch.nn.ReLU(inplace=True) + self.linear2 = torch.nn.Conv1d(bn_channels // reduction, out_channels, 1) + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x): + y = self.linear_local(x) + context = x.mean(-1, keepdim=True) + self.seg_pooling(x) + context = self.relu(self.linear1(context)) + m = self.sigmoid(self.linear2(context)) + return y * m + + def seg_pooling(self, x, seg_len=100, stype="avg"): + if stype == "avg": + seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True) + elif stype == "max": + seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True) + else: + raise ValueError("Wrong segment pooling type.") + shape = seg.shape + seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1) + seg = seg[..., : x.shape[-1]] + return seg + + +class CAMDenseTDNNLayer(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + bn_channels, + kernel_size, + stride=1, + dilation=1, + bias=False, + config_str="batchnorm-relu", + memory_efficient=False, + ): + super(CAMDenseTDNNLayer, self).__init__() + assert kernel_size % 2 == 1, "Expect equal paddings, but got even kernel size ({})".format( + kernel_size + ) + padding = (kernel_size - 1) // 2 * dilation + self.memory_efficient = memory_efficient + self.nonlinear1 = get_nonlinear(config_str, in_channels) + self.linear1 = 
torch.nn.Conv1d(in_channels, bn_channels, 1, bias=False) + self.nonlinear2 = get_nonlinear(config_str, bn_channels) + self.cam_layer = CAMLayer( + bn_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + + def bn_function(self, x): + return self.linear1(self.nonlinear1(x)) + + def forward(self, x): + if self.training and self.memory_efficient: + x = cp.checkpoint(self.bn_function, x) + else: + x = self.bn_function(x) + x = self.cam_layer(self.nonlinear2(x)) + return x + + +class CAMDenseTDNNBlock(torch.nn.ModuleList): + def __init__( + self, + num_layers, + in_channels, + out_channels, + bn_channels, + kernel_size, + stride=1, + dilation=1, + bias=False, + config_str="batchnorm-relu", + memory_efficient=False, + ): + super(CAMDenseTDNNBlock, self).__init__() + for i in range(num_layers): + layer = CAMDenseTDNNLayer( + in_channels=in_channels + i * out_channels, + out_channels=out_channels, + bn_channels=bn_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + bias=bias, + config_str=config_str, + memory_efficient=memory_efficient, + ) + self.add_module("tdnnd%d" % (i + 1), layer) + + def forward(self, x): + for layer in self: + x = torch.cat([x, layer(x)], dim=1) + return x + + +class TransitLayer(torch.nn.Module): + def __init__(self, in_channels, out_channels, bias=True, config_str="batchnorm-relu"): + super(TransitLayer, self).__init__() + self.nonlinear = get_nonlinear(config_str, in_channels) + self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias) + + def forward(self, x): + x = self.nonlinear(x) + x = self.linear(x) + return x + + +class DenseLayer(torch.nn.Module): + def __init__(self, in_channels, out_channels, bias=False, config_str="batchnorm-relu"): + super(DenseLayer, self).__init__() + self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias) + self.nonlinear = get_nonlinear(config_str, out_channels) + + def forward(self, x): + if len(x.shape) == 2: + x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1) + else: + x = self.linear(x) + x = self.nonlinear(x) + return x + +# @tables.register("model_classes", "CAMPPlus") +class CAMPPlus(torch.nn.Module): + def __init__( + self, + feat_dim=80, + embedding_size=192, + growth_rate=32, + bn_size=4, + init_channels=128, + config_str="batchnorm-relu", + memory_efficient=True, + output_level="segment", + **kwargs, + ): + super().__init__() + + self.head = FCM(feat_dim=feat_dim) + channels = self.head.out_channels + self.output_level = output_level + + self.xvector = torch.nn.Sequential( + OrderedDict( + [ + ( + "tdnn", + TDNNLayer( + channels, + init_channels, + 5, + stride=2, + dilation=1, + padding=-1, + config_str=config_str, + ), + ), + ] + ) + ) + channels = init_channels + for i, (num_layers, kernel_size, dilation) in enumerate( + zip((12, 24, 16), (3, 3, 3), (1, 2, 2)) + ): + block = CAMDenseTDNNBlock( + num_layers=num_layers, + in_channels=channels, + out_channels=growth_rate, + bn_channels=bn_size * growth_rate, + kernel_size=kernel_size, + dilation=dilation, + config_str=config_str, + memory_efficient=memory_efficient, + ) + self.xvector.add_module("block%d" % (i + 1), block) + channels = channels + num_layers * growth_rate + self.xvector.add_module( + "transit%d" % (i + 1), + TransitLayer(channels, channels // 2, bias=False, config_str=config_str), + ) + channels //= 2 + + self.xvector.add_module("out_nonlinear", get_nonlinear(config_str, channels)) + + if self.output_level == "segment": + 
self.xvector.add_module("stats", StatsPool()) + self.xvector.add_module( + "dense", DenseLayer(channels * 2, embedding_size, config_str="batchnorm_") + ) + else: + assert ( + self.output_level == "frame" + ), "`output_level` should be set to 'segment' or 'frame'. " + + for m in self.modules(): + if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)): + torch.nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + torch.nn.init.zeros_(m.bias) + + def forward(self, x): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + x = self.head(x) + x = self.xvector(x) + if self.output_level == "frame": + x = x.transpose(1, 2) + return x + + def inference(self, audio_list): + speech, speech_lengths, speech_times = extract_feature(audio_list) + results = self.forward(speech.to(torch.float32)) + return results diff --git a/orator/src/orator/models/s3tokenizer/__init__.py b/orator/src/orator/models/s3tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..99d0168c250cde792a6679510f002443b95558ab --- /dev/null +++ b/orator/src/orator/models/s3tokenizer/__init__.py @@ -0,0 +1,13 @@ +from .s3tokenizer import ( + S3_SR, + S3_HOP, + S3_TOKEN_HOP, + S3_TOKEN_RATE, + SPEECH_VOCAB_SIZE, + S3Tokenizer, +) + + +def drop_invalid_tokens(x): + assert len(x.shape) <= 2 and x.shape[0] == 1, "only batch size of one allowed for now" + return x[x < SPEECH_VOCAB_SIZE] diff --git a/orator/src/orator/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc b/orator/src/orator/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..717760c538cda8eb3167fe178e2619cd52b8bc29 Binary files /dev/null and b/orator/src/orator/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc b/orator/src/orator/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d25539d27187172122848ab843e24cfd1bf0fb1e Binary files /dev/null and b/orator/src/orator/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3tokenizer/s3tokenizer.py b/orator/src/orator/models/s3tokenizer/s3tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8648608ae4d8f28bfeec090b5fdb426b6b0ad336 --- /dev/null +++ b/orator/src/orator/models/s3tokenizer/s3tokenizer.py @@ -0,0 +1,168 @@ +from typing import List, Tuple + +import numpy as np +import librosa +import torch +import torch.nn.functional as F +from s3tokenizer.utils import padding +from s3tokenizer.model_v2 import ( + S3TokenizerV2, + ModelConfig, +) + + +# Sampling rate of the inputs to S3TokenizerV2 +S3_SR = 16_000 +S3_HOP = 160 # 100 frames/sec +S3_TOKEN_HOP = 640 # 25 tokens/sec +S3_TOKEN_RATE = 25 +SPEECH_VOCAB_SIZE = 6561 + + +class S3Tokenizer(S3TokenizerV2): + """ + s3tokenizer.S3TokenizerV2 with the following changes: + - a more integrated `forward` + - compute `log_mel_spectrogram` using `_mel_filters` and `window` in `register_buffers` + """ + + ignore_state_dict_missing = ("_mel_filters", "window") + + def __init__( + self, + name: str="speech_tokenizer_v2_25hz", + config: ModelConfig = ModelConfig() + ): + super().__init__(name) + + self.n_fft = 400 + _mel_filters = librosa.filters.mel( + sr=S3_SR, + n_fft=self.n_fft, + n_mels=config.n_mels + ) + self.register_buffer( + "_mel_filters", + torch.FloatTensor(_mel_filters), + ) + + self.register_buffer( + 
"window", + torch.hann_window(self.n_fft), + ) + + def pad(self, wavs, sr) -> List[torch.Tensor]: + """ + Given a list of wavs with the same `sample_rate`, pad them so that the length is multiple of 40ms (S3 runs at 25 token/sec). + """ + processed_wavs = [] + for wav in wavs: + if isinstance(wav, np.ndarray): + wav = torch.from_numpy(wav) + if wav.dim() == 1: + wav = wav.unsqueeze(0) + + n_tokens = (wav.shape[1] / sr) * S3_TOKEN_RATE + n_tokens = np.ceil(n_tokens) + intended_wav_len = n_tokens * (sr / S3_TOKEN_RATE) + intended_wav_len = int(intended_wav_len) + wav = torch.nn.functional.pad( + wav, + (0, intended_wav_len - wav.shape[-1]), + mode="constant", + value=0 + ) + processed_wavs.append(wav) + return processed_wavs + + def _prepare_audio(self, wavs): + """Prepare a list of audios for s3tokenizer processing.""" + processed_wavs = [] + for wav in wavs: + if isinstance(wav, np.ndarray): + wav = torch.from_numpy(wav) + if wav.dim() == 1: + wav = wav.unsqueeze(0) + + processed_wavs.append(wav) + return processed_wavs + + @torch.no_grad() + def forward( + self, + wavs: torch.Tensor, + accelerator: 'Accelerator'=None, + max_len: int=None, + ) -> Tuple[torch.Tensor, torch.LongTensor]: + """ + NOTE: mel-spec has a hop size of 160 points (100 frame/sec). + FIXME: this class inherits `nn.Module` but doesn't accept `torch.Tensor` and handles a list of wavs one by one, which is unexpected. + + Args + ---- + - `wavs`: 16 kHz speech audio + - `max_len` max length to truncate the output sequence to (25 token/sec). + NOTE: please pad the waveform if longer sequence is needed. + """ + processed_wavs = self._prepare_audio(wavs) + mels, mel_lens = [], [] + for wav in processed_wavs: + wav = wav.to(self.device) + mel = self.log_mel_spectrogram(wav) # [B=1, F, T] + if max_len is not None: + mel = mel[..., :max_len * 4] # num_mel_frames = 4 * num_tokens + mels.append(mel.squeeze(0)) + + mels, mel_lens = padding(mels) + if accelerator is None: + tokenizer = self + else: + tokenizer = accelerator.unwrap_model(self) + + speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device)) + return ( + speech_tokens.long().detach(), + speech_token_lens.long().detach(), + ) + + def log_mel_spectrogram( + self, + audio: torch.Tensor, + padding: int = 0, + ): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: torch.Tensor, shape = (*) + The path to audio or either a NumPy array or Tensor containing the + audio waveform in 16 kHz + + padding: int + Number of zero samples to pad to the right + + Returns + ------- + torch.Tensor, shape = (128, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not torch.is_tensor(audio): + audio = torch.from_numpy(audio) + + audio = audio.to(self.device) + if padding > 0: + audio = F.pad(audio, (0, padding)) + stft = torch.stft( + audio, self.n_fft, S3_HOP, + window=self.window.to(self.device), + return_complex=True + ) + magnitudes = stft[..., :-1].abs()**2 + + mel_spec = self._mel_filters.to(self.device) @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec diff --git a/orator/src/orator/models/t3/__init__.py b/orator/src/orator/models/t3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c15519f6107cd9f4d825e420d2ecbd85e92c8671 --- /dev/null +++ b/orator/src/orator/models/t3/__init__.py @@ -0,0 +1 @@ +from .t3 import T3 diff --git 
a/orator/src/orator/models/t3/__pycache__/__init__.cpython-311.pyc b/orator/src/orator/models/t3/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26e581a98c437c8f004bace5aed0f5e445248941 Binary files /dev/null and b/orator/src/orator/models/t3/__pycache__/__init__.cpython-311.pyc differ diff --git a/orator/src/orator/models/t3/__pycache__/llama_configs.cpython-311.pyc b/orator/src/orator/models/t3/__pycache__/llama_configs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee55c55d42c9253ee521f637f35cf3c61a6d9c5d Binary files /dev/null and b/orator/src/orator/models/t3/__pycache__/llama_configs.cpython-311.pyc differ diff --git a/orator/src/orator/models/t3/__pycache__/t3.cpython-311.pyc b/orator/src/orator/models/t3/__pycache__/t3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8604a9a811aaf1ff22e22df605bb7207bdb54820 Binary files /dev/null and b/orator/src/orator/models/t3/__pycache__/t3.cpython-311.pyc differ diff --git a/orator/src/orator/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc b/orator/src/orator/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a3a360d772733fd0e439662f94c3eade57ce7db Binary files /dev/null and b/orator/src/orator/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc differ diff --git a/orator/src/orator/models/t3/inference/t3_hf_backend.py b/orator/src/orator/models/t3/inference/t3_hf_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..8d2b175093074e9b8a566ce02a807de9804160a0 --- /dev/null +++ b/orator/src/orator/models/t3/inference/t3_hf_backend.py @@ -0,0 +1,116 @@ +from typing import Optional + +import torch +from torch import nn as nn +from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel, GenerationMixin +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions + + +class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin): + """ + Override some HuggingFace interface methods so we can use the standard `generate` method with our + custom embedding / logit layers. + + NOTE: need to extend "*PreTrainedModel" to avoid re-initializing weights! + """ + + def __init__( + self, + config: LlamaConfig, + llama: LlamaModel, + *, + speech_enc, + speech_head, + latents_queue=None, + logits_queue=None, + ): + super().__init__(config) + self.model = llama + self.speech_enc = speech_enc + self.speech_head = speech_head + self.latents_queue = latents_queue + self.logits_queue = logits_queue + self._added_cond = False + + @torch.inference_mode() + def prepare_inputs_for_generation( + self, input_ids: torch.Tensor, decoder_cond: torch.Tensor, use_cache: bool, past_key_values=None, + # This argument was introduced in some recent version of transformers (>=4.29.1) + cache_position=None + ): + """ + This is a method used by huggingface's generate() method. + Overridden here to apply our custom speech token embedding layer. + + :param input_ids: (B, S) int64 tensors of input tokens. 
+ :param decoder_cond: (B, T, C) float32 tensor of conditioning (prefixed to ) + """ + + # Make use of the kv cache: only the last input ID is new, we trim away all the ones before + if not use_cache: + past_key_values = None + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + # custom speech token embedding layer + inputs_embeds = self.speech_enc(input_ids) + + # prefix decoder conditioning if applicable + if not self._added_cond: + assert past_key_values is not None # should be first step + if decoder_cond.size(0) != inputs_embeds.size(0): + decoder_cond = decoder_cond.expand(inputs_embeds.size(0), -1, -1) + inputs_embeds = torch.cat([decoder_cond, inputs_embeds], dim=1) + self._added_cond = True + + return { + "inputs_embeds": inputs_embeds, + "past_key_values": past_key_values, + "use_cache": use_cache, + } + + @torch.inference_mode() + def forward( + self, + inputs_embeds: torch.Tensor, + past_key_values: Optional[torch.Tensor]=None, + use_cache=True, + output_attentions=False, + output_hidden_states=True, + return_dict=True, + ): + """ + This is a method used by huggingface's generate() method. + Overridden here to apply our custom layer norm and speech logit projection layers. + + :param inputs_embeds: (B, S, C) float32 tensor of conditioning inputs. If past key values are given, + S should be 1. + """ + is_large_input = inputs_embeds.size(1) != 1 + has_cache = past_key_values is not None and len(past_key_values) > 0 + assert not (is_large_input and has_cache) + assert return_dict + assert output_hidden_states + + tfmr_out = self.model( + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + hidden_states = tfmr_out.hidden_states[-1] # (B, seq, dim) + if self.latents_queue is not None: + self.latents_queue.put(hidden_states) + + logits = self.speech_head(hidden_states) + if self.logits_queue is not None: + self.logits_queue.put(logits) + + return CausalLMOutputWithCrossAttentions( + logits=logits, + past_key_values=tfmr_out.past_key_values, + hidden_states=tfmr_out.hidden_states, + attentions=tfmr_out.attentions, + ) diff --git a/orator/src/orator/models/t3/llama_configs.py b/orator/src/orator/models/t3/llama_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..14d068161ddb38de613c92d54e34d1ed72261d40 --- /dev/null +++ b/orator/src/orator/models/t3/llama_configs.py @@ -0,0 +1,37 @@ +LLAMA_520M_CONFIG_DICT = dict( + # Arbitrary small number that won't cause problems when loading. + # These param are unused due to custom input layers. 
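+    # (T3 defines its own text/speech embedding tables and logit heads in t3.py, so this vocab size is never consumed.)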
+ vocab_size=8, + # default params needed for loading most pretrained 1B weights + max_position_embeddings=131072, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=30, + num_attention_heads=16, + attn_implementation="sdpa", + head_dim=64, + tie_word_embeddings=False, + hidden_act="silu", + attention_bias=False, + attention_dropout=0.0, + initializer_range=0.02, + mlp_bias=False, + model_type="llama", + num_key_value_heads=16, + pretraining_tp=1, + rms_norm_eps=1e-05, + rope_scaling=dict( + factor=8.0, + high_freq_factor=4.0, + low_freq_factor=1.0, + original_max_position_embeddings=8192, + rope_type="llama3" + ), + rope_theta=500000.0, + torch_dtype="bfloat16", + use_cache=True, +) + +LLAMA_CONFIGS = { + "Llama_520M": LLAMA_520M_CONFIG_DICT, +} diff --git a/orator/src/orator/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc b/orator/src/orator/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c1cbde46e791ff61e0fe278140506bb2d299994 Binary files /dev/null and b/orator/src/orator/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc differ diff --git a/orator/src/orator/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc b/orator/src/orator/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5fac9d18536a99dd46ef5719c9aba8d574dc727 Binary files /dev/null and b/orator/src/orator/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc differ diff --git a/orator/src/orator/models/t3/modules/__pycache__/perceiver.cpython-311.pyc b/orator/src/orator/models/t3/modules/__pycache__/perceiver.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39d8bd691b40e87ee83e217913874321f04bfcad Binary files /dev/null and b/orator/src/orator/models/t3/modules/__pycache__/perceiver.cpython-311.pyc differ diff --git a/orator/src/orator/models/t3/modules/__pycache__/t3_config.cpython-311.pyc b/orator/src/orator/models/t3/modules/__pycache__/t3_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d74a6c5dcf3897b1fcac425735a289fa3901b6e7 Binary files /dev/null and b/orator/src/orator/models/t3/modules/__pycache__/t3_config.cpython-311.pyc differ diff --git a/orator/src/orator/models/t3/modules/cond_enc.py b/orator/src/orator/models/t3/modules/cond_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..b5f15c685783fbb048f6c0e86fc2ea8fbf1ec3de --- /dev/null +++ b/orator/src/orator/models/t3/modules/cond_enc.py @@ -0,0 +1,97 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn, Tensor + +from .perceiver import Perceiver +from .t3_config import T3Config + + +@dataclass +class T3Cond: + """ + Dataclass container for most / all conditioning info. + TODO: serialization methods aren't used, keeping them around for convenience + """ + + speaker_emb: Tensor + clap_emb: Optional[Tensor] = None + cond_prompt_speech_tokens: Optional[Tensor] = None + cond_prompt_speech_emb: Optional[Tensor] = None + emotion_adv: Optional[Tensor] = 0.5 + + def to(self, *, device=None, dtype=None): + "Cast to a device and dtype. Dtype casting is ignored for long/int tensors." 
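+        # Tensors are moved field by field; integer/long tensors only change device, never dtype.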
+ for k, v in self.__dict__.items(): + if torch.is_tensor(v): + is_fp = type(v.view(-1)[0].item()) is not int + setattr(self, k, v.to(device=device, dtype=dtype if is_fp else None)) + return self + + def save(self, fpath): + torch.save(self.__dict__, fpath) + + @staticmethod + def load(fpath, map_location="cpu"): + kwargs = torch.load(fpath, map_location=map_location, weights_only=True) + return T3Cond(**kwargs) + + +class T3CondEnc(nn.Module): + """ + Handle all non-text conditioning, like speaker embeddings / prompts, CLAP, emotion, etc. + """ + + def __init__(self, hp: T3Config): + super().__init__() + self.hp = hp + if hp.encoder_type == "voice_encoder": + self.spkr_enc = nn.Linear(hp.speaker_embed_size, hp.n_channels) + else: + raise NotImplementedError(str(hp.encoder_type)) + + # emotion adv + self.emotion_adv_fc = None + if hp.emotion_adv: + self.emotion_adv_fc = nn.Linear(1, hp.n_channels, bias=False) + + # perceiver resampler + self.perceiver = None + if hp.use_perceiver_resampler: + self.perceiver = Perceiver() + + def forward(self, cond: T3Cond): + # Validate + assert (cond.cond_prompt_speech_tokens is None) == (cond.cond_prompt_speech_emb is None), \ + "no embeddings for cond_prompt_speech_tokens" + + # Speaker embedding projection + cond_spkr = self.spkr_enc(cond.speaker_emb.view(-1, self.hp.speaker_embed_size))[:, None] # (B, 1, dim) + empty = torch.zeros_like(cond_spkr[:, :0]) # (B, 0, dim) + + # TODO CLAP + assert cond.clap_emb is None, "clap_embed not implemented" + cond_clap = empty # (B, 0, dim) + + # Cond prompt + cond_prompt_speech_emb = cond.cond_prompt_speech_emb + if cond_prompt_speech_emb is None: + cond_prompt_speech_emb = empty # (B, 0, dim) + elif self.hp.use_perceiver_resampler: + cond_prompt_speech_emb = self.perceiver(cond_prompt_speech_emb) + + # Emotion Adv: must provide a value if this model uses emotion conditioning + cond_emotion_adv = empty # (B, 0, dim) + if self.hp.emotion_adv: + assert cond.emotion_adv is not None + cond_emotion_adv = self.emotion_adv_fc(cond.emotion_adv.view(-1, 1, 1)) + + # Concat and return + cond_embeds = torch.cat(( + cond_spkr, + cond_clap, + cond_prompt_speech_emb, + cond_emotion_adv, + ), dim=1) + return cond_embeds diff --git a/orator/src/orator/models/t3/modules/learned_pos_emb.py b/orator/src/orator/models/t3/modules/learned_pos_emb.py new file mode 100644 index 0000000000000000000000000000000000000000..9b197f218192688f743a904676d66ff741eb33e3 --- /dev/null +++ b/orator/src/orator/models/t3/modules/learned_pos_emb.py @@ -0,0 +1,32 @@ +from typing import Union + +import torch +from torch import nn, Tensor + + +class LearnedPositionEmbeddings(nn.Module): + def __init__(self, seq_len, model_dim, init=.02): + super().__init__() + self.emb = nn.Embedding(seq_len, model_dim) + # Initializing this way is standard for GPT-2 + self.emb.weight.data.normal_(mean=0.0, std=init) + + def forward(self, x): + """ + Returns positional embeddings for index 0 up to the length of x + """ + sl = x.shape[1] + return self.emb(torch.arange(0, sl, device=x.device)) + + def get_fixed_embedding(self, idx: 'Union[int, Tensor]'): + """ + Args: + idx: scalar int or an integer tensor of shape (T,) or (B, T) + Returns: + positional embeddings for given indices, shape (B, T, dim), ie (1, 1, dim) for int input + """ + device = self.emb.weight.device + idx = idx.to(device) if torch.is_tensor(idx) else torch.tensor(idx, device=device) + idx = torch.atleast_2d(idx) + assert idx.ndim == 2 + return self.emb(idx) # (B, T, dim) diff --git 
a/orator/src/orator/models/t3/modules/perceiver.py b/orator/src/orator/models/t3/modules/perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..eaa4b87c65832d380741c332c0a8288a4e8a9854 --- /dev/null +++ b/orator/src/orator/models/t3/modules/perceiver.py @@ -0,0 +1,208 @@ +import math + +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange + + +class RelativePositionBias(nn.Module): + def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8): + super().__init__() + self.scale = scale + self.causal = causal + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + if not causal: + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + else: + n = torch.max(n, torch.zeros_like(n)) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, qk_dots): + i, j, device = *qk_dots.shape[-2:], qk_dots.device + q_pos = torch.arange(i, dtype=torch.long, device=device) + k_pos = torch.arange(j, dtype=torch.long, device=device) + rel_pos = k_pos[None, :] - q_pos[:, None] + rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets, + max_distance=self.max_distance) + values = self.relative_attention_bias(rp_bucket) + bias = rearrange(values, 'i j h -> () h i j') + return qk_dots + (bias * self.scale) + + +class AttentionQKV(nn.Module): + def __init__(self, n_heads, head_dim, dropout_rate=0.1, scale=None, flash=False): + super().__init__() + self.n_heads = n_heads + self.head_dim = head_dim + self.scale = scale if scale is not None else head_dim ** -0.5 + self.flash = flash + self.dropout_rate = dropout_rate + self.dropout = nn.Dropout(dropout_rate) + self.flash_config = self.setup_flash_config() if flash else None + + def setup_flash_config(self): + # Setup flash attention configuration + flash_config = { + 'enable_flash': True, + 'enable_math': True, + 'enable_mem_efficient': True + } + return flash_config + + def forward(self, q, k, v, mask=None): + q, k, v = [self.split_heads(tensor) for tensor in [q, k, v]] + if self.flash: + out = self.flash_attention(q, k, v, mask=mask) + else: + out = self.scaled_dot_product_attention(q, k, v, mask=mask) + + return self.combine_heads(out) + + def scaled_dot_product_attention(self, q, k, v, mask=None): + sim = torch.einsum("bhlt,bhls->bhts", q, k) * self.scale + if mask is not None: + sim = sim.masked_fill(mask == 0, float('-inf')) + attn = torch.softmax(sim, dim=-1) + attn = self.dropout(attn) + return torch.einsum("bhts,bhls->bhlt", attn, v) + + def flash_attention(self, q, k, v, mask=None): + config = self.flash_config if self.flash_config else {} + with torch.backends.cuda.sdp_kernel(**config): + out = F.scaled_dot_product_attention( + q, k, v, + attn_mask=mask, + dropout_p=self.dropout_rate if self.training else 0. 
+ ) + return out + + def split_heads(self, x): + bs, length, _ = x.shape + x = x.view(bs, length, self.n_heads, self.head_dim) + return x.permute(0, 2, 1, 3) + + def combine_heads(self, x): + bs, _, length, _ = x.shape + x = x.permute(0, 2, 1, 3).contiguous() + return x.view(bs, length, -1) + + +class AttentionBlock2(nn.Module): + """ + An attention block that allows spatial positions to attend to each other, + using AttentionQKV and separate linear transformations for Q, K, and V. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + relative_pos_embeddings=False, + flash_attention=True, + dropout_rate=0.2, + scale=None + ): + super().__init__() + self.channels = channels + + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + + self.norm = nn.LayerNorm(channels) + + # Separate linear layers for Q, K, and V + self.to_q = nn.Linear(channels, channels) + self.to_k = nn.Linear(channels, channels) + self.to_v = nn.Linear(channels, channels) + + self.attention = AttentionQKV(self.num_heads, channels // self.num_heads, dropout_rate=dropout_rate, flash=flash_attention, scale=scale) + + self.proj_out = nn.Linear(channels, channels) + + if relative_pos_embeddings: + self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64) + else: + self.relative_pos_embeddings = None + + def forward(self, x1, x2, mask=None): + b1, c1, *spatial1 = x1.shape + b2, c2, *spatial2 = x2.shape + + x1_norm = self.norm(x1) + x2_norm = self.norm(x2) + + q = self.to_q(x1_norm) + k = self.to_k(x2_norm) + v = self.to_v(x2_norm) + + h = self.attention(q, k, v, mask=mask) + h = self.proj_out(h) + + return (x1 + h).reshape(b1, c1, *spatial1) + + +class Perceiver(nn.Module): + def __init__(self, pre_attention_query_token=32, pre_attention_query_size=1024, embedding_dim=1024, num_attn_heads=4): + """ + Initialize the perceiver module. + + :param pre_attention_query_token: Number of query tokens for pre-attention + :param pre_attention_query_size: Size of each query token + :param embedding_dim: Dimension of the embedding space + :param num_attn_heads: Number of attention heads + """ + super().__init__() + + # Initialize the pre-attention query parameter + self.pre_attention_query = torch.nn.Parameter( + torch.empty(1, pre_attention_query_token, pre_attention_query_size) + ) + + # Calculate the variance for uniform initialization + query_variance = math.sqrt(3.0) * math.sqrt(2.0 / (pre_attention_query_token + pre_attention_query_token)) + + # Initialize the pre-attention query with uniform distribution + self.pre_attention_query.data.uniform_(-query_variance, query_variance) + + # Initialize the attention block + self.attn = AttentionBlock2(embedding_dim, num_attn_heads) + + def forward(self, h): + """ + Forward pass of the perceiver module. 
+ :param h: Input tensor + :return: Output after applying attention mechanisms + """ + # Expand the pre-attention query to match the batch size of the input + query_ = self.pre_attention_query.expand(h.shape[0], -1, -1) + # Apply the first attention mechanism (cross-attention) + pre_att = self.attn(query_, h) + # Apply the second attention mechanism (self-attention) + attn = self.attn(pre_att, pre_att) + return attn diff --git a/orator/src/orator/models/t3/modules/t3_config.py b/orator/src/orator/models/t3/modules/t3_config.py new file mode 100644 index 0000000000000000000000000000000000000000..38c9fafc12a875759a43e23aa955ab135a136b7e --- /dev/null +++ b/orator/src/orator/models/t3/modules/t3_config.py @@ -0,0 +1,27 @@ +from ..llama_configs import LLAMA_CONFIGS + + +class T3Config: + start_text_token = 255 + stop_text_token = 0 + text_tokens_dict_size = 704 + max_text_tokens = 2048 + + start_speech_token = 6561 + stop_speech_token = 6562 + speech_tokens_dict_size = 6563 + max_speech_tokens = 4096 + + llama_config_name = "Llama_520M" + input_pos_emb = "learned" + speech_cond_prompt_len = 150 + + # For T3CondEnc + encoder_type = "voice_encoder" + speaker_embed_size = 256 + use_perceiver_resampler = True + emotion_adv = True + + @property + def n_channels(self): + return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"] \ No newline at end of file diff --git a/orator/src/orator/models/t3/t3.py b/orator/src/orator/models/t3/t3.py new file mode 100644 index 0000000000000000000000000000000000000000..39978dfa8588f8d7bbcf2ea639c739119503708e --- /dev/null +++ b/orator/src/orator/models/t3/t3.py @@ -0,0 +1,276 @@ +import logging +from typing import Union, Optional, List + +import torch +import torch.nn.functional as F +from torch import nn, Tensor +from transformers import LlamaModel, LlamaConfig + +from .modules.learned_pos_emb import LearnedPositionEmbeddings + +from .modules.cond_enc import T3CondEnc, T3Cond +from .modules.t3_config import T3Config +from .inference.t3_hf_backend import T3HuggingfaceBackend +from .llama_configs import LLAMA_CONFIGS + + +logger = logging.getLogger(__name__) + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def _ensure_BOT_EOT(text_tokens: Tensor, hp): + B = text_tokens.size(0) + assert (text_tokens == hp.start_text_token).int().sum() >= B, "missing start_text_token" + assert (text_tokens == hp.stop_text_token).int().sum() >= B, "missing stop_text_token" + + +class T3(nn.Module): + """ + Token-To-Token (T3) TTS model using huggingface transformer models as backbones, + * tokenization, including start / stop tokens are always added externally to this class + * conditioning data like CLAP, emotion, etc are all in a separate file for more modularity + * careful! this class assumes relative positional encoding -- with absolute PE, we would at + least want to reset the position to 0 when speech tokens begin, and optionally use a + different PE embedding space for speech. 
+ """ + + def __init__(self, hp=T3Config()): + super().__init__() + self.hp = hp + self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name]) + self.tfmr = LlamaModel(self.cfg) + self.dim = self.cfg.hidden_size + self.deepspeed_patch_applied = False + + # conditioning / embedding + self.cond_enc = T3CondEnc(hp) + self.text_emb = nn.Embedding(hp.text_tokens_dict_size, self.dim) + self.speech_emb = nn.Embedding(hp.speech_tokens_dict_size, self.dim) + + # custom position embedding + if hp.input_pos_emb == "learned": + max_text_seq_len = hp.max_text_tokens + 2 + self.text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, self.dim) + + max_mel_seq_len = hp.max_speech_tokens + 2 + 2 + self.speech_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, self.dim) + + # logit projection + self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False) + self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=False) + self.compiled = False + + @property + def device(self): + return self.speech_head.weight.device + + def prepare_conditioning(self, t3_cond: T3Cond): + """ + Token cond data needs to be embedded, so that needs to be here instead of in `T3CondEnc`. + """ + if t3_cond.cond_prompt_speech_tokens is not None and t3_cond.cond_prompt_speech_emb is None: + t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens) + \ + self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens) + return self.cond_enc(t3_cond) # (B, len_cond, dim) + + def prepare_input_embeds( + self, + *, + t3_cond: T3Cond, + text_tokens: torch.LongTensor, + speech_tokens: torch.LongTensor, + ): + # prepare input embeddings (skip backbone tranformer embeddings) + cond_emb = self.prepare_conditioning(t3_cond) # (B, len_cond, dim) + text_emb = self.text_emb(text_tokens) # (B, len_text, dim) + speech_emb = self.speech_emb(speech_tokens) # (B, len_speech, dim) + if self.hp.input_pos_emb == "learned": + text_emb = text_emb + self.text_pos_emb(text_tokens) + speech_emb = speech_emb + self.speech_pos_emb(speech_tokens) + len_cond = cond_emb.size(1) + + if cond_emb.size(0) != text_emb.size(0): + cond_emb = cond_emb.expand(text_emb.size(0), -1, -1) + + # concat + embeds = torch.stack([ + torch.cat((ce, te, se)) + for ce, te, se in zip(cond_emb, text_emb, speech_emb) + ]) # (B, length, dim) + return embeds, len_cond + + def forward( + self, + *, + t3_cond: T3Cond, + text_tokens: torch.LongTensor, + text_token_lens: torch.LongTensor, + speech_tokens: torch.LongTensor, + speech_token_lens: torch.LongTensor, + training=False, + ): + _ensure_BOT_EOT(text_tokens, self.hp) + + # prepare custom input embeds + embeds, len_cond = self.prepare_input_embeds( + t3_cond=t3_cond, + text_tokens=text_tokens, + speech_tokens=speech_tokens, + ) + + # backbone tranformer forward + tfmr_out = self.tfmr.forward( + input_ids=None, + # position_ids=position_ids, # TODO? ROPE should be fine? 
+ inputs_embeds=embeds, + output_hidden_states=True, + return_dict=True, + use_cache=(not training), + ) + hidden_states = tfmr_out.hidden_states[-1] # final tfmr layer output, (B, seq, dim) + + # post-processing: splice out text and speech parts of hidden states + len_text = text_tokens.size(1) + len_speech = speech_tokens.size(1) + B, _, dim = hidden_states.shape + device, dtype = hidden_states.device, hidden_states.dtype + text_latents = torch.zeros(B, len_text, dim, dtype=dtype, device=device) + speech_latents = torch.zeros(B, len_speech, dim, dtype=dtype, device=device) + ttl, stl = text_token_lens, speech_token_lens + for i in range(B): + text_end = len_cond + ttl[i].item() + speech_start = len_cond + text_tokens.size(1) + speech_end = speech_start + stl[i].item() + text_latents[i, :ttl[i]] = hidden_states[i, len_cond:text_end] + speech_latents[i, :stl[i]] = hidden_states[i, speech_start:speech_end] + + # logit projection + text_logits = self.text_head(text_latents) + speech_logits = self.speech_head(speech_latents) + + return AttrDict( + text_logits=text_logits, + text_latents=text_latents, + speech_logits=speech_logits, + speech_latents=speech_latents, + hidden_states=hidden_states, + ) + + def loss( + self, + *, + t3_cond: T3Cond, + text_tokens: torch.LongTensor, + text_token_lens: torch.LongTensor, + speech_tokens: torch.LongTensor, + speech_token_lens: torch.LongTensor, + ): + "Training objective: masked cross-entropy over the text and speech logits." + len_text = text_tokens.size(1) + len_speech = speech_tokens.size(1) + assert len_text == text_token_lens.max() + assert len_speech == speech_token_lens.max() + + out = self.forward( + t3_cond=t3_cond, + text_tokens=text_tokens, + text_token_lens=text_token_lens, + speech_tokens=speech_tokens, + speech_token_lens=speech_token_lens, + training=True, + ) # (B, seq, vocab_size) + + # Calc CCE losses + IGNORE_ID = -100 + device = out.text_logits.device + mask_text = torch.arange(len_text, device=device)[None] >= text_token_lens[:, None] # (B, len_text) + mask_speech = torch.arange(len_speech, device=device)[None] >= speech_token_lens[:, None] # (B, len_speech) + masked_text = text_tokens.masked_fill(mask_text, IGNORE_ID) + masked_speech = speech_tokens.masked_fill(mask_speech, IGNORE_ID) + # F.cross_entropy expects logits as (B, num_classes, T), so move the vocab dim to position 1 + loss_text = F.cross_entropy(out.text_logits.transpose(1, 2), masked_text, ignore_index=IGNORE_ID) + loss_speech = F.cross_entropy(out.speech_logits.transpose(1, 2), masked_speech, ignore_index=IGNORE_ID) + + return loss_text, loss_speech + + @torch.inference_mode() + def inference( + self, + *, + t3_cond: T3Cond, + text_tokens: Tensor, + initial_speech_tokens: Optional[Tensor]=None, + + # misc conditioning + prepend_prompt_speech_tokens: Optional[Tensor]=None, + + # HF generate args + num_return_sequences=1, + max_new_tokens=None, + stop_on_eos=True, + do_sample=True, + temperature=0.8, + top_p=0.8, + length_penalty=1.0, + repetition_penalty=2.0, + ): + """ + Args: + text_tokens: a 1D (unbatched) or 2D (batched) tensor of text token ids, including the start/stop text tokens. + initial_speech_tokens: optional speech tokens to seed generation; defaults to a single start-of-speech token. + prepend_prompt_speech_tokens: not implemented yet; must be None. + max_new_tokens: cap on the number of generated speech tokens; defaults to `hp.max_speech_tokens`.
+ """ + # Validate / sanitize inputs + assert prepend_prompt_speech_tokens is None, "not implemented" + _ensure_BOT_EOT(text_tokens, self.hp) + text_tokens = torch.atleast_2d(text_tokens).to(dtype=torch.long, device=self.device) + + # Default initial speech to a single start-of-speech token + if initial_speech_tokens is None: + initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1]) + + # Prepare custom input embeds + embeds, _ = self.prepare_input_embeds( + t3_cond=t3_cond, + text_tokens=text_tokens, + speech_tokens=initial_speech_tokens, + ) + + # In order to use the standard HF generate method, we need to extend some methods to inject our custom logic + # Note the llama-specific logic. Other tfmr types can be added later. + + self.compiled = False + + # TODO? synchronize the expensive compile function + # with self.compile_lock: + if not self.compiled: + patched_model = T3HuggingfaceBackend( + config=self.cfg, + llama=self.tfmr, + speech_enc=self.speech_emb, + speech_head=self.speech_head, + ) + self.patched_model = patched_model + self.compiled = True + + # Run normal generate method, which calls our custom extended methods + return self.patched_model.generate( + inputs=initial_speech_tokens, + decoder_cond=embeds, + bos_token_id=self.hp.start_speech_token, + eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1), + pad_token_id=self.hp.stop_speech_token, + max_new_tokens=max_new_tokens or self.hp.max_speech_tokens, + num_return_sequences=num_return_sequences, + temperature=temperature, + top_p=top_p, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + do_sample=do_sample, + # cache_implementation=None if not self.compiled else "static", + ) diff --git a/orator/src/orator/models/tokenizers/__init__.py b/orator/src/orator/models/tokenizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97457e6fd720a10b2c64d2cdbabce9ca5fbf9aad --- /dev/null +++ b/orator/src/orator/models/tokenizers/__init__.py @@ -0,0 +1 @@ +from .tokenizer import EnTokenizer diff --git a/orator/src/orator/models/tokenizers/__pycache__/__init__.cpython-311.pyc b/orator/src/orator/models/tokenizers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c275725d8ec6d80c237b139479031d3ede3646b9 Binary files /dev/null and b/orator/src/orator/models/tokenizers/__pycache__/__init__.cpython-311.pyc differ diff --git a/orator/src/orator/models/tokenizers/__pycache__/tokenizer.cpython-311.pyc b/orator/src/orator/models/tokenizers/__pycache__/tokenizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..835bf5816abdb64bc97371e9f597a618a6d337ad Binary files /dev/null and b/orator/src/orator/models/tokenizers/__pycache__/tokenizer.cpython-311.pyc differ diff --git a/orator/src/orator/models/tokenizers/tokenizer.py b/orator/src/orator/models/tokenizers/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..f3536bc24db7d37cca9faff11c064c2c5d7c1c64 --- /dev/null +++ b/orator/src/orator/models/tokenizers/tokenizer.py @@ -0,0 +1,50 @@ +import logging + +import torch +from tokenizers import Tokenizer + + +# Special tokens +SOT = "[START]" +EOT = "[STOP]" +UNK = "[UNK]" +SPACE = "[SPACE]" +SPECIAL_TOKENS = [SOT, EOT, UNK, SPACE, "[PAD]", "[SEP]", "[CLS]", "[MASK]"] + +logger = logging.getLogger(__name__) + +class EnTokenizer: + def __init__(self, vocab_file_path): + self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path) + 
self.check_vocabset_sot_eot() + + def check_vocabset_sot_eot(self): + voc = self.tokenizer.get_vocab() + assert SOT in voc + assert EOT in voc + + def text_to_tokens(self, text: str): + text_tokens = self.encode(text) + text_tokens = torch.IntTensor(text_tokens).unsqueeze(0) + return text_tokens + + def encode( self, txt: str, verbose=False): + """ + clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer + """ + txt = txt.replace(' ', SPACE) + code = self.tokenizer.encode(txt) + ids = code.ids + return ids + + def decode(self, seq): + if isinstance(seq, torch.Tensor): + seq = seq.cpu().numpy() + + txt: str = self.tokenizer.decode(seq, + skip_special_tokens=False) + txt = txt.replace(' ', '') + txt = txt.replace(SPACE, ' ') + txt = txt.replace(EOT, '') + txt = txt.replace(UNK, '') + return txt diff --git a/orator/src/orator/models/voice_encoder/__init__.py b/orator/src/orator/models/voice_encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..529e1e63e89f179ec06829bfc5f1afc80912433f --- /dev/null +++ b/orator/src/orator/models/voice_encoder/__init__.py @@ -0,0 +1 @@ +from .voice_encoder import VoiceEncoder, VoiceEncConfig diff --git a/orator/src/orator/models/voice_encoder/__pycache__/__init__.cpython-311.pyc b/orator/src/orator/models/voice_encoder/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..734cd8a35d12eb71f0f1c29d55bbb39347e42839 Binary files /dev/null and b/orator/src/orator/models/voice_encoder/__pycache__/__init__.cpython-311.pyc differ diff --git a/orator/src/orator/models/voice_encoder/__pycache__/voice_encoder.cpython-311.pyc b/orator/src/orator/models/voice_encoder/__pycache__/voice_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c93d77f6ac7f26d75b260410a5cc2b90b54b5fa5 Binary files /dev/null and b/orator/src/orator/models/voice_encoder/__pycache__/voice_encoder.cpython-311.pyc differ diff --git a/orator/src/orator/models/voice_encoder/voice_encoder.py b/orator/src/orator/models/voice_encoder/voice_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..68b398d113eedd4f77a91ebd143389aa2be69b15 --- /dev/null +++ b/orator/src/orator/models/voice_encoder/voice_encoder.py @@ -0,0 +1,256 @@ +# Adapted from https://github.com/CorentinJ/Real-Time-Voice-Cloning +# MIT License + +from typing import List, Union, Optional + +import numpy as np +from numpy.lib.stride_tricks import as_strided +import librosa +from librosa import resample +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from ....orator.transforms.spectrogram import melspectrogram +from ....orator.transforms.syn_transforms import pack + + +class VoiceEncConfig: + num_mels = 40 + sample_rate = 16000 + speaker_embed_size = 256 + ve_hidden_size = 256 + flatten_lstm_params = False + n_fft = 400 + hop_size = 160 + win_size = 400 + fmax = 8000 + fmin = 0 + preemphasis = 0. 
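+ # NOTE: a preemphasis of 0 disables the pre-emphasis filter in melspectrogram()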
+ mel_power = 2.0 + mel_type = "amp" + normalized_mels = False + ve_partial_frames = 160 + ve_final_relu = True + + +def get_num_wins( + n_frames: int, + step: int, + min_coverage: float, + hp: VoiceEncConfig, +): + assert n_frames > 0 + win_size = hp.ve_partial_frames + n_wins, remainder = divmod(max(n_frames - win_size + step, 0), step) + if n_wins == 0 or (remainder + (win_size - step)) / win_size >= min_coverage: + n_wins += 1 + target_n = win_size + step * (n_wins - 1) + return n_wins, target_n + + +def get_frame_step( + overlap: float, + rate: float, + hp: VoiceEncConfig, +): + # Compute how many frames separate two partial utterances + assert 0 <= overlap < 1 + if rate is None: + frame_step = int(np.round(hp.ve_partial_frames * (1 - overlap))) + else: + frame_step = int(np.round((hp.sample_rate / rate) / hp.ve_partial_frames)) + assert 0 < frame_step <= hp.ve_partial_frames + return frame_step + + +def stride_as_partials( + mel: np.ndarray, + hp: VoiceEncConfig, + overlap=0.5, + rate: float=None, + min_coverage=0.8, +): + """ + Takes unscaled mels in (T, M) format + TODO: doc + """ + assert 0 < min_coverage <= 1 + frame_step = get_frame_step(overlap, rate, hp) + + # Compute how many partials can fit in the mel + n_partials, target_len = get_num_wins(len(mel), frame_step, min_coverage, hp) + + # Trim or pad the mel spectrogram to match the number of partials + if target_len > len(mel): + mel = np.concatenate((mel, np.full((target_len - len(mel), hp.num_mels), 0))) + elif target_len < len(mel): + mel = mel[:target_len] + + # Ensure the numpy array data is float32 and contiguous in memory + mel = mel.astype(np.float32, order="C") + + # Re-arrange the array in memory to be of shape (N, P, M) with partials overlapping eachother, + # where N is the number of partials, P is the number of frames of each partial and M the + # number of channels of the mel spectrograms. + shape = (n_partials, hp.ve_partial_frames, hp.num_mels) + strides = (mel.strides[0] * frame_step, mel.strides[0], mel.strides[1]) + partials = as_strided(mel, shape, strides) + return partials + + +class VoiceEncoder(nn.Module): + def __init__(self, hp=VoiceEncConfig()): + super().__init__() + + self.hp = hp + + # Network definition + self.lstm = nn.LSTM(self.hp.num_mels, self.hp.ve_hidden_size, num_layers=3, batch_first=True) + if hp.flatten_lstm_params: + self.lstm.flatten_parameters() + self.proj = nn.Linear(self.hp.ve_hidden_size, self.hp.speaker_embed_size) + + # Cosine similarity scaling (fixed initial parameter values) + self.similarity_weight = nn.Parameter(torch.tensor([10.]), requires_grad=True) + self.similarity_bias = nn.Parameter(torch.tensor([-5.]), requires_grad=True) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, mels: torch.FloatTensor): + """ + Computes the embeddings of a batch of partial utterances. + + :param mels: a batch of unscaled mel spectrograms of same duration as a float32 tensor + of shape (B, T, M) where T is hp.ve_partial_frames + :return: the embeddings as a float32 tensor of shape (B, E) where E is + hp.speaker_embed_size. Embeddings are L2-normed and thus lay in the range [-1, 1]. + """ + if self.hp.normalized_mels and (mels.min() < 0 or mels.max() > 1): + raise Exception(f"Mels outside [0, 1]. 
Min={mels.min()}, Max={mels.max()}") + + # Pass the input through the LSTM layers + _, (hidden, _) = self.lstm(mels) + + # Project the final hidden state + raw_embeds = self.proj(hidden[-1]) + if self.hp.ve_final_relu: + raw_embeds = F.relu(raw_embeds) + + # L2 normalize the embeddings. + return raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True) + + def inference(self, mels: torch.Tensor, mel_lens, overlap=0.5, rate: float=None, min_coverage=0.8, batch_size=None): + """ + Computes the embeddings of a batch of full utterances with gradients. + + :param mels: (B, T, M) unscaled mels + :return: (B, E) embeddings on CPU + """ + mel_lens = mel_lens.tolist() if torch.is_tensor(mel_lens) else mel_lens + + # Compute where to split the utterances into partials + frame_step = get_frame_step(overlap, rate, self.hp) + n_partials, target_lens = zip(*(get_num_wins(l, frame_step, min_coverage, self.hp) for l in mel_lens)) + + # Possibly pad the mels to reach the target lengths + len_diff = max(target_lens) - mels.size(1) + if len_diff > 0: + pad = torch.full((mels.size(0), len_diff, self.hp.num_mels), 0, dtype=torch.float32) + mels = torch.cat((mels, pad.to(mels.device)), dim=1) + + # Group all partials together so that we can batch them easily + partials = [ + mel[i * frame_step: i * frame_step + self.hp.ve_partial_frames] + for mel, n_partial in zip(mels, n_partials) for i in range(n_partial) + ] + assert all(partials[0].shape == partial.shape for partial in partials) + partials = torch.stack(partials) + + # Forward the partials + n_chunks = int(np.ceil(len(partials) / (batch_size or len(partials)))) + partial_embeds = torch.cat([self(batch) for batch in partials.chunk(n_chunks)], dim=0).cpu() + + # Reduce the partial embeds into full embeds and L2-normalize them + slices = np.concatenate(([0], np.cumsum(n_partials))) + raw_embeds = [torch.mean(partial_embeds[start:end], dim=0) for start, end in zip(slices[:-1], slices[1:])] + raw_embeds = torch.stack(raw_embeds) + embeds = raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True) + + return embeds + + @staticmethod + def utt_to_spk_embed(utt_embeds: np.ndarray): + """ + Takes an array of L2-normalized utterance embeddings, computes the mean embedding and L2-normalize it to get a + speaker embedding. + """ + assert utt_embeds.ndim == 2 + utt_embeds = np.mean(utt_embeds, axis=0) + return utt_embeds / np.linalg.norm(utt_embeds, 2) + + @staticmethod + def voice_similarity(embeds_x: np.ndarray, embeds_y: np.ndarray): + """ + Cosine similarity for L2-normalized utterance embeddings or speaker embeddings + """ + embeds_x = embeds_x if embeds_x.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_x) + embeds_y = embeds_y if embeds_y.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_y) + return embeds_x @ embeds_y + + def embeds_from_mels( + self, mels: Union[Tensor, List[np.ndarray]], mel_lens=None, as_spk=False, batch_size=32, **kwargs + ): + """ + Convenience function for deriving utterance or speaker embeddings from mel spectrograms. + + :param mels: unscaled mels strictly within [0, 1] as either a (B, T, M) tensor or a list of (Ti, M) arrays. 
+ :param mel_lens: if passing mels as a tensor, individual mel lengths + :param as_spk: whether to return utterance embeddings or a single speaker embedding + :param kwargs: args for inference() + + :returns: embeds as a (B, E) float32 numpy array if is False, else as a (E,) array + """ + # Load mels in memory and pack them + if isinstance(mels, List): + mels = [np.asarray(mel) for mel in mels] + assert all(m.shape[1] == mels[0].shape[1] for m in mels), "Mels aren't in (B, T, M) format" + mel_lens = [mel.shape[0] for mel in mels] + mels = pack(mels) + + # Embed them + with torch.inference_mode(): + utt_embeds = self.inference(mels.to(self.device), mel_lens, batch_size=batch_size, **kwargs).numpy() + + return self.utt_to_spk_embed(utt_embeds) if as_spk else utt_embeds + + def embeds_from_wavs( + self, + wavs: List[np.ndarray], + sample_rate, + as_spk=False, + batch_size=32, + trim_top_db: Optional[float]=20, + **kwargs + ): + """ + Wrapper around embeds_from_mels + + :param trim_top_db: this argument was only added for the sake of compatibility with metavoice's implementation + """ + if sample_rate != self.hp.sample_rate: + wavs = [ + resample(wav, orig_sr=sample_rate, target_sr=self.hp.sample_rate, res_type="kaiser_fast") + for wav in wavs + ] + + if trim_top_db: + wavs = [librosa.effects.trim(wav, top_db=trim_top_db)[0] for wav in wavs] + + if "rate" not in kwargs: + kwargs["rate"] = 1.3 # Resemble's default value. + + mels = [melspectrogram(w, self.hp).T for w in wavs] + return self.embeds_from_mels(mels, as_spk=as_spk, batch_size=batch_size, **kwargs) diff --git a/orator/src/orator/transforms/__pycache__/spectrogram.cpython-311.pyc b/orator/src/orator/transforms/__pycache__/spectrogram.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3a682cb6495b552f14d322b779476cf250fc1d0 Binary files /dev/null and b/orator/src/orator/transforms/__pycache__/spectrogram.cpython-311.pyc differ diff --git a/orator/src/orator/transforms/__pycache__/syn_transforms.cpython-311.pyc b/orator/src/orator/transforms/__pycache__/syn_transforms.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..037b45e08c6f8d5db21241c609a0e8212fe10d2f Binary files /dev/null and b/orator/src/orator/transforms/__pycache__/syn_transforms.cpython-311.pyc differ diff --git a/orator/src/orator/transforms/spectrogram.py b/orator/src/orator/transforms/spectrogram.py new file mode 100644 index 0000000000000000000000000000000000000000..69147fc8c591c9364ff829a157af0ea3fcbd5770 --- /dev/null +++ b/orator/src/orator/transforms/spectrogram.py @@ -0,0 +1,78 @@ +from functools import lru_cache + +from scipy import signal +import numpy as np +import librosa + + +@lru_cache() +def mel_basis(hp): + assert hp.fmax <= hp.sample_rate // 2 + return librosa.filters.mel( + sr=hp.sample_rate, + n_fft=hp.n_fft, + n_mels=hp.num_mels, + fmin=hp.fmin, + fmax=hp.fmax) # -> (nmel, nfreq) + + +def preemphasis(wav, hp): + assert hp.preemphasis != 0 + wav = signal.lfilter([1, -hp.preemphasis], [1], wav) + wav = np.clip(wav, -1, 1) + return wav + + +def melspectrogram(wav, hp, pad=True): + # Run through pre-emphasis + if hp.preemphasis > 0: + wav = preemphasis(wav, hp) + assert np.abs(wav).max() - 1 < 1e-07 + + # Do the stft + spec_complex = _stft(wav, hp, pad=pad) + + # Get the magnitudes + spec_magnitudes = np.abs(spec_complex) + + if hp.mel_power != 1.0: + spec_magnitudes **= hp.mel_power + + # Get the mel and convert magnitudes->db + mel = np.dot(mel_basis(hp), spec_magnitudes) + if 
hp.mel_type == "db": + mel = _amp_to_db(mel, hp) + + # Normalise the mel from db to 0,1 + if hp.normalized_mels: + mel = _normalize(mel, hp).astype(np.float32) + + assert not pad or mel.shape[1] == 1 + len(wav) // hp.hop_size # Sanity check + return mel # (M, T) + + +def _stft(y, hp, pad=True): + # NOTE: after 0.8, pad mode defaults to constant, setting this to reflect for + # historical consistency and streaming-version consistency + return librosa.stft( + y, + n_fft=hp.n_fft, + hop_length=hp.hop_size, + win_length=hp.win_size, + center=pad, + pad_mode="reflect", + ) + + +def _amp_to_db(x, hp): + return 20 * np.log10(np.maximum(hp.stft_magnitude_min, x)) + + +def _db_to_amp(x): + return np.power(10.0, x * 0.05) + + +def _normalize(s, hp, headroom_db=15): + min_level_db = 20 * np.log10(hp.stft_magnitude_min) + s = (s - min_level_db) / (-min_level_db + headroom_db) + return s diff --git a/orator/src/orator/transforms/syn_transforms.py b/orator/src/orator/transforms/syn_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..13ce597ae05503ef618b2de9b6c7b833f94409cb --- /dev/null +++ b/orator/src/orator/transforms/syn_transforms.py @@ -0,0 +1,46 @@ +# Common transformations used by synthesizers +import logging + +import numpy as np +import torch + + +logger = logging.getLogger(__name__) + + +def pack(arrays, seq_len: int=None, pad_value=0): + """ + Given a list of length B of array-like objects of shapes (Ti, ...), packs them in a single tensor of + shape (B, T, ...) by padding each individual array on the right. + + :param arrays: a list of array-like objects of matching shapes except for the first axis. + :param seq_len: the value of T. It must be the maximum of the lengths Ti of the arrays at + minimum. Will default to that value if None. + :param pad_value: the value to pad the arrays with. + :return: a (B, T, ...) 
tensor + """ + if seq_len is None: + seq_len = max(len(array) for array in arrays) + else: + assert seq_len >= max(len(array) for array in arrays) + + # Convert lists to np.array + if isinstance(arrays[0], list): + arrays = [np.array(array) for array in arrays] + + # Convert to tensor and handle device + device = None + if isinstance(arrays[0], torch.Tensor): + tensors = arrays + device = tensors[0].device + else: + tensors = [torch.as_tensor(array) for array in arrays] + + # Fill the packed tensor with the array data + packed_shape = (len(tensors), seq_len, *tensors[0].shape[1:]) + packed_tensor = torch.full(packed_shape, pad_value, dtype=tensors[0].dtype, device=device) + + for i, tensor in enumerate(tensors): + packed_tensor[i, :tensor.size(0)] = tensor + + return packed_tensor diff --git a/orator/src/orator/transforms/webrtc.py b/orator/src/orator/transforms/webrtc.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d3abf97f27ac7f5c51ea7228920a0c522ed934 --- /dev/null +++ b/orator/src/orator/transforms/webrtc.py @@ -0,0 +1,181 @@ +from itertools import groupby + +import numpy as np +import webrtcvad as _webrtcvad + +from transforms.vad.vad_stream import VADStream +from transforms.wav_encoding import encode_pcm16 + +# The sample rate the algo can operate at +_WEBRTC_SAMPLE_RATES = np.array([8000, 16000, 32000, 48000]) +# The algo operates with window sizes of 10, 20 and 30ms +_WEBRTC_WINDOW_SIZES_MS = (10, 20, 30) +# Greatest common divisor and lowest common multiple of the above +_WEBRTC_WINDOW_SIZES_MS_GCD = 10 +_WEBRTC_WINDOW_SIZES_MS_LCM = 60 + + +class WebRTCVADStream(VADStream): + def __init__(self, sample_rate: int, aggressiveness=2, dilation_ms=40, min_voiced_region_ms=125): + """ + :param sample_rate: sample rate of the wavs that will be passed + :param aggressiveness: parameter for controlling the aggressiveness of the VAD algo. Possible values are 1, + 2 and 3. Higher means less regions will be detected as voiced. + :param dilation_ms: pass a value greater than 0 to include regions directly preceding or succeeding voiced + regions. Voiced regions will be expanded left and right by this value, in milliseconds. + N.B.: this is a best effort parameter. When the output is requested as fast as the input is produced, + it's impossible to foresee an upcoming voiced region. In that case, the dilation on the left of that region + may not appear. + :param min_voiced_region_ms: to exclude regions detected as speech that are considered too short, pass a value + greater than 0. Voiced regions shorter than this value (prior to dilation) will be set as unvoiced. + N.B.: this is also a best effort parameter. A region may be too short, but because VAD has not finished + being computed at the end of that region, it won't be removed as it could potentially be large enough. + """ + webrtc_sr = int(_WEBRTC_SAMPLE_RATES[np.argmin(np.abs(_WEBRTC_SAMPLE_RATES - sample_rate))]) + lcm_win_size = (_WEBRTC_WINDOW_SIZES_MS_LCM * webrtc_sr) // 1000 + self._gcd_win_size = (_WEBRTC_WINDOW_SIZES_MS_GCD * webrtc_sr) // 1000 + + # webrtcvad.Vad is stateful, predictions will be impacted if a new instance is created halfway through an + # audio. This is why we create them now. 
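+ # One Vad instance per supported window size; their per-window votes are ensembled in _wav_vad().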
+ self._detectors = {win_size: _webrtcvad.Vad(mode=aggressiveness) for win_size in _WEBRTC_WINDOW_SIZES_MS} + + super().__init__(sample_rate, webrtc_sr, lcm_win_size, dilation_ms, min_voiced_region_ms) + + def _wav_vad(self, wav: np.ndarray) -> np.ndarray: + pcm = encode_pcm16(wav) + + # Perform the VAD by ensembling the different window sizes + win_vad = np.zeros(len(wav) // self._gcd_win_size, dtype=np.int32) + for sub_win_size_ms in _WEBRTC_WINDOW_SIZES_MS: + detector = self._detectors[sub_win_size_ms] + sub_win_size_pcm = (2 * sub_win_size_ms * self.vad_sr) // 1000 + factor = sub_win_size_ms // _WEBRTC_WINDOW_SIZES_MS_GCD + + for i, win_start in enumerate(range(0, len(pcm), sub_win_size_pcm)): + win_i_vad = detector.is_speech(pcm[win_start:win_start + sub_win_size_pcm], self.vad_sr) + win_vad[i * factor:(i + 1) * factor] += win_i_vad + win_vad = win_vad > (len(_WEBRTC_WINDOW_SIZES_MS) // 2) + + # Convert the output to regions + regions = np.diff(win_vad, prepend=0, append=0).nonzero()[0].reshape(-1, 2) + regions = regions * (len(wav) // len(win_vad)) + + return regions + + +def webrtc_vad(wav: np.ndarray, source_sr: int, aggressiveness=2, dilation_ms=40, min_voiced_region_ms=125): + """ + Performs Voice Activity Detection on a single audio. See WebRTCVADStream for more details. + + :return vad: a boolean numpy array of length equal to `len(wav)`, True where speech is detected + """ + vad_stream = WebRTCVADStream(source_sr, aggressiveness, dilation_ms, min_voiced_region_ms) + vad_stream.feed(wav, close_input=True) + if vad_stream.can_step(): + return vad_stream.step(len(wav)) + else: + return np.zeros_like(wav, dtype=bool) + + +def split_on_silence( + wav, sr, vad, thresholds_ms=[500, 300, 200, 100, 50], min_dur_s=1.5, max_split_dur_s=20, max_dur_s=30, +): + """ + Split a wav into chunks, splitting on silence when the length of the silence exceeds a threshold.
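+ Chunks still longer than `max_split_dur_s` are re-split recursively with the next (smaller) silence threshold
+ in `thresholds_ms`; chunks shorter than `min_dur_s` are merged into the preceding chunk when possible.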
+ Args: + wav: 1D array + sr: sample rate + vad: boolean voice-activity mask aligned with `wav`, e.g. as returned by `webrtc_vad` + thresholds_ms: min length of silence to split on, clips are recursively split using values from this list until + the resulting chunks are all within the min / max duration bounds + min_dur_s: minimum duration of a chunk in seconds + max_split_dur_s: segments above this length continue to be split down with smaller thresholds + max_dur_s: maximum duration of a chunk in seconds + """ + assert isinstance(wav, np.ndarray) and wav.ndim == 1 + + # unpack silence length thresholds + thresh_ms, next_thresh_ms = (thresholds_ms + [0, 0])[:2] + if thresh_ms <= 0: + return [wav] + + # convert thresholds to samples + max_split_dur_s = min(max_split_dur_s, max_dur_s) + thresh = int(thresh_ms * sr / 1000) + min_len = int(min_dur_s * sr) + max_split_len = int(max_split_dur_s * sr) + max_len = int(max_dur_s * sr) + wav_len = len(wav) + + # detect regions of silence using groupby + sil_regions = [] + for is_voiced, idxs in groupby(range(wav_len), key=vad.__getitem__): + idxs = list(idxs) + i = idxs[0] + j = idxs[-1] + j += 1 + n = j - i + mid = (i + j) // 2 + + # record split point if this is a long silence region + if (not is_voiced) and n > thresh: + sil_regions += [( + min(mid, i + (0 if i == 0 else thresh // 2)), + max(mid, j - (0 if j == wav_len else thresh // 2)), + )] + + # invert silence regions to get voiced regions + ptr = 0 + voiced_regions = [] + for i, j in sil_regions: + if i > 0: + voiced_regions += [(ptr, i)] + ptr = j + if ptr < wav_len: + voiced_regions += [(ptr, wav_len)] + + # split the waveform into chunks using the detected content bounds and silence split points + chunks = [] + for i, j in voiced_regions: + chunk = wav[i:j] + chunklen = len(chunk) + + # chunk is within bounds + if chunklen < max_split_len: + chunks += [chunk] + + # chunk is too long, attempt to split it recursively using threshold list + elif next_thresh_ms > 0: + chunks += split_on_silence( + chunk, sr, vad[i:j], thresholds_ms=thresholds_ms[1:], + min_dur_s=min_dur_s, max_dur_s=max_dur_s, + ) + + # NOTE: keeping chunks longer than `max_len` here, filtering is done below + else: + chunks += [chunk] + + # merge short chunks + merged_chunks = [] + for chunk in chunks: + chunklen = len(chunk) + + # chunk is too short, add it to the previous chunk if possible + if chunklen == 0: + continue + + elif chunklen < min_len: + # NOTE: ignore the edge case where this would make the previous chunk too long, by just dropping this chunk + if len(merged_chunks) > 0 and len(merged_chunks[-1]) + chunklen < max_len: + merged_chunks[-1] = np.concatenate([merged_chunks[-1], chunk]) + + elif chunklen < max_len: + merged_chunks += [chunk] + + else: + # TODO: keep long chunks as well? one benefit is to keep the adjacent ordering of chunks, for + # building paragraph-level datasets. However, this should rarely drop any clips, so it's probably okay.
+ # merged_chunks += [chunk] + pass + chunks = merged_chunks + + return chunks diff --git a/orator/src/orator/tts.py b/orator/src/orator/tts.py new file mode 100644 index 0000000000000000000000000000000000000000..5081c6075aa6c58003fd3116e85e35304dbbce40 --- /dev/null +++ b/orator/src/orator/tts.py @@ -0,0 +1,205 @@ +from dataclasses import dataclass +from pathlib import Path + +import librosa +import torch +import torch.nn.functional as F +from huggingface_hub import hf_hub_download + +from .models.t3 import T3 +from .models.s3tokenizer import S3_SR, drop_invalid_tokens +from .models.s3gen import S3GEN_SR, S3Gen +from .models.tokenizers import EnTokenizer +from .models.voice_encoder import VoiceEncoder +from .models.t3.modules.cond_enc import T3Cond + + +REPO_ID = "ResembleAI/Orator" + + +def change_pace(speech_tokens: torch.Tensor, pace: float): + """ + :param speech_tokens: Tensor of shape (L,) + :param pace: float, pace (default: 1) + """ + L = len(speech_tokens) + speech_tokens = F.interpolate(speech_tokens.view(1, 1, -1).float(), size=int(L / pace), mode="nearest") + speech_tokens = speech_tokens.view(-1).long() + return speech_tokens + + +@dataclass +class Conditionals: + """ + Conditionals for T3 and S3Gen + - T3 conditionals: + - speaker_emb + - clap_emb + - cond_prompt_speech_tokens + - cond_prompt_speech_emb + - emotion_adv + - S3Gen conditionals: + - prompt_token + - prompt_token_len + - prompt_feat + - prompt_feat_len + - embedding + """ + t3: T3Cond + gen: dict + + def to(self, device): + self.t3 = self.t3.to(device=device) + for k, v in self.gen.items(): + if torch.is_tensor(v): + self.gen[k] = v.to(device=device) + return self + + def save(self, fpath: Path): + arg_dict = dict( + t3=self.t3.__dict__, + gen=self.gen + ) + torch.save(arg_dict, fpath) + + @classmethod + def load(cls, fpath, map_location="cpu"): + kwargs = torch.load(fpath, map_location=map_location, weights_only=True) + return cls(T3Cond(**kwargs['t3']), kwargs['gen']) + + +class OratorTTS: + ENC_COND_LEN = 6 * S3_SR + DEC_COND_LEN = 10 * S3GEN_SR + + def __init__( + self, + t3: T3, + s3gen: S3Gen, + ve: VoiceEncoder, + tokenizer: EnTokenizer, + device: str, + conds: Conditionals = None, + ): + self.sr = S3GEN_SR # sample rate of synthesized audio + self.t3 = t3 + self.s3gen = s3gen + self.ve = ve + self.tokenizer = tokenizer + self.device = device + self.conds = conds + + @classmethod + def from_local(cls, ckpt_dir, device) -> 'OratorTTS': + ckpt_dir = Path(ckpt_dir) + + ve = VoiceEncoder() + ve.load_state_dict( + torch.load(ckpt_dir / "ve.pt") + ) + ve.to(device).eval() + + t3 = T3() + t3.load_state_dict( + torch.load(ckpt_dir / "t3.pt") + ) + t3.to(device).eval() + + s3gen = S3Gen() + s3gen.load_state_dict( + torch.load(ckpt_dir / "s3gen.pt") + ) + s3gen.to(device).eval() + + tokenizer = EnTokenizer( + str(ckpt_dir / "tokenizer.json") + ) + + conds = None + if (builtin_voice := ckpt_dir / "conds.pt").exists(): + conds = Conditionals.load(builtin_voice).to(device) + + return cls(t3, s3gen, ve, tokenizer, device, conds=conds) + + @classmethod + def from_pretrained(cls, device) -> 'OratorTTS': + for fpath in ["ve.pt", "t3.pt", "s3gen.pt", "tokenizer.json", "conds.pt"]: + local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath) + + return cls.from_local(Path(local_path).parent, device) + + def prepare_conditionals(self, wav_fpath, exaggeration=0.5): + ## Load reference wav + s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR) + + s3_ref_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, 
target_sr=S3_SR) + + s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN] + s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device) + + # Speech cond prompt tokens + if plen := self.t3.hp.speech_cond_prompt_len: + s3_tokzr = self.s3gen.tokenizer + t3_cond_prompt_tokens, _ = s3_tokzr.forward([s3_ref_wav[:self.ENC_COND_LEN]], max_len=plen) + t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device) + + # # Voice-encoder speaker embedding + ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([s3_ref_wav], sample_rate=S3_SR)) + ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device) + + t3_cond = T3Cond( + speaker_emb=ve_embed, + cond_prompt_speech_tokens=t3_cond_prompt_tokens, + emotion_adv=exaggeration * torch.ones(1, 1, 1), + ).to(device=self.device) + self.conds = Conditionals(t3_cond, s3gen_ref_dict) + + def generate( + self, + text, + audio_prompt_path=None, + exaggeration=0.5, + pace=1, + temperature=0.8, + ): + if audio_prompt_path: + self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration) + else: + assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`" + + # Update exaggeration if needed + if exaggeration != self.conds.t3.emotion_adv[0, 0, 0]: + _cond: T3Cond = self.conds.t3 + self.conds.t3 = T3Cond( + speaker_emb=_cond.speaker_emb, + cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens, + emotion_adv=exaggeration * torch.ones(1, 1, 1), + ).to(device=self.device) + + text_tokens = self.tokenizer.text_to_tokens(text).to(self.device) + + sot = self.t3.hp.start_text_token + eot = self.t3.hp.stop_text_token + text_tokens = F.pad(text_tokens, (1, 0), value=sot) + text_tokens = F.pad(text_tokens, (0, 1), value=eot) + + with torch.inference_mode(): + speech_tokens = self.t3.inference( + t3_cond=self.conds.t3, + text_tokens=text_tokens, + max_new_tokens=1000, # TODO: use the value in config + temperature=temperature, + ) + + # TODO: output becomes 1D + speech_tokens = drop_invalid_tokens(speech_tokens) + speech_tokens = speech_tokens.to(self.device) + + speech_tokens = change_pace(speech_tokens, pace=pace) + + wav, _ = self.s3gen.inference( + speech_tokens=speech_tokens, + ref_dict=self.conds.gen, + ) + + return wav.detach().cpu() diff --git a/orator/src/orator/vc.py b/orator/src/orator/vc.py new file mode 100644 index 0000000000000000000000000000000000000000..df140b3bcd5b383b0d9a4417c764f9a31ea379d8 --- /dev/null +++ b/orator/src/orator/vc.py @@ -0,0 +1,84 @@ +from pathlib import Path + +import librosa +import torch +from huggingface_hub import hf_hub_download + +from .models.s3tokenizer import S3_SR +from .models.s3gen import S3GEN_SR, S3Gen + + +REPO_ID = "ResembleAI/Orator" + + +class OratorVC: + ENC_COND_LEN = 6 * S3_SR + DEC_COND_LEN = 10 * S3GEN_SR + + def __init__( + self, + s3gen: S3Gen, + device: str, + ref_dict: dict=None, + ): + self.sr = S3GEN_SR + self.s3gen = s3gen + self.device = device + if ref_dict is None: + self.ref_dict = None + else: + self.ref_dict = { + k: v.to(device) if torch.is_tensor(v) else v + for k, v in ref_dict.items() + } + + @classmethod + def from_local(cls, ckpt_dir, device) -> 'OratorVC': + ckpt_dir = Path(ckpt_dir) + ref_dict = None + if (builtin_voice := ckpt_dir / "conds.pt").exists(): + states = torch.load(builtin_voice) + ref_dict = states['gen'] + + s3gen = S3Gen() + s3gen.load_state_dict( + torch.load(ckpt_dir / "s3gen.pt") + ) + s3gen.to(device).eval() + + return cls(s3gen, device, ref_dict=ref_dict) + + @classmethod + def 
from_pretrained(cls, device) -> 'OratorVC': + for fpath in ["s3gen.pt", "conds.pt"]: + local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath) + + return cls.from_local(Path(local_path).parent, device) + + def set_target_voice(self, wav_fpath): + ## Load reference wav + s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR) + + s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN] + self.ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device) + + def generate( + self, + audio, + target_voice_path=None, + ): + if target_voice_path: + self.set_target_voice(target_voice_path) + else: + assert self.ref_dict is not None, "Please call `set_target_voice` first or specify `target_voice_path`" + + with torch.inference_mode(): + audio_16, _ = librosa.load(audio, sr=S3_SR) + audio_16 = torch.from_numpy(audio_16).float().to(self.device)[None, ] + + s3_tokens, _ = self.s3gen.tokenizer(audio_16) + wav, _ = self.s3gen.inference( + speech_tokens=s3_tokens, + ref_dict=self.ref_dict, + ) + return wav.detach().cpu() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..863f120ec6c9c7402c2cadbc14133c2537ae41bd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +gradio +numpy==1.26.0 +resampy==0.4.3 +librosa==0.10.0 +s3tokenizer +torch==2.6.0 +torchaudio==2.6.0 +transformers==4.46.3 +diffusers==0.29.0 +omegaconf==2.3.0 +conformer==0.3.2
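
A minimal usage sketch of the voice-conversion API defined in orator/src/orator/vc.py above. The import path, device string, and audio file names are illustrative assumptions; `model.sr` is the output sample rate set in `OratorVC.__init__`.

import torchaudio

from orator.vc import OratorVC

# Download s3gen.pt / conds.pt from the hub and build the model (assumes a CUDA device is available).
model = OratorVC.from_pretrained("cuda")

# Convert the speech in source.wav to the voice of the speaker in target_voice.wav.
wav = model.generate("source.wav", target_voice_path="target_voice.wav")

# generate() returns a CPU float tensor; save it at the model's sample rate.
wav = wav if wav.dim() == 2 else wav.unsqueeze(0)  # torchaudio.save expects (channels, frames)
torchaudio.save("converted.wav", wav, model.sr)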