|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
from collections import defaultdict |
|
from functools import lru_cache |
|
from pathlib import Path |
|
from subprocess import CalledProcessError, run |
|
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union |
|
|
|
import kaldialign |
|
import numpy as np |
|
import soundfile |
|
import av |
|
import wave |
|
import torch |
|
import torch.nn.functional as F |
|
from whisper_live.utils import resample |
|
|
|
|
|
# Accept either plain string paths or pathlib.Path objects.
Pathlike = Union[str, Path]

SAMPLE_RATE = 16000  # expected waveform sample rate (Hz)
N_FFT = 400  # STFT window size in samples (25 ms at 16 kHz)
HOP_LENGTH = 160  # STFT hop between frames in samples (10 ms at 16 kHz)
CHUNK_LENGTH = 30  # length of one audio chunk in seconds
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # samples in one 30-second chunk (480000)
|
|
|
|
|
def load_audio(file: str, sr: int = 16000):
    """
    Open an audio file, resample it, and read as a mono waveform.

    Parameters
    ----------
    file: str
        The audio file to open.

    sr: int
        The sample rate to resample the audio if necessary.

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """
    # Delegate format conversion/resampling to the project helper, which
    # yields a path to a PCM WAV file at the requested rate.
    converted_path = resample(file, sr)

    with wave.open(converted_path, "rb") as wav_file:
        pcm_bytes = wav_file.readframes(wav_file.getnframes())

    # Interpret the raw bytes as signed 16-bit PCM and normalize the
    # samples into the float32 range [-1.0, 1.0).
    samples = np.frombuffer(pcm_bytes, dtype=np.int16)
    return samples.astype(np.float32) / 32768.0
|
|
|
|
|
def load_audio_wav_format(wav_path):
    """
    Read a waveform from a 16 kHz ``.wav`` file.

    Parameters
    ----------
    wav_path: str
        Path to the ``.wav`` file to read.

    Returns
    -------
    Tuple of ``(waveform, sample_rate)`` as returned by ``soundfile.read``.

    Raises
    ------
    ValueError
        If ``wav_path`` does not end in ``.wav`` or the file's sample rate
        is not 16000 Hz.
    """
    # Validate with explicit exceptions instead of `assert`, which is
    # silently stripped when Python runs with -O.
    if not wav_path.endswith('.wav'):
        raise ValueError(f"Only support .wav format, but got {wav_path}")
    waveform, sample_rate = soundfile.read(wav_path)
    if sample_rate != 16000:
        raise ValueError(
            f"Only support 16k sample rate, but got {sample_rate}")
    return waveform, sample_rate
|
|
|
|
|
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
    """
    if torch.is_tensor(array):
        if array.shape[axis] > length:
            keep = torch.arange(length, device=array.device)
            array = array.index_select(dim=axis, index=keep)

        if array.shape[axis] < length:
            widths = [(0, 0)] * array.ndim
            widths[axis] = (0, length - array.shape[axis])
            # F.pad wants a flat list ordered from the LAST dimension first.
            flat_pads = [amount for pair in widths[::-1] for amount in pair]
            array = F.pad(array, flat_pads)
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            widths = [(0, 0)] * array.ndim
            widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, widths)

    return array
|
|
|
|
|
@lru_cache(maxsize=None)
def mel_filters(device,
                n_mels: int,
                mel_filters_dir: Optional[str] = None) -> torch.Tensor:
    """
    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling librosa dependency; saved using:

    np.savez_compressed(
        "mel_filters.npz",
        mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
    )

    Parameters
    ----------
    device:
        Torch device (or device string) the filterbank tensor is moved to.
    n_mels: int
        Number of Mel bins; only 80 and 128 are supported.
    mel_filters_dir: Optional[str]
        Directory containing ``mel_filters.npz``. Defaults to the
        ``assets`` directory next to this module.

    Returns
    -------
    torch.Tensor
        The ``mel_{n_mels}`` filterbank array from the npz file, on `device`.
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
    # lru_cache memoizes the result, so the npz file is read at most once
    # per (device, n_mels, mel_filters_dir) combination.
    if mel_filters_dir is None:
        mel_filters_path = os.path.join(os.path.dirname(__file__), "assets",
                                        "mel_filters.npz")
    else:
        mel_filters_path = os.path.join(mel_filters_dir, "mel_filters.npz")
    with np.load(mel_filters_path) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
|
|
|
|
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
    return_duration: bool = False,
    mel_filters_dir: Optional[str] = None,
):
    """
    Compute the log-Mel spectrogram of the given audio.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters, only 80 and 128 are supported

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    return_duration: bool
        If True, also return the duration (in seconds) of the input audio
        before any padding or trimming.

    mel_filters_dir: Optional[str]
        Directory holding ``mel_filters.npz``; forwarded to ``mel_filters``.

    Returns
    -------
    torch.Tensor, shape = (80 or 128, n_frames)
        A Tensor that contains the Mel spectrogram. When ``return_duration``
        is True, a ``(log_spec, duration)`` tuple is returned instead.
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            if audio.endswith('.wav'):
                audio, _ = load_audio_wav_format(audio)
            else:
                audio = load_audio(audio)
        assert isinstance(audio,
                          np.ndarray), f"Unsupported audio type: {type(audio)}"
        duration = audio.shape[-1] / SAMPLE_RATE
        audio = pad_or_trim(audio, N_SAMPLES)
        audio = audio.astype(np.float32)
        audio = torch.from_numpy(audio)
    else:
        # Bug fix: `duration` was previously assigned only on the non-tensor
        # path, so calling with a torch.Tensor and return_duration=True
        # raised NameError. Compute it for tensor input as well. (Tensor
        # input is intentionally not padded/trimmed, matching the original
        # behavior.)
        duration = audio.shape[-1] / SAMPLE_RATE

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio,
                      N_FFT,
                      HOP_LENGTH,
                      window=window,
                      return_complex=True)
    # Power spectrum; NOTE the final STFT frame is dropped here.
    magnitudes = stft[..., :-1].abs()**2

    filters = mel_filters(audio.device, n_mels, mel_filters_dir)
    mel_spec = filters @ magnitudes

    # Log-compress, clamp to at most 8 decades below the peak, then rescale.
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    if return_duration:
        return log_spec, duration
    else:
        return log_spec
|
|
|
|
|
def store_transcripts(filename: Pathlike, texts: Iterable[Tuple[str, str,
                                                                str]]) -> None:
    """Save predicted results and reference transcripts to a file.
    https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
    Args:
      filename:
        File to save the results to.
      texts:
        An iterable of tuples. The first element is the cur_id, the second is
        the reference transcript and the third element is the predicted result.
    Returns:
      Return None.
    """
    # Emit one "ref" line followed by one "hyp" line per utterance.
    with open(filename, "w") as out:
        for cut_id, ref, hyp in texts:
            out.write(f"{cut_id}:\tref={ref}\n")
            out.write(f"{cut_id}:\thyp={hyp}\n")
|
|
|
|
|
def write_error_stats(
    f: TextIO,
    test_set_name: str,
    results: List[Tuple[str, str]],
    enable_log: bool = True,
) -> float:
    """Write statistics based on predicted results and reference transcripts.
    https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
    It will write the following to the given file:

      - WER
      - number of insertions, deletions, substitutions, corrects and total
        reference words. For example::

          Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
          reference words (2337 correct)

      - The difference between the reference transcript and predicted result.
        An instance is given below::

          THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

        The above example shows that the reference word is `EDISON`,
        but it is predicted to `ADDISON` (a substitution error).

        Another example is::

          FOR THE FIRST DAY (SIR->*) I THINK

        The reference word `SIR` is missing in the predicted
        results (a deletion error).
    Args:
      f:
        Open text file the detailed statistics are written to.
      test_set_name:
        Name of the test set, used only in the logged summary line.
      results:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
        (Note: each element is actually a 3-tuple despite the annotation.)
      enable_log:
        If True, also print detailed WER to the console.
        Otherwise, it is written only to the given file.
    Returns:
      The total error rate (WER) as a float percentage, rounded to two
      decimal places.
    """
    # Per-error-type tallies keyed by the word(s) involved.
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # Per-word counters laid out as:
    #   [0] correct  [1] substituted-in-ref  [2] substituted-in-hyp
    #   [3] inserted  [4] deleted
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"  # placeholder kaldialign uses for an insertion/deletion gap
    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    ref_len = sum([len(r) for _, r, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    # WER as a percentage string with two decimals; also the return value.
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

    if enable_log:
        logging.info(f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
                     f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
                     f"{del_errs} del, {sub_errs} sub ]")

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        combine_successive_errors = True
        if combine_successive_errors:
            # Merge runs of adjacent errors into a single (refs->hyps) group
            # so the per-utterance diff is easier to read.
            ali = [[[x], [y]] for x, y in ali]
            for i in range(len(ali) - 1):
                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                    # Fold this error pair into the next one and blank it out.
                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                    ali[i] = [[], []]
            # Drop the ERR placeholders inside each merged group.
            ali = [[
                list(filter(lambda a: a != ERR, x)),
                list(filter(lambda a: a != ERR, y)),
            ] for x, y in ali]
            # Remove the blanked-out entries, then render each group as a
            # space-joined string (or ERR if one side ended up empty).
            ali = list(filter(lambda x: x != [[], []], ali))
            ali = [[
                ERR if x == [] else " ".join(x),
                ERR if y == [] else " ".join(y),
            ] for x, y in ali]

        print(
            f"{cut_id}:\t" + " ".join((ref_word if ref_word == hyp_word else
                                       f"({ref_word}->{hyp_word})"
                                       for ref_word, hyp_word in ali)),
            file=f,
        )

    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)

    # Most frequent substitutions first.
    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()],
                                    reverse=True):
        print(f"{count} {ref} -> {hyp}", file=f)

    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count} {ref}", file=f)

    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count} {hyp}", file=f)

    print("", file=f)
    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp",
          file=f)
    # Sort words by total number of errors (sum of counters [1:]),
    # descending. Note: `ins`/`dels` below intentionally rebind the dict
    # names, which are no longer needed at this point.
    for _, word, counts in sorted([(sum(v[1:]), k, v)
                                   for k, v in words.items()],
                                  reverse=True):
        (corr, ref_sub, hyp_sub, ins, dels) = counts
        tot_errs = ref_sub + hyp_sub + ins + dels
        ref_count = corr + ref_sub + dels
        hyp_count = corr + hyp_sub + ins

        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate)
|
|