# speech-analysis / app.py
# Hugging Face Space by hagenw — commit d8eeff7 ("Debug"), 3.76 kB
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn as nn
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel
import audiofile
class ModelHead(nn.Module):
    r"""Prediction head applied to pooled wav2vec2 features.

    Computes ``out_proj(dropout(tanh(dense(dropout(features)))))``.

    The submodule attribute names (``dense``, ``dropout``, ``out_proj``)
    determine the state-dict keys, so renaming them would break loading the
    pretrained checkpoint into the model that uses this head.
    """

    def __init__(self, config, num_labels):
        super().__init__()
        width = config.hidden_size
        # Registration order is kept (dense, dropout, out_proj) so parameter
        # initialisation consumes the RNG in the same order as before.
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(width, num_labels)

    def forward(self, features, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        hidden = self.dense(self.dropout(features))
        hidden = self.dropout(torch.tanh(hidden))
        return self.out_proj(hidden)
class AgeGenderModel(Wav2Vec2PreTrainedModel):
    r"""Speaker age and gender model on top of a wav2vec2 encoder.

    The encoder's last hidden states are mean-pooled over time (dim 1) and
    fed to two ``ModelHead``s: ``age`` with a single output and ``gender``
    with three softmax-normalised outputs.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = ModelHead(config, 1)
        self.gender = ModelHead(config, 3)
        self.init_weights()

    def forward(
        self,
        input_values,
    ):
        r"""Run the encoder and both heads.

        Args:
            input_values: raw (normalised) audio samples; assumed to be
                ``(batch, samples)`` — this file feeds ``(1, samples)``.

        Returns:
            tuple ``(hidden_states, logits_age, logits_gender)``:
            the time-pooled encoder features, the raw age head output
            (one value per item), and the gender head output after softmax
            over its three classes (one row per item).
        """
        outputs = self.wav2vec2(input_values)
        hidden_states = outputs[0]
        # Average over the time axis to get one feature vector per item.
        hidden_states = torch.mean(hidden_states, dim=1)
        logits_age = self.age(hidden_states)
        logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
        return hidden_states, logits_age, logits_gender
# load model from hub
# 0 selects the first CUDA device; otherwise run on the CPU.
device = 0 if torch.cuda.is_available() else "cpu"
model_name = "audeering/wav2vec2-large-robust-24-ft-age-gender"
processor = Wav2Vec2Processor.from_pretrained(model_name)
# Move the weights to the same device the inputs are sent to; otherwise a
# CUDA input meets CPU weights and inference fails with a device mismatch.
model = AgeGenderModel.from_pretrained(model_name).to(device)
def _to_mono(signal: np.ndarray) -> np.ndarray:
    """Return *signal* as a 1-D array, averaging channels of 2-D input.

    Multi-channel input is expected in ``(channels, samples)`` layout.
    """
    signal = np.asarray(signal)
    if signal.ndim == 1:
        return signal
    if signal.ndim == 2:
        return signal.mean(axis=0)
    raise ValueError(
        "Expected a 1-D or 2-D (channels, samples) signal, "
        f"got shape {signal.shape}."
    )


def _resample(signal: np.ndarray, orig_rate: int, target_rate: int) -> np.ndarray:
    """Resample the last axis of *signal* from *orig_rate* to *target_rate* Hz.

    Fourier-domain resampling: the real spectrum is truncated (downsampling)
    or zero-padded (upsampling), so content above the new Nyquist frequency
    is discarded rather than aliased. Input is returned unchanged when the
    rates match or it has no samples.
    """
    signal = np.asarray(signal)
    num_in = signal.shape[-1]
    if orig_rate == target_rate or num_in == 0:
        return signal
    num_out = int(round(num_in * target_rate / orig_rate))
    if num_out == 0:
        return signal[..., :0]
    spectrum = np.fft.rfft(signal, axis=-1)
    num_bins = num_out // 2 + 1
    resized = np.zeros(signal.shape[:-1] + (num_bins,), dtype=spectrum.dtype)
    keep = min(spectrum.shape[-1], num_bins)
    resized[..., :keep] = spectrum[..., :keep]
    # irfft divides by the output length; rescale so amplitudes are preserved.
    return np.fft.irfft(resized, n=num_out, axis=-1) * (num_out / num_in)


def process_func(x: np.ndarray, sampling_rate: int) -> dict:
    r"""Predict age and gender from a raw audio signal.

    Args:
        x: audio samples, mono ``(samples,)`` or multi-channel
            ``(channels, samples)``; multi-channel input is averaged to mono.
        sampling_rate: sampling rate of ``x`` in Hz. The signal is resampled
            to the feature extractor's rate when the two differ.

    Returns:
        dict with ``"age"`` (the model's age output multiplied by 100) and
        the ``"female"``, ``"male"`` and ``"child"`` gender-head outputs.

    Raises:
        ValueError: if ``x`` has no samples or more than two dimensions.
    """
    signal = _to_mono(x)
    if signal.size == 0:
        raise ValueError("Audio signal is empty.")
    # Wav2Vec2's feature extractor raises a ValueError for any sampling rate
    # other than its configured one, so convert to that rate first.
    target_rate = processor.feature_extractor.sampling_rate
    signal = _resample(signal, sampling_rate, target_rate)
    # Model weights are float32; FFT resampling produces float64.
    signal = signal.astype(np.float32, copy=False)
    # run through processor to normalize signal;
    # it always returns a batch, so we just take the first entry
    y = processor(signal, sampling_rate=target_rate)
    y = y["input_values"][0]
    y = y.reshape(1, -1)
    # send the input to wherever the model weights currently live
    y = torch.from_numpy(y).to(model.device)
    with torch.no_grad():
        _, age, gender = model(y)
    y = torch.hstack([age, gender]).detach().cpu().numpy()
    # NOTE(review): the 100x scaling presumably maps a 0..1 age output to
    # years — confirm against the model card.
    return {
        "age": 100 * y[0][0],
        "female": y[0][1],
        "male": y[0][2],
        "child": y[0][3],
    }
@spaces.GPU
def recognize(file):
    """Gradio callback: predict age and gender for an audio file path.

    Reads the file with ``audiofile`` and passes the samples and their
    sampling rate to ``process_func``. Raises ``gr.Error`` (shown in the UI)
    when no file was submitted.
    """
    if file is not None:
        samples, rate = audiofile.read(file)
        return process_func(samples, rate)
    raise gr.Error(
        "No audio file submitted! "
        "Please upload or record an audio file "
        "before submitting your request."
    )
# Output component shared by both tabs.
# NOTE(review): gr.Label renders values as class confidences, while
# recognize() returns age on a 0-100 scale next to the gender outputs —
# confirm this displays as intended.
outputs = gr.Label()
title = "audEERING age and gender recognition"
# The second line must be an f-string so the checkpoint name and its link
# are filled in instead of showing a literal "{model_name}".
description = (
    "Recognize age and gender of a microphone recording or audio file. "
    f"Demo uses the checkpoint [{model_name}](https://huggingface.co/{model_name})."
)
allow_flagging = "never"
def _build_interface(audio_input):
    """Wrap ``recognize`` in a Gradio Interface reading from *audio_input*."""
    return gr.Interface(
        fn=recognize,
        inputs=audio_input,
        outputs=outputs,
        title=title,
        description=description,
        allow_flagging=allow_flagging,
    )


# One tab records from the microphone, the other accepts an uploaded file;
# both hand a file path to recognize().
microphone = _build_interface(gr.Audio(sources="microphone", type="filepath"))
file = _build_interface(
    gr.Audio(sources="upload", type="filepath", label="Audio file")
)
demo = gr.TabbedInterface([microphone, file], ["Microphone", "Audio file"])
demo.launch()