from dataclasses import dataclass
from typing import Optional, Union

import torch
from torch import nn, Tensor

from .perceiver import Perceiver
from .t3_config import T3Config


@dataclass
class T3Cond:
    """
    Dataclass container for most / all conditioning info.
    TODO: serialization methods aren't used, keeping them around for convenience
    """

    speaker_emb: Tensor
    clap_emb: Optional[Tensor] = None
    cond_prompt_speech_tokens: Optional[Tensor] = None
    cond_prompt_speech_emb: Optional[Tensor] = None
    emotion_adv: Union[Tensor, float, None] = 0.5  # default is a plain float; wrap it in a Tensor before forward()

    def to(self, *, device=None, dtype=None):
        "Cast to a device and dtype. Dtype casting is ignored for long/int tensors."
        for k, v in self.__dict__.items():
            if torch.is_tensor(v):
                # Only cast floating-point tensors; token/index tensors keep their integer dtype.
                is_fp = v.is_floating_point()
                setattr(self, k, v.to(device=device, dtype=dtype if is_fp else None))
        return self

    def save(self, fpath):
        torch.save(self.__dict__, fpath)

    @staticmethod
    def load(fpath, map_location="cpu"):
        kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
        return T3Cond(**kwargs)


class T3CondEnc(nn.Module):
    """
    Handle all non-text conditioning, like speaker embeddings / prompts, CLAP, emotion, etc.
    """

    def __init__(self, hp: T3Config):
        super().__init__()
        self.hp = hp
        if hp.encoder_type == "voice_encoder":
            self.spkr_enc = nn.Linear(hp.speaker_embed_size, hp.n_channels)
        else:
            raise NotImplementedError(str(hp.encoder_type))

        # emotion adv
        self.emotion_adv_fc = None
        if hp.emotion_adv:
            self.emotion_adv_fc = nn.Linear(1, hp.n_channels, bias=False)

        # perceiver resampler
        self.perceiver = None
        if hp.use_perceiver_resampler:
            self.perceiver = Perceiver()

    def forward(self, cond: T3Cond):
        # Validate: prompt tokens and prompt embeddings must be provided together
        assert (cond.cond_prompt_speech_tokens is None) == (cond.cond_prompt_speech_emb is None), \
            "cond_prompt_speech_tokens and cond_prompt_speech_emb must be provided together"

        # Speaker embedding projection
        cond_spkr = self.spkr_enc(cond.speaker_emb.view(-1, self.hp.speaker_embed_size))[:, None]  # (B, 1, dim)
        empty = torch.zeros_like(cond_spkr[:, :0])  # (B, 0, dim)

        # TODO CLAP
        assert cond.clap_emb is None, "clap_embed not implemented"
        cond_clap = empty  # (B, 0, dim)

        # Cond prompt
        cond_prompt_speech_emb = cond.cond_prompt_speech_emb
        if cond_prompt_speech_emb is None:
            cond_prompt_speech_emb = empty  # (B, 0, dim)
        elif self.hp.use_perceiver_resampler:
            cond_prompt_speech_emb = self.perceiver(cond_prompt_speech_emb)

        # Emotion Adv: must provide a value if this model uses emotion conditioning
        cond_emotion_adv = empty  # (B, 0, dim)
        if self.hp.emotion_adv:
            assert cond.emotion_adv is not None
            cond_emotion_adv = self.emotion_adv_fc(cond.emotion_adv.view(-1, 1, 1))

        # Concat along the sequence axis and return
        cond_embeds = torch.cat((
            cond_spkr,
            cond_clap,
            cond_prompt_speech_emb,
            cond_emotion_adv,
        ), dim=1)
        return cond_embeds
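

# --- Usage sketch (illustrative only, not part of the module) ---
# A minimal example of wiring a T3Cond through T3CondEnc. It assumes that
# T3Config() provides usable defaults with encoder_type == "voice_encoder"
# and emotion conditioning enabled; check t3_config.py for the actual
# configuration. Since this file uses relative imports, run it as a module
# (python -m ...) from the package root rather than as a standalone script.
if __name__ == "__main__":
    hp = T3Config()
    enc = T3CondEnc(hp)
    cond = T3Cond(
        speaker_emb=torch.randn(1, hp.speaker_embed_size),
        emotion_adv=torch.tensor(0.5),  # pass a Tensor; forward() calls .view() on it
    ).to(device="cpu", dtype=torch.float32)
    cond_embeds = enc(cond)
    # With no prompt/CLAP inputs this is (B, 1, n_channels) for the speaker
    # token, plus one extra token when emotion conditioning is enabled.
    print(cond_embeds.shape)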