|
|
|
|
|
|
|
__all__ = ['load_datasets', 'rand', 'Tunables', 'Encoder', 'Decoder', 'TSARTransformer', 'make_model'] |
|
|
|
|
|
import dataclasses |
|
import random |
|
import math |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.profiler import record_function |
|
|
|
from huggingface_hub import hf_hub_download |
|
from fastcore.basics import store_attr |
|
from fastprogress import progress_bar |
|
|
|
import webdataset as wds |
|
|
|
|
|
from pathlib import Path |
|
import pylab as plt |
|
import pandas as pd |
|
import numpy as np |
|
|
|
|
|
import whisper |
|
from whisperspeech.train import * |
|
from whisperspeech.modules import * |
|
from whisperspeech import vq_stoks |
|
|
|
|
|
import re |
|
|
|
class CharTokenizer: |
|
"""Trivial tokenizer – just use UTF-8 bytes""" |
|
eot = 0 |
|
|
|
def encode(self, txt): |
|
return list(bytes(txt.strip(), 'utf-8')) |
|
|
|
def decode(self, tokens): |
|
return bytes(tokens).decode('utf-8') |
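    # Round-trip sketch (illustrative): the tokenizer works on raw UTF-8 bytes, so
    # decoding the encoded bytes recovers the stripped input text, e.g.
    #   tok = CharTokenizer()
    #   tok.decode(tok.encode("hello"))  # -> 'hello'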
|
|
|
def tokenizer(ikey, okey, length): |
|
"""Tokenizes a transcript""" |
|
tok = CharTokenizer() |
|
def _tokenizer(samples): |
|
for s in samples: |
|
toks = torch.tensor(tok.encode(s[ikey])) |
|
s[okey] = F.pad(toks, (0, length - toks.shape[-1]), value=tok.eot) |
|
yield s |
|
return _tokenizer |
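# Sketch of what the stage above yields (illustrative): a length-550 tensor of UTF-8
# byte values, right-padded with the eot token (0), e.g.
#   stage = tokenizer('txt', 'ttoks', length=550)
#   sample = next(stage([{'txt': 'hello'}]))
#   sample['ttoks'].shape  # -> torch.Size([550])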
|
|
|
def ar_padder(ikey, okey, length, pad_token): |
|
"""Pads the tokens for autoregresive training""" |
|
def _ar_padder(samples): |
|
for s in samples: |
|
toks = s[ikey] |
|
if isinstance(toks, (list, np.ndarray)): toks = torch.tensor(toks) |
|
toks = toks.to(torch.long) |
|
s['in_' +okey] = F.pad(toks, (1, length - toks.shape[-1] - 1), value=pad_token) |
|
s['out_'+okey] = F.pad(toks, (0, length - toks.shape[-1]), value=pad_token) |
|
yield s |
|
return _ar_padder |
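# Illustrative example of the shifted sequences produced above: for toks == [3, 5, 7],
# length=6 and pad_token=9 the stage adds
#   in_stoks  = [9, 3, 5, 7, 9, 9]   # pad/start token prepended -> teacher-forcing input
#   out_stoks = [3, 5, 7, 9, 9, 9]   # next-token targets, shifted left by one position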
|
|
|
def char_per_seconder(txt_key, stoks_key, cps_key, stoks_per_second=25): |
|
"""Adds the characters per second metric to the input data""" |
|
def _char_per_seconder(samples): |
|
for s in samples: |
|
secs = s[stoks_key].shape[-1] / stoks_per_second |
|
s[cps_key] = len(s[txt_key]) / secs |
|
yield s |
|
return _char_per_seconder |
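# Worked example (illustrative): 250 semantic tokens at 25 tokens/second cover 10 s of
# audio, so a 150-character transcript gives cps = 150 / 10 = 15 characters per second.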
|
|
|
|
|
def build_speaker_map(shards): |
|
speakers = set() |
|
for shard in shards: |
|
with open(shard+'.speakers.txt') as f: speakers = speakers.union(set(x.strip() for x in f.readlines())) |
|
return {id:i for i,id in enumerate(speakers)} |
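# Each shard is expected to come with a sidecar `<shard>.speakers.txt` file listing one
# speaker id per line; the map above assigns every distinct id a dense integer index.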
|
|
|
def speaker_id_extractor(speaker_map): |
|
def _extractor(samples): |
|
for s in samples: |
|
s['speaker'] = torch.tensor(speaker_map[s['__key__'].split("/")[1]]) |
|
yield s |
|
return _extractor |
|
|
|
|
|
def load_datasets( |
|
input:str, |
|
samples:int, |
|
subsample:float=1, |
|
val_samples:int=512, |
|
vq_codes:int=4096, |
|
): |
|
if isinstance(input, (Path, str)): |
|
path = Path(input) |
|
if path.is_dir(): |
|
glob = '*-t2s-*.tar.gz' |
|
else: |
|
glob = path.name |
|
path = path.parent |
|
input = Path(path).glob(glob) |
|
elif isinstance(input, list): |
|
pass |
|
else: |
|
        raise ValueError("input should be either a list or a path with an optional glob specifier")
|
shards = [str(x) for x in input] |
|
|
|
speaker_map = build_speaker_map(shards) |
|
|
|
def ds(shards, length): |
|
ds = wds.WebDataset(wds.ResampledShards(shards)).compose( |
|
wds.decode(), |
|
speaker_id_extractor(speaker_map), |
|
wds.select(lambda s: s['stoks.npy'].shape[-1] > 12), |
|
tokenizer('txt', 'ttoks', length=550), |
|
ar_padder('stoks.npy', 'stoks', length=750, pad_token=vq_codes-1), |
|
char_per_seconder('txt', 'stoks.npy', 'cps', stoks_per_second=25), |
|
wds.to_tuple('ttoks', 'speaker', 'cps', 'in_stoks', 'out_stoks'), |
|
wds.batched(64) |
|
) |
|
ds.speakers = speaker_map |
|
ds.total_samples = length |
|
ds.stoks_len = 750 |
|
ds.stoks_codes = vq_codes |
|
ds.ttoks_len = 550 |
|
return ds.compose(wds.slice(length // 64)).with_epoch(length // 64).with_length(length // 64) |
|
|
|
return ( |
|
ds(shards[1:], samples), |
|
ds(shards[:1], val_samples), |
|
) |
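# Usage sketch (illustrative; the paths below are made up). `input` may be a directory,
# a glob pattern or an explicit list of shard paths; the first shard is held out for
# validation and the rest are used for training:
#   train_ds, val_ds = load_datasets('data/*-t2s-*.tar.gz', samples=500_000)
#   ttoks, speaker, cps, in_stoks, out_stoks = next(iter(train_ds))  # batches of 64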
|
|
|
|
|
def rand(start, end): |
|
return random.random() * (end - start) + start |
|
|
|
@dataclasses.dataclass |
|
class Tunables: |
|
init_std :float = 1 |
|
embeddings_std :float = .01 |
|
embeddings_lr_scale: float = 5 |
|
embedding_projector_lr_scale: float = 2.5 |
|
output_mult :float = .35 |
|
query_mult :float = 1 |
|
encoder_depth_ratio :float = 0.25 |
|
eot_dropout_p :float = .5 |
|
cps_input: bool = True |
|
cps_bins: int = 32 |
|
|
|
lr0 :float = 1.5e-3 |
|
clip_gradient_norm :float = .2 |
|
weight_decay :float = 1e-1 |
|
warmup_steps :float = 4000 |
|
|
|
random :bool = False |
|
|
|
def __post_init__(self): |
|
|
|
if self.random: |
|
self.init_std = 10**rand(-1,1) |
|
self.embeddings_std = 10**rand(-3,-.7) |
|
self.embeddings_lr_scale = rand(2,6) |
|
self.output_mult = rand(0.25,0.65) |
|
self.query_mult = 2**rand(-2,3) |
|
self.encoder_depth_ratio = 0.25 |
|
|
|
self.lr0 = rand(1,5)*1e-3 |
|
self.clip_gradient_norm = 10**rand(-3,0) |
|
self.warmup_steps = 100*(10**rand(1,1.85)) |
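# Minimal sketch (illustrative) of using the `random` flag for hyperparameter search:
# every instantiation draws a fresh configuration from the ranges above.
#   t = Tunables(random=True)
#   print(t.lr0, t.init_std, t.warmup_steps)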
|
|
|
|
|
class EmbeddingProjector(nn.Linear): |
|
pass |
|
|
|
|
|
class Encoder(nn.Module): |
|
def __init__(self, depth=6, width=384, n_head=6, length=1500, codes=1024, emb_width=384, ffn_mult=4, pos_embs=None, tunables=Tunables()): |
|
super().__init__() |
|
self.emb_width = emb_width |
|
|
|
self.emb_factor = width != emb_width |
|
|
|
self.embedding = nn.Embedding(codes, emb_width) |
|
if self.emb_factor: |
|
self.emb_to_hidden = EmbeddingProjector(emb_width, width) |
|
|
|
if pos_embs is None: pos_embs = sinusoids(length, width) |
|
self.register_buffer("positional_embedding", pos_embs) |
|
|
|
self.layers = nn.Sequential(*[ |
|
ResidualAttentionBlock(width, n_head, |
|
qk_scale=tunables.query_mult*8/math.sqrt(width/n_head), ffn_mult=ffn_mult) for _ in range(depth) |
|
]) |
|
|
|
self.ln_post = LayerNorm(width) |
|
|
|
def forward(self, Stoks): |
|
xin = self.embedding(Stoks) |
|
if self.emb_factor: |
|
xin = self.emb_to_hidden(xin) |
|
|
|
assert xin.shape[1:] == self.positional_embedding.shape, "incorrect semantic token shape" |
|
xin = (xin + self.positional_embedding).to(xin.dtype) |
|
|
|
return self.ln_post(self.layers(xin)) |
|
|
|
|
|
class Decoder(nn.Module): |
|
def __init__(self, depth=6, stoks_width=384, width=384, n_head=6, length=1500, codes=1024, ffn_mult=4, pos_embs=None, tunables=Tunables()): |
|
super().__init__() |
|
self.length = length |
|
self.codes = codes |
|
self.width = width |
|
self.stoks_width = stoks_width |
|
|
|
self.emb_factor = width != stoks_width |
|
|
|
|
|
self.embedding = nn.Embedding(codes, stoks_width) |
|
if self.emb_factor: |
|
self.emb_to_hidden = EmbeddingProjector(stoks_width, width) |
|
self.hidden_to_emb = EmbeddingProjector(width, stoks_width) |
|
|
|
if pos_embs is None: pos_embs = sinusoids(length, width) |
|
self.register_buffer("positional_embedding", pos_embs) |
|
|
|
self.layers = nn.ModuleList([ |
|
ResidualAttentionBlock(width, n_head, cross_attention=True, |
|
qk_scale=tunables.query_mult*8/math.sqrt(width/n_head), ffn_mult=ffn_mult) for _ in range(depth) |
|
]) |
|
self.ln_post = LayerNorm(width) |
|
|
|
def forward(self, Stoks, xenc, cps=None): |
|
Sembs = self.embedding(Stoks) |
|
|
|
if self.emb_factor: |
|
Sembs = self.emb_to_hidden(Sembs) |
|
|
|
xin = (Sembs + self.positional_embedding[:Sembs.shape[1]]).to(xenc.dtype) |
|
if cps is not None: xin = xin + cps |
|
|
|
x = xin |
|
for l in self.layers: x = l(x, xenc, causal=True) |
|
|
|
x = self.ln_post(x) |
|
|
|
if self.emb_factor: |
|
x = self.hidden_to_emb(x) |
|
|
|
logits = (x @ self.embedding.weight.to(x.dtype).T).float() |
|
return logits |
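    # Note: the output projection is weight-tied, so the logits are computed as the dot
    # product of the hidden states (projected back to `stoks_width` if needed) with the
    # token embedding matrix.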
|
|
|
|
|
class TSARTransformer(nn.Module): |
|
def __init__(self, depth=6, n_head=6, head_width=64, ffn_mult=4, language='en', |
|
ttoks_len=200, ttoks_codes=50364, ttoks_width=None, |
|
stoks_len=1500, stoks_codes=1024, stoks_width=None, |
|
tunables=Tunables()): |
|
        assert language == 'en', "only English is supported right now"
|
super().__init__() |
|
store_attr("depth,n_head,head_width,ffn_mult,stoks_width,ttoks_width,ttoks_len,stoks_len,ttoks_codes,stoks_codes,language") |
|
|
|
width = n_head * head_width |
|
self.width = width |
|
self.base_width = 3 * head_width |
|
self.tunables = tunables |
|
if self.stoks_width is None: self.stoks_width = self.width |
|
if self.ttoks_width is None: self.ttoks_width = self.width |
|
|
|
if tunables.cps_input: |
|
self.cps_embeddings = nn.Embedding(tunables.cps_bins, self.width) |
|
else: |
|
self.cps_embeddings = None |
|
|
|
encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio) |
|
decoder_depth = depth * 2 - encoder_depth |
|
tformer_args = dict(width=width, n_head=n_head, ffn_mult=ffn_mult, tunables=tunables) |
|
self.encoder = Encoder(length=ttoks_len, codes=ttoks_codes, emb_width=self.ttoks_width, depth=encoder_depth, **tformer_args) |
|
self.decoder = Decoder(length=stoks_len, codes=stoks_codes, stoks_width=self.stoks_width, depth=decoder_depth, **tformer_args) |
|
|
|
self.tokenizer = None |
|
|
|
self.apply(self.init_transformer) |
|
|
|
def load_frozen_semantic_embeddings(self, vqmodel): |
|
with torch.no_grad(): |
|
self.decoder.embedding.weight[:] = vqmodel.rq.layers[0]._codebook.embed[0] |
|
self.decoder.embedding.lr_scale = 0 |
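    # This copies the frozen VQ codebook into the decoder's token embedding and zeroes
    # its `lr_scale`, which is intended to keep the copied weights frozen during training.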
|
|
|
def setup(self, device): |
|
pass |
|
|
|
def init_transformer(self, m): |
|
if isinstance(m, LinearHead): |
|
m.no_weight_decay = True |
|
torch.nn.init.constant_(m.weight, 0) |
|
elif isinstance(m, QueryHead): |
|
m.lr_scale = 1/(m.weight.shape[1] / self.base_width) |
|
torch.nn.init.constant_(m.weight, 0) |
|
elif isinstance(m, nn.Embedding): |
|
m.no_weight_decay = True |
|
m.lr_scale = self.tunables.embeddings_lr_scale |
|
std = self.tunables.embeddings_std |
|
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std) |
|
elif isinstance(m, EmbeddingProjector): |
|
m.lr_scale = self.tunables.embedding_projector_lr_scale |
|
std = self.tunables.init_std |
|
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std) |
|
elif isinstance(m, nn.Linear): |
|
m.lr_scale = 1/(m.weight.shape[1] / self.base_width) |
|
std = self.tunables.init_std / m.weight.shape[1] |
|
torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std) |
|
if m.bias is not None: |
|
torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std) |
|
elif isinstance(m, nn.LayerNorm): |
|
m.no_weight_decay = True |
|
torch.nn.init.constant_(m.bias, 0) |
|
torch.nn.init.constant_(m.weight, 1) |
|
|
|
def forward(self, Ttoks, speakers, cpss, in_stoks, out_stoks=None, loss=True): |
|
with record_function("encoder"): |
|
xenc = self.encoder(Ttoks.to(torch.long)) |
|
with record_function("decoder"): |
|
if self.cps_embeddings: |
|
cps_bin = (cpss / 20 * self.tunables.cps_bins).to(torch.long) |
|
cps_bin[cps_bin >= self.tunables.cps_bins] = self.tunables.cps_bins-1 |
|
cps_embs = self.cps_embeddings(cps_bin).unsqueeze(1) |
|
else: |
|
cps_embs = None |
|
logits = self.decoder(in_stoks, xenc, cps=cps_embs) * self.tunables.output_mult / (self.width / self.base_width) |
|
if loss is not None: |
|
with record_function("loss"): |
|
loss = F.cross_entropy(logits.transpose(-1,-2), out_stoks) |
|
return logits, loss |
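    # Note: the positional arguments of `forward` follow the tuple order produced by the
    # dataset pipeline in `load_datasets` ('ttoks', 'speaker', 'cps', 'in_stoks',
    # 'out_stoks'); `speakers` is accepted for compatibility but currently unused.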
|
|
|
|
|
|
|
|
|
@classmethod |
|
def load_model(cls, repo_id="collabora/whisperspeech", filename="t2s_up_wds.model", local_filename=None): |
|
if not local_filename: |
|
local_filename = hf_hub_download(repo_id=repo_id, filename=filename) |
|
spec = torch.load(local_filename) |
|
model = cls(**spec['config'], tunables=Tunables(**spec['tunables'])) |
|
model.load_state_dict(spec['state_dict']) |
|
model.eval() |
|
return model |
|
|
|
def load_checkpoint(self, local_filename): |
|
spec = torch.load(local_filename, map_location='cpu') |
|
assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint' |
|
state_dict = {k.replace('model.', ''):v |
|
for k,v in spec['state_dict'].items()} |
|
self.load_state_dict(state_dict) |
|
return self |
|
|
|
def save_model(self, fname): |
|
torch.save(dict(config = self.__stored_args__, |
|
tunables = dataclasses.asdict(self.tunables), |
|
state_dict = self.state_dict()), fname) |
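    # Round-trip sketch (illustrative): `save_model` stores the constructor config,
    # tunables and weights, so the model can later be rebuilt without the Hub, e.g.
    #   model.save_model('t2s.model')
    #   restored = TSARTransformer.load_model(local_filename='t2s.model')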
|
|
|
def ensure_tokenizer(self): |
|
assert not self.training |
|
if self.tokenizer is None: self.tokenizer = CharTokenizer() |
|
|
|
|
|
@property |
|
def device(self): |
|
return next(self.parameters()).device |
|
|
|
@torch.no_grad() |
|
def generate(self, txt, cps=15, N=None, T=0.7, top_k=None, show_progress_bar=True): |
|
self.ensure_tokenizer() |
|
N = N or self.stoks_len |
|
dev = self.device |
|
ttoks = torch.tensor(self.tokenizer.encode(txt), device=dev) |
|
ttoks = F.pad(ttoks, (0, self.ttoks_len - len(ttoks)), value=self.tokenizer.eot).unsqueeze(0) |
|
cpss = torch.tensor([cps], device=dev) |
|
toks = torch.zeros((1,N), dtype=torch.long, device=dev) |
|
toks[0,0] = self.stoks_codes-1 |
|
it = range(1,N) |
|
if show_progress_bar: it = progress_bar(it) |
|
for i in it: |
|
p, _ = self(ttoks, None, cpss, toks[:,:i], loss=None) |
|
last_p = p[0,-1] |
|
if top_k: |
|
last_p[last_p < torch.topk(last_p, top_k).values[-1,None]] = -torch.inf |
|
tok = torch.multinomial((last_p / float(T)).softmax(-1), 1) |
|
toks[0,i] = tok |
|
if toks[0,i] == self.stoks_codes-1: return toks[0,1:i] |
|
return toks[0,1:] |
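    # Usage sketch (illustrative): generate semantic tokens for a piece of text; they can
    # then be fed to a separate semantic-to-acoustic model for audio synthesis.
    #   model = TSARTransformer.load_model().to('cuda')
    #   stoks = model.generate("Hello, world!", cps=15, T=0.7)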
|
|
|
@torch.no_grad() |
|
    def generate_batch(self, txts, cps=15, N=None, T=1.1, top_k=7, show_progress_bar=True):
|
self.ensure_tokenizer() |
|
        N = N or self.stoks_len
|
dev = self.device |
|
ttoks = [] |
|
for txt in txts: |
|
ttoks_ = torch.tensor(self.tokenizer.encode(txt), device=dev) |
|
ttoks_ = F.pad(ttoks_, (0, self.ttoks_len - len(ttoks_)), value=self.tokenizer.eot).unsqueeze(0) |
|
ttoks.append(ttoks_) |
|
        ttoks = torch.cat(ttoks, dim=0)
        cpss = torch.tensor([cps] * len(ttoks), device=dev)
        toks = torch.zeros((len(ttoks), N), dtype=torch.long, device=dev)
        toks[:,0] = self.stoks_codes-1
        it = range(1, N)
|
if show_progress_bar: it = progress_bar(it) |
|
for i in it: |
|
            p, _ = self(ttoks, None, cpss, toks[:,:i], loss=None)
|
last_p = p[:,-1] |
|
if top_k: |
|
last_p[last_p < torch.topk(last_p, top_k).values[:,-1,None]] = -torch.inf |
|
tok = torch.multinomial((last_p / float(T)).softmax(-1), 1) |
|
toks[:,i] = tok[:,0] |
|
            if (toks[:,i] == self.stoks_codes-1).all(): return toks[:,1:i]
        return toks[:,1:]
|
|
|
|
|
def _make_model(size:str, tunables:Tunables=Tunables(), dataset=None, **kwargs): |
|
kwargs = dict(stoks_len = dataset.stoks_len, ttoks_len = dataset.ttoks_len, tunables=tunables, **kwargs) |
|
if 'stoks_codes' not in kwargs: kwargs['stoks_codes'] = dataset.stoks_codes |
|
if size == 'micro': |
|
return TSARTransformer(depth=2, n_head=3, ffn_mult=1, **kwargs) |
|
if size == 'tiny': |
|
return TSARTransformer(depth=4, n_head=6, **kwargs) |
|
if size == 'base': |
|
return TSARTransformer(depth=6, n_head=8, **kwargs) |
|
if size == 'small': |
|
return TSARTransformer(depth=12, n_head=16, **kwargs) |
|
|
|
def make_model(size:str, frozen_embeddings_model:str=None, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None): |
|
if frozen_embeddings_model: |
|
vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model) |
|
model = _make_model(size, tunables, dataset, stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1]) |
|
model.load_frozen_semantic_embeddings(vqmodel) |
|
else: |
|
        model = _make_model(size, tunables, dataset)
|
return model |
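# Construction sketch (illustrative): `dataset` supplies sequence lengths and the codebook
# size, and a frozen vq_stoks checkpoint can optionally donate its semantic-token embeddings.
#   train_ds, val_ds = load_datasets('data/*-t2s-*.tar.gz', samples=500_000)
#   model = make_model('tiny', dataset=train_ds)
#   # or: make_model('tiny', frozen_embeddings_model=<path to a vq_stoks checkpoint>, dataset=train_ds)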
|
|