Spaces:
Running
Running
import json | |
import logging | |
import os | |
from collections import OrderedDict | |
from decimal import Decimal | |
from pathlib import Path | |
from typing import Callable, Union | |
from typing import Tuple, Optional, List, Dict | |
import meeteval | |
import numpy as np | |
import pandas as pd | |
from meeteval.io.seglst import SegLstSegment | |
from meeteval.wer.wer.orc import OrcErrorRate | |
# this must be called before any other loggers are instantiated to take effect | |
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] [%(name)s] %(message)s') | |
def get_logger(name: str): | |
""" | |
All modules should use this function to get a logger. | |
This way, we ensure all loggers are instantiated after basicConfig() call and inherit the same config. | |
""" | |
return logging.getLogger(name) | |
_LOG = get_logger('wer') | |
def create_dummy_seg_list(session_id):
    """
    Build a one-segment SegLST placeholder for ``session_id``.

    The single segment starts and ends at 0, has an empty speaker and no words. calc_wer()
    substitutes it for an empty hypothesis.
    """
    placeholder = dict(
        session_id=session_id,
        start_time=Decimal(0),
        end_time=Decimal(0),
        speaker='',
        words='',
    )
    return meeteval.io.SegLST([placeholder])
def calc_session_tcp_wer(ref, hyp, collar):
    """
    Compute tcpWER per session with meeteval and return it as a DataFrame.

    The frame has a 'session_id' column followed by the selected meeteval fields. Each field
    is prefixed with 'tcp_', except 'error_rate', which becomes 'tcp_wer'.
    """
    fields = ['error_rate', 'errors', 'length', 'insertions', 'deletions', 'substitutions',
              'missed_speaker', 'falarm_speaker', 'scored_speaker', 'assignment']
    per_session = meeteval.wer.tcpwer(reference=ref, hypothesis=hyp, collar=collar)
    table = pd.DataFrame.from_dict(per_session, orient='index').reset_index(names='session_id')

    column_names = {field: 'tcp_' + field for field in fields}
    column_names['error_rate'] = 'tcp_wer'
    return table[['session_id', *fields]].rename(columns=column_names)
def calc_wer(
        ref_seglst: 'meeteval.io.SegLST',
        tcp_hyp_seglst: 'meeteval.io.SegLST',
        collar: int = 5,
        metrics_list: Optional[List[str]] = None) -> pd.DataFrame:
    """
    Calculate WER metrics for a session using meeteval.

    Only tcpWER ("tcp_wer") is currently supported. If the hypothesis is empty, a dummy
    zero-length segment (create_dummy_seg_list) for the reference session is scored instead.

    Args:
        ref_seglst: reference transcript as a SegLST. Assumed to contain a single session:
            the session_id of its first segment labels every row of the result.
        tcp_hyp_seglst: hypothesis transcript (SegLST) scored with tcpWER.
        collar: tolerance of tcpWER to temporal misalignment between hypothesis and reference.
        metrics_list: metrics to compute. Defaults to ["tcp_wer"].

    Returns:
        wer_df: pd.DataFrame with columns -
            'session_id' - session_id of the first reference segment
            'tcp_wer': tcpWER
            ... intermediate tcpWER fields such as tcp_insertions/tcp_deletions
            (see calc_session_tcp_wer).

    Raises:
        ValueError: if metrics_list is empty or contains a metric other than "tcp_wer".
    """
    supported_metrics = ('tcp_wer',)
    if metrics_list is None:
        metrics_list = ['tcp_wer']
    # Validate up front: unsupported metrics would otherwise only fail later with an
    # opaque KeyError (debug log) or "No objects to concatenate".
    unsupported = [metric for metric in metrics_list if metric not in supported_metrics]
    if unsupported:
        raise ValueError(f"Unsupported metrics {unsupported}; supported: {list(supported_metrics)}")
    if not metrics_list:
        raise ValueError(f"metrics_list must contain at least one of {list(supported_metrics)}")

    # json to SegLST structure (Segment-wise Long-form Speech Transcription annotation)
    if len(tcp_hyp_seglst) == 0:
        tcp_hyp_seglst = create_dummy_seg_list(ref_seglst.segments[0]['session_id'])
        _LOG.warning(f"Empty tcp_wer_hyp_json, using dummy segment: {tcp_hyp_seglst.segments[0]}")

    tcp_wer_res = calc_session_tcp_wer(ref_seglst, tcp_hyp_seglst, collar)
    wer_df = tcp_wer_res.drop(columns='session_id')
    wer_df['session_id'] = ref_seglst.segments[0]['session_id']
    _LOG.debug('Done calculating WER')
    _LOG.debug(f"\n{wer_df[['session_id', *metrics_list]]}")
    return wer_df
def aggregate_wer_metrics(wer_df: pd.DataFrame, metrics_list: List[str]) -> Dict:
    """
    Aggregate per-session WER results into overall metrics.

    All numeric columns of ``wer_df`` are summed. Then, for every metric in ``metrics_list``
    (e.g. "tcp_wer"), its prefix (text before the first "_", e.g. "tcp") is used to:
      * recompute ``<prefix>_wer`` as total errors / total length, replacing the sum of the
        per-session rates;
      * replace the summed ``<prefix>_{missed,falarm,scored}_speaker`` columns, when present,
        by their per-session mean ``<prefix>_mean_<name>``.

    Args:
        wer_df: one row per session; must contain ``<prefix>_errors`` and ``<prefix>_length``
            for every requested metric.
        metrics_list: metric names such as ["tcp_wer"].

    Returns:
        OrderedDict of aggregated metrics. ``<prefix>_wer`` is NaN when total errors and
        length are both 0, and inf when length is 0 but there are errors. Per-session means
        are NaN when ``wer_df`` has no rows.
    """
    # NOTE(review): this function is defined again further down in this module; that later
    # definition is the one in effect after import.
    # _get_numeric_data() (private pandas helper) keeps only numeric columns, dropping e.g.
    # session_id and assignment.
    num_wer_df = wer_df._get_numeric_data()
    metrics = num_wer_df.sum().to_dict(into=OrderedDict)
    num_sessions = len(num_wer_df)
    for metric in metrics_list:
        mprefix, _ = metric.split("_", maxsplit=1)
        errors = metrics[mprefix + "_errors"]
        length = metrics[mprefix + "_length"]
        if length:
            metrics[mprefix + "_wer"] = errors / length
        else:
            # Avoid ZeroDivisionError when no reference words were scored.
            metrics[mprefix + "_wer"] = float('nan') if errors == 0 else float('inf')
        for k in ('missed_speaker', 'falarm_speaker', 'scored_speaker'):
            # Replace the summed count by its per-session mean.
            key = f"{mprefix}_{k}"
            if key not in metrics:
                continue
            total = metrics.pop(key)
            metrics[f"{mprefix}_mean_{k}"] = total / num_sessions if num_sessions else float('nan')
    return metrics
def normalize_segment(segment: SegLstSegment, tn):
    """
    Run the text normalizer ``tn`` over ``segment["words"]``.

    The segment is updated in place and the same object is returned.
    """
    segment["words"] = tn(segment["words"])
    return segment
def assign_streams(tcorc_hyp_seglst):
    """
    Relabel hypothesis segments so that no two segments in the same stream overlap in time.

    Segments are grouped by speaker. Speaker by speaker, in start_time order, each segment
    goes into the first stream (first-fit) that has no time overlap with it. There are as
    many streams as speakers. Each placed segment's 'speaker' field is overwritten with its
    integer stream index.
    NOTE(review): if SegLST.groupby/sorted share segment dicts with the input, this also
    mutates the caller's segments - confirm.

    Returns:
        A new SegLST with all relabelled segments, sorted by start_time.

    Raises:
        ValueError: if a segment overlaps some segment in every stream.
    """
    by_speaker = tcorc_hyp_seglst.groupby(key='speaker')
    streams = [[] for _ in range(len(by_speaker))]

    def _overlaps(seg, other):
        # Strict inequalities: segments that only touch at a boundary do not overlap.
        return seg['start_time'] < other['end_time'] and seg['end_time'] > other['start_time']

    for speaker_seglst in by_speaker.values():
        for seg in speaker_seglst.sorted(key='start_time'):
            free_stream = next(
                (idx for idx, stream in enumerate(streams)
                 if not any(_overlaps(seg, other) for other in stream)),
                None,
            )
            if free_stream is None:
                raise ValueError('No stream found for segment')
            seg['speaker'] = free_stream
            streams[free_stream].append(seg)

    flattened = [seg for stream in streams for seg in stream]
    return meeteval.io.SegLST(flattened).sorted(key='start_time')
def filter_empty_segments(seg_lst):
    """Return the segments of ``seg_lst`` whose 'words' field is not the empty string."""
    def _has_words(segment):
        return segment['words'] != ''

    return seg_lst.filter(_has_words)
def find_first_non_overlapping_segment_streams(per_speaker_groups, per_speaker_vad_masks):
    """
    Find the first pair of speakers whose VAD masks are never active at the same frame.

    Speakers are taken in the iteration order of ``per_speaker_groups``; only its keys are
    used. Because overlap is symmetric, each unordered pair is tested once, earlier speaker
    first. The returned pair is the first non-overlapping one in that order.

    Args:
        per_speaker_groups: mapping keyed by speaker id.
        per_speaker_vad_masks: speaker id -> boolean numpy mask. All masks must have equal
            length so they can be combined with ``&``.

    Returns:
        ``(speaker_id, other_speaker_id)`` for the first non-overlapping pair, or None if
        every pair overlaps (including when there are fewer than two speakers).
    """
    speaker_ids = list(per_speaker_groups)
    for pos, speaker_id in enumerate(speaker_ids):
        for other_speaker_id in speaker_ids[pos + 1:]:
            vad_mask_merged = per_speaker_vad_masks[speaker_id] & per_speaker_vad_masks[other_speaker_id]
            if not vad_mask_merged.any():
                return speaker_id, other_speaker_id
    return None
def change_speaker_id(segment, speaker_id):
    """Set ``segment['speaker']`` to ``speaker_id`` in place and return the same segment."""
    segment.update(speaker=speaker_id)
    return segment
def merge_streams(tcorc_hyp_seglst):
    """
    Greedily merge speaker streams whose voice activity never overlaps.

    Steps:
      1. Build a VAD mask at 10 ms resolution for each speaker.
      2. While find_first_non_overlapping_segment_streams() reports a pair of speakers
         with disjoint masks, relabel the second speaker's segments with the first
         speaker's id (via change_speaker_id, which mutates the segments).
      3. OR the two masks together before looking for the next pair.

    Args:
        tcorc_hyp_seglst: hypothesis SegLST whose segments carry a 'speaker' field.

    Returns:
        A new SegLST with all (possibly relabelled) segments, sorted by start_time.
        An empty hypothesis yields an empty SegLST.
    """
    per_speaker_groups = tcorc_hyp_seglst.groupby(key='speaker')
    # create per speaker vad masks
    per_speaker_vad_masks = {}
    for speaker_id, speaker_seglst in per_speaker_groups.items():
        per_speaker_vad_masks[speaker_id] = create_vad_mask(speaker_seglst, time_step=0.01)
    # default=0: an empty hypothesis has no masks, and max() would raise on an empty sequence.
    longest_mask = max((len(mask) for mask in per_speaker_vad_masks.values()), default=0)
    # Pad all masks to the same length so they can be combined element-wise with & and |.
    for speaker_id, mask in per_speaker_vad_masks.items():
        per_speaker_vad_masks[speaker_id] = np.pad(mask, (0, longest_mask - len(mask)))
    # Iteratively merge pairs of speakers with no overlapping VAD. Each merge removes one
    # speaker, so the loop terminates.
    while True:
        res = find_first_non_overlapping_segment_streams(per_speaker_groups, per_speaker_vad_masks)
        if res is None:
            break
        speaker_id, other_speaker_id = res
        # Bind the target id explicitly, so the relabelling does not depend on when map()
        # evaluates the lambda.
        per_speaker_groups[speaker_id] = per_speaker_groups[speaker_id] + per_speaker_groups[other_speaker_id].map(
            lambda seg, new_id=speaker_id: change_speaker_id(seg, new_id))
        per_speaker_vad_masks[speaker_id] = per_speaker_vad_masks[speaker_id] | per_speaker_vad_masks[other_speaker_id]
        del per_speaker_groups[other_speaker_id]
        del per_speaker_vad_masks[other_speaker_id]
    tcorc_hyp_seglst = meeteval.io.SegLST(
        [seg for speaker_seglst in per_speaker_groups.values() for seg in speaker_seglst]).sorted(key='start_time')
    return tcorc_hyp_seglst
def normalize_segment(segment: SegLstSegment, tn):
    """
    Apply the text normalizer ``tn`` to the segment's words.

    ``segment["words"]`` is replaced in place and the same segment object is returned.
    """
    # NOTE(review): a function with this name is also defined earlier in this module; this
    # later definition is the one in effect after import.
    segment["words"] = tn(segment["words"])
    return segment
def create_vad_mask(segments, time_step=0.1, total_duration=None):
    """
    Create a VAD mask for the given segments.

    :param segments: Iterable of segments, each containing 'start_time' and 'end_time'
        (numbers or Decimals). It is read into a list first, so generators are accepted.
    :param time_step: The resolution of the VAD mask in seconds (default: 100ms)
    :param total_duration: Optionally specify the total duration to create the mask.
        If not provided, the maximum end time of the segments is used (0 if there are none).
    :return: boolean numpy array of length ``total_duration / time_step + 1`` (truncated),
        where True marks voice activity and False silence. A segment marks frames from
        ``start_time / time_step`` up to, but excluding, ``end_time / time_step``.
    """
    # Materialize once: a generator would otherwise be consumed by max() below.
    segments = list(segments)
    if total_duration is None:
        total_duration = max((seg['end_time'] for seg in segments), default=0)

    def _to_index(t):
        # Round before truncating: float division such as 0.3 / 0.1 yields 2.9999999999999996,
        # which int() alone would truncate one frame short.
        return int(round(float(t) / time_step, 9))

    # Initialize VAD mask as silence
    vad_mask = np.zeros(_to_index(total_duration) + 1, dtype=bool)
    for seg in segments:
        vad_mask[_to_index(seg['start_time']):_to_index(seg['end_time'])] = True
    return vad_mask
def find_group_splits(vad, group_duration=30, time_step=0.1):
    """
    Pick silent frames at which to split a VAD mask into groups of at least ``group_duration`` seconds.

    Inactive frames are scanned in order. A frame becomes a split as soon as it lies at least
    ``group_duration / time_step`` frames after the previous split (after frame 0 for the
    first split). Splits therefore always fall on silence.

    Args:
        vad: 1-D voice-activity mask (e.g. from create_vad_mask). Any array-like is accepted;
            values are interpreted as booleans (non-zero = active).
        group_duration: minimal group length in seconds.
        time_step: duration of one mask frame in seconds.

    Returns:
        Ascending list of int frame indices (frame indices, not seconds).

    Raises:
        ValueError: if ``vad`` is not one-dimensional.
    """
    # Coerce to bool: on an int array, ~ is a bitwise NOT and would mark every frame silent.
    active = np.asarray(vad, dtype=bool)
    if active.ndim != 1:
        raise ValueError(f'vad must be one-dimensional, got shape {active.shape}')
    non_active_indices = np.flatnonzero(~active)
    splits = []
    group_shift = group_duration / time_step
    next_offset = group_shift
    for i in non_active_indices:
        if i >= next_offset:
            splits.append(int(i))
            next_offset = i + group_shift
    return splits
def map_utterance_to_split(utterance_start_time, splits):
    """
    Return the index of the group an utterance falls into.

    The result is the position of the first split strictly greater than
    ``utterance_start_time``, or ``len(splits)`` if there is none.
    NOTE(review): find_group_splits() returns frame indices, so the start time presumably
    has to be expressed in frames too - confirm at the call site.
    """
    for group_idx, boundary in enumerate(splits):
        if utterance_start_time < boundary:
            break
    else:
        return len(splits)
    return group_idx
def agregate_errors_across_groups(res, session_id):
    """
    Combine the per-group ORC error rates of one session into a single OrcErrorRate.

    Error counts and lengths are summed across groups. The per-group assignments are
    concatenated, in ``res`` iteration order, into one tuple.

    Args:
        res: mapping from group key to a per-group result exposing ``errors``, ``length``,
            ``insertions``, ``deletions``, ``substitutions`` and ``assignment``
            (presumably meeteval OrcErrorRate objects - confirm at the call site).
        session_id: key under which the combined result is returned.

    Returns:
        ``{session_id: OrcErrorRate}``. Self-overlap information is not carried over:
        both self-overlap fields are set to None.
    """
    # No error rate is computed here. A total of errors / length would raise
    # ZeroDivisionError for an empty ``res`` or zero total length, and it was never used.
    groups = list(res.values())
    combined = OrcErrorRate(
        errors=sum(group.errors for group in groups),
        length=sum(group.length for group in groups),
        insertions=sum(group.insertions for group in groups),
        deletions=sum(group.deletions for group in groups),
        substitutions=sum(group.substitutions for group in groups),
        hypothesis_self_overlap=None,
        reference_self_overlap=None,
        assignment=tuple(pair for group in groups for pair in group.assignment),
    )
    return {session_id: combined}
def aggregate_wer_metrics(wer_df: pd.DataFrame, metrics_list: List[str]) -> Dict:
    """
    Aggregate per-session WER results into overall metrics.

    All numeric columns of ``wer_df`` are summed. Then, for every metric in ``metrics_list``
    (e.g. "tcp_wer"), its prefix (text before the first "_", e.g. "tcp") is used to:
      * recompute ``<prefix>_wer`` as total errors / total length, replacing the sum of the
        per-session rates;
      * replace the summed ``<prefix>_{missed,falarm,scored}_speaker`` columns, when present,
        by their per-session mean ``<prefix>_mean_<name>``.

    Args:
        wer_df: one row per session; must contain ``<prefix>_errors`` and ``<prefix>_length``
            for every requested metric.
        metrics_list: metric names such as ["tcp_wer"].

    Returns:
        OrderedDict of aggregated metrics. ``<prefix>_wer`` is NaN when total errors and
        length are both 0, and inf when length is 0 but there are errors. Per-session means
        are NaN when ``wer_df`` has no rows.
    """
    # NOTE(review): a function with this name is also defined earlier in this module; this
    # later definition is the one in effect after import.
    # _get_numeric_data() (private pandas helper) keeps only numeric columns, dropping e.g.
    # session_id and assignment.
    num_wer_df = wer_df._get_numeric_data()
    metrics = num_wer_df.sum().to_dict(into=OrderedDict)
    num_sessions = len(num_wer_df)
    for metric in metrics_list:
        mprefix, _ = metric.split("_", maxsplit=1)
        errors = metrics[mprefix + "_errors"]
        length = metrics[mprefix + "_length"]
        if length:
            metrics[mprefix + "_wer"] = errors / length
        else:
            # Avoid ZeroDivisionError when no reference words were scored.
            metrics[mprefix + "_wer"] = float('nan') if errors == 0 else float('inf')
        for k in ('missed_speaker', 'falarm_speaker', 'scored_speaker'):
            # Replace the summed count by its per-session mean.
            key = f"{mprefix}_{k}"
            if key not in metrics:
                continue
            total = metrics.pop(key)
            metrics[f"{mprefix}_mean_{k}"] = total / num_sessions if num_sessions else float('nan')
    return metrics