File size: 13,367 Bytes
ea174b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 |
# -*- coding: utf-8 -*-
import yaml
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from .nn.feature_extractor import MelFeatureExtractor
from .nn.modules import OmniAudioEncoder, OmniAudioDecoder, ResidualDownConv, UpConv, Transformer, Vocos
from .nn.quantizer import ResidualVQ
class XY_Tokenizer(nn.Module):
    """Neural speech codec that maps waveforms to RVQ codes and back.

    Tokenization runs a semantic and an acoustic encoder on the same mel input,
    concatenates their outputs along the channel axis, adapts and downsamples
    them, and quantizes the result with a residual vector quantizer.
    Detokenization maps codes back to audio through an adapter, an upsampler,
    an acoustic decoder and a Vocos vocoder. ``encode``/``decode`` handle inputs
    longer than one window by processing overlapping windows and stitching the
    kept part of each window together.
    """

    # Seconds of audio handled per model call by encode()/decode().
    _WINDOW_SECONDS = 30

    def __init__(self, generator_params):
        """Build every sub-module from the ``generator_params`` config mapping.

        Args:
            generator_params: dict holding ``input_sample_rate``,
                ``output_sample_rate`` and one ``*_kwargs`` dict per sub-module.
        """
        super().__init__()
        # Basic parameters
        self.input_sample_rate = generator_params['input_sample_rate']
        self.output_sample_rate = generator_params['output_sample_rate']
        # Input samples consumed per code frame / output samples produced per code frame.
        self.encoder_downsample_rate = 1280
        self.decoder_upsample_rate = 1920
        self.code_dim = generator_params['quantizer_kwargs']['input_dim']

        # --- Codec part ---
        # Semantic channel
        self.semantic_encoder = OmniAudioEncoder(**generator_params['semantic_encoder_kwargs'])
        self.semantic_encoder_adapter = Transformer(**generator_params['semantic_encoder_adapter_kwargs'])
        # Acoustic channel
        self.acoustic_encoder = OmniAudioEncoder(**generator_params['acoustic_encoder_kwargs'])
        # Semantic & acoustic shared parameters
        self.pre_rvq_adapter = Transformer(**generator_params['pre_rvq_adapter_kwargs'])
        self.downsample = ResidualDownConv(**generator_params['downsample_kwargs'])
        self.quantizer = ResidualVQ(**generator_params['quantizer_kwargs'])
        self.nq = generator_params['quantizer_kwargs']['num_quantizers']
        self.post_rvq_adapter = Transformer(**generator_params['post_rvq_adapter_kwargs'])
        # Acoustic decoding path
        self.upsample = UpConv(**generator_params['upsample_kwargs'])
        self.acoustic_decoder = OmniAudioDecoder(**generator_params['acoustic_decoder_kwargs'])
        self.enhanced_vocos = Vocos(**generator_params['vocos_kwargs'])
        # Feature extractor
        self.feature_extractor = MelFeatureExtractor(**generator_params['feature_extractor_kwargs'])

    @torch.inference_mode()
    def inference_tokenize(self, x, input_lengths):
        """Quantize a single window of audio into RVQ codes.

        Input:
            x: Waveform tensor # (B, 1, T), T <= 30s * sample_rate
            input_lengths: Valid length for each sample # (B,)
        Output:
            dict: Contains the following key-value pairs
                "zq": Quantized embeddings # (B, D, T)
                "codes": Quantization codes # (nq, B, T)
                "codes_lengths": Quantization code lengths # (B,)
        """
        # The feature extractor receives one flat numpy array per item, trimmed to its valid length.
        list_x = [xi[:, :x_len].reshape(-1).cpu().numpy() for xi, x_len in zip(x, input_lengths)]
        features = self.feature_extractor(
            list_x,
            sampling_rate=self.input_sample_rate,
            return_tensors="pt",
            return_attention_mask=True
        )
        input_mel = features['input_features'].to(x.device).to(x.dtype)  # (B, D, 3000)
        audio_attention_mask = features['attention_mask'].to(x.device)  # (B, 3000)
        # Number of valid (unpadded) mel frames per item
        mel_output_length = torch.sum(audio_attention_mask, dim=-1).long()  # (B,)

        # Semantic channel
        semantic_encoder_output, semantic_encoder_output_length = self.semantic_encoder(input_mel, mel_output_length)  # (B, D, T), 100hz -> 50hz
        semantic_encoder_adapter_output, _ = self.semantic_encoder_adapter(semantic_encoder_output, semantic_encoder_output_length)  # (B, D, T), 50hz
        # Acoustic channel
        acoustic_encoder_output, acoustic_encoder_output_length = self.acoustic_encoder(input_mel, mel_output_length)  # (B, D, T), 100hz -> 50hz

        # Semantic & acoustic mixing: channel-wise concat; lengths are taken from the acoustic branch.
        concated_semantic_acoustic_channel = torch.concat([semantic_encoder_adapter_output, acoustic_encoder_output], dim=1)  # (B, D, T)
        concated_semantic_acoustic_channel_length = acoustic_encoder_output_length
        pre_rvq_adapter_output, pre_rvq_adapter_output_length = self.pre_rvq_adapter(concated_semantic_acoustic_channel, concated_semantic_acoustic_channel_length)  # (B, D, T), 50hz
        downsample_output, downsample_output_length = self.downsample(pre_rvq_adapter_output, pre_rvq_adapter_output_length)  # (B, D, T), 50hz -> 12.5hz
        # Quantizer returns (zq, codes, vq_loss, per-layer outputs, lengths); loss and per-layer outputs are unused here.
        zq, codes, _, _, quantizer_output_length = self.quantizer(downsample_output, downsample_output_length)  # (B, D, T), (nq, B, T), (B,)
        return {
            "zq": zq,  # (B, D, T)
            "codes": codes,  # (nq, B, T)
            "codes_lengths": quantizer_output_length  # (B,)
        }

    @torch.inference_mode()
    def inference_detokenize(self, codes, codes_lengths):
        """Synthesize audio for a single window of RVQ codes.

        Input:
            codes: Quantization codes # (nq, B, T)
            codes_lengths: Quantization code lengths for each sample # (B,)
        Output:
            dict: Contains the following key-value pairs
                "y": Synthesized audio waveform # (B, 1, T)
                "output_length": Output lengths # (B,)
        """
        zq = self.quantizer.decode_codes(codes)  # (B, D, T)
        post_rvq_adapter_output, post_rvq_adapter_output_length = self.post_rvq_adapter(zq, codes_lengths)  # (B, D, T), 12.5hz
        # Acoustic channel
        upsample_output, upsample_output_length = self.upsample(post_rvq_adapter_output, post_rvq_adapter_output_length)  # (B, D, T), 12.5hz -> 50hz
        acoustic_decoder_output, acoustic_decoder_output_length = self.acoustic_decoder(upsample_output, upsample_output_length)  # (B, D, T), 50hz -> 100hz
        # 100hz -> output_sample_rate (presumably 24 kHz = 12.5 Hz * decoder_upsample_rate; TODO confirm)
        y, vocos_output_length = self.enhanced_vocos(acoustic_decoder_output, acoustic_decoder_output_length)  # (B, 1, T)
        return {
            "y": y,  # (B, 1, T)
            "output_length": vocos_output_length,  # (B,)
        }

    def _kept_code_frames(self, overlap_seconds):
        """Return how many code frames encode()/decode() keep from each window.

        Raises:
            ValueError: if overlap_seconds is not in [0, window) or leaves less
                than one whole code frame per window.
        """
        window_seconds = self._WINDOW_SECONDS
        if not 0 <= overlap_seconds < window_seconds:
            raise ValueError(
                f"overlap_seconds must be in [0, {window_seconds}), got {overlap_seconds!r}"
            )
        frames = int((window_seconds - overlap_seconds) * self.input_sample_rate // self.encoder_downsample_rate)
        if frames < 1:
            raise ValueError(
                f"overlap_seconds={overlap_seconds!r} leaves less than one code frame per window"
            )
        return frames

    @torch.inference_mode()
    def encode(self, wav_list, overlap_seconds=10, device=torch.device("cuda")):
        """Tokenize a batch of arbitrarily long waveforms window by window.

        Each window spans up to 30 seconds. Only the codes for its first
        (30 - overlap_seconds) seconds are kept; the remainder acts as right
        context. The window stride is a whole number of code frames, so codes
        from consecutive windows line up exactly on the global time axis.

        Input:
            wav_list: List of audio waveforms, each with potentially different length, may exceed 30 seconds # B * (T,)
            overlap_seconds: Overlap in seconds, 0 <= overlap_seconds < 30
            device: Device on which the padded batch is assembled
        Output:
            dict: Contains the following key-value pairs
                "codes_list": List of quantization codes # B * (nq, T); empty list for empty input
        Raises:
            ValueError: if overlap_seconds is out of range.
        """
        if not wav_list:
            return {"codes_list": []}
        hop = self.encoder_downsample_rate
        code_duration_length = self._kept_code_frames(overlap_seconds)  # Valid code length per chunk
        # Derive the sample stride from whole code frames. Using
        # int(duration_seconds * sample_rate) directly makes later chunks drift
        # whenever that is not a multiple of the hop (e.g. overlap_seconds=5 at 16 kHz).
        duration_size = code_duration_length * hop  # Valid input samples per chunk
        chunk_size = int(self._WINDOW_SECONDS * self.input_sample_rate)  # Maximum samples per chunk

        # Pack the batch into one zero-padded tensor
        lengths = [len(wav) for wav in wav_list]
        max_length = max(lengths)
        batch_size = len(wav_list)
        wav_tensor = torch.zeros(batch_size, 1, max_length, device=device)
        for i, wav in enumerate(wav_list):
            wav_tensor[i, 0, :lengths[i]] = wav
        input_lengths = torch.tensor(lengths, dtype=torch.long, device=device)  # (B,)

        # Every chunk starts before max_length, so the longest item is never empty
        # in any chunk; each chunk therefore contributes exactly code_duration_length
        # frames and the concatenation stays time-aligned.
        max_chunks = (max_length + duration_size - 1) // duration_size
        chunk_codes_list = []
        for chunk_idx in range(max_chunks):
            start = chunk_idx * duration_size
            end = min(start + chunk_size, max_length)
            chunk = wav_tensor[:, :, start:end]  # (B, 1, T')
            chunk_lengths = torch.clamp(input_lengths - start, 0, end - start)  # (B,)
            result = self.inference_tokenize(chunk, chunk_lengths)
            chunk_codes = result["codes"]  # (nq, B, T')
            # Keep only the leading frames; the tail is overlap context.
            kept = torch.clamp(result["codes_lengths"], 0, code_duration_length).tolist()
            valid_chunk_codes = torch.zeros(self.nq, batch_size, code_duration_length, device=device, dtype=chunk_codes.dtype)
            for b, n in enumerate(kept):
                if n > 0:
                    valid_chunk_codes[:, b, :n] = chunk_codes[:, b, :n]
            chunk_codes_list.append(valid_chunk_codes)  # (nq, B, code_duration_length)

        if chunk_codes_list:
            codes_tensor = torch.cat(chunk_codes_list, dim=-1)  # (nq, B, T_total)
            codes_list = [codes_tensor[:, i, :n // hop] for i, n in enumerate(lengths)]  # B * (nq, T)
        else:
            # Every waveform was empty
            codes_list = [torch.zeros(self.nq, 0, device=device, dtype=torch.long) for _ in range(batch_size)]  # B * (nq, 0)
        return {
            "codes_list": codes_list  # B * (nq, T)
        }

    @torch.inference_mode()
    def decode(self, codes_list, overlap_seconds=10, device=torch.device("cuda")):
        """Synthesize waveforms from arbitrarily long code sequences window by window.

        Each window spans up to 30 seconds of codes. Only the audio for its
        first (30 - overlap_seconds) seconds is kept; the remainder acts as
        right context.

        Input:
            codes_list: List of quantization codes # B * (nq, T)
            overlap_seconds: Overlap in seconds, 0 <= overlap_seconds < 30
            device: Device on which the padded batch is assembled
        Output:
            dict: Contains the following key-value pairs
                "syn_wav_list": List of synthesized audio waveforms # B * (T,); empty list for empty input
        Raises:
            ValueError: if overlap_seconds is out of range.
        """
        if not codes_list:
            return {"syn_wav_list": []}
        duration_code_length = self._kept_code_frames(overlap_seconds)  # Valid code length per chunk
        chunk_code_length = int(self._WINDOW_SECONDS * self.input_sample_rate // self.encoder_downsample_rate)  # Maximum code length per chunk
        duration_wav_length = duration_code_length * self.decoder_upsample_rate  # Valid waveform length per chunk

        # Pack the batch into one zero-padded code tensor
        lengths = [codes.shape[-1] for codes in codes_list]
        max_code_length = max(lengths)
        batch_size = len(codes_list)
        codes_tensor = torch.zeros(self.nq, batch_size, max_code_length, device=device, dtype=torch.long)
        for i, codes in enumerate(codes_list):
            codes_tensor[:, i, :lengths[i]] = codes.to(device)
        code_lengths = torch.tensor(lengths, dtype=torch.long, device=device)  # (B,)

        # As in encode(), every chunk starts before max_code_length, so no chunk is empty.
        max_chunks = (max_code_length + duration_code_length - 1) // duration_code_length
        wav_chunks = []
        for chunk_idx in range(max_chunks):
            start = chunk_idx * duration_code_length
            end = min(start + chunk_code_length, max_code_length)
            chunk_codes = codes_tensor[:, :, start:end]  # (nq, B, T')
            chunk_code_lengths = torch.clamp(code_lengths - start, 0, end - start)  # (B,)
            result = self.inference_detokenize(chunk_codes, chunk_code_lengths)
            chunk_wav = result["y"]  # (B, 1, T')
            # Keep only the leading samples; the tail is overlap context.
            kept = torch.clamp(result["output_length"], 0, duration_wav_length).tolist()
            valid_chunk_wav = torch.zeros(batch_size, 1, duration_wav_length, device=device)
            for b, n in enumerate(kept):
                if n > 0:
                    valid_chunk_wav[b, :, :n] = chunk_wav[b, :, :n]
            wav_chunks.append(valid_chunk_wav)  # (B, 1, duration_wav_length)

        if wav_chunks:
            wav_tensor = torch.cat(wav_chunks, dim=-1)  # (B, 1, T_total)
            syn_wav_list = [wav_tensor[i, 0, :n * self.decoder_upsample_rate] for i, n in enumerate(lengths)]  # B * (T,)
        else:
            # Every code sequence was empty
            syn_wav_list = [torch.zeros(0, device=device) for _ in range(batch_size)]  # B * (0,)
        return {
            "syn_wav_list": syn_wav_list  # B * (T,)
        }

    @classmethod
    def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
        """Build a model from a YAML config and load its weights from a checkpoint.

        Args:
            config_path: YAML file whose ``generator_params`` entry is passed to the constructor.
            ckpt_path: checkpoint file, loaded onto CPU.
        Returns:
            The constructed model with the checkpoint weights loaded.
        """
        logging.info("Loading model from %s and %s", config_path, ckpt_path)
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
        model = cls(config['generator_params'])
        # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        # Weights may be nested under a 'generator' key; otherwise the file is the state dict itself.
        state_dict = checkpoint['generator'] if 'generator' in checkpoint else checkpoint
        model.load_state_dict(state_dict)
        return model