from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn, Tensor
from .perceiver import Perceiver
from .t3_config import T3Config


@dataclass
class T3Cond:
"""
    Dataclass container for the conditioning info: speaker embedding, optional
    CLAP embedding, optional speech-prompt tokens/embeddings, and an emotion scalar.
    TODO: the serialization methods aren't used yet; kept around for convenience.
"""
speaker_emb: Tensor
clap_emb: Optional[Tensor] = None
cond_prompt_speech_tokens: Optional[Tensor] = None
cond_prompt_speech_emb: Optional[Tensor] = None
    emotion_adv: Optional[Tensor] = 0.5  # float default; pass a Tensor when feeding T3CondEnc (see forward)

    def to(self, *, device=None, dtype=None):
        "Cast to a device and dtype. Dtype casting is skipped for non-floating-point tensors."
        for k, v in self.__dict__.items():
            if torch.is_tensor(v):
                # Only floating-point tensors take the dtype cast; int/long token
                # tensors keep their dtype, since casting them would corrupt indices.
                is_fp = v.is_floating_point()
                setattr(self, k, v.to(device=device, dtype=dtype if is_fp else None))
        return self

    def save(self, fpath):
        torch.save(self.__dict__, fpath)

    @staticmethod
    def load(fpath, map_location="cpu"):
kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
return T3Cond(**kwargs)
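
# Illustrative round-trip (a sketch; the 256-dim speaker embedding is an assumed
# size and would come from the voice encoder / hp.speaker_embed_size in practice):
#   cond = T3Cond(
#       speaker_emb=torch.randn(1, 256),
#       emotion_adv=torch.tensor(0.5),
#   ).to(device="cpu", dtype=torch.float32)
#   cond.save("cond.pt")
#   restored = T3Cond.load("cond.pt")
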
class T3CondEnc(nn.Module):
"""
    Handles all non-text conditioning: speaker embeddings, speech prompts, CLAP, emotion, etc.
"""
def __init__(self, hp: T3Config):
super().__init__()
self.hp = hp
if hp.encoder_type == "voice_encoder":
self.spkr_enc = nn.Linear(hp.speaker_embed_size, hp.n_channels)
else:
            raise NotImplementedError(str(hp.encoder_type))

        # emotion adv
        self.emotion_adv_fc = None
if hp.emotion_adv:
            self.emotion_adv_fc = nn.Linear(1, hp.n_channels, bias=False)

        # perceiver resampler
        self.perceiver = None
if hp.use_perceiver_resampler:
            self.perceiver = Perceiver()

    def forward(self, cond: T3Cond):
# Validate
        assert (cond.cond_prompt_speech_tokens is None) == (cond.cond_prompt_speech_emb is None), \
            "cond_prompt_speech_tokens and cond_prompt_speech_emb must be provided together"

        # Speaker embedding projection
        cond_spkr = self.spkr_enc(cond.speaker_emb.view(-1, self.hp.speaker_embed_size))[:, None]  # (B, 1, dim)
        # Zero-length placeholder on the sequence axis: absent optional streams
        # concatenate as no-ops in the final torch.cat below.
        empty = torch.zeros_like(cond_spkr[:, :0])  # (B, 0, dim)

        # TODO CLAP
assert cond.clap_emb is None, "clap_embed not implemented"
        cond_clap = empty  # (B, 0, dim)

        # Cond prompt
cond_prompt_speech_emb = cond.cond_prompt_speech_emb
if cond_prompt_speech_emb is None:
cond_prompt_speech_emb = empty # (B, 0, dim)
elif self.hp.use_perceiver_resampler:
            cond_prompt_speech_emb = self.perceiver(cond_prompt_speech_emb)

        # Emotion Adv: must provide a value if this model uses emotion conditioning
cond_emotion_adv = empty # (B, 0, dim)
if self.hp.emotion_adv:
            assert cond.emotion_adv is not None, "emotion_adv is required when hp.emotion_adv is enabled"
            cond_emotion_adv = self.emotion_adv_fc(cond.emotion_adv.view(-1, 1, 1))  # (B, 1, dim)

        # Concat and return
cond_embeds = torch.cat((
cond_spkr,
cond_clap,
cond_prompt_speech_emb,
cond_emotion_adv,
), dim=1)
return cond_embeds
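
# Minimal usage sketch (illustrative; assumes T3Config defaults expose the fields
# used above, e.g. speaker_embed_size, n_channels, and emotion_adv=True):
#   hp = T3Config()
#   enc = T3CondEnc(hp)
#   cond = T3Cond(
#       speaker_emb=torch.randn(2, hp.speaker_embed_size),
#       emotion_adv=0.5 * torch.ones(2, 1, 1),
#   )
#   cond_embeds = enc(cond)  # (B, 1 + T_prompt + 1, n_channels); T_prompt = 0 here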