|
import os |
|
import random |
|
from typing import Union |
|
import soundfile as sf |
|
import torch |
|
import yaml |
|
import json |
|
import argparse |
|
import numpy as np |
|
import pandas as pd |
|
from tqdm import tqdm |
|
from pprint import pprint |
|
from scipy.io import wavfile |
|
import warnings |
|
import torchaudio |
|
warnings.filterwarnings("ignore") |
|
import look2hear.models |
|
import look2hear.datas |
|
from look2hear.metrics import MetricsTracker |
|
from look2hear.utils import tensors_to_device, RichProgressBarTheme, MyMetricsTextColumn, BatchesProcessedColumn |
|
|
|
from rich.progress import ( |
|
BarColumn, |
|
Progress, |
|
TextColumn, |
|
TimeRemainingColumn, |
|
TransferSpeedColumn, |
|
) |
|
|
|
|
|
# Command-line interface. The single option is the path of the YAML file that
# is loaded in the ``__main__`` section and handed to ``main`` under the
# "train_conf" key.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--conf_dir",
    default="local/mixit_conf.yml",
    # The previous help text described a model-saving path, which is not what
    # this option is used for: its value is opened and parsed as YAML.
    help="Path to the YAML training configuration of the experiment to test",
)

# Metric names associated with this script. Nothing in the visible part of the
# file reads this list; it is kept for any external users of the module.
compute_metrics = ["si_sdr", "sdr"]

# Default to GPU 8, but respect a CUDA_VISIBLE_DEVICES value that the caller
# (shell, job scheduler, ...) has already exported instead of overwriting it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "8")
|
|
|
def main(config, target_idx=825, save_root="./result/TIGER"):
    """Separate one item of the test set with the pretrained model and save it.

    The function loads ``best_model.pth`` from
    ``<cwd>/Experiments/checkpoint/<exp_name>``, builds the datamodule named in
    the configuration, and walks over the test set. Only the item whose index
    equals ``target_idx`` is forwarded through the model; each estimated source
    is written to ``<save_root>/idx<target_idx>/s<k>/<basename of key>``.

    Args:
        config: Mapping with a ``"train_conf"`` entry (the parsed training
            YAML). ``config["train_conf"]["main_args"]["exp_dir"]`` is
            overwritten in place with the experiment directory.
        target_idx: Index of the test-set item to separate and save. Defaults
            to 825, the value that used to be hard-coded.
        save_root: Directory under which the separated audio is written.
            Defaults to the previously hard-coded ``"./result/TIGER"``.

    Side effects:
        Creates ``<exp_dir>/results/`` and a ``MetricsTracker`` pointing at
        ``metrics.csv`` inside it. No per-item metrics are fed to the tracker
        in this function; only ``final()`` is called on it.
    """
    train_conf = config["train_conf"]

    metricscolumn = MyMetricsTextColumn(style=RichProgressBarTheme.metrics)
    progress = Progress(
        TextColumn("[bold blue]Testing", justify="right"),
        BarColumn(bar_width=None),
        "•",
        BatchesProcessedColumn(style=RichProgressBarTheme.batch_progress),
        "•",
        TransferSpeedColumn(),
        "•",
        TimeRemainingColumn(),
        "•",
        metricscolumn,
    )

    # Resolve the experiment directory and publish it back into the config.
    exp_dir = os.path.join(
        os.getcwd(), "Experiments", "checkpoint", train_conf["exp"]["exp_name"]
    )
    train_conf["main_args"]["exp_dir"] = exp_dir
    model_path = os.path.join(exp_dir, "best_model.pth")

    # The same sample rate is handed to the model and used when saving audio,
    # so the written files always match the data the model was run on.
    sample_rate = train_conf["datamodule"]["data_config"]["sample_rate"]

    model_cls = getattr(look2hear.models, train_conf["audionet"]["audionet_name"])
    model = model_cls.from_pretrain(
        model_path,
        sample_rate=sample_rate,
        **train_conf["audionet"]["audionet_config"],
    )
    # NOTE(review): model.eval() is not called here; confirm that
    # from_pretrain returns the model in evaluation mode.
    if train_conf["training"]["gpus"]:
        model.to("cuda")
    model_device = next(model.parameters()).device

    datamodule = getattr(look2hear.datas, train_conf["datamodule"]["data_name"])(
        **train_conf["datamodule"]["data_config"]
    )
    datamodule.setup()
    _, _, test_set = datamodule.make_sets

    ex_save_dir = os.path.join(exp_dir, "results/")
    os.makedirs(ex_save_dir, exist_ok=True)
    metrics = MetricsTracker(save_file=os.path.join(ex_save_dir, "metrics.csv"))

    # Use no_grad as a real context manager so gradient tracking is restored
    # when inference is done (it used to be entered by hand and never exited).
    with torch.no_grad(), progress:
        for idx in progress.track(range(len(test_set))):
            if idx != target_idx:
                continue

            # presumably (mixture, reference sources, utterance key/path);
            # only the mixture and the key are needed here - TODO confirm
            mix, _, key = tensors_to_device(test_set[idx], device=model_device)

            # Add a leading batch axis for the forward pass, drop it afterwards.
            est_sources = model(mix[None]).squeeze(0)

            file_name = key.split("/")[-1]
            save_dir = os.path.join(save_root, "idx{}".format(idx))
            for src_no, est in enumerate(est_sources, start=1):
                src_dir = os.path.join(save_dir, "s{}/".format(src_no))
                os.makedirs(src_dir, exist_ok=True)
                torchaudio.save(
                    src_dir + file_name,
                    est.unsqueeze(0).cpu(),
                    sample_rate,
                )

    metrics.final()
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: turn the parsed command-line options into a plain
    # dict, attach the parsed YAML configuration under "train_conf", and hand
    # the combined mapping to main().
    args = parser.parse_args()
    run_config = dict(vars(args))

    with open(args.conf_dir, "rb") as conf_file:
        run_config["train_conf"] = yaml.safe_load(conf_file)

    main(run_config)
|
|