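"""DPO (Direct Preference Optimization) fine-tuning for TangoFlux, a rectified
flow-matching text-to-audio model: the trainable transformer is scored on
(caption, chosen-audio, rejected-audio) preference pairs against a frozen
reference copy, with the Stable Audio Open Oobleck VAE providing audio latents."""
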
import time
import argparse
import json
import logging
import math
import os
import copy
from pathlib import Path

import yaml
import diffusers
import datasets
import numpy as np
import pandas as pd
import wandb
import transformers
import torch
import torchaudio
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset, Audio
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import SchedulerType, get_scheduler
from diffusers import AutoencoderOobleck

from tangoflux.model import TangoFlux
from tangoflux.utils import Text2AudioDataset, read_wav_file, DPOText2AudioDataset

logger = get_logger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Rectified flow for text to audio generation task."
    )
    parser.add_argument(
        "--num_examples",
        type=int,
        default=-1,
        help="How many examples to use for training and validation.",
    )
    parser.add_argument(
        "--text_column",
        type=str,
        default="captions",
        help="The name of the column in the datasets containing the input texts.",
    )
    parser.add_argument(
        "--audio_column",
        type=str,
        default="location",
        help="The name of the column in the datasets containing the audio paths.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.95,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="tangoflux_config.yaml",
        help="Config file defining the model size.",
    )
    parser.add_argument(
        "--weight_decay", type=float, default=1e-8, help="Weight decay to use."
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=[
            "linear",
            "cosine",
            "cosine_with_restarts",
            "polynomial",
            "constant",
            "constant_with_warmup",
        ],
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=1e-2,
        help="Weight decay value for the Adam optimizer.",
    )
    parser.add_argument(
        "--seed", type=int, default=None, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default="best",
        help="Whether to save training state at the end of every 'epoch', or only as 'best' whenever the validation loss decreases.",
    )
    parser.add_argument(
        "--save_every",
        type=int,
        default=5,
        help="Save the model every this many epochs when checkpointing_steps is set to 'best'.",
    )
    parser.add_argument(
        "--load_from_checkpoint",
        type=str,
        default=None,
        help="Path to a safetensors weight file to initialize the model from before training.",
    )
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    accelerator_log_kwargs = {}

    def load_config(config_path):
        with open(config_path, "r") as file:
            return yaml.safe_load(file)

    config = load_config(args.config)
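
    # Assumed layout of the YAML config, reconstructed from the keys read in
    # this script (values illustrative; defer to the shipped tangoflux_config.yaml):
    #
    #   model: ...                      # passed straight to TangoFlux(config=...)
    #   training:
    #     learning_rate: 1e-5
    #     num_train_epochs: 10
    #     num_warmup_steps: 100
    #     per_device_batch_size: 4
    #     gradient_accumulation_steps: 1
    #     max_audio_duration: 30
    #   paths:
    #     train_file: path/to/dpo_train.json
    #     val_file: path/to/val.json
    #     test_file: ""                 # falls back to val_file when empty
    #     output_dir: saved/run1
    #     resume_from_checkpoint: ""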

    learning_rate = float(config["training"]["learning_rate"])
    num_train_epochs = int(config["training"]["num_train_epochs"])
    num_warmup_steps = int(config["training"]["num_warmup_steps"])
    per_device_batch_size = int(config["training"]["per_device_batch_size"])
    gradient_accumulation_steps = int(config["training"]["gradient_accumulation_steps"])
    output_dir = config["paths"]["output_dir"]

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        **accelerator_log_kwargs,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    datasets.utils.logging.set_verbosity_error()
    diffusers.utils.logging.set_verbosity_error()
    transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle output directory creation and wandb tracking.
    if accelerator.is_main_process:
        if output_dir is None or output_dir == "":
            output_dir = "saved/" + str(int(time.time()))
            if not os.path.exists("saved"):
                os.makedirs("saved")
            os.makedirs(output_dir, exist_ok=True)
        elif output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)
        os.makedirs("{}/{}".format(output_dir, "outputs"), exist_ok=True)
        # One JSON record per line (JSONL).
        with open("{}/summary.jsonl".format(output_dir), "a") as f:
            f.write(json.dumps(dict(vars(args))) + "\n")
        accelerator.project_configuration.automatic_checkpoint_naming = False
        wandb.init(
            project="Text to Audio Flow matching DPO",
            settings=wandb.Settings(_disable_stats=True),
        )
    accelerator.wait_for_everyone()

    # Get the datasets.
    data_files = {}
    if config["paths"]["train_file"] != "":
        data_files["train"] = config["paths"]["train_file"]
    if config["paths"]["val_file"] != "":
        data_files["validation"] = config["paths"]["val_file"]
    if config["paths"]["test_file"] != "":
        data_files["test"] = config["paths"]["test_file"]
    else:
        data_files["test"] = config["paths"]["val_file"]

    extension = "json"
    train_dataset = load_dataset(extension, data_files=data_files["train"])
    data_files.pop("train")
    raw_datasets = load_dataset(extension, data_files=data_files)
    text_column, audio_column = args.text_column, args.audio_column

    model = TangoFlux(config=config["model"], initialize_reference_model=True)
    vae = AutoencoderOobleck.from_pretrained(
        "stabilityai/stable-audio-open-1.0", subfolder="vae"
    )

    # Freeze the VAE.
    for param in vae.parameters():
        param.requires_grad = False
    vae.eval()

    # Freeze the text encoder.
    for param in model.text_encoder.parameters():
        param.requires_grad = False
    model.text_encoder.eval()

    prefix = ""
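    # "chosen" / "reject" below are the columns of the train JSON holding paths
    # to the preferred and dispreferred audio clip for each caption; validation
    # and test remain plain text-audio pairs scored with the SFT loss.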
    with accelerator.main_process_first():
        train_dataset = DPOText2AudioDataset(
            train_dataset["train"],
            prefix,
            text_column,
            "chosen",
            "reject",
            "duration",
            args.num_examples,
        )
        eval_dataset = Text2AudioDataset(
            raw_datasets["validation"],
            prefix,
            text_column,
            audio_column,
            "duration",
            args.num_examples,
        )
        test_dataset = Text2AudioDataset(
            raw_datasets["test"],
            prefix,
            text_column,
            audio_column,
            "duration",
            args.num_examples,
        )
    accelerator.print(
        "Num instances in train: {}, validation: {}, test: {}".format(
            train_dataset.get_num_instances(),
            eval_dataset.get_num_instances(),
            test_dataset.get_num_instances(),
        )
    )

    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=per_device_batch_size,
        collate_fn=train_dataset.collate_fn,
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        shuffle=True,
        batch_size=per_device_batch_size,
        collate_fn=eval_dataset.collate_fn,
    )
    test_dataloader = DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=per_device_batch_size,
        collate_fn=test_dataset.collate_fn,
    )

    # Optimizer: only the transformer and the `fc` projection layer are trained.
    optimizer_parameters = list(model.transformer.parameters()) + list(
        model.fc.parameters()
    )
    num_trainable_parameters = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    accelerator.print("Num trainable parameters: {}".format(num_trainable_parameters))

    if args.load_from_checkpoint:
        from safetensors.torch import load_file

        w1 = load_file(args.load_from_checkpoint)
        model.load_state_dict(w1, strict=False)
        logger.info("Weights loaded from {}".format(args.load_from_checkpoint))

    # Snapshot the current transformer as the frozen DPO reference model.
    model.ref_transformer = copy.deepcopy(model.transformer)
    model.ref_transformer.eval()
    for param in model.ref_transformer.parameters():
        param.requires_grad = False
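    # The reference transformer stays frozen for the whole run: the DPO loss
    # compares the policy's flow-matching error against this snapshot, and only
    # the parameters handed to the optimizer above (transformer + fc) update.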

    optimizer = torch.optim.AdamW(
        optimizer_parameters,
        lr=learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    if args.max_train_steps is None:
        args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
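    # The scheduler is stepped on every micro-batch below and, after
    # accelerator.prepare, once per process -- hence the warmup/total step
    # counts are scaled by gradient_accumulation_steps (and num_processes
    # for warmup).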
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps
        * gradient_accumulation_steps
        * accelerator.num_processes,
        num_training_steps=args.max_train_steps * gradient_accumulation_steps,
    )

    # Prepare everything with our `accelerator`.
    vae, model, optimizer, lr_scheduler = accelerator.prepare(
        vae, model, optimizer, lr_scheduler
    )
    train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
        train_dataloader, eval_dataloader, test_dataloader
    )

    # We need to recalculate our total training steps as the size of the
    # training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    if overrode_max_train_steps:
        args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs.
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Figure out how many steps we should save the Accelerator states.
    checkpointing_steps = args.checkpointing_steps
    if checkpointing_steps is not None and checkpointing_steps.isdigit():
        checkpointing_steps = int(checkpointing_steps)

    # Train!
    total_batch_size = (
        per_device_batch_size * accelerator.num_processes * gradient_accumulation_steps
    )
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {per_device_batch_size}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(args.max_train_steps), disable=not accelerator.is_local_main_process
    )
    completed_steps = 0
    starting_epoch = 0

    # Potentially load in the weights and states from a previous save.
    resume_from_checkpoint = config["paths"]["resume_from_checkpoint"]
    if resume_from_checkpoint != "":
        accelerator.load_state(resume_from_checkpoint)
        accelerator.print(f"Resumed from local checkpoint: {resume_from_checkpoint}")

    best_loss = np.inf
    # Maximum duration (in seconds) of the audio clips; longer clips are truncated.
    length = config["training"]["max_audio_duration"]

    for epoch in range(starting_epoch, num_train_epochs):
        model.train()
        total_loss, total_val_loss = 0, 0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                device = accelerator.device
                text, audio_w, audio_l, duration, _ = batch
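                # audio_w / audio_l hold file paths for the preferred ("chosen")
                # and dispreferred ("reject") clip of each preference pair.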
                with torch.no_grad():
                    audio_list_w, audio_list_l = [], []
                    # Read at most `length` seconds of each file; mono clips are
                    # duplicated across channels to make them "fake stereo".
                    for audio_path in audio_w:
                        wav = read_wav_file(audio_path, length)
                        if wav.shape[0] == 1:
                            wav = wav.repeat(2, 1)
                        audio_list_w.append(wav)
                    for audio_path in audio_l:
                        wav = read_wav_file(audio_path, length)
                        if wav.shape[0] == 1:
                            wav = wav.repeat(2, 1)
                        audio_list_l.append(wav)

                    audio_input_w = torch.stack(audio_list_w, dim=0).to(device)
                    audio_input_l = torch.stack(audio_list_l, dim=0).to(device)
                    unwrapped_vae = accelerator.unwrap_model(vae)
                    duration = torch.tensor(duration, device=device)
                    # Clamp durations to the maximum audio length.
                    duration = torch.clamp(duration, max=length)

                    audio_latent_w = unwrapped_vae.encode(
                        audio_input_w
                    ).latent_dist.sample()
                    audio_latent_l = unwrapped_vae.encode(
                        audio_input_l
                    ).latent_dist.sample()
                    # Concatenate winner and loser latents along the batch dim.
                    audio_latent = torch.cat((audio_latent_w, audio_latent_l), dim=0)
                    # Transpose to (bsz, seq_len, channels).
                    audio_latent = audio_latent.transpose(1, 2)
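                # sft=False selects the DPO objective: the model returns the
                # preference loss together with the raw flow-matching losses of
                # the policy and frozen reference, plus the implicit accuracy
                # (roughly, how often the implicit reward favors the chosen clip).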
                loss, raw_model_loss, raw_ref_loss, implicit_acc = model(
                    audio_latent, text, duration=duration, sft=False
                )
                total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()

                if accelerator.sync_gradients:
                    # accelerator.clip_grad_value_(model.parameters(), 1.0)
                    progress_bar.update(1)
                    completed_steps += 1

                    if completed_steps % 10 == 0 and accelerator.is_main_process:
                        # Monitor the global L2 gradient norm every 10 steps.
                        total_norm = 0.0
                        for p in model.parameters():
                            if p.grad is not None:
                                param_norm = p.grad.data.norm(2)
                                total_norm += param_norm.item() ** 2
                        total_norm = total_norm**0.5
                        logger.info(
                            f"Step {completed_steps}, Loss: {loss.item()}, Grad Norm: {total_norm}"
                        )

                        lr = lr_scheduler.get_last_lr()[0]
                        result = {
                            "train_loss": loss.item(),
                            "grad_norm": total_norm,
                            "learning_rate": lr,
                            "raw_model_loss": raw_model_loss,
                            "raw_ref_loss": raw_ref_loss,
                            "implicit_acc": implicit_acc,
                        }
                        wandb.log(result, step=completed_steps)

                    # Checks if the accelerator has performed an optimization
                    # step behind the scenes.
                    if isinstance(checkpointing_steps, int):
                        if completed_steps % checkpointing_steps == 0:
                            step_dir = os.path.join(
                                output_dir, f"step_{completed_steps}"
                            )
                            accelerator.save_state(step_dir)

            if completed_steps >= args.max_train_steps:
                break

        model.eval()
        eval_progress_bar = tqdm(
            range(len(eval_dataloader)), disable=not accelerator.is_local_main_process
        )
        for step, batch in enumerate(eval_dataloader):
            with torch.no_grad():
                device = accelerator.device
                text, audios, duration, _ = batch

                audio_list = []
                for audio_path in audios:
                    # Read at most `length` seconds; duplicate mono channels so
                    # the clip becomes "fake stereo".
                    wav = read_wav_file(audio_path, length)
                    if wav.shape[0] == 1:
                        wav = wav.repeat(2, 1)
                    audio_list.append(wav)
                audio_input = torch.stack(audio_list, dim=0).to(device)
                duration = torch.tensor(duration, device=device)

                unwrapped_vae = accelerator.unwrap_model(vae)
                audio_latent = unwrapped_vae.encode(audio_input).latent_dist.sample()
                # Transpose to (bsz, seq_len, channels).
                audio_latent = audio_latent.transpose(1, 2)

                # Validation uses the plain flow-matching (SFT) loss.
                val_loss, _, _, _ = model(
                    audio_latent, text, duration=duration, sft=True
                )
                total_val_loss += val_loss.detach().float()
                eval_progress_bar.update(1)

        if accelerator.is_main_process:
            result = {}
            result["epoch"] = float(epoch + 1)
            result["epoch/train_loss"] = round(
                total_loss.item() / len(train_dataloader), 4
            )
            result["epoch/val_loss"] = round(
                total_val_loss.item() / len(eval_dataloader), 4
            )
            wandb.log(result, step=completed_steps)

            with open("{}/summary.jsonl".format(output_dir), "a") as f:
                f.write(json.dumps(result) + "\n")

            logger.info(result)

            # Only save a "best" checkpoint when the validation loss improves.
            save_checkpoint = result["epoch/val_loss"] < best_loss
            if save_checkpoint:
                best_loss = result["epoch/val_loss"]

        accelerator.wait_for_everyone()
        if accelerator.is_main_process and args.checkpointing_steps == "best":
            if save_checkpoint:
                accelerator.save_state("{}/{}".format(output_dir, "best"))
            if (epoch + 1) % args.save_every == 0:
                accelerator.save_state(
                    "{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
                )
        if accelerator.is_main_process and args.checkpointing_steps == "epoch":
            accelerator.save_state(
                "{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
            )
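
# Typical launch (illustrative; assumes `accelerate config` has been run and the
# YAML paths point at the DPO train/val JSON files):
#
#   accelerate launch <this_script>.py --config tangoflux_config.yaml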

if __name__ == "__main__":
    main()