# Voxsim scoring script: computes speaker-similarity scores between input
# and reference wav files using a checkpoint loaded by score.Score
# (default: wavlm_ecapa.model).
import argparse
import pathlib
import tqdm
from torch.utils.data import Dataset, DataLoader
import librosa
import numpy
from score import Score
import torch
import warnings
# Suppress all warnings globally — presumably to hide librosa/torch
# deprecation noise during batch scoring; NOTE(review): this also hides
# genuinely useful warnings.
warnings.filterwarnings("ignore")
def get_arg():
    """Parse and return this script's command-line arguments.

    Required: --mode (predict_file | predict_dir) and --out_path.
    The remaining options are validated per-mode inside main().
    """
    p = argparse.ArgumentParser()
    p.add_argument("--bs", required=False, default=None, type=int)
    p.add_argument("--mode", required=True, choices=["predict_file", "predict_dir"], type=str)
    p.add_argument("--ckpt_path", required=False, default="wavlm_ecapa.model", type=pathlib.Path)
    p.add_argument("--inp_dir", required=False, default=None, type=pathlib.Path)
    p.add_argument("--ref_dir", required=False, default=None, type=pathlib.Path)
    p.add_argument("--inp_path", required=False, default=None, type=pathlib.Path)
    p.add_argument("--ref_path", required=False, default=None, type=pathlib.Path)
    p.add_argument("--out_path", required=True, type=pathlib.Path)
    p.add_argument("--num_workers", required=False, default=0, type=int)
    return p.parse_args()
def loadWav(filename, max_frames: int = 400, num_eval: int = 10):
    """Load a waveform and slice it into fixed-length evaluation segments.

    Args:
        filename: Either a path to a wav file (loaded at 16 kHz), or a
            pre-loaded ``(sample_rate, samples)`` tuple.
            NOTE(review): the tuple branch normalizes but does NOT resample —
            assumes the caller already provides 16 kHz audio; confirm.
        max_frames: Segment length in frames; samples = max_frames*160 + 240.
        num_eval: Number of evenly spaced segments to extract (was a
            hard-coded 10; default preserves the original behavior).

    Returns:
        Tuple of (segments, full):
        - segments: FloatTensor of shape (num_eval, max_audio).
        - full: FloatTensor of shape (1, original_length) with the
          unpadded waveform.
    """
    # Segment length in samples (160-sample hop per frame + 240 window tail).
    max_audio = max_frames * 160 + 240
    if isinstance(filename, tuple):
        sr, audio = filename
        audio = librosa.util.normalize(audio)
    else:
        audio, sr = librosa.load(filename, sr=16000)
    audio_org = audio.copy()
    audiosize = audio.shape[0]
    if audiosize <= max_audio:
        # Wrap-pad short clips so at least one full segment exists.
        shortage = max_audio - audiosize + 1
        audio = numpy.pad(audio, (0, shortage), 'wrap')
        audiosize = audio.shape[0]
    # num_eval evenly spaced start offsets spanning the whole clip.
    startframe = numpy.linspace(0, audiosize - max_audio, num=num_eval)
    feats = [audio[int(asf):int(asf) + max_audio] for asf in startframe]
    feat = numpy.stack(feats, axis=0).astype(numpy.float32)
    return (torch.FloatTensor(feat),
            torch.FloatTensor(numpy.stack([audio_org], axis=0).astype(numpy.float32)))
class AudioDataset(Dataset):
    """Dataset of (input, reference) wav pairs matched by sorted filename order.

    Both directories must contain the same number of ``*.wav`` files; the
    i-th sorted input file is paired with the i-th sorted reference file.
    """

    def __init__(self, inp_dir_path: pathlib.Path, ref_dir_path: pathlib.Path, max_frames: int = 400):
        self.inp_wavlist = sorted(inp_dir_path.glob("*.wav"))
        self.ref_wavlist = sorted(ref_dir_path.glob("*.wav"))
        assert len(self.inp_wavlist) == len(self.ref_wavlist)
        # Native sample rate of the first input file (sr=None keeps it unchanged).
        _, self.sr = librosa.load(self.inp_wavlist[0], sr=None)
        self.max_audio = max_frames * 160 + 240

    def __len__(self):
        return len(self.inp_wavlist)

    def __getitem__(self, idx):
        inp_wavs, inp_wav = loadWav(self.inp_wavlist[idx])
        ref_wavs, ref_wav = loadWav(self.ref_wavlist[idx])
        return inp_wavs, inp_wav, ref_wavs, ref_wav
def main():
    """CLI entry point: score one wav pair, or every pair in two directories.

    predict_file mode writes a single score to --out_path; predict_dir mode
    writes one score per line, in dataset (sorted filename) order.
    """
    args = get_arg()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.mode == "predict_file":
        # Single-pair mode: exactly the *_path arguments must be supplied.
        assert args.inp_path is not None
        assert args.ref_path is not None
        assert args.inp_dir is None
        assert args.ref_dir is None
        assert args.inp_path.exists()
        assert args.inp_path.is_file()
        assert args.ref_path.exists()
        assert args.ref_path.is_file()
        inp_wavs, inp_wav = loadWav(args.inp_path)
        ref_wavs, ref_wav = loadWav(args.ref_path)
        scorer = Score(ckpt_path=args.ckpt_path, device=device)
        score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
        print("Voxsim score: ", score[0])
        with open(args.out_path, "w") as fw:
            fw.write(str(score[0]))
    else:
        # Batch mode over two directories of paired wav files.
        assert args.inp_dir is not None, "inp_dir is required when mode is predict_dir."
        assert args.ref_dir is not None, "ref_dir is required when mode is predict_dir."
        assert args.bs is not None, "bs is required when mode is predict_dir."
        assert args.inp_path is None, "inp_path should be None"
        assert args.ref_path is None, "ref_path should be None"
        assert args.inp_dir.exists()
        assert args.ref_dir.exists()
        assert args.inp_dir.is_dir()
        assert args.ref_dir.is_dir()
        dataset = AudioDataset(args.inp_dir, args.ref_dir)
        loader = DataLoader(
            dataset,
            batch_size=args.bs,
            shuffle=False,
            num_workers=args.num_workers)
        scorer = Score(ckpt_path=args.ckpt_path, device=device)
        # Truncate any previous output before appending per-batch results.
        with open(args.out_path, 'w'):
            pass
        for batch in tqdm.tqdm(loader):
            # BUG FIX: original read `score.score(batch.to(device))` — the
            # name `score` is undefined in this branch and a collated batch
            # is a sequence of tensors with no .to(); unpack the four
            # tensors and call the scorer as in the predict_file branch.
            inp_wavs, inp_wav, ref_wavs, ref_wav = batch
            scores = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
            with open(args.out_path, 'a') as fw:
                for s in scores:
                    fw.write(str(s) + "\n")
        print("save to ", args.out_path)


if __name__ == "__main__":
    main()