import os

import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

from ssl_ecapa_model import SSL_ECAPA_TDNN


def load_model(ckpt_path):
    """Build the WavLM-Large + ECAPA-TDNN evaluator and load its weights."""
    model = SSL_ECAPA_TDNN(feat_dim=1024, emb_dim=256, feat_type='wavlm_large')
    load_parameters(model, ckpt_path)
    return model


def load_parameters(model, ckpt_path):
    """Copy checkpoint weights into `model`, downloading the checkpoint from the Hub if it is not found locally."""
    model_state = model.state_dict()
    if not os.path.isfile(ckpt_path):
        print("Downloading model from Hugging Face Hub...")
        ckpt_path = hf_hub_download(repo_id="junseok520/voxsim-models", filename=ckpt_path, local_dir="./")
    loaded_state = torch.load(ckpt_path, map_location='cpu', weights_only=True)

    for name, param in loaded_state.items():
        # Strip the '__S__.' prefix (if present) so keys match this model's state dict.
        if name.startswith('__S__.'):
            if name[6:] in model_state:
                model_state[name[6:]].copy_(param)
            else:
                print("{} is not in the model.".format(name[6:]))
        else:
            if name in model_state:
                model_state[name].copy_(param)
            else:
                print("{} is not in the model.".format(name))


class Score:
    """Predicts a speaker-similarity score for each pair of audio clips."""

    def __init__(
        self,
        ckpt_path: str = "wavlm_ecapa.pt",
        device: str = "cuda"):
        """
        Args:
            ckpt_path: path to the pretrained checkpoint of the voxsim evaluator.
            device: torch device string on which the model runs, e.g. "cuda" or "cpu".
        """
        print(f"Using device: {device}")
        self.device = device
        self.model = load_model(ckpt_path).to(self.device)
        self.model.eval()
    
    def score(self, inp_wavs: torch.Tensor, inp_wav: torch.Tensor, ref_wavs: torch.Tensor, ref_wav: torch.Tensor) -> np.ndarray:
        """
        Args:
            inp_wavs: segments cropped from the audio to be evaluated, shaped
                (num_segments, num_samples) for a single clip or
                (batch, num_segments, num_samples) for batch processing.
            inp_wav: the full waveform to be evaluated, typically shaped
                (1, num_samples), or (batch, 1, num_samples) when batching.
            ref_wavs: reference segments, with the same layout as inp_wavs.
            ref_wav: the full reference waveform, with the same layout as inp_wav.
        Returns:
            A numpy array of speaker-similarity scores, one per item in the batch.
        """

        # Infer the batch size from the segment tensor: 2-D input is a single
        # clip, 3-D input is a batch of clips.
        if len(inp_wavs.shape) == 2:
            bs = 1
        elif len(inp_wavs.shape) == 3:
            bs = inp_wavs.shape[0]
        else:
            raise ValueError('Input tensor must be 2-D or 3-D.')

        # Flatten any leading batch dimension so each row is one waveform of
        # shape (num_samples,) before it is fed to the embedding model.
        inp_wavs = inp_wavs.reshape(-1, inp_wavs.shape[-1]).to(self.device)
        inp_wav = inp_wav.reshape(-1, inp_wav.shape[-1]).to(self.device)
        ref_wavs = ref_wavs.reshape(-1, ref_wavs.shape[-1]).to(self.device)
        ref_wav = ref_wav.reshape(-1, ref_wav.shape[-1]).to(self.device)
        
        with torch.no_grad():
            # L2-normalize the speaker embeddings so that the dot products below
            # correspond to cosine similarities.
            input_emb_1 = F.normalize(self.model(inp_wavs), p=2, dim=1)
            input_emb_2 = F.normalize(self.model(inp_wav), p=2, dim=1)
            ref_emb_1 = F.normalize(self.model(ref_wavs), p=2, dim=1)
            ref_emb_2 = F.normalize(self.model(ref_wav), p=2, dim=1)

            # Restore the batch dimension: (batch, num_segments, emb_size).
            emb_size = input_emb_1.shape[-1]
            input_emb_1 = input_emb_1.reshape(bs, -1, emb_size)
            input_emb_2 = input_emb_2.reshape(bs, -1, emb_size)
            ref_emb_1 = ref_emb_1.reshape(bs, -1, emb_size)
            ref_emb_2 = ref_emb_2.reshape(bs, -1, emb_size)

            # Average the pairwise cosine similarities between input and reference
            # segments, do the same for the full waveforms, then take the mean of
            # the two scores.
            score_1 = torch.mean(torch.bmm(input_emb_1, ref_emb_1.transpose(1, 2)), dim=(1, 2))
            score_2 = torch.mean(torch.bmm(input_emb_2, ref_emb_2.transpose(1, 2)), dim=(1, 2))
            score = (score_1 + score_2) / 2

            return score.cpu().numpy()
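

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original API): shows how the Score
# class above could be called on dummy 16 kHz waveforms. The crop_segments
# helper and the segment count of 10 are illustrative assumptions; replace
# the random tensors and the segmentation with your own audio pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    scorer = Score(ckpt_path="wavlm_ecapa.pt", device="cpu")

    # Dummy 3-second waveforms at 16 kHz standing in for real audio.
    inp_wav = torch.randn(1, 48000)
    ref_wav = torch.randn(1, 48000)

    def crop_segments(wav, num_segments=10, seg_len=32000):
        # Evenly spaced fixed-length crops of a (1, num_samples) waveform.
        starts = torch.linspace(0, wav.shape[-1] - seg_len, num_segments).long()
        return torch.stack([wav[0, int(s):int(s) + seg_len] for s in starts], dim=0)

    inp_wavs = crop_segments(inp_wav)  # (num_segments, seg_len)
    ref_wavs = crop_segments(ref_wav)  # (num_segments, seg_len)

    similarity = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
    print("Speaker similarity:", similarity)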