Spaces:
Paused
Paused
from .istftnet import Decoder | |
from .modules import CustomAlbert, ProsodyPredictor, TextEncoder | |
from dataclasses import dataclass | |
from huggingface_hub import hf_hub_download | |
from loguru import logger | |
from transformers import AlbertConfig | |
from typing import Dict, Optional, Union | |
import json | |
import torch | |
import os | |
class KModel(torch.nn.Module):
    '''
    KModel is a torch.nn.Module with 2 main responsibilities:
    1. Init weights, downloading config.json + model.pth from HF if needed
    2. forward(phonemes: str, ref_s: FloatTensor) -> (audio: FloatTensor)

    You likely only need one KModel instance, and it can be reused across
    multiple KPipelines to avoid redundant memory allocation.

    Unlike KPipeline, KModel is language-blind.

    KModel stores self.vocab and thus knows how to map phonemes -> input_ids,
    so there is no need to repeatedly download config.json outside of KModel.
    '''

    # Hugging Face repo_id -> checkpoint filename inside that repo.
    MODEL_NAMES = {
        'hexgrad/Kokoro-82M': 'kokoro-v1_0.pth',
        'hexgrad/Kokoro-82M-v1.1-zh': 'kokoro-v1_1-zh.pth',
    }

    def __init__(
        self,
        repo_id: Optional[str] = None,
        config: Union[Dict, str, None] = None,
        model: Optional[str] = None,
        disable_complex: bool = False
    ):
        """Build the sub-networks from ``config`` and load weights from ``model``.

        Args:
            repo_id: Hugging Face repo used to fetch ``config.json`` and the
                checkpoint. Defaults to ``'hexgrad/Kokoro-82M'`` (with a
                printed warning).
            config: An already-parsed config dict, a path to a JSON config
                file, or falsy to download ``config.json`` from ``repo_id``.
            model: Path to a checkpoint. If falsy, the file named in
                ``MODEL_NAMES`` is downloaded from ``repo_id``; if that fails
                (or ``repo_id`` is not in ``MODEL_NAMES``) the local path
                ``<repo_id>/<filename>`` is used instead.
            disable_complex: Forwarded unchanged to ``Decoder``.

        Raises:
            AssertionError: if a top-level checkpoint key is not an attribute
                of this module.
        """
        super().__init__()
        if repo_id is None:
            repo_id = 'hexgrad/Kokoro-82M'
            print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
        self.repo_id = repo_id

        # Resolve ``config`` into a dict: download if missing, then parse JSON.
        if not isinstance(config, dict):
            if not config:
                logger.debug("No config provided, downloading from HF")
                config = hf_hub_download(repo_id=repo_id, filename='config.json')
            with open(config, 'r', encoding='utf-8') as r:
                config = json.load(r)
                logger.debug(f"Loaded config: {config}")

        # Phoneme -> token id mapping used by forward().
        self.vocab = config['vocab']
        self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
        self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
        # Upper bound on token sequence length, enforced in forward().
        self.context_length = self.bert.config.max_position_embeddings
        self.predictor = ProsodyPredictor(
            style_dim=config['style_dim'], d_hid=config['hidden_dim'],
            nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout']
        )
        self.text_encoder = TextEncoder(
            channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'],
            depth=config['n_layer'], n_symbols=config['n_token']
        )
        self.decoder = Decoder(
            dim_in=config['hidden_dim'], style_dim=config['style_dim'],
            dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet']
        )

        if not model:
            # Pick the filename up front so the local fallback uses the right
            # checkpoint name for known repos (not always 'kokoro-v1_0.pth').
            filename = KModel.MODEL_NAMES.get(repo_id, 'kokoro-v1_0.pth')
            try:
                # Indexing (rather than .get) keeps the original behavior of
                # skipping the download for repo_ids not in MODEL_NAMES.
                model = hf_hub_download(repo_id=repo_id, filename=KModel.MODEL_NAMES[repo_id])
            except Exception:
                # Unknown repo_id (KeyError) or download failure: fall back to
                # a local checkpoint under a directory named after repo_id.
                model = os.path.join(repo_id, filename)

        # The checkpoint maps sub-module attribute names to their state dicts.
        for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items():
            assert hasattr(self, key), key
            try:
                getattr(self, key).load_state_dict(state_dict)
            except RuntimeError:
                logger.debug(f"Did not load {key} from state_dict")
                # Drop a 7-character prefix from every key and retry non-strictly.
                # NOTE(review): presumably the 'module.' prefix of a
                # DataParallel-wrapped checkpoint; confirm.
                state_dict = {k[7:]: v for k, v in state_dict.items()}
                getattr(self, key).load_state_dict(state_dict, strict=False)

    @property
    def device(self):
        """Device reported by ``self.bert``; inputs are moved here in forward()."""
        return self.bert.device

    @dataclass
    class Output:
        """Result of ``forward(..., return_output=True)``."""
        audio: torch.FloatTensor
        pred_dur: Optional[torch.LongTensor] = None

    def forward_with_tokens(
        self,
        input_ids: torch.LongTensor,
        ref_s: torch.FloatTensor,
        speed: float = 1
    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
        """Synthesize audio from already-tokenized input.

        Args:
            input_ids: Token ids indexed as ``(batch, seq)``. The alignment
                construction below (squeezed durations, single-row alignment
                matrix) only works for a batch of one.
            ref_s: Reference style tensor. Columns ``[:128]`` feed the decoder
                and columns ``[128:]`` feed the prosody predictor.
            speed: Predicted durations are divided by this, so values > 1
                shrink the durations (shorter output).

        Returns:
            ``(audio, pred_dur)``: the squeezed decoder output and the
            per-token predicted durations (each clamped to at least 1).
        """
        # Every sequence is treated as full length (no padding).
        input_lengths = torch.full(
            (input_ids.shape[0],),
            input_ids.shape[-1],
            device=input_ids.device,
            dtype=torch.long
        )
        # True marks positions at or beyond each sequence's length. With the
        # lengths above, nothing is actually masked.
        text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths)
        text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(self.device)
        bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int())
        d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
        s = ref_s[:, 128:]
        d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = self.predictor.lstm(d)
        duration = self.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1) / speed
        pred_dur = torch.round(duration).clamp(min=1).long().squeeze()
        # One-hot alignment of shape (seq, total_frames): column j is 1 at the
        # token that output frame j is expanded from.
        indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur)
        pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device)
        pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1
        pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device)
        en = d.transpose(-1, -2) @ pred_aln_trg
        F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
        t_en = self.text_encoder(input_ids, input_lengths, text_mask)
        asr = t_en @ pred_aln_trg
        audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze()
        return audio, pred_dur

    @torch.no_grad()
    def forward(
        self,
        phonemes: str,
        ref_s: torch.FloatTensor,
        speed: float = 1,
        return_output: bool = False
    ) -> Union['KModel.Output', torch.FloatTensor]:
        """Convert a phoneme string to audio (inference only, no autograd).

        Characters of ``phonemes`` that are missing from ``self.vocab`` are
        dropped silently. The remaining ids are wrapped with id 0 on both
        sides (presumably a boundary token; confirm against the vocab).

        Args:
            phonemes: Phoneme string; each character is looked up in ``self.vocab``.
            ref_s: Reference style tensor; see ``forward_with_tokens``.
            speed: Duration divisor; see ``forward_with_tokens``.
            return_output: If True, return a ``KModel.Output`` that also
                carries ``pred_dur``.

        Returns:
            The audio tensor on CPU, or a ``KModel.Output`` when ``return_output``.

        Raises:
            AssertionError: if the wrapped token sequence exceeds ``self.context_length``.
        """
        input_ids = [i for i in map(self.vocab.get, phonemes) if i is not None]
        logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
        assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
        input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device)
        ref_s = ref_s.to(self.device)
        audio, pred_dur = self.forward_with_tokens(input_ids, ref_s, speed)
        audio = audio.squeeze().cpu()
        pred_dur = pred_dur.cpu() if pred_dur is not None else None
        logger.debug(f"pred_dur: {pred_dur}")
        return self.Output(audio=audio, pred_dur=pred_dur) if return_output else audio
class KModelForONNX(torch.nn.Module):
    """Adapter whose ``forward`` is ``KModel.forward_with_tokens``.

    ``KModel.forward`` takes a phoneme string. This wrapper instead takes
    pre-tokenized ``input_ids``, so its inputs are tensors and numbers only
    (presumably for ONNX export, per the class name).
    """

    def __init__(self, kmodel: KModel):
        super().__init__()
        # Assigning a torch.nn.Module attribute registers it as a submodule.
        self.kmodel = kmodel

    def forward(
        self,
        input_ids: torch.LongTensor,
        ref_s: torch.FloatTensor,
        speed: float = 1
    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
        """Run the wrapped model's token-level path and return ``(audio, pred_dur)``."""
        synthesize = self.kmodel.forward_with_tokens
        audio, pred_dur = synthesize(input_ids, ref_s, speed)
        return audio, pred_dur