|
from dataclasses import dataclass |
|
from pathlib import Path |
|
from typing import Optional, Tuple, Union, Literal |
|
|
|
import torch |
|
import torch.nn as nn |
|
from transformers.modeling_outputs import BaseModelOutput |
|
from transformers import HubertModel, AutoConfig, AutoModel |
|
|
|
|
|
@dataclass
class CustomHubertConfig:
    """Settings consumed by ``HubertFeatureExtractor``.

    Attributes:
        checkpoint_name: Identifier passed to ``AutoConfig.from_pretrained``
            and ``HubertModel.from_pretrained`` by the feature extractor.
        feature_layer: Index into the HuBERT ``hidden_states`` tuple whose
            activations are returned. The extractor rejects values outside
            ``[0, num_hidden_layers - 1]``.
        target_sample_rate: Sample rate (Hz) the input audio is expected to
            have. The extractor stores this value but the code in this file
            performs no resampling.
        seq_len_multiple_of: Optional sequence-length granularity. No code
            in the visible part of this file reads it.
    """

    # Required: which HuBERT checkpoint to load.
    checkpoint_name: str

    # Which entry of the hidden-state tuple to expose as features.
    feature_layer: int = 11

    # Expected sample rate of the raw waveform, in Hz.
    target_sample_rate: int = 16000

    # Reserved; currently unused by the code shown in this file.
    seq_len_multiple_of: Optional[int] = None
|
|
|
|
|
@dataclass
class HubertForBarkSemanticConfig:
    """Configuration for ``HuBERTForBarkSemantic``.

    Attributes:
        checkpoint_name: HuBERT checkpoint used as the audio encoder. The
            value is forwarded unchanged to ``from_pretrained``. The Hub repo
            id of the large model carries the ``facebook/`` namespace; the
            bare ``"hubert-large-ls960-ft"`` spelling is still accepted by
            the annotation for backward compatibility.
        vocab_size: Size of the decoder embedding table and of the output
            projection. ``generate`` feeds ``sos_token_id`` through the
            embedding and indexes the output distribution at
            ``eos_token_id``, so ``vocab_size`` must be larger than both ids.
        feature_layer: Index of the HuBERT hidden state used as decoder
            memory.
        max_target_length: Number of rows in the learned positional
            embedding, i.e. the longest target sequence the decoder can
            handle.
        num_decoder_layer: Number of ``nn.TransformerDecoderLayer`` blocks.
        sos_token_id: Token id that starts every generated sequence.
        eos_token_id: Token id that signals the end of a sequence.
    """

    # Encoder checkpoint. The annotation is not enforced at runtime.
    checkpoint_name: Literal[
        "facebook/hubert-base-ls960",
        "facebook/hubert-large-ls960-ft",
        "hubert-large-ls960-ft",  # legacy spelling, kept so existing callers still type-check
    ]
    vocab_size: int

    # Which HuBERT hidden state feeds the decoder.
    feature_layer: int = 11

    # Decoder hyper-parameters and special tokens.
    max_target_length: int = 2000
    num_decoder_layer: int = 12
    sos_token_id: int = 10000
    eos_token_id: int = 10001
|
|
|
|
|
class HubertFeatureExtractor(nn.Module):
    """Wrap a Hugging Face HuBERT model and expose one hidden layer as features.

    The wrapped model is called with ``output_hidden_states=True`` and the
    entry of the resulting ``hidden_states`` tuple selected by
    ``config.feature_layer`` is returned.

    Args:
        config (CustomHubertConfig): Checkpoint name, layer index and audio
            settings.
        load_pretrained_weights (bool): When true, weights are loaded with
            ``HubertModel.from_pretrained``; otherwise the model is built
            from the checkpoint's configuration only via
            ``AutoModel.from_config``.
        device (torch.device, optional): If given, the module is moved to
            this device at the end of construction.

    Raises:
        ValueError: If ``config.feature_layer`` is outside
            ``[0, num_hidden_layers - 1]``.
    """

    def __init__(
        self,
        config: CustomHubertConfig,
        load_pretrained_weights: bool,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.config = config
        self.target_sample_rate = config.target_sample_rate

        # The configuration is always fetched from the checkpoint; only the
        # weights are optional.
        self.hubert_config = AutoConfig.from_pretrained(config.checkpoint_name)
        self.model = (
            HubertModel.from_pretrained(config.checkpoint_name)
            if load_pretrained_weights
            else AutoModel.from_config(self.hubert_config)
        )

        # Reject layer indices the loaded model cannot provide.
        num_layers = self.model.config.num_hidden_layers
        if config.feature_layer < 0 or config.feature_layer >= num_layers:
            raise ValueError(
                f"feature_layer must be between 0 and {num_layers - 1}, got {config.feature_layer}"
            )
        self.feature_layer = config.feature_layer

        if device is not None:
            self.to(device)

    @property
    def hidden_size(self) -> int:
        """Hidden size reported by the wrapped model's configuration."""
        return self.model.config.hidden_size

    def forward(
        self,
        wav_input: torch.Tensor,
    ) -> torch.Tensor:
        """Run HuBERT on raw audio and return the selected hidden state.

        The audio is expected to already be sampled at 16 kHz; no resampling
        happens here.

        Args:
            wav_input (torch.Tensor): Raw waveforms, shape
                ``[batch_size, audio_length]``.

        Returns:
            torch.Tensor: ``hidden_states[self.feature_layer]`` made
            contiguous. Presumably shaped
            ``[batch_size, num_frames, hidden_size]`` (the decoder in this
            file consumes it as batch-first memory) -- confirm against the
            Transformers documentation. No flattening or reshaping is
            applied.
        """
        outputs: BaseModelOutput = self.model(
            input_values=wav_input,
            output_hidden_states=True,
            return_dict=True,
        )
        # NOTE(review): in Transformers the hidden_states tuple usually starts
        # with the embedding output, so index 0 may not be a transformer
        # layer -- verify before relying on a specific index.
        selected = outputs.hidden_states[self.feature_layer]
        return selected.contiguous()
|
|
|
|
|
class HuBERTForBarkSemantic(nn.Module):
    """HuBERT encoder plus Transformer decoder that predicts semantic tokens.

    Audio is encoded by ``HubertFeatureExtractor``; the resulting features
    serve as memory for an ``nn.TransformerDecoder`` that is fed embedded
    target tokens plus a learned positional embedding. A linear layer maps
    decoder states to ``config.vocab_size`` logits.

    Args:
        config (HubertForBarkSemanticConfig): Model hyper-parameters and
            special token ids.
        load_hubert_pretrained_weights (bool): Forwarded to
            ``HubertFeatureExtractor`` as ``load_pretrained_weights``.
        device (torch.device, optional): If given, the whole module is moved
            to this device after construction.
    """

    def __init__(
        self,
        config: "HubertForBarkSemanticConfig",
        load_hubert_pretrained_weights: bool = True,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.config = config

        # Audio encoder.
        hubert_config = CustomHubertConfig(
            checkpoint_name=config.checkpoint_name,
            feature_layer=config.feature_layer,
        )
        self.hubert = HubertFeatureExtractor(
            config=hubert_config,
            load_pretrained_weights=load_hubert_pretrained_weights,
            device=device,
        )

        # The decoder works at the encoder's width so the HuBERT features
        # can be used as memory without a projection.
        input_size = self.hubert.model.config.hidden_size

        # Token decoder.
        self.decoder_embedding = nn.Embedding(config.vocab_size, input_size)
        # Learned positions; the number of rows bounds the target length.
        self.pos_embedding = nn.Parameter(
            torch.zeros(1, config.max_target_length, input_size)
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=input_size,
                nhead=8,
                dim_feedforward=2048,
                dropout=0.1,
                batch_first=True,
            ),
            num_layers=config.num_decoder_layer,
        )
        self.fc = nn.Linear(input_size, config.vocab_size)

        if device is not None:
            self.to(device)

    def save_state_dict(self, save_path: str):
        """Serialize ``self.state_dict()`` to ``save_path`` with ``torch.save``."""
        torch.save(self.state_dict(), save_path)

    def forward(self, wav_input: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        """Teacher-forced pass: encode audio and score every target position.

        Args:
            wav_input: Raw audio, ``[batch_size, audio_length]``.
            tgt: Target token ids, ``[batch_size, target_length]``.

        Returns:
            Logits of shape ``[batch_size, target_length, vocab_size]``.

        Raises:
            ValueError: If ``target_length`` exceeds the number of learned
                positions (``config.max_target_length``).
        """
        memory: torch.Tensor = self.hubert(wav_input)
        _, tgt_len = tgt.shape

        max_positions = self.pos_embedding.shape[1]
        if tgt_len > max_positions:
            # Without this guard the addition below fails with an opaque
            # broadcasting error.
            raise ValueError(
                f"target length {tgt_len} exceeds max_target_length {max_positions}"
            )

        tgt_emb = self.decoder_embedding(tgt) + self.pos_embedding[:, :tgt_len, :]
        # Causal mask: position i may only attend to positions <= i.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len).to(
            tgt.device
        )

        output: torch.Tensor = self.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
        return self.fc(output)

    @torch.no_grad()
    def generate(
        self,
        wav_input: torch.Tensor,
        temperature: Optional[float] = 0.8,
        eos_p: Optional[float] = 0.5,
        max_length: int = 600,
    ) -> torch.Tensor:
        """Autoregressively decode semantic tokens for a batch of audio.

        Puts the module in eval mode (it is not switched back afterwards).
        The audio is assumed to be sampled at 16 kHz.

        Args:
            wav_input: Raw audio, ``[batch_size, audio_length]``.
            temperature: Sampling temperature. ``None`` or a non-positive
                value selects greedy (argmax) decoding.
            eos_p: If positive, decoding stops as soon as the EOS
                probability exceeds this value for every batch element.
                ``None`` or a non-positive value disables the check.
            max_length: Maximum number of tokens to generate. Capped at the
                number of learned positions so the positional embedding is
                never outgrown.

        Returns:
            Generated token ids without the leading SOS token, shape
            ``[batch_size, generated_length]``. Decoding only stops when all
            batch elements agree, so a row that emitted EOS earlier than the
            others keeps receiving tokens after it.
        """
        self.eval()
        memory = self.hubert(wav_input)
        batch_size = wav_input.shape[0]
        eos_id = self.config.eos_token_id

        tgt = torch.full(
            size=(batch_size, 1),
            fill_value=self.config.sos_token_id,
            dtype=torch.long,  # embedding indices must be integer typed
            device=wav_input.device,
        )

        sample = temperature is not None and temperature > 0
        check_eos_p = eos_p is not None and eos_p > 0
        # At step i the decoder input has i + 1 tokens, so the positional
        # table supports at most `pos_embedding.shape[1]` steps.
        max_steps = min(max_length, self.pos_embedding.shape[1])

        for _ in range(max_steps):
            cur_len = tgt.shape[1]
            tgt_emb = self.decoder_embedding(tgt) + self.pos_embedding[:, :cur_len, :]
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(cur_len).to(
                tgt.device
            )

            output = self.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
            # Only the last position predicts the next token.
            logits: torch.Tensor = self.fc(output[:, -1, :])

            if sample:
                probs = torch.softmax(input=logits / temperature, dim=-1)
                next_token = torch.multinomial(input=probs, num_samples=1)
            else:
                probs = torch.softmax(input=logits, dim=-1)
                next_token = logits.argmax(dim=-1, keepdim=True)

            # Early stop on EOS probability, before the token is appended.
            if check_eos_p and torch.all(probs[:, eos_id] > eos_p):
                break

            # Stop when every row picked EOS; the EOS itself is not appended.
            if torch.all(next_token == eos_id):
                break

            tgt = torch.cat([tgt, next_token], dim=1)

        return tgt[:, 1:]
|
|