import os
import numpy as np
import torch
import torch.nn.functional as F
from ssl_ecapa_model import SSL_ECAPA_TDNN
from huggingface_hub import hf_hub_download
def load_model(ckpt_path):
    # Build the WavLM-Large + ECAPA-TDNN speaker-embedding model and load its weights.
    model = SSL_ECAPA_TDNN(feat_dim=1024, emb_dim=256, feat_type='wavlm_large')
    load_parameters(model, ckpt_path)
    return model
def load_parameters(model, ckpt_path):
    model_state = model.state_dict()
    # Fetch the checkpoint from the Hugging Face Hub if it is not available locally.
    if not os.path.isfile(ckpt_path):
        print("Downloading model from Hugging Face Hub...")
        new_ckpt_path = hf_hub_download(repo_id="junseok520/voxsim-models", filename=ckpt_path, local_dir="./")
        ckpt_path = new_ckpt_path
    loaded_state = torch.load(ckpt_path, map_location='cpu', weights_only=True)
    # Copy parameters into the model, stripping the '__S__.' prefix added by the training wrapper.
    for name, param in loaded_state.items():
        if name.startswith('__S__.'):
            if name[6:] in model_state:
                model_state[name[6:]].copy_(param)
            else:
                print("{} is not in the model.".format(name[6:]))
        else:
            if name in model_state:
                model_state[name].copy_(param)
            else:
                print("{} is not in the model.".format(name))
class Score:
    """Predicts a speaker-similarity score for pairs of audio clips."""

    def __init__(
            self,
            ckpt_path: str = "wavlm_ecapa.pt",
            device: str = "cuda"):
"""
Args:
ckpt_path: path to pretrained checkpoint of voxsim evaluator.
input_sample_rate: sampling rate of input audio tensor. The input audio tensor
is automatically downsampled to 16kHz.
"""
print(f"Using device: {device}")
self.device = device
self.model = load_model(ckpt_path).to(self.device)
self.model.eval()
    def score(self, inp_wavs: torch.Tensor, inp_wav: torch.Tensor, ref_wavs: torch.Tensor, ref_wav: torch.Tensor) -> np.ndarray:
        """
        Args:
            inp_wavs: segments of the audio to be evaluated, shaped (n_segments, T) for a
                single clip or (batch, n_segments, T) for batch processing.
            inp_wav: the corresponding full waveform(s), shaped (1, T) or (batch, 1, T).
            ref_wavs, ref_wav: reference audio in the same layout as inp_wavs / inp_wav.
                All waveforms are expected at 16 kHz.
        Returns:
            A numpy array with one similarity score per item in the batch.
        """
        # Accept a single clip (n_segments, T) or a batch (batch, n_segments, T).
        if len(inp_wavs.shape) == 2:
            bs = 1
        elif len(inp_wavs.shape) == 3:
            bs = inp_wavs.shape[0]
        else:
            raise ValueError('Dimension of input tensor needs to be 2 or 3.')
        # Flatten the batch and segment dimensions so all segments are embedded in one forward pass.
        inp_wavs = inp_wavs.reshape(-1, inp_wavs.shape[-1]).to(self.device)
        inp_wav = inp_wav.reshape(-1, inp_wav.shape[-1]).to(self.device)
        ref_wavs = ref_wavs.reshape(-1, ref_wavs.shape[-1]).to(self.device)
        ref_wav = ref_wav.reshape(-1, ref_wav.shape[-1]).to(self.device)
        with torch.no_grad():
            # Embed each waveform and L2-normalize so dot products become cosine similarities.
            input_emb_1 = F.normalize(self.model.forward(inp_wavs), p=2, dim=1).detach()
            input_emb_2 = F.normalize(self.model.forward(inp_wav), p=2, dim=1).detach()
            ref_emb_1 = F.normalize(self.model.forward(ref_wavs), p=2, dim=1).detach()
            ref_emb_2 = F.normalize(self.model.forward(ref_wav), p=2, dim=1).detach()
        emb_size = input_emb_1.shape[-1]
        # Restore the batch dimension: (batch, n_segments, emb_size).
        input_emb_1 = input_emb_1.reshape(bs, -1, emb_size)
        input_emb_2 = input_emb_2.reshape(bs, -1, emb_size)
        ref_emb_1 = ref_emb_1.reshape(bs, -1, emb_size)
        ref_emb_2 = ref_emb_2.reshape(bs, -1, emb_size)
        # Average the pairwise cosine similarities between segments (score_1) and between
        # the full waveforms (score_2), then report the mean of the two.
        score_1 = torch.mean(torch.bmm(input_emb_1, ref_emb_1.transpose(1, 2)), dim=(1, 2))
        score_2 = torch.mean(torch.bmm(input_emb_2, ref_emb_2.transpose(1, 2)), dim=(1, 2))
        score = (score_1 + score_2) / 2
        score = score.detach().cpu().numpy()
        return score
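
if __name__ == "__main__":
    # Minimal usage sketch (not part of the original script). Random tensors stand in for
    # real 16 kHz waveforms that would normally be loaded with e.g. torchaudio; shapes
    # follow the score() contract: several segments per utterance plus the full utterance
    # for both the evaluated and the reference speaker.
    scorer = Score(ckpt_path="wavlm_ecapa.pt",
                   device="cuda" if torch.cuda.is_available() else "cpu")
    inp_wavs = torch.randn(2, 10, 64000)   # batch of 2 utterances, 10 segments of 4 s at 16 kHz
    inp_wav = torch.randn(2, 1, 64000)     # the corresponding full waveforms
    ref_wavs = torch.randn(2, 10, 64000)
    ref_wav = torch.randn(2, 1, 64000)
    scores = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
    print(scores)                          # numpy array with one similarity score per pair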