import argparse
import logging
import os
from dataclasses import asdict

import torch
import torch.nn as nn

from core.trainer import train_hubert_quantizer
from core.model.hubert import (
    HuBERTForBarkSemantic,
    HubertForBarkSemanticConfig,
)
from core.utils import download_dataset_from_hf
from core.bark.constants import HUBERT_OUTPUT_VOCAB_SIZE
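
# Timestamped, INFO-level logging for the entire training run.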
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

WORKSPACE = "./"
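
# Hugging Face repos: one hosts the wave/semantic training data (plus the
# noise archive used for augmentation), the other receives published checkpoints.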
DATASET_REPO_ID = "sleeper371/bark-wave-semantic"
CHECKPOINT_REPO_ID = "sleeper371/hubert-for-bark-semantic"
HF_NOISE_FILE_NAME = "environmental_sound.zip"
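
# Workspace-relative directories; resolved against --workspace in main()
# via prefix_workspace().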
_LOCAL_NOISE_DATA_PATH = "noise_dataset"
_LOCAL_TRAINING_DATA_PATH = "wav_semantic_dataset"
_LOCAL_CHECKPOINTS_PATH = "checkpoints"


def prefix_workspace(workspace_path: str, path: str) -> str:
    return os.path.join(workspace_path, path)


def parse_args():
    parser = argparse.ArgumentParser(description="HuBERT Training Script")
    parser.add_argument(
        "--hubert-checkpoint-name",
        type=str,
        default="facebook/hubert-base-ls960",
        help="checkpoint name used as the feature extractor layer for CustomHuBERT",
    )
    parser.add_argument(
        "--feature-layer",
        type=int,
        default=11,
        help="layer whose features are fed to the LSTM",
    )
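
    # Forwarded to the trainer as enable_grad_scaler; the model weights keep
    # their original dtype (contrast --model-bfloat16 below).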
    parser.add_argument(
        "--mix-precision",
        action="store_true",
        help="train model with mixed precision (bfloat16) and a gradient scaler",
    )
    parser.add_argument(
        "--lr", type=float, default=8e-5, help="Learning rate (default: 8e-5)"
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=3,
        help="Number of training epochs (default: 3)",
    )
    parser.add_argument(
        "--train-ratio",
        type=float,
        default=0.8,
        help="Train/validation split ratio (default: 0.8)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        help="Batch size for training (default: 2)",
    )
    parser.add_argument(
        "--dataset-file-name",
        type=str,
        default="short_sentences.zip",
        help="name of the dataset file in the HF repo to download",
    )
    parser.add_argument(
        "--save-checkpoint-every",
        type=int,
        default=1,
        help="Save checkpoint every N epochs (default: 1)",
    )
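
    # Unlike --mix-precision, this converts the model weights themselves to
    # bfloat16 before training (see main()).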
    parser.add_argument(
        "--model-bfloat16",
        action="store_true",
        default=False,
        help="set to convert the model to bfloat16 and train it in that precision",
    )
    parser.add_argument(
        "--augment-data-with-noise",
        action="store_true",
        default=False,
        help="randomly add noise to training data as a regularization technique",
    )
    parser.add_argument(
        "--augment-prob",
        type=float,
        default=0.5,
        help="probability with which noise is added to an audio sample",
    )
    parser.add_argument(
        "--publish-hf",
        action="store_true",
        default=False,
        help="if set, publish checkpoints to the Hugging Face hub",
    )
    parser.add_argument(
        "--workspace",
        type=str,
        default=WORKSPACE,
        help="workspace folder to store data",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=10000,
        help="number of examples to load from the dataset",
    )
    return parser.parse_args()


def ensure_directory(path: str):
    """Create directory if it doesn't exist."""
    os.makedirs(path, exist_ok=True)


def calculate_model_memory(model: nn.Module):
    """
    Calculate and log the memory usage of a PyTorch model's parameters
    based on their detected data type.

    Args:
        model (nn.Module): The PyTorch model to analyze.
    """
    bytes_per_param_dict = {
        torch.float32: 4,
        torch.float16: 2,
        torch.bfloat16: 2,
        torch.int8: 1,
        torch.int32: 4,
        torch.int64: 8,
    }
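
    # Detect the dtype from the first parameter; this assumes every parameter
    # in the model shares that dtype.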
    param_iter = iter(model.parameters())
    try:
        first_param = next(param_iter)
        dtype = first_param.dtype
    except StopIteration:
        logger.warning("Model has no parameters!")
        return
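
    # Dtypes missing from the table above fall back to 4 bytes per parameter.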
    bytes_per_param = bytes_per_param_dict.get(dtype, 4)
    dtype_name = str(dtype).replace("torch.", "")

    total_params = sum(p.numel() for p in model.parameters())
    total_memory_bytes = total_params * bytes_per_param
    total_memory_mb = total_memory_bytes / (1024**2)
    total_memory_gb = total_memory_bytes / (1024**3)

    logger.info(f"Model Memory Usage (Detected dtype: {dtype_name}):")
    logger.info(f"Total Parameters: {total_params:,}")
    logger.info(f"Total Memory: {total_memory_mb:,.2f} MB ({total_memory_gb:,.2f} GB)")


def main():
    args = parse_args()

    LOCAL_NOISE_DATA_PATH = prefix_workspace(args.workspace, _LOCAL_NOISE_DATA_PATH)
    LOCAL_TRAINING_DATA_PATH = prefix_workspace(
        args.workspace, _LOCAL_TRAINING_DATA_PATH
    )
    LOCAL_CHECKPOINTS_PATH = prefix_workspace(args.workspace, _LOCAL_CHECKPOINTS_PATH)

    ensure_directory(LOCAL_CHECKPOINTS_PATH)

    logger.info("Starting HuBERT training")
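
    # Prefer CUDA, then Apple-silicon MPS, and fall back to CPU.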
    device = (
        torch.device("cuda")
        if torch.cuda.is_available()
        else (
            torch.device("mps")
            if torch.backends.mps.is_available()
            else torch.device("cpu")
        )
    )
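
    # Decoder depth is fixed at 6 here; vocab size, base checkpoint, and
    # feature layer come from the CLI.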
    config = HubertForBarkSemanticConfig(
        vocab_size=HUBERT_OUTPUT_VOCAB_SIZE,
        checkpoint_name=args.hubert_checkpoint_name,
        feature_layer=args.feature_layer,
        num_decoder_layer=6,
    )
    model = HuBERTForBarkSemantic(
        config=config, load_hubert_pretrained_weights=True, device=device
    )

    if args.model_bfloat16:
        model = model.to(torch.bfloat16)
        logger.info("Training model in bfloat16 precision")

    calculate_model_memory(model)
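
    # Fetch archives from the HF dataset repo only when the extracted folders
    # are not already present locally.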
    if not os.path.exists(LOCAL_TRAINING_DATA_PATH):
        download_dataset_from_hf(
            DATASET_REPO_ID,
            args.dataset_file_name,
            LOCAL_TRAINING_DATA_PATH,
        )

    if args.augment_data_with_noise and not os.path.exists(LOCAL_NOISE_DATA_PATH):
        download_dataset_from_hf(
            DATASET_REPO_ID,
            HF_NOISE_FILE_NAME,
            LOCAL_NOISE_DATA_PATH,
        )
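
    # train_hubert_quantizer (core.trainer) takes over from here; judging by
    # its keyword arguments, it handles the train/validation split, periodic
    # checkpointing, noise augmentation, and optional publishing to
    # CHECKPOINT_REPO_ID.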
    trained_model = train_hubert_quantizer(
        model=model,
        model_config=asdict(config),
        lr=args.lr,
        num_epoch=args.num_epochs,
        train_ratio=args.train_ratio,
        batch_size=args.batch_size,
        data_path=LOCAL_TRAINING_DATA_PATH,
        checkpoint_path=LOCAL_CHECKPOINTS_PATH,
        save_checkpoint_every=args.save_checkpoint_every,
        augment_data_with_noise=args.augment_data_with_noise,
        augment_prob=args.augment_prob,
        noise_data_path=LOCAL_NOISE_DATA_PATH,
        publish_hf=args.publish_hf,
        publish_to_repo=CHECKPOINT_REPO_ID,
        device=device,
        num_samples=args.num_samples,
        enable_grad_scaler=args.mix_precision,
    )
    logger.info("Training completed")

    return trained_model


if __name__ == "__main__":
    main()