Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| import yaml | |
| import logging | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .nn.feature_extractor import MelFeatureExtractor | |
| from .nn.modules import OmniAudioEncoder, OmniAudioDecoder, ResidualDownConv, UpConv, Transformer, Vocos | |
| from .nn.quantizer import ResidualVQ | |
class XY_Tokenizer(nn.Module):
    """Audio tokenizer that maps waveforms to residual-VQ codes and back.

    The encoding path runs a mel feature extractor, a semantic encoder (plus
    adapter) and an acoustic encoder, concatenates both channels, and then
    applies an adapter, a downsampler and a residual vector quantizer.
    The decoding path turns codes back into embeddings and runs an adapter,
    an upsampler, an acoustic decoder and a Vocos vocoder.

    ``inference_tokenize`` / ``inference_detokenize`` process one window at a
    time; ``encode`` / ``decode`` handle arbitrarily long inputs by sliding an
    overlapping window over them.
    """

    # Length of one processing window in seconds; longer inputs are chunked.
    _WINDOW_SECONDS = 30

    def __init__(self, generator_params):
        """Build all sub-modules from a configuration mapping.

        Args:
            generator_params: Mapping providing ``input_sample_rate``,
                ``output_sample_rate`` and one ``*_kwargs`` mapping per
                sub-module (``semantic_encoder_kwargs``,
                ``semantic_encoder_adapter_kwargs``, ``acoustic_encoder_kwargs``,
                ``pre_rvq_adapter_kwargs``, ``downsample_kwargs``,
                ``quantizer_kwargs``, ``post_rvq_adapter_kwargs``,
                ``upsample_kwargs``, ``acoustic_decoder_kwargs``,
                ``vocos_kwargs``, ``feature_extractor_kwargs``).
        """
        super().__init__()
        # Basic parameters
        self.input_sample_rate = generator_params['input_sample_rate']
        self.output_sample_rate = generator_params['output_sample_rate']
        # Input samples per code frame / output samples per code frame.
        self.encoder_downsample_rate = 1280
        self.decoder_upsample_rate = 1920
        self.code_dim = generator_params['quantizer_kwargs']['input_dim']

        ## Codec part
        ## Semantic channel
        self.semantic_encoder = OmniAudioEncoder(**generator_params['semantic_encoder_kwargs'])
        self.semantic_encoder_adapter = Transformer(**generator_params['semantic_encoder_adapter_kwargs'])
        ## Acoustic channel
        self.acoustic_encoder = OmniAudioEncoder(**generator_params['acoustic_encoder_kwargs'])
        ## Semantic & acoustic shared parameters
        self.pre_rvq_adapter = Transformer(**generator_params['pre_rvq_adapter_kwargs'])
        self.downsample = ResidualDownConv(**generator_params['downsample_kwargs'])
        self.quantizer = ResidualVQ(**generator_params['quantizer_kwargs'])
        self.nq = generator_params['quantizer_kwargs']['num_quantizers']
        self.post_rvq_adapter = Transformer(**generator_params['post_rvq_adapter_kwargs'])
        ## Acoustic channel
        self.upsample = UpConv(**generator_params['upsample_kwargs'])
        self.acoustic_decoder = OmniAudioDecoder(**generator_params['acoustic_decoder_kwargs'])
        self.enhanced_vocos = Vocos(**generator_params['vocos_kwargs'])
        ## Feature extractor
        self.feature_extractor = MelFeatureExtractor(**generator_params['feature_extractor_kwargs'])

    @classmethod
    def _valid_seconds_per_chunk(cls, overlap_seconds):
        """Return how many seconds of each window are kept, validating the overlap.

        Raises:
            ValueError: If ``overlap_seconds`` is negative or not smaller than
                the window length (the window would never advance, or audio
                between windows would be skipped).
        """
        if not 0 <= overlap_seconds < cls._WINDOW_SECONDS:
            raise ValueError(
                f"overlap_seconds must satisfy 0 <= overlap_seconds < "
                f"{cls._WINDOW_SECONDS}, got {overlap_seconds}"
            )
        return cls._WINDOW_SECONDS - overlap_seconds

    def inference_tokenize(self, x, input_lengths):
        """Tokenize a single window of audio.

        Input:
            x: Waveform tensor # (B, 1, T), T <= 30s * sample_rate
            input_lengths: Valid length for each sample # (B,)
        Output:
            dict: Contains the following key-value pairs
                "zq": Quantized embeddings # (B, D, T)
                "codes": Quantization codes # (nq, B, T)
                "codes_lengths": Quantization code lengths # (B,)
        """
        # The feature extractor consumes a list of 1-D numpy arrays, one per
        # sample, each trimmed to its valid length.
        list_x = [
            xi[:, :x_len].reshape(-1).cpu().numpy()
            for xi, x_len in zip(x, input_lengths)
        ]
        features = self.feature_extractor(
            list_x,
            sampling_rate=self.input_sample_rate,
            return_tensors="pt",
            return_attention_mask=True
        )
        input_mel = features['input_features'].to(x.device).to(x.dtype)  # (B, D, 3000)
        audio_attention_mask = features['attention_mask'].to(x.device)  # (B, 3000)

        # Number of valid mel frames per sample
        mel_output_length = torch.sum(audio_attention_mask, dim=-1).long()  # (B,)

        # Semantic channel
        semantic_encoder_output, semantic_encoder_output_length = self.semantic_encoder(
            input_mel, mel_output_length
        )  # (B, D, T), 100hz -> 50hz
        semantic_encoder_adapter_output, _ = self.semantic_encoder_adapter(
            semantic_encoder_output, semantic_encoder_output_length
        )  # (B, D, T), 50hz

        # Acoustic channel
        acoustic_encoder_output, acoustic_encoder_output_length = self.acoustic_encoder(
            input_mel, mel_output_length
        )  # (B, D, T), 100hz -> 50hz

        # Semantic & acoustic mixing: concatenate along the feature dimension.
        concated_semantic_acoustic_channel = torch.concat(
            [semantic_encoder_adapter_output, acoustic_encoder_output], dim=1
        )  # (B, D, T)
        concated_semantic_acoustic_channel_length = acoustic_encoder_output_length

        pre_rvq_adapter_output, pre_rvq_adapter_output_length = self.pre_rvq_adapter(
            concated_semantic_acoustic_channel, concated_semantic_acoustic_channel_length
        )  # (B, D, T), 50hz
        downsample_output, downsample_output_length = self.downsample(
            pre_rvq_adapter_output, pre_rvq_adapter_output_length
        )  # (B, D, T), 50hz -> 12.5hz

        # The quantizer also returns a loss and per-layer outputs, unused here.
        zq, codes, _, _, quantizer_output_length = self.quantizer(
            downsample_output, downsample_output_length
        )  # (B, D, T), (nq, B, T), (nq,), (nq, B, D, T), (B,)

        return {
            "zq": zq,  # (B, D, T)
            "codes": codes,  # (nq, B, T)
            "codes_lengths": quantizer_output_length  # (B,)
        }

    def inference_detokenize(self, codes, codes_lengths):
        """Synthesize audio for a single window of codes.

        Input:
            codes: Quantization codes # (nq, B, T)
            codes_lengths: Quantization code lengths for each sample # (B,)
        Output:
            dict: Contains the following key-value pairs
                "y": Synthesized audio waveform # (B, 1, T)
                "output_length": Output lengths # (B,)
        """
        zq = self.quantizer.decode_codes(codes)  # (B, D, T)
        post_rvq_adapter_output, post_rvq_adapter_output_length = self.post_rvq_adapter(
            zq, codes_lengths
        )  # (B, D, T), 12.5hz

        # Acoustic channel
        upsample_output, upsample_output_length = self.upsample(
            post_rvq_adapter_output, post_rvq_adapter_output_length
        )  # (B, D, T), 12.5hz -> 50hz
        acoustic_decoder_output, acoustic_decoder_output_length = self.acoustic_decoder(
            upsample_output, upsample_output_length
        )  # (B, D, T), 50hz -> 100hz
        y, vocos_output_length = self.enhanced_vocos(
            acoustic_decoder_output, acoustic_decoder_output_length
        )  # (B, 1, T), 100hz -> waveform

        return {
            "y": y,  # (B, 1, T)
            "output_length": vocos_output_length,  # (B,)
        }

    def encode(self, wav_list, overlap_seconds=10, device=torch.device("cuda")):
        """Tokenize a batch of waveforms of arbitrary length.

        Input:
            wav_list: List of audio waveforms, each with potentially different length, may exceed 30 seconds # B * (T,)
            overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output
            device: Device on which the padded batch and the results are allocated
        Output:
            dict: Contains the following key-value pairs
                "codes_list": List of quantization codes # B * (nq, T)
        Raises:
            ValueError: If ``overlap_seconds`` is outside ``[0, 30)`` or leaves
                less than one code frame of valid output per window.
        """
        duration_seconds = self._valid_seconds_per_chunk(overlap_seconds)

        batch_size = len(wav_list)
        if batch_size == 0:
            # Nothing to encode; avoids max() on an empty sequence below.
            return {"codes_list": []}

        chunk_size = int(self._WINDOW_SECONDS * self.input_sample_rate)  # Maximum samples per chunk
        duration_size = int(duration_seconds * self.input_sample_rate)  # Valid output samples per chunk
        code_duration_length = duration_size // self.encoder_downsample_rate  # Valid code length per chunk
        if code_duration_length <= 0:
            raise ValueError(
                f"overlap_seconds={overlap_seconds} leaves no valid code frames per chunk"
            )

        # Right-pad every waveform with zeros up to the longest one.
        wav_lengths = [len(wav) for wav in wav_list]
        max_length = max(wav_lengths)
        wav_tensor = torch.zeros(batch_size, 1, max_length, device=device)
        for i, wav in enumerate(wav_list):
            wav_tensor[i, 0, :len(wav)] = wav
        input_lengths = torch.tensor(wav_lengths, dtype=torch.long, device=device)  # (B,)

        # Number of windows needed to cover the longest waveform (ceil division)
        max_chunks = (max_length + duration_size - 1) // duration_size

        chunk_results = []
        # Process the entire batch in chunks
        for chunk_idx in range(max_chunks):
            start = chunk_idx * duration_size
            end = min(start + chunk_size, max_length)
            chunk = wav_tensor[:, :, start:end]  # (B, 1, T')
            chunk_lengths = torch.clamp(input_lengths - start, 0, end - start)  # (B,)

            # Skip empty chunks
            if chunk_lengths.max() == 0:
                continue

            # Encode
            result = self.inference_tokenize(chunk, chunk_lengths)  # {"zq": (B, D, T'), "codes": (nq, B, T'), "codes_lengths": (B,)}
            chunk_codes = result["codes"]  # (nq, B, T')
            chunk_code_lengths = result["codes_lengths"]  # (B,)

            # Keep only the non-overlapping head of each window. Lengths are
            # converted to Python ints once to avoid a device sync per sample.
            valid_code_lengths = torch.clamp(chunk_code_lengths, 0, code_duration_length).tolist()  # B ints
            valid_chunk_codes = torch.zeros(
                self.nq, batch_size, code_duration_length, device=device, dtype=chunk_codes.dtype
            )
            for b, valid_len in enumerate(valid_code_lengths):
                if valid_len > 0:
                    valid_chunk_codes[:, b, :valid_len] = chunk_codes[:, b, :valid_len]

            chunk_results.append(valid_chunk_codes)  # (nq, B, code_duration_length)

        # Concatenate all chunks and trim each sample to its own code length
        if chunk_results:
            codes_tensor = torch.cat(chunk_results, dim=-1)  # (nq, B, T_total)
            codes_list = [
                codes_tensor[:, i, :wav_lengths[i] // self.encoder_downsample_rate]
                for i in range(batch_size)
            ]  # B * (nq, T)
        else:
            codes_list = [
                torch.zeros(self.nq, 0, device=device, dtype=torch.long)
                for _ in range(batch_size)
            ]  # B * (nq, 0)

        return {
            "codes_list": codes_list  # B * (nq, T)
        }

    def decode(self, codes_list, overlap_seconds=10, device=torch.device("cuda")):
        """Synthesize waveforms from a batch of code sequences of arbitrary length.

        Input:
            codes_list: List of quantization codes # B * (nq, T)
            overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output
            device: Device on which the padded batch and the results are allocated
        Output:
            dict: Contains the following key-value pairs
                "syn_wav_list": List of synthesized audio waveforms # B * (T,)
        Raises:
            ValueError: If ``overlap_seconds`` is outside ``[0, 30)`` or leaves
                less than one code frame of valid output per window.
        """
        duration_seconds = self._valid_seconds_per_chunk(overlap_seconds)

        batch_size = len(codes_list)
        if batch_size == 0:
            # Nothing to decode; avoids max() on an empty sequence below.
            return {"syn_wav_list": []}

        chunk_code_length = int(self._WINDOW_SECONDS * self.input_sample_rate // self.encoder_downsample_rate)  # Maximum code length per chunk
        duration_code_length = int(duration_seconds * self.input_sample_rate // self.encoder_downsample_rate)  # Valid code length per chunk
        if duration_code_length <= 0:
            raise ValueError(
                f"overlap_seconds={overlap_seconds} leaves no valid code frames per chunk"
            )
        duration_wav_length = duration_code_length * self.decoder_upsample_rate  # Valid waveform length per chunk

        # Right-pad every code sequence with zeros up to the longest one.
        code_length_list = [codes.shape[-1] for codes in codes_list]
        max_code_length = max(code_length_list)
        codes_tensor = torch.zeros(self.nq, batch_size, max_code_length, device=device, dtype=torch.long)
        for i, codes in enumerate(codes_list):
            codes_tensor[:, i, :codes.shape[-1]] = codes.to(device)
        code_lengths = torch.tensor(code_length_list, dtype=torch.long, device=device)  # (B,)

        # Number of windows needed to cover the longest sequence (ceil division)
        max_chunks = (max_code_length + duration_code_length - 1) // duration_code_length

        chunk_results = []
        # Process the entire batch in chunks
        for chunk_idx in range(max_chunks):
            start = chunk_idx * duration_code_length
            end = min(start + chunk_code_length, max_code_length)
            chunk_codes = codes_tensor[:, :, start:end]  # (nq, B, T')
            chunk_code_lengths = torch.clamp(code_lengths - start, 0, end - start)  # (B,)

            # Skip empty chunks
            if chunk_code_lengths.max() == 0:
                continue

            # Decode
            result = self.inference_detokenize(chunk_codes, chunk_code_lengths)  # {"y": (B, 1, T'), "output_length": (B,)}
            chunk_wav = result["y"]  # (B, 1, T')
            chunk_wav_lengths = result["output_length"]  # (B,)

            # Keep only the non-overlapping head of each window. Lengths are
            # converted to Python ints once to avoid a device sync per sample.
            valid_wav_lengths = torch.clamp(chunk_wav_lengths, 0, duration_wav_length).tolist()  # B ints
            valid_chunk_wav = torch.zeros(batch_size, 1, duration_wav_length, device=device)
            for b, valid_len in enumerate(valid_wav_lengths):
                if valid_len > 0:
                    valid_chunk_wav[b, :, :valid_len] = chunk_wav[b, :, :valid_len]

            chunk_results.append(valid_chunk_wav)  # (B, 1, duration_wav_length)

        # Concatenate all chunks and trim each sample to its own waveform length
        if chunk_results:
            wav_tensor = torch.cat(chunk_results, dim=-1)  # (B, 1, T_total)
            syn_wav_list = [
                wav_tensor[i, 0, :code_length_list[i] * self.decoder_upsample_rate]
                for i in range(batch_size)
            ]  # B * (T,)
        else:
            syn_wav_list = [torch.zeros(0, device=device) for _ in range(batch_size)]  # B * (0,)

        return {
            "syn_wav_list": syn_wav_list  # B * (T,)
        }

    @classmethod
    def load_from_checkpoint(cls, config_path: str, ckpt_path: str) -> "XY_Tokenizer":
        """Build a model from a YAML configuration and load weights from a checkpoint.

        Args:
            config_path: Path to a YAML file with a ``generator_params`` section.
            ckpt_path: Path to a checkpoint readable by ``torch.load``. Either a
                state dict, or a mapping holding the state dict under ``generator``.

        Returns:
            The model instance with the checkpoint weights loaded (on CPU).
        """
        logging.info("Loading model from %s and %s", config_path, ckpt_path)

        # Load configuration
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        # Create model instance
        model = cls(config['generator_params'])

        # Load checkpoint.
        # NOTE(security): torch.load unpickles the file and can execute arbitrary
        # code; only load checkpoints from trusted sources (or pass
        # weights_only=True where the installed PyTorch supports it).
        checkpoint = torch.load(ckpt_path, map_location='cpu')

        # Training checkpoints wrap the weights under the 'generator' key.
        state_dict = checkpoint['generator'] if 'generator' in checkpoint else checkpoint
        model.load_state_dict(state_dict)

        return model