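"""Training entry point in the Look2Hear style: the datamodule, model,
optimizer, scheduler, and losses are all resolved by name from a YAML config
and trained with PyTorch Lightning. Run e.g.
``python train.py --conf_dir local/conf.yml`` (script filename assumed).
"""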
import os
import argparse
import json
import yaml  # used in main() to dump the resolved config
import torch
import look2hear.datas
import look2hear.models
import look2hear.system
import look2hear.losses
import look2hear.metrics
import look2hear.utils
from look2hear.system import make_optimizer
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.strategies.ddp import DDPStrategy
from look2hear.utils import print_only, MyRichProgressBar, RichProgressBarTheme

import warnings
warnings.filterwarnings("ignore")

import wandb
wandb.login()
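# wandb.login() uses cached credentials or the WANDB_API_KEY environment
# variable, and prompts interactively if neither is available.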
parser = argparse.ArgumentParser()
parser.add_argument(
    "--conf_dir",
    default="local/conf.yml",
    help="Full path to the YAML configuration file",
)


def main(config):
    print_only(
        "Instantiating datamodule <{}>".format(config["datamodule"]["data_name"])
    )
    datamodule: object = getattr(look2hear.datas, config["datamodule"]["data_name"])(
        **config["datamodule"]["data_config"]
    )
    datamodule.setup()
    train_loader, val_loader, test_loader = datamodule.make_loader
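    # make_loader is assumed to be a property on the datamodule returning
    # the (train, val, test) dataloaders.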
    # Define model and optimizer
    print_only(
        "Instantiating AudioNet <{}>".format(config["audionet"]["audionet_name"])
    )
    model = getattr(look2hear.models, config["audionet"]["audionet_name"])(
        sample_rate=config["datamodule"]["data_config"]["sample_rate"],
        **config["audionet"]["audionet_config"],
    )
    print_only("Instantiating Optimizer <{}>".format(config["optimizer"]["optim_name"]))
    optimizer = make_optimizer(model.parameters(), **config["optimizer"])
    # Define scheduler
    scheduler = None
    if config["scheduler"]["sche_name"]:
        print_only(
            "Instantiating Scheduler <{}>".format(config["scheduler"]["sche_name"])
        )
        if config["scheduler"]["sche_name"] != "DPTNetScheduler":
            scheduler = getattr(torch.optim.lr_scheduler, config["scheduler"]["sche_name"])(
                optimizer=optimizer, **config["scheduler"]["sche_config"]
            )
        else:
            scheduler = {
                "scheduler": getattr(look2hear.system.schedulers, config["scheduler"]["sche_name"])(
                    optimizer, len(train_loader) // config["datamodule"]["data_config"]["batch_size"], 64
                ),
                "interval": "step",
            }
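    # The dict form above is Lightning's lr-scheduler config format: with
    # "interval": "step", the DPTNet warm-up schedule advances on every
    # optimizer step instead of once per epoch.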
    # Save the resolved config right after instantiation so the experiment
    # can be reloaded easily later.
    config["main_args"]["exp_dir"] = os.path.join(
        os.getcwd(), "Experiments", "checkpoint", config["exp"]["exp_name"]
    )
    exp_dir = config["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(config, outfile)
    # Define loss functions for training and validation.
    print_only(
        "Instantiating Loss, Train <{}>, Val <{}>".format(
            config["loss"]["train"]["sdr_type"], config["loss"]["val"]["sdr_type"]
        )
    )
    loss_func = {
        "train": getattr(look2hear.losses, config["loss"]["train"]["loss_func"])(
            getattr(look2hear.losses, config["loss"]["train"]["sdr_type"]),
            **config["loss"]["train"]["config"],
        ),
        "val": getattr(look2hear.losses, config["loss"]["val"]["loss_func"])(
            getattr(look2hear.losses, config["loss"]["val"]["sdr_type"]),
            **config["loss"]["val"]["config"],
        ),
    }
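    # Both losses are resolved by name from look2hear.losses; presumably a
    # PIT-style wrapper applied to the chosen SDR criterion, given the
    # wrapper/criterion split in the config.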
    print_only("Instantiating System <{}>".format(config["training"]["system"]))
    system = getattr(look2hear.system, config["training"]["system"])(
        audio_model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        scheduler=scheduler,
        config=config,
    )
    # Define callbacks
    print_only("Instantiating ModelCheckpoint")
    callbacks = []
    checkpoint_dir = exp_dir
    checkpoint = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename="{epoch}",
        monitor="val_loss/dataloader_idx_0",
        mode="min",
        save_top_k=5,
        verbose=True,
        save_last=True,
    )
    callbacks.append(checkpoint)
    if config["training"]["early_stop"]:
        print_only("Instantiating EarlyStopping")
        callbacks.append(EarlyStopping(**config["training"]["early_stop"]))
    callbacks.append(MyRichProgressBar(theme=RichProgressBarTheme()))
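    # MyRichProgressBar (from look2hear.utils) replaces Lightning's default
    # progress bar with a rich-based one.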
    # Don't request GPUs if none are available; fall back to CPU.
    gpus = config["training"]["gpus"] if torch.cuda.is_available() else "auto"
    accelerator = "cuda" if torch.cuda.is_available() else "cpu"
    # Default logger used by the trainer
    logger_dir = os.path.join(os.getcwd(), "Experiments", "tensorboard_logs")
    os.makedirs(os.path.join(logger_dir, config["exp"]["exp_name"]), exist_ok=True)
    # comet_logger = TensorBoardLogger(logger_dir, name=config["exp"]["exp_name"])
    comet_logger = WandbLogger(
        name=config["exp"]["exp_name"],
        save_dir=os.path.join(logger_dir, config["exp"]["exp_name"]),
        project="Real-work-dataset",
        # offline=True
    )
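    # Note: despite the comet_logger name, this is a Weights & Biases logger;
    # the TensorBoardLogger alternative is kept commented out above.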
    trainer = pl.Trainer(
        max_epochs=config["training"]["epochs"],
        callbacks=callbacks,
        default_root_dir=exp_dir,
        devices=gpus,
        accelerator=accelerator,
        strategy=DDPStrategy(find_unused_parameters=True),
        limit_train_batches=1.0,  # lower this fraction for quick experiments
        gradient_clip_val=5.0,
        logger=comet_logger,
        sync_batchnorm=True,
        # precision="bf16-mixed",
        # num_sanity_val_steps=0,
        # fast_dev_run=True,
    )
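    # find_unused_parameters=True lets DDP tolerate parameters that receive
    # no gradient in a step, at some synchronization cost.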
    trainer.fit(system)
    print_only("Finished Training")
    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)
    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()
    to_save = system.audio_model.serialize()
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
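    # serialize() is assumed to bundle the weights with whatever metadata
    # look2hear needs to rebuild the model, so best_model.pth is
    # self-contained for inference.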


if __name__ == "__main__":
    from pprint import pprint
    from look2hear.utils.parser_utils import (
        prepare_parser_from_dict,
        parse_args_as_dict,
    )

    args = parser.parse_args()
    with open(args.conf_dir) as f:
        def_conf = yaml.safe_load(f)
    parser = prepare_parser_from_dict(def_conf, parser=parser)
    arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)
    # pprint(arg_dic)
    main(arg_dic)
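# Usage sketch (script filename assumed): prepare_parser_from_dict exposes
# the conf.yml keys as CLI flags, so individual values can presumably be
# overridden from the command line, e.g.:
#   python train.py --conf_dir local/conf.yml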