File size: 3,053 Bytes
027a68c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse

import numpy as np
import torch
import torch.nn as nn
import librosa
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model, Wav2Vec2PreTrainedModel

from project_settings import project_path


class ModelHead(nn.Module):
    """Small feed-forward head: dropout -> linear -> tanh -> dropout -> linear.

    The input's last dimension must equal ``config.hidden_size``; the output's
    last dimension is ``num_labels``.

    Args:
        config: object exposing ``hidden_size`` and ``final_dropout``.
        num_labels: number of output units of the final projection.
    """

    def __init__(self, config, num_labels):
        super().__init__()
        width = config.hidden_size
        # Attribute names and creation order are kept as-is: the names are the
        # state-dict keys, and the order fixes how seeded init consumes the RNG.
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(width, num_labels)

    def forward(self, features, **kwargs):
        """Map ``features`` to ``num_labels`` outputs; extra kwargs are ignored."""
        hidden = self.dense(self.dropout(features))
        activated = self.dropout(torch.tanh(hidden))
        return self.out_proj(activated)


class AgeGenderModel(Wav2Vec2PreTrainedModel):
    """Wav2Vec2 backbone with two ``ModelHead`` outputs: age (1 unit) and gender (3 units).

    Args:
        config: model configuration, passed to the parent class, to
            ``Wav2Vec2Model`` and to both heads.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = ModelHead(config, 1)
        self.gender = ModelHead(config, 3)
        self.init_weights()

    def forward(self,
                input_values,
                ):
        """Run the backbone and both heads.

        Args:
            input_values: input passed unchanged to the Wav2Vec2 backbone.
                NOTE(review): presumably a (batch, samples) float tensor - confirm.

        Returns:
            A 3-tuple ``(hidden_states, logits_age, logits_gender)`` where
            ``hidden_states`` is the backbone's first output averaged over
            dim 1, ``logits_age`` is the raw output of the age head, and
            ``logits_gender`` is the gender head output after a softmax over
            dim 1 (so despite the name it holds probabilities, not logits).
        """
        outputs = self.wav2vec2(input_values)
        # Mean-pool the backbone's first output over dim 1.
        hidden_states = torch.mean(outputs[0], dim=1)

        # Call the heads as modules rather than via ``.forward`` so that any
        # registered forward hooks are executed.
        logits_age = self.age(hidden_states)
        logits_gender = torch.softmax(self.gender(hidden_states), dim=1)

        return hidden_states, logits_age, logits_gender


class AudeeringModel(object):
    """
    Inference wrapper that loads a processor and an ``AgeGenderModel`` from
    ``model_path`` and predicts age / gender scores for one audio signal.

    https://arxiv.org/abs/2306.16962

    https://github.com/audeering/w2v2-age-gender-how-to

    https://huggingface.co/audeering/wav2vec2-large-robust-6-ft-age-gender
    https://huggingface.co/audeering/wav2vec2-large-robust-24-ft-age-gender
    """
    def __init__(self, model_path: str):
        """Load processor and model from ``model_path``.

        The model is moved to CUDA when available (otherwise CPU) and put in
        eval mode.

        Args:
            model_path: location handed to ``from_pretrained`` for both the
                processor and the model.
        """
        self.model_path = model_path

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        processor: Wav2Vec2Processor = Wav2Vec2Processor.from_pretrained(self.model_path)
        model = AgeGenderModel.from_pretrained(self.model_path).to(device)
        model.eval()

        self.device = device
        self.processor = processor
        self.model = model

    def predict(self, signal: np.ndarray, sample_rate: int) -> dict:
        """Predict age and gender scores for a single signal.

        Args:
            signal: audio samples, forwarded to the processor.
            sample_rate: sampling rate of ``signal``, forwarded to the
                processor as ``sampling_rate``.

        Returns:
            dict with keys ``"age"``, ``"female"``, ``"male"`` and ``"child"``,
            each rounded to 4 decimals. The gender keys map to indices 0, 1
            and 2 of the model's gender output.
            NOTE(review): the unit / scale of ``"age"`` is not visible here -
            confirm against the model card.
        """
        features = self.processor(signal, sampling_rate=sample_rate)
        # Take the first (only) item and add a leading batch dimension of 1.
        batch = features["input_values"][0].reshape(1, -1)
        batch = torch.from_numpy(batch).to(self.device)

        # Inference only: skip autograd bookkeeping so no graph is built or
        # kept alive for the returned tensors.
        with torch.no_grad():
            _, age, gender = self.model(batch)

        age_value = age.detach().cpu().numpy().tolist()[0][0]
        gender_scores = gender.detach().cpu().numpy().tolist()[0]

        return {
            "age": round(age_value, 4),
            "female": round(gender_scores[0], 4),
            "male": round(gender_scores[1], 4),
            "child": round(gender_scores[2], 4),
        }

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`predict`."""
        return self.predict(*args, **kwargs)


# No script behaviour is defined; running this module directly does nothing.
if __name__ == "__main__":
    pass