# Hugging Face Spaces status: Runtime error
import functools

import gradio as gr
import torch
import torchaudio
from transformers import AutoFeatureExtractor, AutoModelForAudioXVector
# Run the speaker-verification model on the GPU when one is available,
# otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def similarity_fn(speaker1, speaker2): | |
if not (speaker1 and speaker2): | |
return gr.Textbox(value='<b style="color:red">ERROR: Please record audio for *both* speakers!</b>') | |
wav1, _ = torchaudio.load(speaker1) | |
wav2, _ = torchaudio.load(speaker2) | |
feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sv") | |
model = AutoModelForAudioXVector.from_pretrained("microsoft/wavlm-base-plus-sv").to(device) | |
input1 = feature_extractor(wav1.squeeze(0), return_tensors="pt", sampling_rate=16000).input_values.to(device) | |
input2 = feature_extractor(wav2.squeeze(0), return_tensors="pt", sampling_rate=16000).input_values.to(device) | |
with torch.no_grad(): | |
emb1 = model(input1).embeddings | |
emb2 = model(input2).embeddings | |
emb1 = torch.nn.functional.normalize(emb1, dim=-1).cpu() | |
emb2 = torch.nn.functional.normalize(emb2, dim=-1).cpu() | |
similarity = torch.nn.CosineSimilarity(dim=-1)(emb1, emb2).numpy()[0] | |
if similarity >= 0.8: | |
label = "The speakers are similar" | |
color = "green" | |
else: | |
label = "The speakers are different" | |
color = "red" | |
return gr.Textbox(value=f"<span style='color:{color}'>{label}</span>") | |
def _microphone_input(label):
    """Return a microphone-recording ``gr.Audio`` input that yields a file path.

    ``source="microphone"`` is the Gradio 3.x spelling used originally; newer
    Gradio releases take ``sources=[...]`` instead and reject the old keyword
    with TypeError, so fall back to the new spelling in that case.
    """
    try:
        return gr.Audio(source="microphone", type="filepath", label=label)
    except TypeError:
        return gr.Audio(sources=["microphone"], type="filepath", label=label)


# gr.Interface receives its input components via ``inputs`` and its output via
# ``outputs``; ``speaker1=``/``speaker2=``/``output=`` are not Interface
# parameters and made construction fail at startup. gr.HTML is used so the
# coloured markup returned by similarity_fn is rendered, not shown as text.
demo = gr.Interface(
    fn=similarity_fn,
    inputs=[_microphone_input("Speaker 1"), _microphone_input("Speaker 2")],
    outputs=gr.HTML(),
)
demo.launch()