Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	| import os | |
| import random | |
| from typing import Union | |
| import soundfile as sf | |
| import torch | |
| import yaml | |
| import json | |
| import argparse | |
| import numpy as np | |
| import pandas as pd | |
| from tqdm import tqdm | |
| from pprint import pprint | |
| from scipy.io import wavfile | |
| import warnings | |
| import torchaudio | |
| warnings.filterwarnings("ignore") | |
| import look2hear.models | |
| import look2hear.datas | |
| from look2hear.metrics import MetricsTracker | |
| from look2hear.utils import tensors_to_device, RichProgressBarTheme, MyMetricsTextColumn, BatchesProcessedColumn | |
| from rich.progress import ( | |
| BarColumn, | |
| Progress, | |
| TextColumn, | |
| TimeRemainingColumn, | |
| TransferSpeedColumn, | |
| ) | |
# Command-line interface: the only option is the path to the YAML training
# config, which main() uses to rebuild the model and the test datamodule.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--conf_dir",
    default="local/mixit_conf.yml",
    help="Full path to the training configuration YAML file",
)

# Metric names; not referenced elsewhere in this script (kept for compatibility).
compute_metrics = ["si_sdr", "sdr"]

# Default to GPU 8, but respect a CUDA_VISIBLE_DEVICES value the caller has
# already exported instead of silently overwriting it.
# NOTE(review): this only takes effect if no CUDA call has happened during the
# imports above — confirm none of the project imports initialise CUDA.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "8")
def _build_progress(metricscolumn):
    """Return the rich Progress bar used while iterating over the test set."""
    return Progress(
        TextColumn("[bold blue]Testing", justify="right"),
        BarColumn(bar_width=None),
        "•",
        BatchesProcessedColumn(style=RichProgressBarTheme.batch_progress),
        "•",
        TransferSpeedColumn(),
        "•",
        TimeRemainingColumn(),
        "•",
        metricscolumn,
    )


def main(config):
    """Export the model's separated sources for one test-set mixture.

    Loads ``best_model.pth`` from ``<cwd>/Experiments/checkpoint/<exp_name>/``.
    Builds the test split of the configured datamodule and runs the model on
    the mixture at index ``config.get("sample_index", 825)``. Writes one audio
    file per estimated source to
    ``./result/TIGER/idx<idx>/s<k>/<last path component of key>``.

    Args:
        config: dict whose ``"train_conf"`` entry holds the parsed training
            YAML. ``config["train_conf"]["main_args"]["exp_dir"]`` is
            overwritten with the resolved checkpoint directory. An optional
            ``"sample_index"`` entry selects which test item to export.
    """
    train_conf = config["train_conf"]
    data_config = train_conf["datamodule"]["data_config"]
    # The model is built with this rate, so its outputs are at this rate too.
    sample_rate = data_config["sample_rate"]

    metricscolumn = MyMetricsTextColumn(style=RichProgressBarTheme.metrics)
    progress = _build_progress(metricscolumn)

    exp_dir = os.path.join(
        os.getcwd(), "Experiments", "checkpoint", train_conf["exp"]["exp_name"]
    )
    train_conf["main_args"]["exp_dir"] = exp_dir
    model_path = os.path.join(exp_dir, "best_model.pth")

    model = getattr(look2hear.models, train_conf["audionet"]["audionet_name"]).from_pretrain(
        model_path,
        sample_rate=sample_rate,
        **train_conf["audionet"]["audionet_config"],
    )
    if train_conf["training"]["gpus"]:
        device = "cuda"
        model.to(device)
    # Inputs are moved to wherever the parameters actually ended up.
    model_device = next(model.parameters()).device

    datamodule = getattr(look2hear.datas, train_conf["datamodule"]["data_name"])(
        **data_config
    )
    datamodule.setup()
    _, _, test_set = datamodule.make_sets

    ex_save_dir = os.path.join(exp_dir, "results/")
    os.makedirs(ex_save_dir, exist_ok=True)
    metrics = MetricsTracker(save_file=os.path.join(ex_save_dir, "metrics.csv"))

    # Only one test item is exported (previously hard-coded to 825).
    target_idx = config.get("sample_index", 825)

    # Scoped no_grad: the old `torch.no_grad().__enter__()` was never exited and
    # left autograd disabled for the rest of the process.
    with torch.no_grad():
        with progress:
            for idx in progress.track(range(len(test_set))):
                if idx != target_idx:
                    continue
                mix, _sources, key = tensors_to_device(test_set[idx], device=model_device)
                # Add a batch dim for the forward pass, then drop it.
                # Assumes the output is (1, n_src, T) — TODO confirm.
                est_sources = model(mix[None]).squeeze(0)
                # NOTE(review): the metrics(...) update is disabled, so
                # metrics.final() below sees no recorded samples — confirm intended.
                save_dir = os.path.join("./result/TIGER", "idx{}".format(idx))
                filename = key.split("/")[-1]
                for i, est in enumerate(est_sources):
                    src_dir = os.path.join(save_dir, "s{}".format(i + 1))
                    os.makedirs(src_dir, exist_ok=True)
                    # Save at the configured rate rather than a hard-coded 16 kHz.
                    torchaudio.save(
                        os.path.join(src_dir, filename),
                        est.unsqueeze(0).cpu(),
                        sample_rate,
                    )
        metrics.final()
if __name__ == "__main__":
    # Combine the parsed CLI options with the YAML training config (stored
    # under "train_conf") into one dict and hand it to main().
    cli_args = parser.parse_args()
    with open(cli_args.conf_dir, "rb") as conf_file:
        loaded_conf = yaml.safe_load(conf_file)
    run_config = {**vars(cli_args), "train_conf": loaded_conf}
    main(run_config)