Spaces:
Running
Running
import torch | |
from transformers import WhisperConfig | |
import librosa | |
import numpy as np | |
import pathlib | |
from torch.nn import functional as F | |
from ..model import WhiStress | |
# Directory with the trained WhiStress weights: "weights" inside the
# grandparent directory of this module's file.
PATH_TO_WEIGHTS = pathlib.Path(__file__).parent.parent.joinpath("weights")
def get_loaded_model(device="cuda", weights_path=None):
    """Build a WhiStress model, load its trained weights and switch to eval mode.

    The model is constructed on the ``openai/whisper-small.en`` backbone with
    ``layer_for_head=9`` (presumably the backbone layer the emphasis head
    reads from — see ``WhiStress``).

    Args:
        device: torch device string the model is moved to (default ``"cuda"``).
        weights_path: location passed to ``WhiStress.load_model``. Defaults to
            ``PATH_TO_WEIGHTS`` when ``None``.

    Returns:
        The loaded ``WhiStress`` instance, in evaluation mode, on ``device``.
    """
    # Plain string literal: the former f-string had no placeholders.
    whisper_model_name = "openai/whisper-small.en"
    whisper_config = WhisperConfig()
    whistress_model = WhiStress(
        whisper_config, layer_for_head=9, whisper_backbone_name=whisper_model_name
    ).to(device)
    # Restrict the fields the tokenizer emits; "labels_head" is presumably the
    # emphasis-label field WhiStress consumes — TODO confirm in ..model.
    whistress_model.processor.tokenizer.model_input_names = [
        "input_ids",
        "attention_mask",
        "labels_head",
    ]
    whistress_model.load_model(
        PATH_TO_WEIGHTS if weights_path is None else weights_path
    )
    # Move again after loading: the freshly loaded parameters may not live on
    # `device` (depends on WhiStress.load_model — TODO confirm).
    whistress_model.to(device)
    whistress_model.eval()
    return whistress_model
def get_word_emphasis_pairs(
    transcription_preds, emphasis_preds, processor, filter_special_tokens=True
):
    """Pair each decoded token with its predicted emphasis label.

    Args:
        transcription_preds: 1-D sequence of token ids (a tensor row or a list
            of ints).
        emphasis_preds: per-position emphasis labels aligned with
            ``transcription_preds``. Either an object with ``tolist()`` (tensor
            or array) or a plain sequence.
        processor: object whose ``tokenizer`` provides ``decode`` and
            ``all_special_ids``.
        filter_special_tokens: when true, drop every position whose token id
            is one of the tokenizer's special ids.

    Returns:
        List of ``(token_string, emphasis_label)`` tuples. If the two inputs
        differ in length, the surplus tail of the longer one is ignored.
    """
    tokenizer = processor.tokenizer
    # Plain ints are hashable (needed for the set lookup) and decode cleanly.
    token_ids = [int(token_id) for token_id in transcription_preds]
    if hasattr(emphasis_preds, "tolist"):
        emphasis_labels = emphasis_preds.tolist()
    else:
        emphasis_labels = list(emphasis_preds)
    # On Hugging Face tokenizers `all_special_ids` is a computed property, so
    # read it once and use a set for O(1) membership instead of per element.
    special_ids = set(tokenizer.all_special_ids) if filter_special_tokens else set()
    # Filtering pairs by the same index set that previously filtered both lists
    # separately yields exactly the same pairs, in one O(n) pass.
    return [
        (tokenizer.decode([token_id], skip_special_tokens=False), label)
        for token_id, label in zip(token_ids, emphasis_labels)
        if token_id not in special_ids
    ]
def inference_from_audio(audio: np.ndarray, model: WhiStress, device: str):
    """Transcribe ``audio`` with WhiStress and return per-token emphasis.

    Args:
        audio: waveform sampled at 16 kHz (the rate passed to the feature
            extractor), e.g. the output of ``prepare_audio``.
        model: a loaded ``WhiStress`` model.
        device: device the input features are moved to; should be the
            model's device.

    Returns:
        List of ``(token_string, emphasis_label)`` pairs for the generated
        transcription of the first batch item, with special tokens removed.
    """
    input_features = model.processor.feature_extractor(
        audio, sampling_rate=16000, return_tensors="pt"
    )["input_features"]
    # Pure inference: disable autograd so no graph/activations are retained.
    with torch.no_grad():
        out_model = model.generate_dual(input_features=input_features.to(device))
        emphasis_probs = F.softmax(out_model.logits, dim=-1)
        emphasis_preds = torch.argmax(emphasis_probs, dim=-1)
        # Rotate right by one along the sequence axis (last column wraps to
        # the front) — presumably so label k lines up with token k of
        # `out_model.preds`; TODO confirm against WhiStress.generate_dual.
        emphasis_preds_right_shifted = torch.cat(
            (emphasis_preds[:, -1:], emphasis_preds[:, :-1]), dim=1
        )
    return get_word_emphasis_pairs(
        out_model.preds[0],
        emphasis_preds_right_shifted[0],
        model.processor,
        filter_special_tokens=True,
    )
def prepare_audio(audio, target_sr=16000):
    """Resample an audio record to ``target_sr`` and peak-normalize it.

    Args:
        audio: mapping with ``"array"`` (the samples) and ``"sampling_rate"``
            (their rate in Hz).
        target_sr: output sampling rate in Hz (default 16 kHz).

    Returns:
        A new floating-point ``np.ndarray`` scaled so its largest absolute
        sample is 1.0. Silent (all-zero) and empty inputs are returned
        unscaled instead of producing NaNs or raising.
    """
    sr = audio["sampling_rate"]
    # np.array copies, so the in-place division below never mutates the caller's data.
    y = np.array(audio["array"], dtype=float)
    # librosa.resample is a no-op for equal rates; skip the call entirely.
    if sr != target_sr:
        y = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
    # Vectorised peak (the builtin max() looped in Python and broke on 2-D
    # input); guard against a zero peak, which would divide by zero.
    peak = np.max(np.abs(y)) if y.size else 0.0
    if peak > 0:
        y /= peak
    return y
def merge_stressed_tokens(tokens_with_stress):
    """Merge sub-word tokens into words, propagating stress to the whole word.

    A token is glued onto the word in progress only when that word is
    non-empty and the token does not start with a space. Otherwise the token
    starts a new word. A merged word takes the maximum stress value of its
    pieces, so one stressed sub-token marks the entire word as stressed.

    Example:
        [(" I", 0), (" didn", 1), ("'t", 0), (" money", 0), (".", 0)]
        -> [(" I", 0), (" didn't", 1), (" money.", 0)]

    Args:
        tokens_with_stress: iterable of ``(token_string, stress_value)`` tuples.

    Returns:
        List of ``(word, stress_value)`` tuples, leading spaces preserved.
        Words that would be empty strings are never emitted.
    """
    words = []
    word, word_stress = "", 0
    for piece, flag in tokens_with_stress:
        continues_word = bool(word) and not piece.startswith(" ")
        if continues_word:
            word += piece
            word_stress = max(word_stress, flag)
            continue
        # Close off the previous word (if any) before starting a new one.
        if word:
            words.append((word, word_stress))
        word, word_stress = piece, flag
    # Flush the word still being built when the input runs out.
    if word:
        words.append((word, word_stress))
    return words
def inference_from_audio_and_transcription(
    audio: np.ndarray, transcription, model: WhiStress, device: str, max_length: int = 30
):
    """Score emphasis for a given ``transcription`` of ``audio``.

    The transcription is tokenized and fed as decoder input (teacher forcing),
    so emphasis is predicted for exactly these tokens rather than for a
    generated transcription.

    Args:
        audio: waveform sampled at 16 kHz (the rate passed to the feature
            extractor), e.g. the output of ``prepare_audio``.
        transcription: reference text for the audio.
        model: a loaded ``WhiStress`` model.
        device: device the inputs are moved to; should be the model's device.
        max_length: token length the transcription is padded/truncated to
            (default 30). Text beyond this many tokens is silently dropped.

    Returns:
        List of ``(token_string, emphasis_label)`` pairs for the tokenized
        transcription, with special tokens removed.
    """
    input_features = model.processor.feature_extractor(
        audio, sampling_rate=16000, return_tensors="pt"
    )["input_features"]
    # convert transcription to input_ids
    input_ids = model.processor.tokenizer(
        transcription,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )["input_ids"]
    # Pure inference: without no_grad the forward pass would build and retain
    # an autograd graph for every call.
    with torch.no_grad():
        out_model = model(
            input_features=input_features.to(device),
            decoder_input_ids=input_ids.to(device),
        )
        emphasis_probs = F.softmax(out_model.logits, dim=-1)
        emphasis_preds = torch.argmax(emphasis_probs, dim=-1)
        # Rotate right by one along the sequence axis (last column wraps to
        # the front) — presumably so label k lines up with input token k;
        # TODO confirm against the WhiStress head.
        emphasis_preds_right_shifted = torch.cat(
            (emphasis_preds[:, -1:], emphasis_preds[:, :-1]), dim=1
        )
    return get_word_emphasis_pairs(
        input_ids[0],
        emphasis_preds_right_shifted[0],
        model.processor,
        filter_special_tokens=True,
    )
def scored_transcription(audio, model, strip_words=True, transcription: str = None, device="cuda"):
    """Return word-level emphasis predictions for an audio clip.

    Args:
        audio: mapping with ``"array"`` and ``"sampling_rate"`` keys; it is
            resampled and peak-normalized by ``prepare_audio``.
        model: a loaded ``WhiStress`` model.
        strip_words: strip surrounding whitespace from every returned word.
        transcription: optional reference text. When non-empty, emphasis is
            scored for this text; otherwise the model transcribes the audio.
        device: device the model inputs are moved to.

    Returns:
        List of ``(word, emphasis_label)`` tuples.
    """
    waveform = prepare_audio(audio)
    if not transcription:
        token_pairs = inference_from_audio(waveform, model, device)
    else:
        # Use the caller-supplied (e.g. ground-truth) transcription.
        token_pairs = inference_from_audio_and_transcription(
            waveform, transcription, model, device
        )
    words = merge_stressed_tokens(token_pairs)
    if not strip_words:
        return words
    return [(text.strip(), label) for text, label in words]